めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

Transformer モデルの仕組みを JAX/Flax で実装しながら解説してみる(パート2)

パート1はこちら。

enakai00.hatenablog.com

なんの話かと言うと

最近、大規模言語モデルを用いたチャットシステムがよく話題になりますが、言語モデルの性能が大きく向上するきっかけとなったのが、下記の論文で公表された「Transformer」のアーキテクチャーです。

arxiv.org

ここでは、JAX/Flax を用いて Transformer を実装しながら、その仕組みを解説していきます。このパート2では、Attention Head / Muti-head Attention を解説します。

JAX/Flax の使い方を学びたいという方は、こちらの書籍を参照してください。

Attention Head の役割

画像を扱うモデルでは、畳み込みフィルターを用いて、画像から特定の情報を抽出した新しい画像を生成します。畳み込みフィルターは、画像のそれぞれのピクセルに対して、その周辺のピクセルとの関係性を利用して、現在のピクセル値を新しいピクセル値に変換します。Attention Head は、これに類似した処理を埋め込み空間のベクトルに対して行います。前回説明した Embedding レイヤーでは、トークン(単語)を埋め込み空間のベクトル値に変換しました。データ構造を見ると分かるように、個々のトークンに対して、特定のベクトル値を1対1でマッピングする仕組みになっています。そして、このベクトルが単語の「意味」を表すものと期待するわけですが、実際には、単語の意味というのはテキスト全体の中で決まるものであり、特定のトークンだけを見て決められるものではありません。その意味では、Embedding レイヤーが出力するベクトルは、まだまだ不正確な情報と言えます。よくある例がこちらですね。

・Time flies like an arrow(時間は矢のように飛ぶ)
・Time flies like a banana(「時間バエ」はバナナを好む)

最初の文の「flies」は「飛ぶ」という動詞ですが、2つ目の文の「flies」は「ハエ」という意味の名詞で、トークンとしては同一ですが、その意味はまったく異なります。文の最後にある、「arrow」もしくは「banana」というトークンを見ることで、初めて「flies」の意味が決まるのです。

そこで、Attention Head は、それぞれのトークンに割り当てられた現在のベクトル値から、トークン同士の関連度を表す、Query、および、Key の値と、そのトークンが及ぼす影響を表す Value の値を生成します。Query, Key, Value はすべて同じ次元のベクトル値です。これらを用いて、現在のベクトル値から新しいベクトル値に変換します。たとえば、「flies」の場合、自身の Query と(自分自身を含む)周りの単語の Key を比較して、この2つが近い(ベクトルの内積が大きい)ほど、関連度(「flies」の意味に対する影響度)が大きいものと考えて、この関連度を重みとして、それぞれのトークンの Value の値を合成します。これが、このトークンの新しいベクトル値になります。Attention Head を1つだけ使用する場合、Query, Key, Value は元のベクトルと同じ次元のベクトルにしておきます。

一方、畳み込みフィルターでは、一般に、複数のフィルターを用いてさまざまな情報を取り出します。これと同様に、複数の Attention Head を用いて、1つのトークンから複数の新しい情報を取り出すことができます。これが、Multi-head Attention です。たとえば、「flies」の場合、「ハエ」「飛ぶ」という意味情報の他にも「名詞」「動詞」と言った品詞情報などもあります。このような異なる種類の情報を個別の Attention Head で取り出そうというわけです。Multi-head Attention の場合は、計算量を抑えるために、1つのAttention Headの Query, Key, Value の次元は「トークンの次元 // Attention Head の個数(割り算して余りを切り捨てた値)」にしておき、それぞれの Attention Head から得られたベクトルを横一列に連結することで、元のベクトルと同じ次元の新しいベクトルが得られます。うまく割り切れない場合、元のベクトルと微妙に異なる次元になってしまいますが、そのような場合にそなえて、最後に、一次関数(アフィン変換)で次元数を再調整します。

Attention Head の実装

それでは、Attention Head の実装を見てみましょう。

