【Pytorch】torch.catの意味や引数を詳しく解説!dim=-1を設定した場合は?

torch.cat
はgithubでモデルの構成を見ていると、かなり登場するので色々と調べてみました。
例を使って色々と試しています。
torch.cat
の意味
torch.cat
の役割は、与えられた次元のテンソルの列を連結することです。
catと聞くと猫?となりますが、concatenate(連結する)の略だと考えれば納得ですね。
torch.cat
の引数
では、torch.cat
の引数について見ていきます。
torch.catの引数は、torch.cat(tensors, dim=0, *, out=None)
です。
それぞれの引数の説明を見ていきます。
入力テンソルが入ったlist
またはtuple
です。
注意点として、list
またはtuple
内のテンソルの次元は同じでなくてはいけません。
テンソルが連結される次元を設定できます。
デフォルトはdim=0
です。
出力テンソルです。
では、dim
とout
の説明がこれだけだと分かりにくいと思うので、この2つについて詳しく説明していきます。
dimの詳細
dimは連結させる次元を指定できます。
この意味を以下の3次元のテンソル同士を連結させることで確かめます。
m = torch.arange(8).reshape(2, 2, 2)
"""
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
"""
m.size()
# torch.Size([2, 2, 2])
dim=0
を指定した場合torch.cat(tensors=(m, m), dim=0)
"""
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
"""
# torch.Size([4, 2, 2])
サイズを見てみると、torch.Size([2, 2, 2])
→torch.Size([4, 2, 2])
になっているのが分かりますね。つまり、dim=0
を指定すると、1番目の次元に連結されるということです。pythonは0から始まるものなので、dim=0
と指定するという訳です。
dim=1
を指定した場合torch.cat(tensors=(m, m), dim=1)
"""
tensor([[[0, 1],
[2, 3],
[0, 1],
[2, 3]],
[[4, 5],
[6, 7],
[4, 5],
[6, 7]]])
"""
# torch.Size([2, 4, 2])
同様に、サイズを見てみると、torch.Size([2, 2, 2])
→torch.Size([2, 4, 2])
になっています。つまり、dim=1
を指定すると、2番目の次元に連結されるということです。
dim=2
を指定した場合torch.cat(tensors=(m, m), dim=2)
"""
tensor([[[0, 1, 0, 1],
[2, 3, 2, 3]],
[[4, 5, 4, 5],
[6, 7, 6, 7]]])
"""
# torch.Size([2, 2, 4])
同様に、サイズを見てみると、torch.Size([2, 2, 2])
→torch.Size([2, 2, 4])
になっています。つまり、dim=2
を指定すると、3番目の次元に連結されるということです。
ここまで理解できていれば、dim=-1
などの意味も分かってくると思います。
pythonではa = [0, 1, 2]
のようなリストがあった場合、a[-1]
は2を表しますよね。
同じように考えると、今回はdim=-1
はdim=2
と等しいということが分かると思います。
連結前のテンソルの次元によって変化するので、その点は注意が必要です。
outの詳細
outにテンソルを指定すると上書きされます。(多分この説明であっているはずです。間違っていたら指摘お願いします。)
色々試しているので、動作についてはそこから感じ取ってください。
# 以下のテンソルを連結します
m = torch.arange(8).reshape(2, 2, 2)
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
out=連結前のテンソル
を指定した場合torch.cat(tensors=(m, m), out=m)
# RuntimeError: 0unsupported operation:
# the input tensors cannot refer to any of the output memory locations.
# Found overlap in input tensor 0
入力テンソルと出力テンソルは異なる必要があります。
out=torch.tensor([])
を指定した場合n = torch.tensor([])
# tensor([])
torch.cat(tensors=(m, m), out=n)
"""
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
"""
print(n)
"""
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
"""
nが上書きされてますね。
また、mはtorch.LongTensor
でしたが、連結後はtorch.FloatTensor
になっています。
これは、outに指定したテンソルに合わせるためだと思います。
まとめ
dim
はどの次元を基準に連結させるかを決めることができるout
にテンソルを指定すると、連結後のテンソルに上書きされる

コメント