パート3はこちら。
なんの話かと言うと
最近、大規模言語モデルを用いたチャットシステムがよく話題になりますが、言語モデルの性能が大きく向上するきっかけとなったのが、下記の論文で公表された「Transformer」のアーキテクチャーです。
ここでは、JAX/Flax を用いて Transformer を実装しながら、その仕組みを解説していきます。このパート4では、完成済みの Transformer Encoder に線形多項分類器をくっつけて、テキスト分類(感情分析)を行います。
JAX/Flax の使い方を学びたいという方は、こちらの書籍を参照してください。
Classification Head の追加
Transformer Encoder は、テキストの「意味」を取り出すという汎用的な役割を持ちますが、一般には、取り出した意味から何らかの予測処理をするレイヤーを付け加えることで、実用的なタスクを実行することができます。ここでは分類処理を行うヘッダー(Classification Head)として、線形多項分類器を追加します。
class TransformerForSequenceClassification(nn.Module): num_labels: int num_heads: int embed_dim: int num_hidden_layers: int def setup(self): self.transformer_encoder = TransformerEncoder( self.num_heads, self.embed_dim, self.num_hidden_layers) @nn.compact def __call__(self, input_ids, attention_mask=None, eval=True): x = self.transformer_encoder(input_ids, attention_mask, eval)[:, 0, :] # select [CLS] token *1 x = nn.Dropout(0.1, deterministic=eval)(x) logits = nn.Dense(features=self.num_labels)(x) return logits
ここでちょっと面白いのが *1 の部分です。Transformer Encoder は、テキストに含まれるそれぞれのトークンに対して、埋め込み空間のベクトルを割り当てますが、後続の処理では、必ずしもすべてのベクトルを利用する必要はありません。Multi-head Attention を繰り返し適用することで、トークン間の情報が混じり合うので、どれか1つのベクトルだけを見ても、テキスト全体の情報が得られる可能性があります。*1 では、先頭のトークンに対応するベクトルだけを取り出して、self.num_labels 個のロジット値を返す線形多項分類器に入力しています。ロジットの値が大きいほど、そのクラスである確度が高いことになります。
先頭のトークンのベクトルだけを見て、本当に分類に有用な情報が得られるのか疑問を持つかもしれませんが、このように考えてください。この後、「先頭のトークンのベクトルを元に感情を正しく予測できる」という条件の下に学習処理を行うので、結果として、先頭のトークン部分に「感情に関わる情報」が集積するようにモデルが学習されるものと期待するのです。
この後は、誤差関数を定義して、勾配降下法で学習するという定番の流れなので、コードの詳細は、下記のノートブックに譲ります。このあたりについては、冒頭の参考書籍(JAX/Flaxで学ぶディープラーニングの仕組み)に詳しい解説があります。
モデルのオブジェクト生成は以下の部分で行っており、*1 の部分で Attention Head の個数(num_heads)、埋め込み空間の次元数(embed_dim)、Transformer Encoder ブロックの段数(num_hidden_layers)が指定できます。
class TrainState(train_state.TrainState): epoch: int dropout_rng: type(random.PRNGKey(0)) model = TransformerForSequenceClassification( num_labels=6, num_heads=8, embed_dim=512, num_hidden_layers=2) # *1 key, key1, key2 = random.split(random.PRNGKey(0), 3) variables = model.init(key1, train_text[:1]) state = TrainState.create( apply_fn=model.apply, params=variables['params'], tx=optax.adam(learning_rate=0.00005), dropout_rng=key2, epoch=0)
今回は、8個の Attention Head を持つ Multi-head Attention レイヤーを用いた Transformer Encoder ブロックを2段使用しており、下記のように、16エポックの学習で、テストセットに対して90%の正解率を達成しています。学習時間は、GPUを接続した環境で、8分39秒でした。
%%time ckpt_dir = './checkpoints/' prefix = 'TextClassification_checkpoint_' state, history = fit(state, ckpt_dir, prefix, train_text, train_mask, train_label, test_text, test_mask, test_label, epochs=16, batch_size=32) #### output #### Epoch: 1, Loss: 1.6664, Accuracy: 0.3089 / Loss(Test): 1.5604, Accuracy(Test): 0.3814 Epoch: 2, Loss: 1.5732, Accuracy: 0.3621 / Loss(Test): 1.4319, Accuracy(Test): 0.4653 Epoch: 3, Loss: 1.3712, Accuracy: 0.4824 / Loss(Test): 0.8860, Accuracy(Test): 0.6870 Epoch: 4, Loss: 0.9353, Accuracy: 0.6611 / Loss(Test): 0.4823, Accuracy(Test): 0.8433 Epoch: 5, Loss: 0.5925, Accuracy: 0.7938 / Loss(Test): 0.3562, Accuracy(Test): 0.8760 Epoch: 6, Loss: 0.4374, Accuracy: 0.8471 / Loss(Test): 0.3241, Accuracy(Test): 0.8884 Epoch: 7, Loss: 0.3606, Accuracy: 0.8709 / Loss(Test): 0.2955, Accuracy(Test): 0.8938 Epoch: 8, Loss: 0.3204, Accuracy: 0.8810 / Loss(Test): 0.2717, Accuracy(Test): 0.8948 Epoch: 9, Loss: 0.2834, Accuracy: 0.8922 / Loss(Test): 0.2661, Accuracy(Test): 0.8963 Epoch: 10, Loss: 0.2607, Accuracy: 0.8988 / Loss(Test): 0.2571, Accuracy(Test): 0.9053 Epoch: 11, Loss: 0.2400, Accuracy: 0.9058 / Loss(Test): 0.2572, Accuracy(Test): 0.9023 Epoch: 12, Loss: 0.2180, Accuracy: 0.9126 / Loss(Test): 0.2601, Accuracy(Test): 0.9003 Epoch: 13, Loss: 0.2077, Accuracy: 0.9177 / Loss(Test): 0.2516, Accuracy(Test): 0.9023 Epoch: 14, Loss: 0.1926, Accuracy: 0.9214 / Loss(Test): 0.2478, Accuracy(Test): 0.9053 Epoch: 15, Loss: 0.1855, Accuracy: 0.9246 / Loss(Test): 0.2587, Accuracy(Test): 0.8998 Epoch: 16, Loss: 0.1713, Accuracy: 0.9281 / Loss(Test): 0.2594, Accuracy(Test): 0.9048 CPU times: user 8min 13s, sys: 19.3 s, total: 8min 32s Wall time: 8min 39s
Confusion matrix はこんな感じ。
わかりやすい例文で予測すると、こんな感じです。
上記のノートブックは Colab で実行できるので、実際に学習して、好きな例文を予測して楽しんでください。
※ 無償版の Colab は GPU の使用時間に制限があります。GPU のランタイムを起動したままにすると、制限時間が超過して、一定期間(1日程度)GPU が使用できなくなるので注意してください。ノートブックの使用が終わったら、「ランタイム」メニューから「ランタイムを接続解除して削除」を実行して、ランタイムを停止しておいてください。
おまけ
今回の実装内容は、下記の書籍を参考にしています。Transformer を勉強するのに最適な書籍なので、こちらもぜひ参照してください。