(機械学習基礎)数式なし! LSTM・GRU超入門

当記事では数式を使用せずに、機械学習プログラミングで差し支えない程度の知識獲得を目指して、LSTMとGRUについて説明していきます。

深い理解をするうえで計算式は必要ですので、追々別の記事で触れたいと思いますが、数式よりも概要としての理論の方が需要が高いと思い、数式を使用せずにLSTMやGRUの説明記事を執筆しました。同様の内容を動画でも説明していますので、ぜひご視聴ください!

これだけは(ポイント)

本記事で話す内容のうち、ページを閉じる前にこれだけは最低でも理解してほしいというポイントを示します。

  • 再帰セルの表記は時間軸上に展開されていないものと展開されたものの2種類があり、どちらの表示で示されているかを意識することで混乱を減らせる。
  • RNN:系列データを学習することができるニューラルネットワーク。勾配消失問題と重み衝突という二つの理由により長期的な特徴の学習は苦手。
  • LSTM:長期的特徴と短期的特徴を学習することができる。欠点は計算量が多いこと。
  • GRU:LSTMの代替となるモデルでLSTMより計算量が少なくてすむ。性能はLSTMと変わらないとされている。

RNNは長期的な学習ができない!!

ここでは、最初にRNNの基礎及び再帰セルの表記について説明した後に、RNNの学習を妨げる問題点、すなわちLSTMが考案されるに至った理由を2種類説明します。

RNNの基礎

RNNとは、内部に再帰的な構造を持つニューラルネットワークで、気温変動、言語処理といった系列的な情報の識別、推論に使用されます。一般的なフィードフォワードニューラルネットワークは基本的に時間依存の情報を学習することできません(例えば、時刻\(t-n\)から時刻\(t\)までの情報をすべて一括で入力するなどの方法を使用すれば不可能というわけではないが…)。各時刻の信号が逐次的に入力される状況で、時間依存の特徴を学習する方法として、中間層のセルに再帰可能な構造を用いたものが再帰型ニューラルネットワーク(Recurrent Neural Network:RNN)になります。以下に再帰型ニューラルネットワークを示します。

再帰型ニューラルネットワーク(RNN)
再帰型ニューラルネットワークには(a)のように時間軸上に展開されていない(ノード間を通る情報がどの時刻のものか理解しにくい)表記と、(b)のように時間軸上に展開されている表記の2種類がある。

学習には、BPTT(backpropagation through time)が用いられます。しかし、BPTTは万能ではなく、長期的特徴を学習することは苦手という問題があります。

再帰セルの表記は2種類

上の図において注目してほしいのは、再帰セルの表記が2種類あることです。違いは時間軸で展開されているかどうかです。(a)は時間軸で展開されていません。(b)は時間軸に沿って展開されています。最も単純な再帰セルなら(a)の表記でも混乱することはありませんが、LSTMなど複雑な再帰セルを理解しようとすると、再帰セルが(a)と(b)のどちらの表記で示されているのか意識しないと混乱を招きますので、時間軸で展開された図かどうか意識することをお勧めします。

RNNの学習を妨げるもの

先ほど、RNNは長期にわたる特徴の学習は上手くできないことを言及しました。その原因は勾配消失問題と重み衝突の2つが関与しています。

勾配消失問題

再帰型ニューラルネットワークに関わらず、深層ニューラルネットワークでは勾配消失問題が度々問題になります。勾配消失問題とは、誤差逆伝播により誤差を低次の層に伝播していく過程で、完全に消失しまい学習ができないという問題です。これを、フィードフォワード型の深層ニューラルネットワークでは、活性化関数のかたちを試行錯誤したり、事前学習をしたりすることで、あまり問題にならないよう対処してきました。一方でRNNの場合は、勾配消失問題だけでなく次に述べる重み衝突問題も同時に生じているため全く異なるアプローチで対処していきます。

重み衝突

重み衝突は、さらに入力重み衝突と出力重み衝突に分けられます。入力重み衝突とは、「現時刻に入力された信号が有用なものなら重みを大きく、有用でなければ小さくするが、今は有用でなくても将来有用な情報だったら重みはどうするの(大きくすべきか小さくすべきか)??」という問題です。出力重み衝突も同様で、「現時刻に出力された信号が有用なものなら重みを大きく、有用でなければ小さくするが、今は有用でなくても将来有用な情報だったら重みはどうするの??」という問題です。ここで、出力の場合は重み云々に関係ないのではと疑問を持つ方がいるかもしれませんので説明すると、再帰セルの出力が再帰の重みを介して再帰セルの入力として再度伝わる際の話をしています。

