めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

強化学習(DQN)に Explainable AI のテクニックを応用してみる

Explainable AI とは

学習済みのディープラーニングのモデルをリバースエンジニアリング的に分析して、モデルがどのようなロジックで推論しているのかを明らかにする手法です。特定の決まった技術があるわけではなく、モデルの種類に応じてさまざまなテクニックを組み合わせて実現します。

DQN (Deep Q-Network) とは

ニューラルネットワークを強化学習に適用する手法で、「Q-Learning」と呼ばれる強化学習のアルゴリズムとニューラルネットワークを用いた近似表現を組み合わせます。数年前に、ビデオゲームを自動プレイするエージェントで有名なったやつです。

OpenAI Gym とは

強化学習のシミュレーション用プラットフォームで、さまざまなビデオゲームのシミュレーターがライブラリー形式で提供されています。このシミュレーターを用いて、自動プレイエージェントの学習に挑戦することができます。

で・・・ここでは何をするかというと

まず、OpenAI Gym を使って、DQN のモデルでビデオゲームをプレイするエージェントを作ります。このエージェントは、ビデオゲームの画面出力だけを元にして、とるべきアクションを決定します。そこで、この出来上がったモデルに対して、Explainable AI 的なテクニックを駆使して、「このエージェントは画面のどこを見てプレイしているのか?」を調べてみよう的な試みです。

モデルの学習

まず、実際に作った学習用のコードはこちらになります。今回は、car-racing-v2 というカーレースのゲームを使いました。

github.com

経験ベースのテクニックをいろいろと駆使して試行錯誤しましたが、結果的には、そこそこプレイできるエージェントが得られました。(よかった!)ちなみに、Google Cloud の Vertex AI Workbench(Jupyter Lab のマネージドサービス)を使って、Tesla T4 が1枚ついた環境で1日〜2日程度学習させました。

学習中のスコアの変化はこんな感じ。ゲームとしての最高点(ノーミスでクリアした時の点数)は 1000 点で、横軸は学習データを集めるためのプレイ回数(エピソード数)です。

これを見るとかなり激しくスコアが変動しており、かつ後半はむしろ下手になっていますが、まぁ、強化学習ってこういうもんなんです。ここでは、700エピソードあたりのモデルを採用することにします。

実際にプレイする様子はこんな感じです。(コースがランダムに選ばれるようになっており、コースによって得意・不得意があるのですが、ここでは、きれいに走れた例を意図的に選んで載せています。)


何を調べたいのか?

今回のモデルでは、画面イメージから畳み込みフィルターで情報を抽出する CNN(畳み込みニューラルネットワーク)を利用しています。32 枚のフィルターと 64 枚のフィルターを 2 段階で適用しています。2 段目のフィルターからの情報を全結合層で総合して、この場面でとるべきアクションを決定します。(正確には、行動状態価値関数の値を計算しているのですが・・・・)。

なのですが・・・そもそもフィルターの数がこれで適切なのかはまったくわかりません。経験ベースの直感で決めてみただけです。

そこで、実際に学習後のそれぞれのフィルターがどんな情報を抽出しており、それがとるべき行動の予測に本当に役立っているのかを「Explainable AI」的に調べてみようというわけです。

どうやるのか

まず、実際にプレイ中の画面データをモデルに入力して、それぞれのフィルターからの出力がどのような画像になっているかを可視化します。これで、それぞれのフィルターがどんな情報を取り出しているか、おおよその検討がつくはずです。

さらに、それぞれのフィルターからの出力をあえて大きくした時に、予測結果がどの程度変化するかを調べます。(フィルターの出力に対する予測値の勾配を計算するわけですね。)大きな正の勾配を持つほど、そのフィルターからの情報は予測に大きな影響を与えていると考えられます。

やってみた

実際のコードは後で紹介しますが、結果としては、やや身も蓋もない感じになりました・・・

ここでは、2 段目の 64 枚のフィルターからの出力を示しており、画像の上の数字が勾配の値になります。左上の4枚目の画像をのぞいて、ほどんどが真っ白です・・・。勾配の値を見ると、この4枚目の画像だけが大きな正の勾配を持っており、要は、このモデルは、この1枚の画像だけをもとにして予測しているわけで、他の 63 枚は無駄だったということですね・・・・。

