めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

Transformer モデルの仕組みを JAX/Flax で実装しながら解説してみる(パート4)

パート3はこちら。

enakai00.hatenablog.com

なんの話かと言うと

最近、大規模言語モデルを用いたチャットシステムがよく話題になりますが、言語モデルの性能が大きく向上するきっかけとなったのが、下記の論文で公表された「Transformer」のアーキテクチャーです。

arxiv.org

ここでは、JAX/Flax を用いて Transformer を実装しながら、その仕組みを解説していきます。このパート4では、完成済みの Transformer Encoder に線形多項分類器をくっつけて、テキスト分類(感情分析)を行います。

JAX/Flax の使い方を学びたいという方は、こちらの書籍を参照してください。

Classification Head の追加

Transformer Encoder は、テキストの「意味」を取り出すという汎用的な役割を持ちますが、一般には、取り出した意味から何らかの予測処理をするレイヤーを付け加えることで、実用的なタスクを実行することができます。ここでは分類処理を行うヘッダー(Classification Head)として、線形多項分類器を追加します。

class TransformerForSequenceClassification(nn.Module):
    num_labels: int
    num_heads: int
    embed_dim: int
    num_hidden_layers: int

    def setup(self):
        self.transformer_encoder = TransformerEncoder(
            self.num_heads, self.embed_dim, self.num_hidden_layers)

    @nn.compact
    def __call__(self, input_ids, attention_mask=None, eval=True):
        x = self.transformer_encoder(input_ids, attention_mask, eval)[:, 0, :] # select [CLS] token *1
        x = nn.Dropout(0.1, deterministic=eval)(x)
        logits = nn.Dense(features=self.num_labels)(x)
        return logits

ここでちょっと面白いのが *1 の部分です。Transformer Encoder は、テキストに含まれるそれぞれのトークンに対して、埋め込み空間のベクトルを割り当てますが、後続の処理では、必ずしもすべてのベクトルを利用する必要はありません。Multi-head Attention を繰り返し適用することで、トークン間の情報が混じり合うので、どれか1つのベクトルだけを見ても、テキスト全体の情報が得られる可能性があります。*1 では、先頭のトークンに対応するベクトルだけを取り出して、self.num_labels 個のロジット値を返す線形多項分類器に入力しています。ロジットの値が大きいほど、そのクラスである確度が高いことになります。

先頭のトークンのベクトルだけを見て、本当に分類に有用な情報が得られるのか疑問を持つかもしれませんが、このように考えてください。この後、「先頭のトークンのベクトルを元に感情を正しく予測できる」という条件の下に学習処理を行うので、結果として、先頭のトークン部分に「感情に関わる情報」が集積するようにモデルが学習されるものと期待するのです。

この後は、誤差関数を定義して、勾配降下法で学習するという定番の流れなので、コードの詳細は、下記のノートブックに譲ります。このあたりについては、冒頭の参考書籍(JAX/Flaxで学ぶディープラーニングの仕組み)に詳しい解説があります。

github.com

モデルのオブジェクト生成は以下の部分で行っており、*1 の部分で Attention Head の個数(num_heads)、埋め込み空間の次元数(embed_dim)、Transformer Encoder ブロックの段数(num_hidden_layers)が指定できます。

class TrainState(train_state.TrainState):
    epoch: int
    dropout_rng: type(random.PRNGKey(0))

model = TransformerForSequenceClassification(
    num_labels=6, num_heads=8, embed_dim=512, num_hidden_layers=2) # *1

key, key1, key2 = random.split(random.PRNGKey(0), 3)
variables = model.init(key1, train_text[:1])
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=0.00005),
    dropout_rng=key2,
    epoch=0)

今回は、8個の Attention Head を持つ Multi-head Attention レイヤーを用いた Transformer Encoder ブロックを2段使用しており、下記のように、16エポックの学習で、テストセットに対して90%の正解率を達成しています。学習時間は、GPUを接続した環境で、8分39秒でした。

%%time
ckpt_dir = './checkpoints/'
prefix = 'TextClassification_checkpoint_'
state, history = fit(state, ckpt_dir, prefix,
        train_text, train_mask, train_label,
        test_text, test_mask, test_label,
        epochs=16, batch_size=32)
#### output ####
Epoch: 1, Loss: 1.6664, Accuracy: 0.3089 / Loss(Test): 1.5604, Accuracy(Test): 0.3814
Epoch: 2, Loss: 1.5732, Accuracy: 0.3621 / Loss(Test): 1.4319, Accuracy(Test): 0.4653
Epoch: 3, Loss: 1.3712, Accuracy: 0.4824 / Loss(Test): 0.8860, Accuracy(Test): 0.6870
Epoch: 4, Loss: 0.9353, Accuracy: 0.6611 / Loss(Test): 0.4823, Accuracy(Test): 0.8433
Epoch: 5, Loss: 0.5925, Accuracy: 0.7938 / Loss(Test): 0.3562, Accuracy(Test): 0.8760
Epoch: 6, Loss: 0.4374, Accuracy: 0.8471 / Loss(Test): 0.3241, Accuracy(Test): 0.8884
Epoch: 7, Loss: 0.3606, Accuracy: 0.8709 / Loss(Test): 0.2955, Accuracy(Test): 0.8938
Epoch: 8, Loss: 0.3204, Accuracy: 0.8810 / Loss(Test): 0.2717, Accuracy(Test): 0.8948
Epoch: 9, Loss: 0.2834, Accuracy: 0.8922 / Loss(Test): 0.2661, Accuracy(Test): 0.8963
Epoch: 10, Loss: 0.2607, Accuracy: 0.8988 / Loss(Test): 0.2571, Accuracy(Test): 0.9053
Epoch: 11, Loss: 0.2400, Accuracy: 0.9058 / Loss(Test): 0.2572, Accuracy(Test): 0.9023
Epoch: 12, Loss: 0.2180, Accuracy: 0.9126 / Loss(Test): 0.2601, Accuracy(Test): 0.9003
Epoch: 13, Loss: 0.2077, Accuracy: 0.9177 / Loss(Test): 0.2516, Accuracy(Test): 0.9023
Epoch: 14, Loss: 0.1926, Accuracy: 0.9214 / Loss(Test): 0.2478, Accuracy(Test): 0.9053
Epoch: 15, Loss: 0.1855, Accuracy: 0.9246 / Loss(Test): 0.2587, Accuracy(Test): 0.8998
Epoch: 16, Loss: 0.1713, Accuracy: 0.9281 / Loss(Test): 0.2594, Accuracy(Test): 0.9048
CPU times: user 8min 13s, sys: 19.3 s, total: 8min 32s
Wall time: 8min 39s

Confusion matrix はこんな感じ。

わかりやすい例文で予測すると、こんな感じです。

上記のノートブックは Colab で実行できるので、実際に学習して、好きな例文を予測して楽しんでください。

※ 無償版の Colab は GPU の使用時間に制限があります。GPU のランタイムを起動したままにすると、制限時間が超過して、一定期間(1日程度)GPU が使用できなくなるので注意してください。ノートブックの使用が終わったら、「ランタイム」メニューから「ランタイムを接続解除して削除」を実行して、ランタイムを停止しておいてください。

おまけ

今回の実装内容は、下記の書籍を参考にしています。Transformer を勉強するのに最適な書籍なので、こちらもぜひ参照してください。