まとめると、再帰セルに入力された入力層および再帰セルから信号が、長期的な特徴をもつのか短期的な特徴を持つのか分からない状態において、重みを大きくすべきか小さくすべきか決定できない問題を重み衝突といいます。RNNの再帰セルは単純なニューロンモデルが使用されるため、現在の入力が有用かどうかのみを判定することになり、必然的に短期的な記憶は学習できますが、長期的な記憶は学習することがきません。

LSTM

RNNの問題の対処

RNNが長期的な特徴を学習できない理由として、勾配消失問題と重み衝突の2つを説明しました。LSTMでは両問題を以下の方法で対処します。

  • 勾配消失問題
    →誤差を保存するセル(Constant Error Carrousel:CEC)を使用する(記憶セルを用意)。
  • 重み衝突(入力重み衝突・出力重み衝突)
    入力ゲート・出力ゲートを使用する

勾配消失問題により、長期にわたる誤差伝播が不可能(すなわち、長期にわたる記憶が不可能)ならば、専用の記憶セルを用意して長期的な記憶を可能にすることで対処するわけです。

そして、重み衝突に関しては、入力層から再帰セルに入力される信号や再帰セルから再帰セルに入力される信号を適切に処理(どの程度と通すか)するゲートを用意することで対処するわけです。

付随して必要になる仕組み

上で述べた対処法には2つの問題点が付随して生じます。

1つ目は記憶セルに記憶された情報をどのようにリセットするかです。リセットする仕組みがないと、永遠に入力ゲートからの入力が加算されることになり、過去の記憶が将来の判断に影響を及ぼす可能性があります。そこで、忘却ゲートをという忘れさせる仕組みを用意します。これにより、CECセルに記憶された情報を適切に忘却することができます。

2つ目はそれぞれのゲートをどのように操作するかです。入力信号の通過具合を調節する入力ゲート、再帰セルの出力信号お通過具合を調節する出力ゲート、どの程度情報を忘却するか調節する忘却ゲートは神様ではないので、いつどのように処理すべきか学習しなければ分かりません。そこで、それぞれのゲートを操作するゲートコントローラ、すなわち新たなニューロンを用意します。ゲートでは乗算のみで通過する情報量を調節できるように、ゲートコントローラのニューロンは0~1を出力するように設定します。そのため、活性化関数としてシグモイド関数が使用されることが一般的となります。

ここまでのまとめ:基本LSTM(Basic LSTM)

ここまで話してきた内容は、LSTMの中でもとても基本的なもので、基本LSTM(Basic LSTM)と呼ばれることがあります。当記事では基本LSTMと呼ぶことにします。では、主な構成要素をまとめます。

  • 記憶セル(CECセル)
  • 入力ゲート・出力ゲート・忘却ゲート
  • ゲートを操作するニューロンが3つ
  • セルへの入力を求めるニューロンが1つ

基本LSTMは、セルへの入力を求めるニューロンが1つ、CECセルを操作するゲートコントローラニューロンが3つの合計4つのニューロンで実現されていると解釈できます。それでは、時間軸で展開した基本LSTMセルの表記と時間軸で展開していない表記を見てみます。

時間軸展開していない表記

時間軸で状態を展開しない場合の表記を2つ示してみました。前時刻の情報は点線で示しています。左側の表記は入門書等で頻繁に見かける図ですね。また、右のような図もしばしば見かけます。主な違いは、どこがニューロンか一目で判断できることと、向きが横ではなく縦である点です。

左図の上に位置する3つのニューロンが入力ゲート・忘却ゲート・出力ゲートを操作するゲートコントローラです。ゲートはゲートコントローラからの出力値(0~1)を乗算します。gは活性化関数を表し、デフォルト値としてtanhが使用されることが多いです。また、gは省略されることがあります。

時間軸展開した表記

時間軸展開しなかった表記には、CECとh(もしくは出力y)の2つが時間を扱う状態ベクトルであることが分かります。すなわち、この二つを時間軸上に展開して図を描けばいいのです。すると、以下の図になります。

覗き穴結合(ピープホール結合)

実は、ここまでで説明してきた基本LSTMはあまり使用されません。我々がLSTMというときに使用されるセルはこれから説明する覗き穴結合(ピープホール結合)という接続が内部的に存在します。

基本LSTMの問題点:CECの気持ちが無視されている

ゲートコントローラは入力x(t)と直前の出力h(t-1)のみを判断材料として、全結合計算をするため、CECの記憶の忘却、更新や出力などに、CECセルが保持している長期記憶情報が考慮されません。

改善策:CECに耳を傾けてあげよう

改善策は簡単です。CECに耳を傾けられるように、ゲートコントローラと接続を作ればよいのです。具体的には、入力ゲートと忘却ゲートを操作するニューロンには時刻\(t-1\)の値が、出力ゲートを操作するニューロンには時刻\(t\)の値が届くように接続します。

