Explainable AI とは
学習済みのディープラーニングのモデルをリバースエンジニアリング的に分析して、モデルがどのようなロジックで推論しているのかを明らかにする手法です。特定の決まった技術があるわけではなく、モデルの種類に応じてさまざまなテクニックを組み合わせて実現します。
DQN (Deep Q-Network) とは
ニューラルネットワークを強化学習に適用する手法で、「Q-Learning」と呼ばれる強化学習のアルゴリズムとニューラルネットワークを用いた近似表現を組み合わせます。数年前に、ビデオゲームを自動プレイするエージェントで有名なったやつです。
OpenAI Gym とは
強化学習のシミュレーション用プラットフォームで、さまざまなビデオゲームのシミュレーターがライブラリー形式で提供されています。このシミュレーターを用いて、自動プレイエージェントの学習に挑戦することができます。
で・・・ここでは何をするかというと
まず、OpenAI Gym を使って、DQN のモデルでビデオゲームをプレイするエージェントを作ります。このエージェントは、ビデオゲームの画面出力だけを元にして、とるべきアクションを決定します。そこで、この出来上がったモデルに対して、Explainable AI 的なテクニックを駆使して、「このエージェントは画面のどこを見てプレイしているのか?」を調べてみよう的な試みです。
モデルの学習
まず、実際に作った学習用のコードはこちらになります。今回は、car-racing-v2 というカーレースのゲームを使いました。
経験ベースのテクニックをいろいろと駆使して試行錯誤しましたが、結果的には、そこそこプレイできるエージェントが得られました。(よかった!)ちなみに、Google Cloud の Vertex AI Workbench(Jupyter Lab のマネージドサービス)を使って、Tesla T4 が1枚ついた環境で1日〜2日程度学習させました。
学習中のスコアの変化はこんな感じ。ゲームとしての最高点(ノーミスでクリアした時の点数)は 1000 点で、横軸は学習データを集めるためのプレイ回数(エピソード数)です。
これを見るとかなり激しくスコアが変動しており、かつ後半はむしろ下手になっていますが、まぁ、強化学習ってこういうもんなんです。ここでは、700エピソードあたりのモデルを採用することにします。
実際にプレイする様子はこんな感じです。(コースがランダムに選ばれるようになっており、コースによって得意・不得意があるのですが、ここでは、きれいに走れた例を意図的に選んで載せています。)
何を調べたいのか?
今回のモデルでは、画面イメージから畳み込みフィルターで情報を抽出する CNN(畳み込みニューラルネットワーク)を利用しています。32 枚のフィルターと 64 枚のフィルターを 2 段階で適用しています。2 段目のフィルターからの情報を全結合層で総合して、この場面でとるべきアクションを決定します。(正確には、行動状態価値関数の値を計算しているのですが・・・・)。
なのですが・・・そもそもフィルターの数がこれで適切なのかはまったくわかりません。経験ベースの直感で決めてみただけです。
そこで、実際に学習後のそれぞれのフィルターがどんな情報を抽出しており、それがとるべき行動の予測に本当に役立っているのかを「Explainable AI」的に調べてみようというわけです。
どうやるのか
まず、実際にプレイ中の画面データをモデルに入力して、それぞれのフィルターからの出力がどのような画像になっているかを可視化します。これで、それぞれのフィルターがどんな情報を取り出しているか、おおよその検討がつくはずです。
さらに、それぞれのフィルターからの出力をあえて大きくした時に、予測結果がどの程度変化するかを調べます。(フィルターの出力に対する予測値の勾配を計算するわけですね。)大きな正の勾配を持つほど、そのフィルターからの情報は予測に大きな影響を与えていると考えられます。
やってみた
実際のコードは後で紹介しますが、結果としては、やや身も蓋もない感じになりました・・・
ここでは、2 段目の 64 枚のフィルターからの出力を示しており、画像の上の数字が勾配の値になります。左上の4枚目の画像をのぞいて、ほどんどが真っ白です・・・。勾配の値を見ると、この4枚目の画像だけが大きな正の勾配を持っており、要は、このモデルは、この1枚の画像だけをもとにして予測しているわけで、他の 63 枚は無駄だったということですね・・・・。
実際にフィルターが 1 枚だけの場合に、そのフィルターが適切に学習できるかどうかはやってみないとわからないので、本当にこれが無駄だったかはわかりませんが、例えば、この学習済みのモデルから不要なフィルターを削除してモデルのサイズを圧縮する、と言った実用的な応用は考えられるかもしれません。(おそらく学習の途中では複数のフィルターの情報を利用していたのが、学習が進むにつれて、特定の 1 枚に必要な情報が集約していったものと想像しています。)
もうちょっと可視化してみる
これだけだとあまり見栄えがしないので、最後に、ちょっとした可視化を行なってみます。正の勾配を持つフィルターについて、それぞれのフィルター出力を勾配の重みを掛けて合成して、さらにそれを元の画像に重ね合わせてみます。これにより、「画像のどの部分の情報を予測に使用しているのか」を読み取ることができます。
結果はこんな感じになります。
左端が元の画像で真ん中がフィルターを合成したもの、そして右端がこれらを重ね合わせた結果です。
実際のコードは下記にありますが、これを開くとその他の例もいろいろ見れるので、どういう情報を使っているのかを色々と想像してみてください。
パッと見て気がつくのはこのあたりでしょうか・・・・
・コースの左端を主に見ている(左カーブが多いから?)
・コーナーの縞模様はしっかりと見ている(コーナーを認識するため?)
・画面下部のメーター類も見ている(スピードやブレーキングの状態もちゃんと情報として利用しているのかも?)
モデルの構造を変えてみる
モデルの構造を少し変えて、オリジナルの画面を半分のサイズのグレイスケール画像に変換してから入力するようにしてみました。
今回は次のようなフィルターが得られました。先ほどと少し様子が異なり、複数のフィルターを利用しているようです。
オーバーレイで可視化するとこんな感じ
左から順に、オリジナルの画面、モデルに入力する画面、フィルターを合成したもの、オリジナル画面に重ねたものになります。
さらにアニメーションも作ってみました。こちらから参照できます。
宣伝
DQN を含む強化学習の理論を基礎から理解したい方は、上記の書籍を読んでみてください。
畳み込みフィルターの出力から「どこを見ているのか」を調べる方法については、上記の書籍でも解説しています。