実際にフィルターが 1 枚だけの場合に、そのフィルターが適切に学習できるかどうかはやってみないとわからないので、本当にこれが無駄だったかはわかりませんが、例えば、この学習済みのモデルから不要なフィルターを削除してモデルのサイズを圧縮する、と言った実用的な応用は考えられるかもしれません。(おそらく学習の途中では複数のフィルターの情報を利用していたのが、学習が進むにつれて、特定の 1 枚に必要な情報が集約していったものと想像しています。)

もうちょっと可視化してみる

これだけだとあまり見栄えがしないので、最後に、ちょっとした可視化を行なってみます。正の勾配を持つフィルターについて、それぞれのフィルター出力を勾配の重みを掛けて合成して、さらにそれを元の画像に重ね合わせてみます。これにより、「画像のどの部分の情報を予測に使用しているのか」を読み取ることができます。

結果はこんな感じになります。

左端が元の画像で真ん中がフィルターを合成したもの、そして右端がこれらを重ね合わせた結果です。

実際のコードは下記にありますが、これを開くとその他の例もいろいろ見れるので、どういう情報を使っているのかを色々と想像してみてください。

github.com

パッと見て気がつくのはこのあたりでしょうか・・・・

・コースの左端を主に見ている(左カーブが多いから?)

・コーナーの縞模様はしっかりと見ている(コーナーを認識するため?)

・画面下部のメーター類も見ている(スピードやブレーキングの状態もちゃんと情報として利用しているのかも?)

モデルの構造を変えてみる

モデルの構造を少し変えて、オリジナルの画面を半分のサイズのグレイスケール画像に変換してから入力するようにしてみました。

github.com

今回は次のようなフィルターが得られました。先ほどと少し様子が異なり、複数のフィルターを利用しているようです。

オーバーレイで可視化するとこんな感じ

左から順に、オリジナルの画面、モデルに入力する画面、フィルターを合成したもの、オリジナル画面に重ねたものになります。

さらにアニメーションも作ってみました。こちらから参照できます。

宣伝

DQN を含む強化学習の理論を基礎から理解したい方は、上記の書籍を読んでみてください。

畳み込みフィルターの出力から「どこを見ているのか」を調べる方法については、上記の書籍でも解説しています。

React と Firebase と Cloud Run を連携するサンプル実装

前提知識

  • React:インタラクティブな Web フロントエンド(クライアント上で稼働する Javascript)を実装するためのライブラリーで、状態変数の値の変化を自動的に画面に反映する機能があります。
  • Firebase:モバイルアプリのバックエンドを Google Cloud で提供するサービスで、ユーザー認証やユーザー管理などの機能を専用のライブラリで簡単に実装できます。
  • Cloud Run:アプリケーションのコンテナイメージを Google Cloud 上にデプロイして実行するサーバーレスタイプのサービスで、オートスケールなどの機能が簡単に利用できます。

なんの話かというと

上記の3つの技術(サービス)を組み合わせて、

  • エンドユーザーは、Google アカウントで Web アプリケーションにログインする
  • Web アプリケーションから Cloud Run で稼働するバックエンド API を実行する
  • バックエンド API は、ログイン済みのユーザーからのリクエストのみを受け付ける

という要件を満たす Web アプリケーションのサンプル実装を作ったので、実装上のポイントを解説します。

github.com

サンプルを実際に利用する手順は、GitHub 上の README にありますが、ここでは、手順の各ステップについてポイントとなる部分を説明していきます。

Firebase の利用準備

Firebase は Google Cloud をバックエンドに使う前提なので、はじめに Google Cloud の利用登録をして、Google Cloud のコンソールから、Google Cloud のプロジェクトを作成する必要があります。その上で、Firebase のコンソールから、作成したプロジェクトを Firebase に登録して、利用する機能(今回の場合は、Google アカウントによるユーザー認証)の有効化やこれから作成するアプリケーションの登録などを行います。手順の「Do this first」では、この部分の作業を行ないます。

バックエンド API を Cloud Run にデプロイする

バックエンド用のアプリケーションは、Python の Flask というフレームワークで実装してあります。手順の「Build and deply the backend API service.」では、Google Cloud 上でコンテナイメージをビルドして、Cloud Run の環境にデプロイしています。アプリケーションの内容は単純で、

{"name": "Etsuji Nakai"}

というデータを受け取って、

{"message": "Hello, Etsuji Nakai!"}