一般的なLSTM

それでは、基本LSTMに覗き穴結合(ピープホール結合)を加えた場合のLSTMセルの構造を先ほどと同様に時間軸で展開していない表記と展開した表記について見てみます。

時間軸展開していない表記

基本LSTMセルのモデルで示した図において、新たに3つ接続が増えていることが分かります。

少し言及しましたが、出力ゲートのゲートコントローラには過去の状態ではなく、現在の状態が反映されるよう、現在の値が覗ける接続になっています。

時間軸展開した表記

これも同様で、基本LSTMセルモデルに、3つの覗き穴結合が追加されただけです。

※以下の図において、ピープホール結合を表す矢印が一か所間違っていました。tanhを活性化関数に持つニューロンではなく、その右のニューロン(入力ゲートコントローラ)へ矢印を引くのが正解です。

時間軸方向にセルを並べてみると、以下のような感じになります。

GRU

これから触れるGRU(Gated Recurrent Unit)セルは、LSTMと同様の性能を持つとされており、LSTMより計算量が少なく、高速に学習を進めることができます。

LSTMの欠点

LSTMはRNNで不可能だった長期的特徴の学習を可能にしたセルですが、計算コストが大きいという問題点があります。計算コストが大きいことは、機械学習に限らず好ましくありません。

アイディア

解決策として、状態ベクトル数を減らす、ゲートコントローラ数を減らすという主に2つのアプローチで改善策を考えていきます。

  • 状態ベクトル数を減らす(時間依存の状態数を減らす)
    LSTMではCECとhの2つが状態を保持していました。これを1つにまとめます。
    これにより、記憶セルという構造は無くなります。
  • ゲートコントローラ数を減らす
    LSTMでは入力ゲート、忘却ゲート、出力ゲートに1つずつゲートコントローラが必要でした。そこで、忘却ゲートと入力ゲートの操作を1つのコントローラで操作するように変更します。

2つの方針における具体的な変更は上で述べたようなものですが、この変更に伴って、別の修正を加える必要性が生まれます。具体的には、

  • CECとhをまとめたため、出力ゲートによる出力制限は望ましくない。すなわち、出力ゲートの仕組みを取り除く。
  • 出力ゲートを取り除くと、重み衝突が発生する可能性があるため、入力値を求めるニューロンに再帰される経路上にリセットゲートを設ける。

です。

GRUセルの構造

LSTMセルと比較できるように2つを並べて示してあります。今までと同様に、時間軸で展開していない表記と展開した表記を見てみます。

時間軸展開していない表記

以下に時間軸展開していない表記のセルをそれぞれ示します。上がLSTMセルで、下がGRUセルです。とても簡素になっていることが分かると思います。CECセルがなくなったため、覗き穴結合はありません。注目すべき部分は、記憶の忘却と更新という操作が、1つのゲートコントローラで行われている点で、その手法は興味深いものがあります。忘却はゲートコントローラの値を1から差し引いたものを乗算することで、更新はゲートコントローラの値を1時刻前の状態ベクトルと乗算した後に加算することで実現されている点です。すなわち新たな記憶を保持するときは過去の記憶は忘れるように連動させることで構造を簡略化させています。

時間軸展開した表記

以下の図は、時間軸で展開したときの表記で、上がLSTMセル、下がGRUセルです。

GRUはLSTMに劣るのか

この点はかなり気になるところです。内部様態のバリエーションはGRUの方が確実にLSTMより少ないため、表現能力という点では落ちていると考えるべきだと思います。しかし、ハッキリとしていない部分が多く、性能は変わらないと結論付けている書籍が多くあるため、あまり気にする必要はないと思います。

念頭に置くべきことは、一般論として計算量と性能はトレードオフの関係にあるということです。絶対とは言い切れませんが、念頭に置いておくことで、壁にぶつかったときの解決策を見つけるきっかけにできるはずです。

機械学習ライブラリの再帰セル一覧

ここでは代表的な機械学習ライブラリのKerasとPyTorchがもつ再帰セルについて一覧にしたいと思います。

Keras

  • RNN
  • SimpleRNN
  • GRU
  • LSTM
  • ConvLSTM2D
  • SimpleRNNCell
  • GRUCell
  • LSTMCell
  • CuDNNGRU
  • CuDNNLSTM

PyTorch

  • RNNBase
  • RNN
  • LSTM
  • GRU
  • RNNCell
  • LSTMCell
  • GRUCell

参考文献

当記事を作成するにあたり、以下の書籍を参考に勉強をさせていただいたので、参考文献として紹介させていただきます。

Follow me!

コメントを残す