めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

Transformer モデルの仕組みを JAX/Flax で実装しながら解説してみる(パート1)

なんの話かと言うと

最近、大規模言語モデルを用いたチャットシステムがよく話題になりますが、言語モデルの性能が大きく向上するきっかけとなったのが、下記の論文で公表された「Transformer」のアーキテクチャーです。

arxiv.org

ここでは、JAX/Flax を用いて Transformer を実装しながら、その仕組みを解説していきます。このパート1では、Embedding レイヤーを解説します。

JAX/Flax の使い方を学びたいという方は、こちらの書籍を参照してください。

Transformer の全体像

冒頭の論文では、Transformer Encoder と Transformer Decoder を組み合わせた下記のモデルが説明されています。

左側の Encoder でテキストを解釈して、右側の Decoder で新しいテキストを生成するという形になっており、例えば、機械翻訳システムとして使用することができます。

ここでは、左側の Encoder にフォーカスして、これを実装していきます。この後で説明するように、Encoder だけでもテキスト分類やテキスト生成を実施することが可能です。

学習用データセットの準備

はじめに、学習に使用するデータセットを用意します。ここでは、英語の短文とその文章に付随する「感情」を正解ラベルとするデータセットを用いて、テキスト文に対する感情分析を行うモデルを構築します。

まずは、必要なパッケージをインストールして、基本的なモジュールをインポートします。(次のコードでは、実装済みの Transformer モデルを提供する Hugging Face Transformer をインストールしていますが、ここでは、テキストをトークンに分割する Tokenizer を使用するためにインストールしています。Transfomer 自体は、きちんと一から実装するのでご安心ください。)

%%bash
pip install -q git+https://github.com/huggingface/transformers.git datasets
pip install -q flax==0.6.1 jax==0.3.25 optax==0.1.3 jedi
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame
from functools import partial

import jax, optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state
from flax.training import train_state, checkpoints

plt.rcParams.update({'font.size': 12})

先ほどインストールした datasets ライブラリーを使って、"Emotion" データセットをダウンロードします。

from datasets import load_dataset
emotions = load_dataset('emotion')

データの一部を表示すると、次のように英語の短文が含まれることがわかります。この後で見るように、それぞれの文章に対する「感情」を示すラベルが付与されています。

emotions['train']['text'][:2]
#### output ####
['i didnt feel humiliated',
 'i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake']

ここで、テキストデータをトークンに分割する Tokenizer を用意します。

from transformers import AutoTokenizer, AutoConfig
model_ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
vocab_size = AutoConfig.from_pretrained(model_ckpt).vocab_size

vocab_size
#### output ####
30522

Tokenizer は、テキストに含まれるそれぞれの単語を対応する整数値に個別に置き換えます。テキスト文を扱う機械学習モデルは、一般に、生のテキスト文ではなく、単語ごとに整数値に置き換えられた、「整数値のリスト」を入力データとして受け取ります。この入力データを用意するのが Tokenizer の役割です。

上記で "vocab_size" に保存した値 30,522 は、この Tokenizer は全部で 30,522 種類の整数値を使用する(つまり、30,522 種類の単語を識別する)ことを示しています。実際のテキスト文には、より多くの種類の単語が使用されている可能性がありますが、Tokenizer は類似した単語、あるいは、重要度の低い単語を同じ整数値に置き換えるなどの方法で、30,522 種類に抑えています。

それでは、Tokenizer でトークンに分割して、モデルに入力可能なデータセットを用意しましょう。

トークンに分割する際は、トークン数の最大値を指定する必要があるので、一文に含まれる単語数の最大値を確認しておきます。

max([len(text.split(' '))
     for text in emotions['train']['text'] + emotions['validation']['text']])
#### output ####
66

最大でも66ワードなので、ここでは、最大128トークンとして Tokenizer を実行します。

text_length = 128

# Training set
train_set = tokenizer(emotions['train']['text'], max_length=text_length,
                      padding='max_length', truncation=True)
