RNNをイメージで理解する

再帰型ニューラルネットワーク(RNN)と一般的なニューラルネットワーク

ザックリと

まず、導入として、一般的なニューラルネットワークとRNNがどのように違い、どこで使用されるのか、大まかに説明していきます。

違いをザックリといえば、時間領域でパターンを学習できるか否かです。学習できるのがRNNで、不可能なのが一般的なニューラルネットワークです。

RNNが使用される用途は、機械翻訳、感情認識、画像や動画のキャプショニング等々、一般的なニューラルネットワークで扱えなかった時間的領域が中心となります。最後のキャプショニングではCNNで認識したものをRNNを使用して文章生成が行われます。

ちなみに、人間の記憶には陳述記憶、非陳述記憶があり、機械学習で主に用いられるのは陳述記憶のほうですが、陳述記憶はさらにエピソード記憶と意味記憶に分けられ、今までのニューラルネットワークの重み学習は意味記憶、RNNは意味記憶にエピソード記憶を加えたものといえます。

・意味記憶:猫は「猫」、犬は「犬」というように、概念化された記憶で、いつ学んだか、といった時間的な情報は記憶されません。
・エピソード記憶:「車を運転していたら猫が飛び出してきたのでブレーキを踏んだ」というような経験を覚えているように、個人の経験に対する記憶で、出来事の時間的・空間的な順序やその時の感情などが記憶されます。

少し詳しく

一般的なニューラルネットワークで時系列データを扱うにはどのようにしたらよいでしょうか?答えは、過去の入力も含め現在までのデータ全てを同時に入力するのです。同時に入力することは、簡単に言えば画像を入力するようなもので、時系列データを空間的な領域に変換して扱っていることになります。簡単な例を示します。時刻1の入力\(\boldsymbol{x}_1\)から時刻\(t\)の入力\(\boldsymbol{x}_t\)までの時間的パターンを学習させるとしましょう。このとき、一般的なニューらえるネットワークへの入力ベクトルは

\(\boldsymbol{x} = \begin{pmatrix} \boldsymbol{x}_1 \\ \boldsymbol{x}_2 \\・\\・\\・\\ \boldsymbol{x}_t \end{pmatrix}\)

となります。再度言いますが、これは過去の記憶がないために、時系列パターンを空間的な領域の表現に置き換えて入力しているのです。

一方でRNNには時間的領域のパターンを学習できる、つまり、記憶があるため入力は時刻と同期して順番に入力していけばいいのです。過去の全ての入力を毎回入力する必要はありません。

どのように時間的な記憶を形成するのでしょうか?それは、過去のニューロンが将来の自分(ニューロン)に過去の状態を教えてあげればいいのです。つまり再帰です。

図1 ニューロンからニューロンモデル、再帰ニューロンモデルへ

図1は、神経細胞(a)、人工ニューロン(b)、自身へ再帰可能な構造を持つ再帰ニューロン(c)を示しています。過去の自分の出力を将来の自分の入力に再帰させることで、過去の状態を教える、つまり記憶を持つ状態を作り出しています。これを多層にして図示したものが図2です。

図2 一般的なNNとRNNの基本単位

(a)は一般的なニューラルネットワークとして三層ニューラルネットワークを示しています。これを2つ用意して、時間を超えて中間層の入力をさせるようにしたのが(b)です。一般的なニューラルネットワークであれば、時刻tと時刻t+1のニューラルネットワークは独立していますが、RNNでは、過去の自分からのカンニングペーパーを参考に現在の入力を考慮した出力をしているイメージです。

より統計的に近い式で具体的に説明するなら、中間層の任意のニューロンの出力\(y_j\)、入力層と中間層の間の重みを\(w\)、1時刻前の自信との間の重みを\(v\)として、以下の線形モデル(回帰など)と考えてみます。

再帰のない状態の線形モデルは

\(y_n^{(t)} = b + \sum_iw_{ni}x_i^{(t)}\)

です。これに、以前の出力を重み荷重する項を加えます。式は

\(y_n^{(t)} = b + \sum_iw_{ni}x_i^{(t)} + \sum_jv_{nj}y_j^{(t-1)}\)

であり、線形モデルに時間を超えた項を追加しただけだと考えることができます。この加えた項が過去の自分から贈られたカンニングペーパーにあたりますね!

話を戻しますが、図2(b)は1時刻前の記憶にすぎませんが、1時刻前の記憶はさらに前の記憶から生成されているため、記憶が焼失しない限り永遠と続きます(実際には勾配消失や爆発の問題があるので簡単には永遠の記憶は生成できない)。ニューラルネットワークを簡略化して一つのボックスとして表すと、時間的なつながりは以下のようになります。

図3 RNNを展開すると

一般的なニューラルネットワークでは過去の記憶がないため、過去の入力全てを同時入力する必要がありましたが、RNNでは過去の入力を中間層を通じて間接的に情報をつないでいるのです。

Follow me!