LSTM で足し算する

参考文献

  1. [1803.01271]An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling ;TCNの原論文。
  2. locuslab/TCN: Sequence modeling benchmarks and temporal convolutional networks ;TCNの原論文のリポジトリ。
  3. https://pytorch.org/docs/master/generated/torch.nn.LSTM.html ;torch.nn.LSTM のリファレンス。
  4. An Empirical Exploration of Recurrent Network ArchitecturesKerasのドキュメントの「初期化時に忘却ゲートのバイアスに1を加えます.」の箇所からリンクがある論文。
参考文献 1. の13ページ目の Table 3. に、この論文の提案手法である TCN のライバル手法にされている LSTM のコンフィギュレーションがありますよね。例えば1番最初の行は Adding Problem タスクの T=200 の場合ですが、n=2, Hidden=77, Dropout=0.0, Grad Clip=50, bias=5.0 とありますが、Dropout=0.0 はいいとして、それ以外は何を指しているのでしょう…?
最初の n=2 についてだけど、論文の6ページ目に the depth of the network n とあるよね。ただこれは TCN の TemporalBlock の積み重ね数だけど、LSTM でも同じ文字 n を使っているのは、LSTM ブロックの積み重ね数だと考えていいんじゃないかな。Hidden=77 は各 LSTM ブロックの入出力の次元数でいいと思うよ(但し1つ目の LSTM ブロックの入力次元数については入力データの次元数)。だって、著者のリポジトリで nhid とかかれているパラメータ、各 TemporalBlock の入出力次元数にセットされているからね。 Grad Clip=50 は訓練時に勾配ベクトルのノルムが50をはみ出したら50になるように縮めるという意味だからちょっと置いておこう。bias=5.0 は論文の6ページ目に initial forget-gate bias とあるからこのことだと思う。忘却ゲートのバイアスの初期値が重要だという話は調べると結構出てくるね。きちんと話を追えていないけど、でも、仮にもし忘却ゲートの重みやバイアスの初期値がゼロだったら記憶セルの最適化って全く進まないよね。記憶セルなんてなかったんだって感じで学習が進んじゃう。だから、「記憶セルがあることが前提で学習しなさい」と伝えるためにわざと最初に値を入れておくという感じだと思う。…まあそれで、今回の LSTM を実装すると以下のようになるかな。torch.nn.LSTM は引数 num_layers に指定した数だけ積み重ねられるからこれを使おう。
In [1]:
import torch.nn as nn
from collections import OrderedDict

def _debug_print(debug, *content):
    if debug:
        print(*content)

class myLSTM(nn.Module):
    def __init__(self,
                 input_size=2,   # 足し算タスクなので入力は2次元
                 output_size=1,  # 足し算タスクなので出力は1次元
                 num_layers=2,   # LSTM ブロックの積み重ね数が2
                 d_hidden=77,    # 各 LSTM ブロックの出力次元数が77
                 initial_forget_gate_bias=5.0,  # 忘却ゲートのバイアスの初期値
                 dropout=0.0):
        super(myLSTM, self).__init__()
        self.num_layers = num_layers
        self.d_hidden = d_hidden
        self.layers = OrderedDict()
        self.layers['lstm'] = nn.LSTM(input_size, d_hidden,
                                      num_layers=num_layers,
                                      dropout=dropout)
        self.layers['linear'] = nn.Linear(d_hidden, output_size)
        self.network = nn.Sequential(self.layers)
        self.init_weights(initial_forget_gate_bias)

    def init_weights(self, initial_forget_gate_bias):
        # 忘却ゲートのバイアスの初期値をセット
        for i_layer in range(self.num_layers):
            bias = getattr(self.layers['lstm'], f'bias_ih_l{i_layer}')
            bias.data[self.d_hidden:(2*self.d_hidden)] = initial_forget_gate_bias
            bias = getattr(self.layers['lstm'], f'bias_hh_l{i_layer}')
            bias.data[self.d_hidden:(2*self.d_hidden)] = initial_forget_gate_bias
        self.layers['linear'].weight.data.normal_(0, 0.01)

    def forward(self, x, hidden, debug=False):
        _debug_print(debug, '========== forward ==========')
        _debug_print(debug, x.size())
        out, hidden = self.layers['lstm'](x, hidden)
        _debug_print(debug, out.size())
        _debug_print(debug, hidden[0].size())
        _debug_print(debug, hidden[1].size())
        x = self.layers['linear'](hidden[0][-1,:,:])
        _debug_print(debug, x.size())
        _debug_print(debug, '=============================')
        return x, hidden

model_lstm = myLSTM()
print('◆ モデル')
print(model_lstm)
print('◆ 学習対象パラメータ')
for name, param in model_lstm.named_parameters():
    print(name.ljust(14), param.size())
