公式ドキュメント

torch.gather(source, dim, index, *sparse_grad=Falseout=None-> Tensor

指定された次元 dim に沿って source から index 番目を「集めてくる」関数。

  • index を作る方法としては torch.topk などがある

例えば 3 次元のとき

out[i][j][k] = source[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = source[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = source[i][j][index[i][j][k]]  # if dim == 2

となるという。さらに以下が成立する。

source and index must have the same number of dimensions. It is also required that index.size(d) <= source.size(d) for all dimensions d != dimout will have the same shape as index. Note that source and index do not broadcast against each other.

ややこしくて全然わからんので例を出して考える。

3 次元の例

例えば、最後の行である

out[i][j][k] = source[i][j][index[i][j][k]]  # if dim == 2

の例を考える。

source と index は同じ次元数でなければならない。今回は 3 次元の場合を考えている。
どちらも 3 次元なので、index の shape を (A,B,C)、source の shape を (X,Y,Z) とする。
このとき以下が成り立つ

  • i,j,k は index 全体をイテレートするので、出力である out の shape は (A,B,C) になる。
  • 右辺で out of index を起こさないために、A<=X, B<=Y, max(index)<=Z である必要がある。
    • 普通は A=X, B=Y の状況で使うだろうけど

別の言い方をすると、torch.gather は以下の 3 重 for ループ

for i in range(A):
    for j in range(B):
        for k in range(C):
            out[i][j][k] = source[i][j][index[i][j][k]]

を並列に実行する関数である。

1 次元の例

もっと簡単にして、1 次元バージョンで考えてみる。

for k in range(C):
	out[k] = source[index[k]]

ここまですればさすがにわかりやすい。index 番目を持ってきているだけである。
dim で指定されているもの以外の次元(上の三次元の例で言うと ij の次元)は、out, source, index で同じようにイテレートしているだけだ。

NOTE

公式ドキュメントでは sourceinput と呼ばれているが、これは予約語だし index と語感が似ていて紛らわしいので source に変えた。

cf. torch.index_add: 足していくバージョン。ただし、gather は index 番目「から集めてくる」ものであるのに対し、こちらは index 番目「に足す」ものである