Attention Is All You Need

Attention Is All You Need [paper] は、エンコーダが「ソース系列」を「前後の文脈を反映した」系列に変換し、それを受けた下でデコーダが別の「ターゲット系列」を「1 単語先にシフトした系列」に変換するように訓練する Seq2Seq モデル "Transformer" を提案する論文である。Google Brain の Ashish Vaswani らによって著され (他に Google Research、University of Toronto 所属の著者も)、2017 年の NeurIPS に採択された。

Transformer は推論時には対象の「ソース系列」をエンコードした下で、最初は文頭トークンのみの「ターゲット系列」から順に次の単語をデコードしていく。Transformer のエンコーダ・デコーダは系列の処理に再帰も畳込みも用いず、"Scaled Dot-Product Attention" を内包する "Multi-Head Attention" 層と、ステップごとの線形変換のみを用いる。

Transformer のイラスト (原論文等を元に私が描画)

参考文献

paper
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems (NeurIPS 2017), vol. 30, pp. 5998−6008, 2017.

Transformer の構造

"Scaled Dot-Product Attention" は、全て「長さ L_in の d_k 次元のベクトル列」である Q, K, V に対し、Attention(Q, K, V) := softmax(QK^T/√d_k)V と定義される (原論文では d_k=64)。よって、出力も「長さ L_in の d_k 次元のベクトル列」である。

"Multi-Head Attention" 層は、「長さ L_in の n_head×d_k 次元のベクトル列」(原論文では n_head=8) である入力系列をステップごとに 3×n_head×d_k 次元に線形変換して Q, K, V 用に 3 分割し、さらにそれぞれ n_head ヘッドに分割して、全ヘッド並列に Attention(Q, K, V) を適用し、n_head ヘッドの出力を連結して (n_head×d_k 次元)、改めて各ステップを n_head×d_k 次元に線形変換する。よって、出力も「長さ L_in の n_head×d_k 次元のベクトル列」である。

なお、「ソース系列」/「ターゲット系列」を埋め込んでエンコーダ / デコーダに入力する前に、固定の "Positional Encoding" を各ステップに加算する (系列の位置情報をもたせるため)。

Transformer 全体

以下の 3 箇所の (※) における語彙数×512次元の重み行列は共有する (Weight tying)。ただし、読み出し側では転置する。また、エンベディング側では重みに √512 を乗じる。

  • 「ソース系列」(トークン ID 列) をステップごとに語彙数次元 → 512 次元にエンベディングし (※)、位置エンコーディングを加算し、エンコーダ入力とする。
  • エンコーダ入力にエンコーダを適用してエンコーダ出力を得る。
  • 「ターゲット系列」(トークン ID 列) をステップごとに語彙数次元 → 512 次元にエンベディングし (※)、位置エンコーディングを加算し、デコーダ入力とする。
    • 訓練時は、例えば正解系列が「A B C D」のとき、「ターゲット系列」は「<BOS> A B C」とし、デコーダが「A B C D」を出力することを目指す (ために各ステップの出力の交差エントロピーの平均の最小化を目指す)。
    • 推論時は、「ターゲット系列」はまず「<BOS>」のみとして最初のトークン「A」を生成し、次に「<BOS> A」を入力して次のトークン「A B」を生成する、という操作を繰り返す。
  • エンコーダ出力の下に、デコーダ入力にデコーダを適用してデコーダ出力を得る。
  • デコーダ出力を 512 次元 → 語彙数次元に線形変換して読み出し (※)、ステップごとに softmax して各ステップのトークンの確率分布を得る。推論時はビームサーチによって確率の大きいトップ 4 本の系列のデコードを継続する。

原論文での位置エンコーディング

  • 位置 pos (0-indexed)、次元インデックス i (0-indexed) に対して、512 次元のベクトルの各成分を以下で定義する [脚注PE]。また、位置エンコーディングの加算後にはドロップアウトする。
    • 偶数インデックス 2i: sin(pos / 10000^(2i/512))
    • 奇数インデックス 2i+1: cos(pos / 10000^(2i/512))