◆ モデル
myLSTM(
  (network): Sequential(
    (lstm): LSTM(2, 77, num_layers=2)
    (linear): Linear(in_features=77, out_features=1, bias=True)
  )
)
◆ 学習対象パラメータ
network.lstm.weight_ih_l0 torch.Size([308, 2])
network.lstm.weight_hh_l0 torch.Size([308, 77])
network.lstm.bias_ih_l0 torch.Size([308])
network.lstm.bias_hh_l0 torch.Size([308])
network.lstm.weight_ih_l1 torch.Size([308, 77])
network.lstm.weight_hh_l1 torch.Size([308, 77])
network.lstm.bias_ih_l1 torch.Size([308])
network.lstm.bias_hh_l1 torch.Size([308])
network.linear.weight torch.Size([1, 77])
network.linear.bias torch.Size([1])
えっと、パラメータの次元数が、 308?
77 × 4 = 308 だね。2次元の入力を77次元にするには右からサイズ [77, 2] の行列をかければいいけど、LSTM は「入力ゲート」「忘却ゲート」「通常のRNNの重み」「出力ゲート」があるから 77 が4倍になる。あと LSTM のバイアスを表示してみるね。忘却ゲートのバイアスに該当する 77~153 次元目に 5.0 を代入しただけだけどこんなんでいいのかな…Keras のドキュメントには忘却ゲートのバイアスを1にして、逆に忘却ゲートのバイアス以外は0にするといいというようにあったけど、とりあえず論文に言及があるのは忘却ゲートのバイアスだけだから他は放っておいた。
In [2]:
print(model_lstm.layers['lstm'].bias_ih_l0)
Parameter containing:
tensor([ 4.9165e-02,  5.3896e-02, -4.8629e-02,  9.2515e-02,  2.1954e-02,
         9.4397e-02,  9.4592e-02, -6.4292e-02,  3.2555e-02, -2.1142e-02,
        -8.1746e-02,  9.1270e-02,  5.8425e-02,  3.5523e-02,  8.8397e-02,
        -7.2806e-02, -9.3471e-02,  4.9092e-02,  7.1229e-03, -5.1445e-02,
        -8.1698e-02, -3.8696e-02, -2.5925e-02, -9.2030e-02,  1.0211e-01,
        -9.0567e-02, -1.0435e-01,  1.0762e-01,  5.9898e-02,  3.2932e-02,
        -5.2855e-02,  5.6225e-02,  4.4851e-02, -1.0331e-01, -7.9663e-02,
        -9.0007e-02,  3.9228e-02, -4.3425e-02, -7.9913e-02,  9.2957e-02,
        -2.9188e-02, -6.8715e-02,  6.6197e-02, -5.7450e-02, -5.3279e-02,
         4.1563e-02, -8.3667e-02,  7.3850e-02, -2.8193e-02,  6.4358e-02,
        -4.1299e-02, -8.7524e-02,  7.3115e-02, -6.8227e-02, -1.9827e-02,
         9.8330e-02, -1.0648e-01, -4.2002e-02,  3.3780e-02, -3.0554e-02,
         4.5411e-02,  6.9738e-02, -4.2859e-02, -6.1650e-02,  4.4743e-02,
        -2.4768e-02, -8.5655e-02, -2.8372e-02, -8.7389e-02,  1.0232e-01,
         3.9674e-02, -5.2236e-02,  1.0560e-01, -9.7303e-02, -7.2981e-02,
        -9.2142e-02, -7.0825e-02,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00,
         5.0000e+00,  5.0000e+00,  5.0000e+00,  5.0000e+00, -6.2721e-02,
        -5.8241e-02, -1.6867e-02, -4.4569e-02, -2.7921e-02,  7.4363e-02,
        -4.2447e-02, -5.5075e-02,  1.6161e-02, -2.4220e-02,  7.7675e-02,
        -6.0809e-02,  5.4841e-02,  8.6750e-02,  3.6424e-02,  2.8048e-02,
        -5.3524e-02,  3.5246e-02, -1.0170e-01,  1.7603e-02,  1.8143e-02,
        -8.9160e-02,  1.0085e-01,  1.0683e-01, -1.3512e-02, -5.2527e-02,
         6.8354e-02, -1.0891e-02,  2.8305e-02,  2.8064e-02,  1.7084e-02,
         7.6924e-02, -4.3729e-02,  1.1172e-01, -1.1196e-01,  4.4700e-02,
         8.4991e-02, -8.6713e-02, -1.3968e-02,  4.3707e-02, -8.8109e-02,
         6.2691e-02,  8.2722e-02, -6.4417e-02, -9.9627e-03,  4.9382e-02,
        -3.7137e-03, -3.1076e-02,  9.7192e-02, -4.7260e-02,  7.8700e-02,
         9.8256e-02, -2.5228e-02,  8.1686e-02,  5.6391e-02,  3.2331e-02,
         8.2114e-02, -3.0936e-02, -8.8276e-02,  1.8504e-02, -9.9111e-02,
        -1.2504e-02,  5.6677e-02, -2.6727e-02,  5.9095e-02, -9.7887e-02,
         7.5615e-02,  2.5467e-02, -2.1051e-02, -2.2644e-03,  3.6444e-03,
         1.1376e-01,  1.0347e-02,  8.6061e-02, -1.5429e-02,  9.6989e-02,
         1.3893e-02, -8.8932e-03, -9.1217e-02, -1.8536e-03, -1.1309e-01,
        -5.9436e-02,  6.8963e-02, -7.6549e-02, -4.3315e-02, -1.0767e-01,
         1.3827e-02,  3.8573e-02, -3.2789e-02, -4.6074e-02,  4.1241e-02,
        -7.4975e-02, -2.8728e-02,  3.2466e-02, -5.4221e-02, -2.2455e-02,
        -8.3136e-02,  7.5368e-02, -6.2805e-02, -7.6103e-02,  1.8733e-03,
        -7.1987e-02, -1.0046e-01, -4.7820e-02,  6.5614e-02, -6.4605e-03,
         1.1242e-01,  8.9668e-02, -5.1174e-02, -5.9855e-02, -8.3304e-02,
         1.6023e-03,  1.7028e-03,  1.0380e-03,  2.0233e-02, -8.6717e-02,
         4.5389e-02, -8.3116e-02,  3.1499e-03,  1.8729e-02, -8.1673e-02,
        -1.0322e-01,  2.6273e-02,  8.2414e-02, -9.9420e-02,  1.1180e-01,
         1.0178e-01,  3.9553e-02,  3.0256e-02, -3.9064e-02, -4.5343e-02,
        -8.0867e-02,  3.5570e-02,  1.0058e-01,  1.0549e-01,  7.8596e-02,
        -9.6782e-02,  7.9042e-02, -3.2222e-02, -7.1313e-02, -1.9761e-02,
         5.7132e-03,  7.6312e-02, -9.1598e-02, -8.6944e-02,  6.7028e-03,
        -6.1358e-02,  9.3485e-04,  2.1562e-02,  2.2943e-02,  6.1359e-02,
        -9.1821e-02, -2.2169e-03, -8.0288e-02], requires_grad=True)
