【Pytorch】torch.catの使い方・引数を徹底解説!テンソルを結合する方法
torch.catはPyTorchでテンソル(多次元配列)を結合するための関数です。
この関数を使用すると、指定した次元でテンソルを連結することができます。
ドキュメント:torch.cat — PyTorch 2.0 documentation
イメージを掴もう
イメージを掴むと使いやすくなると思うので画像にしました。
torch.catの引数
torch.catは基本的に以下のように使うことができます。
import torch
# サンプルのテンソルを作成します
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
# テンソルを連結します
result = torch.cat((tensor1, tensor2))
print(result)
# 出力:
# tensor([1, 2, 3, 4, 5, 6])
では引数について見ていきます。
- tensors
結合するテンソルのリストまたはタプルを指定します。 - dim
結合する次元のインデックスを指定します。
デフォルト値は0です。 - out
結果を格納する出力テンソルです。
指定されない場合は新しいテンソルが作成されます。
dimとoutについて具体例を用いて詳しく解説します。
dimの詳細
dimは連結させる次元を指定できます。
この意味を以下の3次元のテンソル同士を連結させることで確かめます。
m = torch.arange(8).reshape(2, 2, 2)
"""
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
"""
m.size()
# torch.Size([2, 2, 2])
dim=0
を指定した場合torch.cat(tensors=(m, m), dim=0)
"""
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
"""
# torch.Size([4, 2, 2])
サイズを見てみると、torch.Size([2, 2, 2])
→torch.Size([4, 2, 2])
になっているのが分かりますね。つまり、dim=0
を指定すると、1番目の次元に連結されるということです。pythonは0から始まるものなので、dim=0
と指定するという訳です。
dim=1
を指定した場合torch.cat(tensors=(m, m), dim=1)
"""
tensor([[[0, 1],
[2, 3],
[0, 1],
[2, 3]],
[[4, 5],
[6, 7],
[4, 5],
[6, 7]]])
"""
# torch.Size([2, 4, 2])
同様に、サイズを見てみると、torch.Size([2, 2, 2])
→torch.Size([2, 4, 2])
になっています。つまり、dim=1
を指定すると、2番目の次元に連結されるということです。
dim=2
を指定した場合torch.cat(tensors=(m, m), dim=2)
"""
tensor([[[0, 1, 0, 1],
[2, 3, 2, 3]],
[[4, 5, 4, 5],
[6, 7, 6, 7]]])
"""
# torch.Size([2, 2, 4])
同様に、サイズを見てみると、torch.Size([2, 2, 2])
→torch.Size([2, 2, 4])
になっています。つまり、dim=2
を指定すると、3番目の次元に連結されるということです。
ここまで理解できていれば、dim=-1
などの意味も分かってくると思います。
pythonではa = [0, 1, 2]
のようなリストがあった場合、a[-1]
は2を表しますよね。
同じように考えると、今回はdim=-1
はdim=2
と等しいということが分かると思います。
連結前のテンソルの次元によって変化するので、その点は注意が必要です。
outの詳細
outにテンソルを指定すると上書きされます。
色々試しているので、動作についてはそこから感じ取ってください。
# 以下のテンソルを連結します
m = torch.arange(8).reshape(2, 2, 2)
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
out=連結前のテンソル
を指定した場合torch.cat(tensors=(m, m), out=m)
# RuntimeError: 0unsupported operation:
# the input tensors cannot refer to any of the output memory locations.
# Found overlap in input tensor 0
入力テンソルと出力テンソルは異なる必要があります。
out=torch.tensor([])
を指定した場合n = torch.tensor([])
# tensor([])
torch.cat(tensors=(m, m), out=n)
"""
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
"""
print(n)
"""
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]],
[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
"""
nが上書きされてますね。
また、mはtorch.LongTensor
でしたが、連結後はtorch.FloatTensor
になっています。
これは、outに指定したテンソルに合わせるためだと思います。
まとめ
dim
はどの次元を基準に連結させるかを決めることができるout
にテンソルを指定すると、連結後のテンソルに上書きされる
コメント