class AttentionHead(nn.Module):
    head_dim: int

    def scaled_dot_product_attention(self, q, k, v, mask):
        scores = jnp.matmul(q, jnp.transpose(k, (0, 2, 1))) # *4
        if mask is not None:
            mask = jnp.tile(mask, mask.shape[-1]).reshape(
                    mask.shape[0], -1, mask.shape[-1]) # *7
            scores = jnp.where(mask==0, -jnp.inf, scores) # *8
        w = nn.softmax(scores / jnp.sqrt(self.head_dim)) # *5
        return jnp.matmul(w, v) # *6

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        q = nn.Dense(features=self.head_dim)(hidden_state) # *1
        k = nn.Dense(features=self.head_dim)(hidden_state) # *2
        v = nn.Dense(features=self.head_dim)(hidden_state) # *3
        output = self.scaled_dot_product_attention(q, k, v, attention_mask)
        return output

hidden_state には、Embedding レイヤーの出力、すなわち、[テキスト数, トークン数, 埋め込み空間の次元] = [N, 128, 512] サイズのデータが入ります。

*1 〜 *3 は、入力値に含まれる個々の(512次元の)ベクトルに対して、対応する Query, Key, Value の値を生成します。ここで、512個の値の単純な一次関数で計算しています。先ほどの説明では、この Qeury, Key, Value によって意味情報や品詞情報を抽出すると言いましたが、「そのような特定の情報を抽出する Query, Key, Value なんてどうやって用意するんだ?」と思った方もいるかも知れません。ここでは、Query, Key, Value の値を生成する関数自体を学習対象のパラメーターにしており、学習がうまく進めば、結果として、「テキストの内容を理解するために必要ななんらかの情報」を抽出する Query, Key, Value が得られるだろう、と期待しています。「理屈はわかるが、実際にどうやれば?」という部分をすべて学習処理に押し付けるというディープラーニングならではの発想ですね。

次に、これらの値を用いて、新しいベクトル値を生成するのが、関数 scaled_dot_product_attention() です。まず、*4 の部分で、1つのテキストに含まれるトークンのすべての組み合わせについて、Qeury と Key の内積を一気にまとめて計算しています。ちょっとわかりにくい所なので、次の例で説明します。

入力データとして、3つのトークンからなるテキストが2つあって、Query と Key は4次元ベクトルだとすると、変数 q には次のようなデータが格納されています。1行の4つの数字が1つの Query です。3×4行列が2つ積み重なっていますが、1つの行列が1つのテキストに対応します。

q = np.array([
    [[1, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 0, 1, 0]],

    [[0, 0, 1, 0],
     [0, 1, 0, 0],
     [1, 0, 0, 0]],     
])

変数 k は、同様の形で Key が格納されていますが、jnp.transpose で、テキストごとの行列をそれぞれ転置しています。

k = np.array([
    [[0, 0, 1, 0],
     [1, 0, 0, 0],
     [0, 1, 0, 0]],

    [[0, 1, 0, 0],
     [0, 0, 1, 0],
     [1, 0, 0, 0]],     
])
np.transpose(k, (0, 2, 1))
### [output] ###
array([[[0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 0, 0]],

       [[0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 0]]])

q と転置した k に含まれるテキストごとの行列について、対応するテキストごとに行列としての積を取ると、すべてのトークンの組み合わせに応じた q と k の内積が一気に計算されます。

np.matmul(q, np.transpose(k, (0, 2, 1)))
#### output ####
array([[[0, 1, 0],
        [0, 0, 1],
        [1, 0, 0]],

       [[0, 1, 0],
        [1, 0, 0],
        [0, 0, 1]]])

それぞれのテキストには3個のトークンがあるので、テキスト内のトークンの組み合わせは 3×3 で、ちょうど 3×3 行列に結果がまとめられています。1つの行は、特定のトークンの Query と(自分自身を含む)他のトークンの Key との内積を並べたものになります。

*4 の直後の if 文を一旦無視すると、この後は、内積で計算した score を Query, Key, Value の次元数(self.head_dim)の平方根で割って大きさを調整した後、ソフトマックス関数でトークンごと(行ごと)の合計が1になるように正規化した重み w に変換します(*5)。w のサイズは、先ほどの score と同じで、[テキスト数, トークン数(Query側), トークン数(Key側)] というサイズです。これと [テキスト数, トークン数, Value の次元数] というサイズを持つ v とテキストごとに内積をとれば、うまいぐあいに、正規化した重みで Value を合成した、新しい、埋め込み表現のベクトルがそれぞれのトークンに対して得られます(*6)。