それで、実際に学習できるんですか?
たぶん以下のようになると思うけど(雑だけど)。忘却ゲートのバイアスが5の方が0のときよりロスは小さいみたい?
In [3]:
import numpy as np
import torch

from torch.autograd import Variable
torch.manual_seed(1)

# TCN向けの足し算データを生成する関数
# ソース https://github.com/locuslab/TCN/blob/master/TCN/adding_problem/utils.py
# out: torch.Size([N, 2, seq_length]), torch.Size([N, 1])
def data_generator(N, seq_length):
    X_num = torch.rand([N, 1, seq_length])
    X_mask = torch.zeros([N, 1, seq_length])
    Y = torch.zeros([N, 1])
    for i in range(N):
        positions = np.random.choice(seq_length, size=2, replace=False)
        X_mask[i, 0, positions[0]] = 1
        X_mask[i, 0, positions[1]] = 1
        Y[i,0] = X_num[i, 0, positions[0]] + X_num[i, 0, positions[1]]
    X = torch.cat((X_num, X_mask), dim=1)
    return Variable(X), Variable(Y)

# TCN向けのデータをLSTM向けに転置する関数
# out: torch.Size([seq_length, N, 2]), torch.Size([N, 1])
def copy_for_lstm(x):
    x_ = x.clone().detach()
    x_ = x_.transpose(0, 2).transpose(1, 2).contiguous()
    return x_
In [4]:
# LSTMの隠れ状態と記憶セルの初期値を作成
num_layers = 2
batch_size = 100
d_hidden = 77
hidden0 = torch.zeros([num_layers, batch_size, d_hidden])
cell0 = torch.zeros([num_layers, batch_size, d_hidden])

seq_length = 200
data = data_generator(batch_size, seq_length)
x = copy_for_lstm(data[0])

hidden = (hidden0.clone().detach(), cell0.clone().detach())
out, hidden = model_lstm.forward(x, hidden, debug=True)
print('\n◇ out = 足し算結果(ネットワークを学習していないので足し算にはなっていない、次元だけ確認)')
print(out.size())
print(out[:10])
print('◇ hidden = 流した後の隠れ状態と記憶セル')
print(hidden[0].size(), hidden[1].size())
print(hidden)
========== forward ==========
torch.Size([200, 100, 2])
torch.Size([200, 100, 77])
torch.Size([2, 100, 77])
torch.Size([2, 100, 77])
torch.Size([100, 1])
=============================

