読者です 読者をやめる 読者になる 読者になる

めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

倒立振子でDQNにおけるモデルの複雑さと学習内容の関係をちらっと確かめてみた系の話

何の話かというと

qiita.com

上記の記事では、「倒立振子」を題材にした、DQN(Deep Q Network)による強化学習の解説があり、非常によくまとまっています。

一方、この記事の中では、全結合層を4層に重ねたネットワークを利用しているのですが、倒立振子の問題に限定すれば、もっとシンプルなネットワークでも対応できる気がしなくもありません。

というわけで、「0層(パーセプトロン)」「1層」「2層」のネットワークでどこまで学習できるのか、モデルの複雑さと学習内容の関係を確認してみたよー、というのがこのブログのネタになります。

DQNとは?

まずは簡単にDQNを解説しておきます。

ビデオゲームの自動プレイで有名になりましたが、「与えられた環境において、最善の行動を選択する」という処理を実現することが目標です。ここで言う「行動」は、ビデオゲームの操作のように、「どのボタンを押すのか」というように、決められた選択肢から1つを選ぶという簡単なものとします。一方、何を持って「最善」とするかは、ちょっとトリッキーです。ビデオゲームの場合、一定時間プレイを続けて、トータルのスコアができるだけ高くなる事を「最善」と考える必要があります。その行動を選択した時点でのスコアの増加ではなく、その後の行動全体を通じてのトータルスコアを増加することを考えないといけません。ブロック崩しのゲームであれば、目の前のブロックを急いで消すよりは、少し時間がかかっても、端のブロックを集中して狙った方がよい、というような判断が必要です。

DQNでは、この問題を次のように解いていきます。

まず、時刻 t における環境を変数 s_t で表します。この時、行動 a_t を選択した場合に時刻 t+1 ではスコアが r_t だけ増加するものとします。最善の行動かどうかはわかりませんが、とにかく、t=0N までなんらかの行動を続けた場合のトータルのスコアを次式で計算します。

  R = r_0 + \gamma r_1 + \gamma^2 r_2 + \cdots = \sum_{t=0}^N\gamma^t r_t ―― (1)

ここで、\gamma は、0.9 程度の 1 よりちょっとだけ小さい値とします。本来は、\gamma=1 とするべきですが N が大きい時に R が大きくなりすぎて計算処理が発散するのを防止するために入れてあります。この R が最大になる行動を学習することが目的になります。また、このルールにしておけば、永遠にゲームを続けても R は有限の値に収束するので、これ以降は、N=\infty として話を進めます。

ここで、ちょっと大胆ですが、次のような魔法の関数 Q(s,a) が存在すると仮定します。それは・・・、

・環境 s において行動 a を選択したと仮定して、その後、ずっと最善の行動をとり続けた場合のトータルスコア ―― (2)

を与える関数です。

最善の行動がどのようなものか分からないのに、こんなものが計算できるわけがないのですが、とにかく、こんな都合のよい関数 Q(s,a) があったと仮定します。

このとき、よーく考えると、Q(s,a) は次の性質を満たすことがわかります。

 Q(s_0,a) = r_0 + \gamma \max_{a'} Q(s_{1}, a') ―― (3)

