{src/tgt/memory}_[key_padding_]mask の合計 6 種類がある

source, target, memory について

  • src: Encoder の self-attention
  • tgt: Decoder の self-attention
  • memory: Decoder の source-target attention (cross-attention)

mask と key_padding_mask の違い

Attention の計算は

attention = query * key.T
output = softmax(attention) * value

となっている

  • 厳密には attention は で割られて normalize されているが

この計算式において、

  • attention をマスクするのが mask
  • 入力のうち padding のもの(実際には中身のないトークン)を mask するのが key_padding_mask

である。以下でそれぞれをより詳しく説明する

mask

additive mask である; すなわち、attention に足される。よって attention の計算式は実際には以下になる

attention = query * key.T + mask

(query および key).shape = (seq, enc_dim) より
(query*key.T).shape = (seq,seq) したがって
mask.shape = (seq,seq) である

  • mask[i][j]=-inf にすることにより、出力の i (out of seq) 番目について value の j 番目を使わせないことができる
  • 答えがリークしないよう隠すときに使う。例えば のときに mask することで、出力が未来を参照せず計算することを保証できる (causal attention)
  • pytorch では、型が boolean の場合、 True の際に隠す。すなわち -float.inf と同じ効果を得る

key_padding_mask

入力シーケンスのうち padding であるものを隠すのに使う

  • これは batch ごとに異なるであろう。そのため、pytorch ではサイズは (batch_size, seq)

内部的には mask とマージされているのでやっていることは一緒

  • 入力・出力の関係に基づいて mask することができない代わりに、batch ごとに違う padding に基づいて mask できる