原論文でのエンコーダ

  • エンコーダ層 (以下) を 6 層積み重ねる [脚注EN]。
    • 各ステップが 512 次元の入力系列を受け取る (ア)。
    • マルチヘッドアテンション層 (以下) を適用する。
      • 入力系列をステップごとにQ, K, V 用に 3×8×64 次元に線形変換して、8 ヘッドに分割する。
      • 8 ヘッド並列に Attention(Q, K, V) := softmax(QK^T/√d)V を適用する。
      • 各ヘッドからの出力を 8×64=512 次元に連結する。
      • 512 次元に線形変換する。
    • ドロップアウトする。
    • (ア) から残差接続する。
    • 層正規化する。(イ)
    • 2048 次元に線形変換、ReLU 活性化 [脚注ReLU]、ドロップアウトする。
    • 再び 512 次元に線形変換、ドロップアウトする。
    • (イ) から残差接続する。
    • 層正規化する。

原論文でのデコーダ

  • デコーダ層 (以下) を 6 層積み重ねる。
    • 各ステップが 512 次元の入力系列を受け取る (ア)。
    • マスク付きマルチヘッドアテンション層 (以下) を適用する。
      • 入力系列をステップごとに 3×8×64 次元に線形変換する。
      • Q, K, V 用に 3 分割して、さらにそれぞれ 8 ヘッドに分割する。
      • 8 ヘッド並列に MaskedAttention(Q, K, V) := softmax(mask(QK^T)/√d)V を適用する。ただし、mask(QK^T) は QK^T の対角線を含めない右上三角を $-\infty$ にする操作である。
        • このマスクは、訓練時に「『<BOS>』から『A』を予測」「『<BOS> A』から『- B』を予測」「『<BOS> A B』から『- - C』を予測」… を一挙に行えるようにするための措置である。このマスクによって、出力の 1 ステップ目は「<BOS>」までから、出力の 2 ステップ目は「<BOS> A」までから、出力の 3 ステップ目は「<BOS> A B」までから生成されるようにできる。このマスクを導入しない場合に比べモデルの表現力は制限されるが、このマスクを導入することにより各ステップの予測用のサンプルデータを別々に用意して別々に予測する必要がなくなり、訓練時間を短縮できる。なお、Transformer ではこの一挙訓練のために同じ文章の各ステップの予測は強制的に同じバッチになる。
      • 各ヘッドからの出力を 8×64=512 次元に連結する。
      • 512 次元に線形変換する。
    • ドロップアウトする。
    • (ア) から残差接続する。
    • 層正規化する。(イ)
    • クロスマルチヘッドアテンション層 (以下) を適用する。
      • (イ) から Q 用に 8×64 次元に線形変換して 8 ヘッドに分割する。
      • エンコーダ最終出力を K, V 用に 2×8×64 次元に線形変換して 8 ヘッドに分割する。
      • 8 ヘッド並列に Attention(Q, K, V) を適用する。
      • 各ヘッドからの出力を 8×64=512 次元に連結する。
      • 512 次元に線形変換する。
    • ドロップアウトする。
    • (イ) から残差接続する。
    • 層正規化する。(ウ)
    • 2048 次元に線形変換、ReLU 活性化、ドロップアウトする。
    • 再び 512 次元に線形変換、ドロップアウトする。
    • (ウ) から残差接続する。
    • 層正規化する。

トイモデル

以下のトイ Transformer は Positional Encoding を加算するところから Output Projection の手前までである (Weight tying は含まない)。動作確認では埋め込み次元数 6、中間層次元数 12、ヘッド数 2、エンコーダ層数 2、デコーダ層数 2 でインスタンス化し、長さ 8 の「ソース系列」に基づき長さ 4 の「ターゲット系列」をシフトしている。

import torch
from torch.nn import TransformerEncoder as Encoder
from torch.nn import TransformerEncoderLayer as EncoderLayer
from torch.nn import TransformerDecoder as Decoder
from torch.nn import TransformerDecoderLayer as DecoderLayer
import math


class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_m, len_max=100):
        super().__init__()
        pe = torch.zeros(len_max, d_m)
        pos = torch.arange(0, len_max).unsqueeze(1)  # (max_len, 1)
        div = torch.exp(torch.arange(0, d_m, 2) * (-math.log(10000.0) / d_m))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div[:pe[:, 1::2].shape[1]])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, len_max, d_m)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]  # (batch, seq_len, d_m)