まず、時刻 t=0 の環境 s_0 において行動 a を取ると、その時点でスコア r_0 が得られます。そして、時刻 t=1 以降は、最善の行動を続けるという前提になりますが、時刻 t=1 の環境 s_1 における最善の行動とは何でしょうか? そう、Q(s_1, a') が最も大きくなるような a' が最善の行動ですよね。(2)の定義をよーく思い出してください。

その結果、(1)の定義を用いて、(3)の関係式が得られます。

これは、ある関数を再帰的に定義する場合と同じテクニックです。n! を計算するサブルーチン f(n) を再帰的に定義する場合、(n-1)! を計算するサブルーチンはすでにあるものと仮定して、f(n) = n \times f(n-1) という計算を行います。ここでは、Q(s_{1}, a) がすでにあるものと仮定して、Q(s_0, a) を計算しているわけです。

そして、(3)を一般化すると、次のように表現することが可能です。

・環境 s において行動 a を選択すると、スコア r が得られて次の環境が s' になるとした場合、次式が成立する。

 Q(s,a) = r + \gamma \max_{a'} Q(s', a') ―― (4)

したがって、なんらかの方法で(4)の性質を満たす魔法の関数 Q(s,a) を見つけることができれば、これが問題の答えになります。環境 s において、Q(s,a) が最も大きくなる行動 a を選択すれば、それがトータルスコアを最大にする、最善の行動ということになります。

ここまでは、完全に理論(理屈?)だけの世界ですが、DQNでは、実際に(4)を満たす関数 Q(s,a) を機械学習で構成していきます。厳密に(4)を満たす関数を構成するのは困難ですが、できるだけ(4)に近い関数を近似的に作り出します。

具体的な手順は、次の通りです。

まず機械学習の前提となるトレーニングセットが必要ですが、(4)の前文にある、

・環境 s において行動 a を選択すると、スコア r が得られて次の環境が s' になる

という (s,a,r,s') の4つ組の情報が必要です。たとえば、人間がプレイした結果や適当なアルゴリズムでランダムにプレイした結果をかき集めれば、これは入手可能です。このデータは、最善のプレイに基づく必要はありません。純粋に、これから学習するゲームの特性を示すデータになります。

次に、関数 Q(s,a) を適当なニューラルネットワークで構成します。このニューラルネットワークのパラメーターをうまく調整すれば、求める Q(s,a) に一致するものと期待します。(理屈の上では、ニューラルネットワークを複雑にすれば、どれほど複雑な関数でも表現できるので、決して無謀な期待ではありません。)

ニューラルネットワークに含まれるパラメーターを w として、これを Q(s,a \mid w) と表現しておきます。

この時、トレーニングセットから学習用データ (s,a,r,s') を1つ取り出して、次の2つの値を計算します。

Q(s,a \mid w)
r + \gamma \max_{a'} Q(s', a' \mid w)

求めるべき Q(s,a) では、この2つは一致するはずですので、この2つが近い値になるようにパラメーター w を修正します。そして、その後、また別の学習用データ (s,a,r,s') を1つ取り出して、再度、パラメーター w を修正する、ということを何度も繰り返します。

実際には、1つずつ計算するのではなく、複数のトレーニングデータをまとめて利用するミニバッチを適用します。具体的には、バッチに含まれるデータに対して、次式で二乗誤差を計算して、

 E(w) = \sum_{(s,a,r,s')} \left\{Q(s,a \mid w) - \left(r + \gamma \max_{a'} Q(s', a' \mid w)\right)\right\}^2

これを最小化するように、勾配降下法でパラメーター w を修正するということを繰り返します。これを何度も繰り返せば、最終的に、すべての状態 s に対して(4)を満たす魔法の関数 Q(s,a) が得られる(かも知れない)というわけです。満たすべき性質を再帰的に定義しておいて、そこから議論を進めるのは、数学の論証ではよくやることで、なんとなく騙された気分になることもあるのですが、ここでは、その性質を満たす関数を力技でとにかく見つけ出してしまうというわけですね。これ、この方法を最初に考えた人がほんとにすごいと思います。

最後にトレーニングセットを収集する手法について補足しておきます。先ほど、「適当なアルゴリズムでランダムにプレイした結果をかき集める」と言いましたが、学習途中の完璧ではない Q(s,a\mid w) を「適切なアルゴリズム」として使うという方法があります。たとえば、100回分のプレイ結果を保存する領域をFIFOのキュー形式で用意しておき、1回バッチ学習すると、その段階の Q(s,a\mid w) で1回プレイして、その結果をキューに入れる、という事を繰り返します。これにより、直近の100回分のプレイ結果をトレーニングセットとして、バッチ学習を継続することが可能になります。ただし、Q(s,a\mid w) をそのまま使ってプレイするのではなく、ある程度、ランダムな選択もまぜて、なるべく多くの (s,a) に対するデータを蓄えるようにしておきます。

倒立振子の問題

という前置きを踏まえて、冒頭の記事をもう一度読んでください。これは、環境、行動、スコアを次のように定義した問題になっています。

・環境:直近の4ステップ分の振り子の角度 s_t = (\theta_{t-3},\theta_{t-2},\theta_{t-1},\theta_t)
・行動:回転軸の回転方向の決定 a = \pm 1
・スコア:次のステップでの振り子の高さ h(\theta_{t+1})

「最善」の行動は、「引いて、押して(振り上げて)、上で静止させる」という動作になります。下記は、この最善に近い結果について、記事中のGIFアニメーションを引用したものです。

https://qiita-image-store.s3.amazonaws.com/0/30340/c292ed92-ef24-807b-a2e3-123fabdbbdf9.gif

もしくは、本当の意味での最善をグラフにすると次になります。青い線が振り子の高さ、緑が回転方向(右か左か)、赤は回転速度の変化です。

これを見ると、大きく2つの動作を学習する必要があることがわかります。前半の振り上げる動作と後半の上で静止させる動作です。

まず、前半の動作については、これはそれほど複雑なものではなさそうです。棒の高さと速度の組み合わせで、引くタイミングと押すタイミングを判断すればOKです。もしかしたら、この程度の判断は、単なる線形関数(パーセプトロン)でも学習できるかも知れません。

一方、後半の動作はちょっと高度です。上空で静止しているという(ほぼ)同じ状態に対して、右と左を交互に選択するという処理が必要です。判断の材料は、あくまで「現在の状態」だけで、過去の行動の選択を参照することは許されていません。これは、どの程度複雑なニューラルネットワークを用意すれば判断できるようになるのでしょうか? また、前半と後半を別々のニューラルネットワークで判断するということも許されません。結局、「前半と後半の2種類の判断を1つのニューラルネットワークで行う」ためには、どの程度複雑なニューラルネットワークが必要なのか、ということを考える必要があります。

・・・・というか、現実には、適当に複雑なネットワークを組んでそれでうまくいけばいいのですが、ここでは、純粋に理論的な興味として、線形関数(パーセプトロン)、1層、2層、それぞれのニューラルネットワークでどこまで学習できるかを実験してみようというわけです。

実験結果

はい。やってみました。先ほどの記事ではChainerを使って実装していましたが、私は、TensorFlowでやりました。Jupyterのノートブックを下記に置いておきました。

github.com

先の記事でも触れられていますが、学習結果は初期条件に依存して割と不安定です。ここでは、プレイ結果の保存数とバッチのサンプル数を変えながら何度か試して、その中でベストの結果を選んでいます。

まずは、単なる線形関数(パーセプトロン)です。

GIFアニメーションは作ってないので、各自、脳内再生してください。引いて押して振り上げた後、押し続けてぐるぐる周りはじめますが、途中でやばいことに気づいてブレーキをかけようとしています。後からブレーキをかける所まで学習できるとは思っていませんでした。意外とがんばってます。

次は、1024個のユニットの隠れ層をもった1層ネットワークです。

なんと。。。。前半の振り上げにちょっと手間取っていますが、後半の空中維持をちゃんと実現しています。恐らく、空中維持だけなら1層ネットワークで十分なのですが、前半と後半の両方を学習するのは、ちょっとつらいのでしょう。後半の処理(回転方向をプルプル変える動作)の学習にひきづられて、前半でも無駄にプルプルしちゃっているのかも知れません。複数の事を独立に学習するには、複数の層が必要なことがわかります。

次は、512個(1層目)+1024個(2層目)のユニットからなる2層ネットワークの結果です。

前半の振り上げがまだ完璧ではありません(1回だけ無駄にスイングしてる)が、後半はほぼ完璧ですね。条件を変えてもっとがんばって学習すれば、もしかしたら、これで完璧な結果を達成できるのかも知れませんね。

まとめ

DeepLearningにおいては、とにかくやたらと複雑な多層ネットワークが登場しますが、本当は学習内容に合わせた適切なネットワークの組み方、という知見も大切なんだと思います。簡単な例ですが、ネットワークの構成と学習内容の関係がよく見えるという意味で、面白い結果が得られました。

ちなみに、AlphaGoでもDQNが用いられていますが、碁の場合は、各ステップのスコアをどう与えるかが課題になります。ビデオゲームのような「各ステップにおけるスコア」に相当するものは与えられず、ある手が最善かどうかは最終的な勝ち負けだけで決まります。AlphaGoの場合は、独自の方法で各ステップ(その時点の盤面の様子)に対応するスコア(優劣を数値化したもの)を与えることで、DQNを適用していたような気がします。(すいません。後でもう一回、論文読みます。。。)

しかし繰り返しになりますが、(4)の再帰的な定義によって、これほど複雑な処理が実現できるとは驚きですね。再帰処理おそろしや。

追記事項

その後、がんばって何度も学習を繰り返していた所・・・・1層のネットワークで下記を達成してしまいました。かなり完璧に近いです。いやぁ面白いですね。

ちなみに、もちろんこれは、真下で静止した位置から始めるという条件に特化して過学習している可能性もあります。(というか、たぶんそう。)このネットワークで、他の初期条件の振り子をどこまで正しく扱えるかは、別途検証が必要です。