◇ out = 足し算結果(ネットワークを学習していないので足し算にはなっていない、次元だけ確認)
torch.Size([100, 1])
tensor([[-0.0025],
        [-0.0088],
        [-0.0166],
        [-0.0055],
        [-0.0239],
        [-0.0173],
        [ 0.0040],
        [-0.0158],
        [-0.0021],
        [-0.0226]], grad_fn=<SliceBackward>)
◇ hidden = 流した後の隠れ状態と記憶セル
torch.Size([2, 100, 77]) torch.Size([2, 100, 77])
(tensor([[[ 0.4744, -0.3614, -0.5132,  ..., -0.4230,  0.4121,  0.4119],
         [ 0.4572, -0.3583, -0.5197,  ..., -0.4263,  0.4100,  0.3948],
         [ 0.4590, -0.3679, -0.5241,  ..., -0.4239,  0.4007,  0.3892],
         ...,
         [ 0.4665, -0.3798, -0.4826,  ..., -0.4065,  0.3867,  0.4315],
         [ 0.4696, -0.3296, -0.5443,  ..., -0.4045,  0.3817,  0.4176],
         [ 0.4625, -0.3428, -0.5251,  ..., -0.4157,  0.4224,  0.4039]],

        [[ 0.5115, -0.5246, -0.5442,  ..., -0.3999, -0.4695,  0.3655],
         [ 0.5027, -0.5164, -0.5433,  ..., -0.4205, -0.4633,  0.3725],
         [ 0.5146, -0.4881, -0.5898,  ..., -0.3399, -0.5100,  0.3577],
         ...,
         [ 0.4788, -0.5288, -0.5811,  ..., -0.4078, -0.5186,  0.4486],
         [ 0.4448, -0.5565, -0.5365,  ..., -0.3549, -0.4890,  0.3322],
         [ 0.4976, -0.5203, -0.5409,  ..., -0.3922, -0.4817,  0.3674]]],
       grad_fn=<StackBackward>), tensor([[[ 25.3851, -19.5185,  -2.2148,  ...,  -7.0757,  21.8752,  21.2672],
         [ 24.9314, -18.4482,  -2.1745,  ...,  -7.2815,  21.8010,  21.4574],
         [ 26.8274, -21.9711,  -2.5544,  ...,  -6.7963,  23.2875,  20.9006],
         ...,
         [ 29.8533, -22.5231,  -2.4401,  ...,  -3.5235,  24.2307,  21.8897],
         [ 24.6162, -15.4336,  -4.5720,  ...,  -5.4434,  18.8222,  23.8761],
         [ 24.6132, -18.0139,  -2.4366,  ...,  -7.3646,  21.3561,  21.7261]],

        [[ 17.5962, -28.9121, -15.6031,  ...,  -8.2359, -13.6865,   4.7207],
         [ 17.1507, -29.2552, -15.5106,  ...,  -7.4294, -13.2211,   5.2530],
         [ 17.0677, -22.5658, -15.9626,  ...,  -9.6145, -23.2147,   3.4630],
         ...,
         [ 28.1819, -33.5668, -22.2365,  ...,  -8.5320, -28.6113,   3.7873],
         [  6.8183, -14.3391, -26.0844,  ...,  -7.9752, -11.1114,  12.1419],
         [ 17.1735, -25.4016, -15.4087,  ...,  -7.9439, -11.6768,   5.0952]]],
       grad_fn=<StackBackward>))
In [5]:
import torch.optim as optim
import torch.nn.functional as F


def train(model_lstm):
    optimizer = optim.SGD(model_lstm.parameters(), lr=0.001)
    grad_clip = 50.0
    total_loss = 0

    for epoch in range(100):
        optimizer.zero_grad()
        # data = data_generator(batch_size, seq_length)
        data = data_generator(100, 100)
        x = copy_for_lstm(data[0])
    
        hidden = (hidden0.clone().detach(), cell0.clone().detach())
        out, hidden = model_lstm.forward(x, hidden)

        loss = F.mse_loss(out, data[1])
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model_lstm.parameters(), grad_clip)
        optimizer.step()
        
        total_loss += loss.item()

    print(total_loss)
In [6]:
model_lstm = myLSTM(initial_forget_gate_bias=0.0)
train(model_lstm)

model_lstm = myLSTM(initial_forget_gate_bias=5.0)
train(model_lstm)
101.7260605096817
26.77291202545166
In [ ]: