めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

Transformer モデルの仕組みを JAX/Flax で実装しながら解説してみる(パート3)

パート2はこちら。

enakai00.hatenablog.com

なんの話かと言うと

最近、大規模言語モデルを用いたチャットシステムがよく話題になりますが、言語モデルの性能が大きく向上するきっかけとなったのが、下記の論文で公表された「Transformer」のアーキテクチャーです。

arxiv.org

ここでは、JAX/Flax を用いて Transformer を実装しながら、その仕組みを解説していきます。このパート2では、Muti-head Attention と Feed Forward のブロックを積み重ねた Transformer Encoder ブロックを完成させます。

JAX/Flax の使い方を学びたいという方は、こちらの書籍を参照してください。

Transformer Encoder ブロックとは?

パート1に示した論文の図を再掲します。

パート2では、Multi-head Attention レイヤーを用意したので、これに Feed Forward レイヤーを組み合わせると、左側の Encoder 部分のパーツが揃います。図の左端にある「N×」という記号は、Muti-head Attention と Feed Forward を組み合わせたブロックを 「Transformer Encoder ブロック」として、これを好きな段数だけ積み重ねることを表します。

それでは、まずは、Feed Forward レイヤーを用意します。これは、「全結合層+活性化関数(ReLU)」というよくある構成で、それぞれのトークンに対応した(埋め込み空間の)ベクトルに個別に適用します。一般には、全結合層のノード数はベクトルの次元と異なるので、最後に一次関数(アフィン変換)で、元のベクトルと同じ次元に戻します。つまり、Multi-head Attention レイヤーと同様に、Feed Foward レイヤーは入出力データの構造が一致しています。このため、Transformer Encoder ブロックは、何段でも自由に積み重ねることができます。

Feed Forward レイヤーの実装は次のとおりです。

class FeedForward(nn.Module):
    embed_dim: int
    intermediate_size: int = 2048

    @nn.compact
    def __call__(self, x, eval):
        x = nn.Dense(features=self.intermediate_size)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.embed_dim)(x)
        x = nn.Dropout(0.1, deterministic=eval)(x)
        return x

ここでは、全結合数のノードは 2,048 個に設定してあります。また、出力の直前にドロップアウト層を入れてあるのは、オーバーフィッティングを回避するいつものテクニックです。

そして、Multi-head Attention レイヤーと Feed Forward レイヤーを結合した Transformer Encoder ブロックは次になります。

class TransformerEncoderBlock(nn.Module):
    num_heads: int
    embed_dim: int

    def setup(self):
        self.attention = MultiHeadAttention(
            num_heads=self.num_heads, embed_dim=self.embed_dim)
        self.feed_forward = FeedForward(embed_dim=self.embed_dim)

    @nn.compact
    def __call__(self, x, attention_mask, eval):
        x = x + self.attention(x, attention_mask) # Skip connection *1
        x = nn.LayerNorm()(x) # *2
        x = x + self.feed_forward(x, eval) # Skip connection *3
        x = nn.LayerNorm()(x) # *4
        return x

先ほどの図と見比べて、レイヤー正規化とスキップ接続が加えられている点に注意してください。

*1 では、Multi-head Attention レイヤへの入力 x に、Multi-head Attention レイヤーからの出力を加えたものを次のレイヤーへの入力としています。理屈の上では、Multi-head Attention レイヤーからの出力だけを次のレイヤーに入力してもよいのですが、このスキップ接続によって学習の速度を高めることができます。なぜかというと・・・

スキップ接続がない場合とある場合を比較すると、スキップ接続がある場合、Multi-head Attention レイヤーは、入力値と(スキップ接続がない場合の)出力値の差分だけを計算すればよいことになります。違う見方をすると、Multi-head Attention レイヤーの出力が多少でたらめでも、次のレイヤーには、少なくとも元の入力値と同じ情報を伝えることができます。学習が進むにつれて、Multi-head Attention レイヤーによる差分補正が意味のあるものになっていき、次のレイヤーは(元の入力値に比べて)より有用な情報を受け取るようになっていきます。

一方、スキップ接続がない場合、学習の初期では、次のレイヤーは Multi-head Attention レイヤーが出力するでたらめな情報しか受け取りません。Multi-head Attention レイヤーの学習がある程度進むまで、次のレイヤーは適切な学習ができないことになります。スキップ接続によって、このような問題を回避します。

この後は、レイヤー正規化を挟んで(*2)、先ほど用意した Feed Forward レイヤーを適用します(*3)。ここでもまたスキップ接続を利用しています。最後にもう一度、レイヤー正規化(*4)で終わりです。

ここで、*3 にある self.feed_forward(x, eval) の処理について補足しておきます。x は [テキスト数, トークン数, 埋め込み空間の次元数] というサイズのデータ構造になっていますが、Feed Forward レイヤーの計算処理は、最後の「埋め込み空間の次元数」の部分のデータ、すなわち、埋め込み空間の個々のベクトルに対して、同一の計算処理(同一のパラメーター値を持つ関数)が個別に適用されます。複数のベクトルを混ぜ合わせる処理を行うわけではありません。トークン間の関係性は、Attention Head によってのみ考慮されるのです。

Transformer Encoder の完成

それでは、パート1で用意した、Embedding レイヤーの上に、上記の Transformer Encoder ブロックを好きな段数だけ積み重ねて、Transformer Encoder を完成させましょう。

class TransformerEncoder(nn.Module):
    num_heads: int
    embed_dim: int
    num_hidden_layers: int

    def setup(self):
        self.embeddings = Embeddings(self.embed_dim)
        self.layers = [TransformerEncoderBlock(num_heads=self.num_heads,
                                               embed_dim=self.embed_dim)
                       for _ in range(self.num_hidden_layers)] # *1

    def __call__(self, input_ids, attention_mask, eval):
        x = self.embeddings(input_ids, eval) # *2
        for layer in self.layers: # *3
            x = layer(x, attention_mask, eval=eval)
        return x

*1 では、self.num_hidden_layers で指定した個数だけ Transformer Encoder ブロックのオブジェクトを生成してリストに詰め込んでいます。個別のオブジェクトを使用するので、それぞれ、異なるパラメータ値を持つ、独立したブロックになります。*2 で Embedding レイヤーを通した後に、*3 のループで Transformer Encoder ブロックを順番に適用しています。

これでめでたく Transformer Encoder が完成しましたが、これだけでは有用な仕事はできません。Transformer Encoder が入力テキストから抽出した「意味」を利用して、何らかの仕事をするレイヤーを最後に追加する必要があります。次のパート4では、線形多項分類器をくっつけて、テキスト分類を実行します。パート1で説明した様に、テキストごとに対応する「感情」が割り当てられたデータがあるので、これを利用して、テキストから対応する「感情」を予測するモデルを構築します。

パート4をお楽しみに!

パート4はこちらからどうぞ!

enakai00.hatenablog.com