RNN의 단점을 보완한 모델인 LSTM에 대한 정리

LSTM (Long Short-Term Memory)

LSTM 개요

  • RNN은 관련 정보와 그 정보를 사용하는 지점 사이 거리가 멀 경우 역전파시 그래디언트가 점차 줄어 학습능력이 크게 저하되는 것으로 알려져 있다. 이를 vanishing gradient problem이라고 합니다.
  • 이 문제를 극복하기 위해서 고안된 LSTM은 RNN의 hidden-state에 cell-state를 추가한 구조

+full


LSTM 구조

+full

  • LSTM 은 RNN과 다르게 4개의 값이 사용됨
    • : input gate @sigmoid = 현재 입력 값을 얼마나 반영할 것인지
    • : forget gate @sigmoid = 이전 입력 값을 얼마나 기억할 것인지
    • : output gate @sigmoid = 현재 셀 안에서의 값을 얼마나 보여줄 것인지
    • : gate gate @tanh = input cell을 얼마나 포함시킬지 결정하는 가중치, 얼마나 학습 시킬지 (-1, 1)
  • 는 잊기 위해서 0에 가까운 값을, 기억하기 위해서 1에 가까운 값을 사용하기 위해 sigmoid 함수를 사용함.
  • 에서 tanh 는 01의 크기는 강도, -11은 방향을 나타냄

Forget Gate

  • 는 과거의 정보를 잊기 위한 게이트
  • 직전의 hidden state()과 입력 를 입력 받고, 시그모이드 함수를 취해준 값이 forget gate의 출력
  • 시그모이드 함수의 출력은 0~1로, 값이 0에 가까울수록 이전 정보를 잊고, 1에 가까울 수록 이전정보를 많이 기억하도록함

Input Gate

  • 현재의 정보를 기억하기 위한 게이트
  • 직전의 hidden state()과 입력 를 입력 받고, 각각 sigmoid와 tanh를 취해줌
  • sigmoid의 출력 01 값은 input gate의 역할을, tanh의 출력 -11 값은 현재 cell state를 나타냅니다

Cell State Update

  • Cell state의 업데이트 과정
  • (1) 이전의 cell state 값과 얼마나 잊을지를 정해줄 forget gate의 출력값을 원소별 곱셈 연산
  • (2) 현재의 cell state 값과 얼마나 기억할지를 정해줄 input gate의 출력값을 원소별 곱셈 연산해 더한 값으로 update 됩니다.

Output Gate

  • output gate는 update된 cell state의 정보를 얼마나 다음 hidden state에 반영할지를 정해준다
  • hidden state의 값은 tanh를 취한 cell state의 값과 얼마나 반영할지를 정해줄 output gate 출력 값의 원소별 곱셈으로 계산된다

LSTM 순전파

+full

LSTM 역전파

+full

  • 함수 전까지 역전파를 한모습

+full

  • 각 sigmoid 함수와 tanh의 activation의 역전파에 대한 설명
  • sigmoid의 gradient 값은 (1-sigmoid(x)) x sigmoid(x)이며, tanh의 gradient는 1-tanh^2(x)이며 이를 적용하여 역전파

+full

  • 이후 입력과 hidden state에 역전파의 결과

GRU (Gated Recurrent Unit)

  • LSTM의 변형으로 GRU(Gated Recurrent Unit) Cell 의 의미
  • LSTM 셀의 간소화 버전
  • LSTM에서 ct와 ht가 하나의 ht로 합쳐졌습니다.
  • rt의 추가로 과거의 정보를 어느정도 reset 할 것인지 정합니다.
  • Update를 위해 사용되던 f와 i의 값이 zt와 (1-zt)인 하나 값으로 input과 hidden state의 update 정도를 정합니다.
  • 현 시점의 정보 후보군(Candidate)을 계산합니다. gt는 과거 hidden state(은닉층) 값을 그대로 사용하지 않고 reset gate(rt)를 곱하여 사용합니다.
  • 현 시점 hidden state(은닉층) 값은 update gate 결과와 Candidate 결과를 결합하여 계산합니다.


참고