というメッセージを返すものになります。ただし、重要な要件として、「ログイン済みのユーザーからのリクエストのみを受け付ける」という機能を満たす必要があります。これは、次のように、API を実装する関数 hello() に対して @jwt_authenticated というデコレーターを付与することで実現しています。

from middleware import jwt_authenticated

@app.route('/hello-world-service/api/v1/hello', methods=['POST'])
@jwt_authenticated
def hello():
...

デコレーターの中身は、middleware.py で実装されており、デコレーションした関数を実行する前に、API リクエストのヘッダーに Firebase が発行したユーザー ID トークンが含まれていることを検証します。実際に検証する部分は、Firebase が提供するライブラリ firebase_admin を使用しており、次の関数を実行するだけで検証ができます。

firebase_admin.auth.verify_id_token(token)

フロントエンドをビルドする

次は Web フロントエンドをビルドする部分になります。手順の「Build React Application.」の部分にあたります。React のコードは、JSX と呼ばれる形式で書かれており、これを Javascript 用のコンパイラ Babel を用いてブラウザーで実行可能な Javascript に変換します。Babel は Node.js を用いて実行されるので、手順の中では、Node.js をインストールした後にビルドを実行するという流れになります。ビルド済みのファイルは /build 以下に保存されます。

ビルド作業自体はとても簡単なのですが、重要なのはコードの中身なので、ポイントとなる部分を解説しておきます。

まずは、Firebase.js では、Firebase の基本的な設定情報を用意しています。

Firebase.js

import { initializeApp } from "firebase/app";
import { getAuth, GoogleAuthProvider, signInWithPopup } from "firebase/auth";


const firebaseConfig = {
...(省略)...
};

const app = initializeApp(firebaseConfig);
export const auth = getAuth(app);
const provider = new GoogleAuthProvider();

export const signInWithGoogle = () => {
  signInWithPopup(auth, provider)
    .catch((error) => {console.log(error)})
};

export const projectId = firebaseConfig.projectId;

手順の中でも説明しているように、変数 firebaseConfig には、Firebase にアプリケーションを登録した際に発行された各種情報をコピペで書きこんでおきます。これらの情報はアプリケーションがクラウド上の Firebase の機能を利用する際に必要となります。

ここでは特に、変数 auth に格納されたオブジェクトが重要になります。このオブジェクトを用いて、ログインユーザーの情報を管理することができます。また、関数 singInWithGoogle は、Google アカウントによるログイン画面をポップアップ表示して、エンドユーザーにログイン処理を行なってもらう際に利用します。

次に、App.js がアプリケーションの本体です。ここでは特に、次の部分がポイントになります。

App.js

  userAuthHandler(user) {
    if (user) {
      // Login
      this.setState({loginUser: user});
    } else {
      // Logout
      this.setState({
        loginUser: null,
        message: "no message"
      });
    }
  }

  componentDidMount() {
    onAuthStateChanged(auth, this.userAuthHandler);
  }

componentDidMount() は、ブラウザ上でアプリケーションが実行されたタイミングで、ハンドラー関数 userAuthHandler() を登録しています。このハンドラーは、エンドユーザーがアプリケーションにログイン、もしくは、ログアウトしたタイミングで自動的に実行されます。この例では、ログインしたタイミングで、React の状態変数 loginUser にユーザー情報を記録しています。あるいは、ログアウトしたタイミングでこの情報を削除します。冒頭で説明したように、React は、状態変数の値が変化すると自動で画面の再描画が行われるので、これによって、ログイン中の画面とログイン前(ログアウト後)の画面を切り替えています。(このあたりは、React の本当に便利な所ですね・・・。変数を書き換えれば勝手に画面も書き変わるとは・・・。)

次にポイントになるのは、バックエンド API を呼び出す次の関数 getMessage() です。

  getMessage() {
    const callBackend = async () => {
      const baseURL = "https://" + projectId + ".web.app";
      const apiEndpoint = baseURL + "/hello-world-service/api/v1/hello";
      const user = auth.currentUser;
      const token = await user.getIdToken();
      const request = {  
        method: "POST",
        headers: {
          "Authorization": "Bearer " + token,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          name: user.displayName,
        })
      };
      fetch(apiEndpoint, request)
        .then((res) => res.json())
        .then((data) => this.setState({message: data.message}));
    };
    const waitMessage = new Promise(resolve => {
      this.setState({message: "Wait..."});
      resolve();
    });
    waitMessage.then(callBackend);
  }