class ToyTransformer(torch.nn.Module):
    def __init__(self, d_m, n_head, d_ff, p_drop=0.1):
        super().__init__()
        self.pos_enc = PositionalEncoding(d_m)
        self.dropout = torch.nn.Dropout(p_drop)
        common_params = {
            'd_model': d_m,
            'nhead': n_head,
            'dim_feedforward': d_ff,
            'batch_first': True,
        }
        self.encoder = Encoder(EncoderLayer(**common_params), num_layers=2)
        self.decoder = Decoder(DecoderLayer(**common_params), num_layers=2)

    def forward(self, src, tgt):
        # src: (b, len_in, d_m)
        # tgt: (b, len_out, d_m)
        len_out = tgt.size(1)
        mask = torch.nn.Transformer.generate_square_subsequent_mask(len_out)
        encoded = self.encoder(self.dropout(self.pos_enc(src)))
        decoded = self.decoder(self.dropout(self.pos_enc(tgt)), encoded, tgt_mask=mask)
        return decoded


if __name__ == '__main__':
    d_m = 6
    n_head = 2
    d_ff = 12
    model = ToyTransformer(d_m, n_head, d_ff)

    # パラメータ数の確認
    n_param = sum(p.numel() for p in model.parameters())
    n_param_enc = sum(p.numel() for p in model.encoder.layers[0].parameters())
    n_param_dec = sum(p.numel() for p in model.decoder.layers[0].parameters())
    print(f'パラメータ総数: {n_param}')
    print(f'encoder層のパラメータ数: {n_param_enc}')
    print(f'decoder層のパラメータ数: {n_param_dec}')
    assert n_param == 2 * n_param_enc + 2 * n_param_dec

    # 入出力テンソルサイズの確認
    src = torch.randn(1, 8, 6)
    tgt = torch.randn(1, 4, 6)
    with torch.no_grad():
        out = model(src, tgt)
    assert out.shape == (1, 4, 6)

パラメータ総数: 1776
encoder層のパラメータ数: 354
decoder層のパラメータ数: 534

備考

  • 英独翻訳タスクなら、Transformer のエンコーダは元の英文 (単語特徴列) を前後の文脈を反映した特徴列とし (エンコーダ表現)、デコーダは現時点でデコードされている独文を、まず過去からの文脈のみ反映した特徴列として、各ステップがエンコーダ表現の何ステップ目にどれだけ関連するかを反映して更新することを繰り返す (最終的には、出力列の最終ステップを独単語に読み出してデコードを 1 ステップ進める)。エンコーダ層・デコーダ層を積み重ねることで、より複雑な対応関係を段階的に捉えられるのではないかと考えられている。
  • 後年、Transformer エンコーダの各層が担う言語情報の種類を probing [脚注probing] により分析した研究があり、浅い層ほど品詞・構文情報を、深い層ほど意味・共参照 [脚注coref] などの情報を捉える傾向が報告されている (BERT を対象とした分析だが Transformer 一般にも広く引用される) ()。
  • 後年、Transformer のアテンションヘッドの役割を分析した研究があり、一部のヘッドは隣接トークンへの注意・構文依存関係への注意などの特定のパターンに特化しており、残りの大多数は冗長で除去しても性能がほとんど低下しないことが報告されている (, )。
  • 後年、モデルサイズ・データ量・計算量と損失の関係がべき乗則に従うという「スケーリング則」が報告されており、層数や幅といった個別のアーキテクチャ詳細よりもパラメータ総数が性能を規定するという経験的知見が示されている ()。

脚注

ReLU
当時既に GELU () が提案されていたが広まっておらず、Attention Is All You Need では ReLU が採用されている。GELU は BERT (2018) で採用されてから Transformer 系モデルでの標準的な活性化関数として普及した。
PE
10000 は経験的な値である。
EN
原論文のように残差接続後に層正規化する (Post-LN) のではなく、入力直後に層正規化する (Pre-LN) 方が学習が安定するという研究もあり ()、PyTorch の torch.nn.TransformerEncoderLayer には Pre-LN に切り替えるフラグが実装されている。もっとも、Pre-LN に切り替えたとしても、その他の種々の最新の発展を取り込んでいるわけではないので、torch.nn.TransformerEncoderLayer の実利用は既に推奨されていない ()。
probing
probing とは、学習済みモデルの各層の出力ベクトルを入力として、品詞タグなど特定の言語情報を予測する小さな分類器を訓練する分析手法である。その分類器の精度が高ければ、その層がその情報を内部に持っていると解釈する。
coref
共参照 (coreference) とは、文中の複数の表現が同じ実体を指すことである。例えば「太郎が来た。彼は疲れていた」において「太郎」と「彼」は同じ実体を指している。