{src/tgt/memory}_[key_padding_]mask の合計 6 種類がある
source, target, memory について
- src: Encoder の self-attention
- tgt: Decoder の self-attention
- memory: Decoder の source-target attention (cross-attention)
mask と key_padding_mask の違い
Attention の計算は
attention = query * key.T
output = softmax(attention) * value
となっている
- 厳密には attention は で割られて normalize されているが
この計算式において、
- attention をマスクするのが mask
- 入力のうち padding のもの(実際には中身のないトークン)を mask するのが key_padding_mask
である。以下でそれぞれをより詳しく説明する
mask
additive mask である; すなわち、attention に足される。よって attention の計算式は実際には以下になる
attention = query * key.T + mask
(query および key).shape = (seq, enc_dim) より
(query*key.T).shape = (seq,seq) したがって
mask.shape = (seq,seq) である
mask[i][j]=-infにすることにより、出力のi(out ofseq) 番目について value のj番目を使わせないことができる- 答えがリークしないよう隠すときに使う。例えば のときに mask することで、出力が未来を参照せず計算することを保証できる (causal attention)
- pytorch では、型が boolean の場合、
Trueの際に隠す。すなわち-float.infと同じ効果を得る
key_padding_mask
入力シーケンスのうち padding であるものを隠すのに使う
- これは batch ごとに異なるであろう。そのため、pytorch ではサイズは
(batch_size, seq)
内部的には mask とマージされているのでやっていることは一緒
- 入力・出力の関係に基づいて mask することができない代わりに、batch ごとに違う padding に基づいて mask できる