【Pytorch】torch.catの使い方・引数を徹底解説!テンソルを結合する方法

torch.catはPyTorchでテンソル(多次元配列)を結合するための関数です。
この関数を使用すると、指定した次元でテンソルを連結することができます。
ドキュメント:torch.cat — PyTorch 2.0 documentation
イメージを掴もう
イメージを掴むと使いやすくなると思うので画像にしました。

torch.catの引数
torch.catは基本的に以下のように使うことができます。
import torch
# サンプルのテンソルを作成します
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
# テンソルを連結します
result = torch.cat((tensor1, tensor2))
print(result)
# 出力:
# tensor([1, 2, 3, 4, 5, 6])では引数について見ていきます。
- tensors
結合するテンソルのリストまたはタプルを指定します。 - dim
結合する次元のインデックスを指定します。
デフォルト値は0です。 - out
結果を格納する出力テンソルです。
指定されない場合は新しいテンソルが作成されます。
dimとoutについて具体例を用いて詳しく解説します。
dimの詳細
dimは連結させる次元を指定できます。
この意味を以下の3次元のテンソル同士を連結させることで確かめます。
m = torch.arange(8).reshape(2, 2, 2)
"""
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
"""
m.size()
# torch.Size([2, 2, 2])dim=0を指定した場合torch.cat(tensors=(m, m), dim=0)
"""
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
"""
# torch.Size([4, 2, 2])サイズを見てみると、torch.Size([2, 2, 2])→torch.Size([4, 2, 2])になっているのが分かりますね。つまり、dim=0を指定すると、1番目の次元に連結されるということです。pythonは0から始まるものなので、dim=0と指定するという訳です。
dim=1を指定した場合torch.cat(tensors=(m, m), dim=1)
"""
tensor([[[0, 1],
[2, 3],
[0, 1],
[2, 3]],
[[4, 5],
[6, 7],
[4, 5],
[6, 7]]])
"""
# torch.Size([2, 4, 2])同様に、サイズを見てみると、torch.Size([2, 2, 2])→torch.Size([2, 4, 2])になっています。つまり、dim=1を指定すると、2番目の次元に連結されるということです。
dim=2を指定した場合torch.cat(tensors=(m, m), dim=2)
"""
tensor([[[0, 1, 0, 1],
[2, 3, 2, 3]],
[[4, 5, 4, 5],
[6, 7, 6, 7]]])
"""
# torch.Size([2, 2, 4])同様に、サイズを見てみると、torch.Size([2, 2, 2])→torch.Size([2, 2, 4])になっています。つまり、dim=2を指定すると、3番目の次元に連結されるということです。
ここまで理解できていれば、dim=-1などの意味も分かってくると思います。
pythonではa = [0, 1, 2]のようなリストがあった場合、a[-1]は2を表しますよね。
同じように考えると、今回はdim=-1はdim=2と等しいということが分かると思います。
連結前のテンソルの次元によって変化するので、その点は注意が必要です。
outの詳細
outにテンソルを指定すると上書きされます。
色々試しているので、動作についてはそこから感じ取ってください。
# 以下のテンソルを連結します
m = torch.arange(8).reshape(2, 2, 2)
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])out=連結前のテンソルを指定した場合torch.cat(tensors=(m, m), out=m)
# RuntimeError: 0unsupported operation:
# the input tensors cannot refer to any of the output memory locations.
# Found overlap in input tensor 0入力テンソルと出力テンソルは異なる必要があります。
out=torch.tensor([])
を指定した場合n = torch.tensor([])
# tensor([])
torch.cat(tensors=(m, m), out=n)
"""
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
"""
print(n)
"""
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
"""nが上書きされてますね。
また、mはtorch.LongTensorでしたが、連結後はtorch.FloatTensorになっています。
これは、outに指定したテンソルに合わせるためだと思います。
まとめ
dimはどの次元を基準に連結させるかを決めることができるoutにテンソルを指定すると、連結後のテンソルに上書きされる


コメント