torch.gather(source, dim, index, *, sparse_grad=False, out=None) -> Tensor指定された次元 dim に沿って source から index 番目を「集めてくる」関数。
indexを作る方法としては torch.topk などがある
例えば 3 次元のとき
out[i][j][k] = source[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = source[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = source[i][j][index[i][j][k]] # if dim == 2となるという。さらに以下が成立する。
sourceandindexmust have the same number of dimensions. It is also required thatindex.size(d) <= source.size(d)for all dimensionsd != dim.outwill have the same shape asindex. Note thatsourceandindexdo not broadcast against each other.
ややこしくて全然わからんので例を出して考える。
3 次元の例
例えば、最後の行である
out[i][j][k] = source[i][j][index[i][j][k]] # if dim == 2の例を考える。
source と index は同じ次元数でなければならない。今回は 3 次元の場合を考えている。
どちらも 3 次元なので、index の shape を (A,B,C)、source の shape を (X,Y,Z) とする。
このとき以下が成り立つ
i,j,kは index 全体をイテレートするので、出力であるoutの shape は(A,B,C)になる。- 右辺で out of index を起こさないために、
A<=X,B<=Y,max(index)<=Zである必要がある。- 普通は
A=X,B=Yの状況で使うだろうけど
- 普通は
別の言い方をすると、torch.gather は以下の 3 重 for ループ
for i in range(A):
for j in range(B):
for k in range(C):
out[i][j][k] = source[i][j][index[i][j][k]]を並列に実行する関数である。
1 次元の例
もっと簡単にして、1 次元バージョンで考えてみる。
for k in range(C):
out[k] = source[index[k]]ここまですればさすがにわかりやすい。index 番目を持ってきているだけである。
dim で指定されているもの以外の次元(上の三次元の例で言うと i と j の次元)は、out, source, index で同じようにイテレートしているだけだ。
NOTE
公式ドキュメントでは
sourceがinputと呼ばれているが、これは予約語だしindexと語感が似ていて紛らわしいのでsourceに変えた。
cf. torch.index_add: 足していくバージョン。ただし、gather は index 番目「から集めてくる」ものであるのに対し、こちらは index 番目「に足す」ものである