ここまでの *4, *5, *6 の行列計算が、ちょうど、論文に書かれている次の計算式に対応することになります。

Attention Mask の役割

先ほど無視した、*4 の直後の if 文について補足しておきます。前回みたように、個々の入力テキストは、一定の個数(今の場合は、128個)のトークンから構成されており、実際の単語が存在しない部分は、0 でパディングされています。

train_text[0]
#### output ####
array([  101,  1045,  2134,  2102,  2514, 26608,   102,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0])

先ほどの行列計算では、128個のトークンすべてについて重みを計算しているので、このままでは、0 でパディングされた部分も計算対象になります。つまり、存在しない単語からの影響度を計算するという無駄な処理が入ってしまいます。そこで、0 でパディングされた部分の重みを強制的に 0 にする処理を if 文の中で加えています。

変数 mask は、次のようにパディング部分を 0 で示したマスク値を受け取ります。正確には、テキストごとのマスク値を縦に並べた行列ですね。

train_mask[0]
#### output ####
array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

*7 の部分がややアクロバティックですが、次の例を見れば、何をやっているかがわかるでしょう。

mask = np.array([
    [1, 0, 0],
    [1, 1, 0]
])

np.tile(mask, mask.shape[-1]).reshape(mask.shape[0], -1, mask.shape[-1])
#### output ####
array([[[1, 0, 0],
        [1, 0, 0],
        [1, 0, 0]],

       [[1, 1, 0],
        [1, 1, 0],
        [1, 1, 0]]])

ここでは、3個のトークンからなテキストが2つある場合を想定していますが、それぞれのテキストのマスクを複製して、トークンの数だけ縦に積み重ねています。これは、*4 で計算したスコアと同じ構造になっており、*8 では、マスク値が 0 の部分のスコアを -\infty に強制変更しています。これをソフトマックス関数で重みに変換すると、-\infty の部分は重みが 0 になり、これで、存在しない単語の影響を無視することができます。

これで、Attention Head が完成しました。*4 の関数 scaled_dot_product_attention() は、論文にある下記の図と同じものになっています。(厳密には、今回の実装では、scale と mask の順番が入れ替わってますが、本質的な違いではありません。)

Multi-head attention の実装

単体の Attention Head ができたので、これを複数並べた Multi-head attention を実装しましょう。完成品がこちらになります。

class MultiHeadAttention(nn.Module):
    num_heads: int
    embed_dim: int

    def setup(self):
        head_dim=self.embed_dim // self.num_heads # *1
        self.attention_heads = [AttentionHead(head_dim=head_dim)
                                for _ in jnp.arange(self.num_heads)] # *2

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        attention_outputs = [head(hidden_state, attention_mask)
                             for head in self.attention_heads] # *3
        x = jnp.concatenate(attention_outputs, axis=-1) # *4
        x = nn.Dense(features=self.embed_dim)(x) # *5
        return x

先に説明したように、複数の Attention Head を用いる場合、個々の Attention Head が出力するベクトルの次元は、「元のベクトルの次元 // Attention Head の個数」になります(*1)。*2 では、self.num_heads 個の Attention Head のオブジェクトを生成して、リストに詰め込んでいます。*3 でそれぞれの Attention Head からの出力を得た後に、これを横に連結して(*4)、アフィン変換で元のベクトルと同じ次元(self.embed_dim)に調整します(*5)。

ここまでの説明からわかるように、Multi-head Attention レイヤーは、入力データと出力データは同じ構造(同じサイズのリスト)になっています。テキストに含まれる個々のトークンが、埋め込み空間のベクトルに置き換えられた形です。最初の Embedding レイヤーからの出力では、トークン同士の関係性は考慮されていませんでしたが、Multi-head Attention レイヤーを通ることで、トークン同士の関係性を考慮した、より本質的な「意味」を表すベクトルが得られるものと期待するわけです。この後、さらに何度も Multi-head Attention レイヤーを積み重ねれば、3個以上のトークンの関係性を考慮したより深い意味内容が得られると期待ができますが、はたしてうまくいくのでしょうか・・・?

パート3をお楽しみに!

パート3はこちらからどうぞ。

enakai00.hatenablog.com