非同期処理をチェーンしているのでちょっと読みづらいですが、画面上のメッセージ表示部分を「Wait...」に書き換えた後に、非同期でバックエンド API を呼び出しています。バックエンド API からメッセージが返ってくると、その内容を状態変数に message に書き込みます。これによって、画面上のメッセージ部分にバックエンド API から受け取ったメッセージが表示されます。

バックエンド API を呼び出す際は、次の部分で取得したユーザー ID トークンをヘッダーに埋め込んでいます。

      const user = auth.currentUser;
      const token = await user.getIdToken();

前述のように変数 auth を用いてログインユーザーの情報が管理されており、ここでは、auth.currentUser でログイン中のユーザー情報を取得して、さらに、user.getIdToken() で該当ユーザーの ID トークンを取得しています。これは、Firebase が独自に発行するトークンで、有効期限が切れると自動的に新しいトークンが発行されます。つまり、user.getIdToken() でトークンを取得するようにしておけば、自分で有効期限を管理する必要はありません。

フロントエンドを FIrebase Hosting にデプロイする

ビルドされたアプリケーションは、普通の(?)Javascript なので任意の環境にデプロイすることができますが、この手順では、Firebase 標準の Web ホスティング環境である Firebase hosting にデプロイしています。具体的な手順は、「Deploy the application on Firebase hosting.」の部分になります。

Firebase hosting を利用するメリットの1つに Cloud Run との連携機能があります。Cloud Run の環境はセキュリティ強化のためにデフォルトでは CORS が禁止されています。CORS が禁止されているというのは、バックエンド API が稼働するドメインとは異なるドメインからのリクエストを拒絶するということで、Google Cloud 上で稼働するアプリケーション以外からのリクエストは受け付けないということです。アプリケーションの設定で CORS を許可することもできますが、Firebase hosting を利用した場合は、Firebase hosting の環境が Cloud Run に対する Proxy として動作することで、外部のクライアントから Cloud Run で稼働するバックエンド API が利用可能になります。

手順の中で、Firebase hosting の設定ファイル Firebase.json を書き換えているのは、このためで、具体的には次の部分が対応します。

    "rewrites": [
      {
        "source": "/hello-world-service/**",
        "run": {
          "serviceId": "hello-world-service",
          "region": "us-central1"
        }
      }
    ]

この設定により、Firebase hosting 上のアプリケーションのパス 「https://www.[Project ID].web.app/hello-world-service/」以下にリクエストを投げると、指定された Cloud Run のサービスにリクエストが転送されます。この際、Cloud Run と同じドメインからリクエストが転送されるので、CORS の制限を回避してバックエンド API の処理が行われます。

アプリケーションを試してみる

https://raw.githubusercontent.com/enakai00/react-google-login-example/main/doc/img/screenshot.png

Firebase hosting にデプロイしたアプリケーションにアクセスすると「Sign in with Google」のボタンが出るので、これを押して、Google アカウントでログインします。すると、Google アカウントに設定されたプロフィール情報から取得した名前とプロフィール画像が表示されます。さらに、「Get message from the backend API」のボタンを押すと、バックエンド API からメッセージを取得します。「Logout」ボタンでログアウトした後に、他の Google アカウントで再ログインすることもできます。

マイクロサービスに関する参考書籍

マイクロサービスの全体像を把握するのに最適


マイクロサービスの「パターン」を網羅した辞書的な本


マイクロサービスにおける適切な実装パターンを選ぶための考え方(トレードオフ)を解説。モノリスからマイクロサービスへの移行の基本的な考え方も学べる。


モノリスからマイクロサービスへの移行に関連したテクニックを解説。変更が困難な例外ケースへの対応テクニックが多いので、まずは、上の書籍(Software Architecture: Modern Trade-off Analyses for Distributed Architectures)で基本パターンを学んだ上で読むのが良い。


www.humio.com
マイクロサービス環境におけるオブザーバビリティについて、コンパクトにまとめられた小冊子


マイクロサービスにおけるセキュリティと信頼性に関する話題を提供


マイクロサービスの基本となる「単一責任の原則」を根本から理解したい方向け


マイクロサービスの全体像を説明した書籍だけど、筆者の経験や冗長な「たとえ話」も多いエッセイ的な内容。マイクロサービスを一通り理解した後に、頭を整理する為に気軽に読むのがよさそう。


Domain-Driven Design の平易な入門書。マイクロサービスとの関連性も説明されており、サービス設計の基本が学べる。