train_text = np.array(train_set['input_ids'])
train_mask = np.array(train_set['attention_mask'])
train_label = np.eye(6)[emotions['train']['label']]

# Test set
test_set = tokenizer(emotions['validation']['text'], max_length=text_length,
                     padding='max_length', truncation=True)
test_text = np.array(test_set['input_ids'])
test_mask = np.array(test_set['attention_mask'])
test_label = np.eye(6)[emotions['validation']['label']]

emotion_labels = emotions['train'].features['label'].names

トークンに分割されたテキストは、train_text(トレーニングセット)と test_text(テストセット)に保存しており、サンプルを表示すると次のようになっています。

train_text[0]
#### output ####
array([  101,  1045,  2134,  2102,  2514, 26608,   102,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0])

101, 102 は、文の始まりと終わりを表すトークンで、その間が文章の本体です。残りは0でパディングして、全体として、text_length で定義した128トークンに揃えられています。それぞれのデータに対して、どこまでが実際のテキストかを示すマスクも生成されています。これは、後ほど、Attention 機構を適用する際に、パディングした部分を無視するために使用します。

train_mask[0]
#### output ####
array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

正解ラベルは、上記のコードでワンホット表現に変換してあり、次のようになります。

train_label[0]
#### output ####
array([1., 0., 0., 0., 0., 0.])

この出力からわかるように、6種類のラベルがあり、それぞれの意味は、次のようになります。

emotion_labels
#### output ####
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

Embedding レイヤーの実装

入力データが用意できたので、まずは、これを受け取る、下記の Embedding レイヤーを実装します。

先に実装を見せると、次のようになります。

class Embeddings(nn.Module):
    embed_dim: int
    text_length: int = text_length
    vocab_size: int = vocab_size

    @nn.compact
    def __call__(self, input_ids, eval):
        token_embeddings = nn.Embed(
            self.vocab_size, self.embed_dim)(input_ids) # *1
        position_ids = jnp.arange(self.text_length) # *2
        position_embeddings = nn.Embed(
            self.text_length, self.embed_dim)(position_ids) # *3
        embeddings = token_embeddings + position_embeddings # broadcast *4
        embeddings = nn.LayerNorm(epsilon=1e-12)(embeddings)
        embeddings = nn.Dropout(0.5, deterministic=eval)(embeddings)
        return embeddings

入力データ input_ids は、次のように、トークンに分割されたテキストを縦に積み重ねた形になります。

array([[  101,  1045,  2134,  2102,  2514, 26608,   102,    0, ... ,0],
       [  101,  1045,  2064,  2175,  2013,  3110,  2061, 20625,... ,0],
       ...
       [  101, 10047,  9775,  1037,  3371,  2000,  2695,  1045,... ,0],
      ])

ここでまず重要なのが *1 の部分で、先ほどの図の「Input Embedding」と書かれた箱にあたります。これは、入力データに含まれるそれぞれのトークンを高次元空間(埋め込み空間)のベクトルにマッピングした「埋め込み表現」に変換します。次の図は、word2vecの説明でよく登場するもので、2次元空間に複数の単語がマッピングされており、単語の意味が近いものは、この空間内で近い位置に存在します。

このマッピングは、それぞれの単語に2次元空間の座標(2個の実数値)を割り当てるので、全部で[単語数, 2] サイズのデータとして表現されます。*1 では、self.vocab_size (= 30522) 個のトークンを self.embed_dim (= 512) 次元のベクトルにマッピングするので、[30522, 512] サイズのデータが用意されます。

ただしこの段階では、用意されたデータはランダムに初期化されたもので、単語の意味に応じたマッピングが行われるわけではありません。この後、学習処理を行うことで、なんらかの意味で有用なマッピングが得られることになります。

