はじめに
前回は、時系列で可変長なデータを処理できるリカレントニューラルネットワーク(Recurrent Neural Network)について説明しました。そこで取り上げたRNNは長いデータを処理すると計算が爆発するため、記憶打ち切り型通時的逆伝搬(Truncated Back propagation Through time)により、短期記憶情報だけで処理していました。でも、やっぱりそう割り切るとAIの精度も限定的になります。そこで、この単純RNN(Simple Recurrent NN)の長期依存性問題を解決する構造を持った長・短期記憶ユニット(Long Short-Term Memory)というモデルが現れ、これが現在のRNNの主流となっています。
単純RNNの長期依存性問題
LSTMの説明に入る前に、前回学んだ単純リカレントニューラルネットワークの長期依存性問題について復習しましょう。単純RNNは超能力者ではないので、「僕の」「麻里ちゃんの」「愛してやまない」「ケーキは」という一文の最後の言葉を当てることができません。でも記憶範囲がもっと広がり、「麻里ちゃんが、玲奈ちゃんととても幸せそうにモンブランを食べている顔を見てピンときたのですが、」という文のあとで、「僕の」「麻里ちゃんの」「愛してやまない」「ケーキは」と続いた場合はどうでしょうか。今度は(モンブランです)という言葉がすぐに浮かんできますね。
私たち人間は、このように直前の情報だけではなく、必要に応じてもっと前のセンテンスの情報も利用しているわけです。こうした長期記憶をRNNが構造的にできないわけではありません。ただし、「麻里ちゃんが」「玲奈ちゃんと」「とても」「幸せそうに」「モンブランを」「食べている」「顔も」「見て」「ピンときた」「のですが」などと情報量が増えてくると、それらがどのくらいの重みでどう関連するかが著しく複雑になります。十数ステップであれば対応できますが、100ステップ以上にもなると計算が爆発してしまいます。これが、単純RNNの長期依存性問題です。
この問題を解決するために登場したのがLSTMです。Long Short-Term Memoryという、ロングとショートが混ざった妙な言葉ですね。図1のように単純RNNが短期記憶しか利用しないのに対し、LSTMは長期依存(long-term dependencies)を学習できるように改良したモデルです。単純RNNが計算が爆発するのにLSTMは大丈夫って、そんなことがどうして可能なのでしょうか。
図1:ニューラルネットワークの記憶範囲
RNNの構造
単純RNNは、図2のようなリカレント(再帰)構造を持ちます。前セルの出力(Recurerent)と入力(Input)が合わさって出力(output)が出されている単純なモデルですが、前回、説明を省いたtanhという妙なものがありますね。これは、ハイボリックタンジェントと呼ばれる関数で、統計でよく登場するロジスティックシグモイド関数です。あ~面倒っちい言葉がいっぱい出たので、ここでくじけないために少し補足します。
図2:RNNのリカレント構造
ハイボリックという言葉は双曲線という意味です。シグモイドという言葉は、麻里ちゃんの焼き肉の回(Vol.5)にシグモイドニューロンとして登場しましたね。パーセプトロンの入出力が1か0の二値であるのに対し、シグモイドニューロンは0から1までの実数モデルでした。そして、ロジスティックという言葉も、麻里ちゃんの定時帰りを予測するロジスティック回帰(Vol.14)で登場しました。こちらは”発生確率を予測して、確率に応じてYes/Noに分類するもの”だったことを思い出してください。
Vol.14のロジスティック関数をの図を思い出してください。シグモイド関数もロジスティック関数も同じS字型の関数で、Xの値を0から1の値に変換する関数です。一方、tanhはハイボリック(双曲線)という言葉がついているように、0~1ではなく-1から1の値に変換します(図3)。
シグモイドと違い、tanhの出力は正と負の値が持てるので状態の増減が可能です。そのためtanhはセルの反復接続で使用され、内部で追加される候補値を決定するのに便利です。2次微分がゼロとなる前に長期間値を維持できますので勾配消失問題に対処しやすい関数です。2次微分やら勾配消失やら難しいので、情報を利用しやすいようにいい塩梅に変換してくれるものだと思っておいてください。図2で言えば、前セルからのリカレント情報(記憶情報)をそのまま垂れ流しにするのではなく、tanhが要点をうまく整理してくれるイメージです。
図3:Tanhは双曲線の関数
LSTMの構造
続いて図4のLSTMのリカレント構造を見てみましょう。うわっ、格段に複雑になっていますね。でも、これでもノーマルなLSTMなんです。順番に説明すれば理解できますので、どうぞついて来てください。
図4:LSTMのリカレント構造
(1)前セルの出力に記憶ラインが追加(ht-1とCt-1)
単純RNNで1つだった前セルからの情報伝搬が、出力(Recurrent)のほかに記憶(Memory)が追加されて2ラインになっていますね。これ、大まかに言ってRecurrentの方がRNNと同じ短期記憶で、Memoryが長期記憶だと思ってください。LSTMは、その名の通り、短期と長期を関連させながらも別々のラインで記憶保持しているのです。
(2)前セルの出力(Recurrent)と入力の合流(ht-1とXt)
前セルの出力ht-1(短期記憶)と今セルの入力Xtが合流します。合流された信号は4つのラインに分岐(同一情報コピー)されます。この合流結果は、僕の麻里ちゃんの好きな」という短期記憶に入力「ケーキは」が加わったものになります(これは前回のRNNと同じです)。
(3)忘却ゲート(ftの出力)
一番上のラインは忘却ゲートです。これは、前セルからの長期記憶1つずつに対して、σ(シグモイド関数)からでた0~1までの値ftで情報の取捨選択を行うものです。1は全て残すで、0は全部捨てるです。短期記憶ht-1と入力Xtで「僕の麻里ちゃんの愛してやまないケーキは」まで認識した時点(t)において、長期記憶の中から「玲奈ちゃんと」は重要でないと判断したとき、σの出力ftは0付近の値となり、この記憶を忘却します。一方、「モンブランを」という情報は重要そうなのでftは1でそのまま残しています。
RNNが過去の情報を全て利用しようとすると計算が爆発すると説明しましたが、忘却ゲートにより不要と思われる情報を捨てることで爆発を防ぐのです(いらない情報をどんどん忘れるのは、人間と一緒ですね)。
3つのゲートとシグモイド関数σ LSTMには、忘却ゲート(forget gate)、入力ゲート(input gate)、出力ゲート(output gate)の3つのゲートがあります。ゲートというと自身の信号の出入口になっているようにイメージしますが、ここではちょっと違います。シグモイド関数σによって、流れてくる信号のゲートの開け閉めを行っている制御門なのです。1が全開、0が閉め切りで、例えば0.5なら半開きというゲートで信号の重みコントロールを行っているのです。 |
(4)入力ゲート(Ct'とit)
短期記憶ht-1と入力Xtで合算された入力データを長期保存用に変換した上で、どの信号をどのくらいの重みで長期記憶に保存するか制御します。ここは2つのステップで処理されます。
①tanhによる変換(Ct'を出力)
入ってきた情報をそのまま流すのではなく、要点を絞った端的な形にした方が、情報量を削減できるうえに利用しやすくなります。さきほどtanhは内部で追加される候補値を決定するのに便利な関数と説明しました。例えば、「麻里ちゃんの」は(コンピュータ的には)「麻里の」でいいでしょうし(僕はよくないが…)、「愛してやまない」は要するに「好きな」という候補に置き換えてもいいでしょう。そんなふうにシンプルに変換されてCt'が出力されています。
②入力ゲート(it)による取捨選択
前回触れたように、LSTMは通時的誤差逆伝搬(Back propagation Through time)によって重みを調節します。通常の誤差逆伝搬は入力Xtの重みの調節ですが、通時的誤差逆伝搬は、これに加えて前セルからの短期記憶ht-1からの情報にも影響をうけます。そのため、ht-1から入ってくる無関係な情報によって重みがミス更新されるのを防止するために、入力ゲートが必要な誤差信号だけが適切に伝搬するように制御しています。
ht-1+Xtで作られた「僕の麻里ちゃんの愛してやまないケーキは」という情報の中から、入力ゲートのσ(シグモイド関数)が防止すべきものと流すべきものを選別します。今回は、「僕の」という部分に対して”なに妄想してんだよ”ってことで、この文節に対するシグモイド関数の出力itは0が出されてしまいました(あらら)。一方、「麻里ちゃんの」「愛してやまない」「ケーキは」は出力が1に近い値で一応残されました。
(5)出力ゲート(otを出力)
htは短期記憶の出力です。上記のような処理により長期記憶に短期記憶が加わって取捨選択された値(長期記憶の出力Ct)の中で、短期記憶に関する部分のみを出力します。ここも先ほどと同様2つのステップで処理されます。
①tanhによる変換
tanhの入力は、前セルからの長期記憶Ct-1に入力Xtを変換した短期記憶Ct'を加えたものです。それぞれ忘却ゲートおよび入力ゲートで取捨選択はされています。これをそのまま長期記憶として出力するのがCtですが、そこに含まれる短期記憶部分も長期記憶と合わせることによって、短期記憶のみの時より端的で利用しやすいものに変換することができます。
例えば、短期記憶が「僕の彼女の好きなケーキは」だったとしましょう。この場合、長期記憶に僕の彼女が麻里ちゃんだという重要要素があれば、短期記憶をより明確な「麻里ちゃんの好きなケーキは」に変換するようなイメージです。
②短期記憶の取捨選択
入力ゲートで自セルを保護したように、出力ゲートでも次のセルへの悪い情報伝搬を防止します。次のセルを活性化するための重みhtを更新する際に、無関係の情報を流して悪い影響を与えないようにしなければなりません出力ゲートのσ(シグモイド関数)により0から1の範囲でOtが出力され、短期記憶出力htに必要な信号だけを適切に伝搬するように制御しています。
今回は、入力ゲートで「僕の」という言葉がすでにカットされているので、出力ゲートでは特にストップする言葉がありませんでした。このように入力でも出力でも二重にゲートチェックすることにより、無関係な情報が流れないように徹底しているのです。これまでの処理の結果、このセルからは「麻里の好きなケーキは」という情報がhtに出力されたことになります。
[RELATED_POSTS]
まとめ
今回は、前回説明したリカレントニューラルネットワーク(単純RNN)の欠点を補った、LSTMの構造について説明しました。前セルからの出力が短期記憶と長期記憶に分かれていること、情報が暴発しないように不要と思われる情報は忘却ゲートで消し去ること、不必要な情報で誤った重み更新をしないように入力ゲートと出力ゲートで取捨選択していること、tanhによって情報をそのまま流すのではなく利用しやすい形に変換していること、などが理解できましたでしょうか。
なお、ここでは理解しやすいように例文を使って説明していますが、実際にRNNがどんな処理を行うかはブラックボックスですし、学習度合いによっても異なります。LSTMの各パートでこんな感じで処理しているんだってイメージしていただければOKです。
梅田弘之 株式会社システムインテグレータ :Twitter @umedano
- カテゴリ:
- AI技術をぱっと理解する(基礎編)