次に、先ほどの図の「Positional Encoding」と書かれた箱を付け加えます。この箱は、1文に含まれるそれぞれのトークンに対して、「頭から何個目のトークン(単語)であるか」という情報を追加します。具体的には、「頭からn個目のトークンである」という情報を表す(埋め込み空間の)位置情報ベクトルを n=1,2,... について用意して、先ほど得られた各トークンの埋め込み表現(ベクトル)に対して(テキスト内でのトークンの位置に応じて)対応する位置情報ベクトルを加えます。ベクトルを足してしまうと、トークンの意味情報と位置情報が混じり合って区別できなくなる気がするかも知れませんが、そこは問題ありません。今の場合、512次元のベクトルを用いているので、仮に、先頭の500次元で意味情報を表して、残りの12次元で位置情報を表せば、情報が混じり合うことはありません。実際には、このような明確な区別はありませんが、学習による最適化処理の中で、自然に情報が分離されるような表現が得られるものと期待します。

オリジナルの論文では、位置情報ベクトルとして、次のような数学関数で計算されるものを用いていました。

一方、今回の実装では、位置情報ベクトルも学習対象のパラメーターとしてあります。具体的には、*2 で、[0, 1, 2,...,127](1文に含まれるトークン数と同じ長さの数列)を用意して、*3 で、それぞれの値を(擬似的にトークンとみなして)self.embed_dim (= 512) 次元のベクトルにマッピングします。これを *1 で用意した埋め込み表現のベクトルに加えます(*4)。

なお、*4 のコメントに # broadcast とあるのは、次の理由によります。input_ids に N 個のテキストを入力した場合、*1 の出力は、[N, トークン数, 埋め込み空間の次元] = [N, 128, 512] サイズですが、*3 で用意される位置情報ベクトルは、[トークン数, 埋め込み空間の次元] = [128, 512] サイズであり、*4 は、サイズの異なるデータの足し算になっています。このような場合、NumPy の Array オブジェクトと同様にブロードキャスト処理が行われて、N 個のそれぞれのデータに対して、個別に *3 のデータが加えられます。次の例を参照してください。

a = np.array(
    [[ 1,  2,  3,  4,  5],
     [ 6,  7,  8,  9, 10],
     [11, 12, 13, 14, 15]])
b = np.array([1, 1, 1, 1, 1])
a + b # broadcast
#### output ####
array([[ 2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16]])

コードの残りの部分は、レイヤー正規化とドロップアウト層を適用していますが、この部分は、大規模モデルで学習をスムーズに進めるための一般的なテクニックです。

それでは最後に、学習対象のパラメーターの構造と、入出力データのサイズを確認しておきます。まず、パラメーターの構造は次になります。

variables = Embeddings(embed_dim=512).init(random.PRNGKey(0), train_text[:1], eval=True)
jax.tree_util.tree_map(lambda x: x.shape, variables['params'])
#### output ####
FrozenDict({
    Embed_0: {
        embedding: (30522, 512), # *1
    },
    Embed_1: {
        embedding: (128, 512), # *2
    },
    LayerNorm_0: {
        bias: (512,),
        scale: (512,),
    },
})

*1 は、self.vocab_size (= 30522) 個のトークンを self.embed_dim (= 512) 次元のベクトルにマッピングする、[30522, 512] サイズのデータで、*2 は、[0, 1, 2,...,127](1文に含まれるトークン数と同じ長さの数列)を self.embed_dim (= 512) 次元のベクトルにマッピングする、[128, 512] サイズのデータに当たります。

入出力データのサイズは次になります。

input_text = train_text[:3]
output = Embeddings(embed_dim=512).apply(variables, input_text, eval=True)
input_text.shape, output.shape
#### output ####
((3, 128), (3, 128, 512))

ここでは、3個のテキストを含むバッチを入力しており、1つのテキストは128個のトークンを持つので、入力データは [3, 128] サイズです。それぞれのトークンが 512 次元の埋め込み空間のベクトルに置き換えられるので、出力データは、[3, 128, 512] サイズになります。

パート2に続く。。。。

enakai00.hatenablog.com