【Pytorch】torch.catの意味や引数を詳しく解説!dim=-1を設定した場合は?

torch.catはgithubでモデルの構成を見ていると、かなり登場するので色々と調べてみました。
例を使って色々と試しています。

目次

torch.catの意味

torch.catの役割は、与えられた次元のテンソルの列を連結することです。
catと聞くと猫?となりますが、concatenate(連結する)の略だと考えれば納得ですね。

torch.catの引数

では、torch.catの引数について見ていきます。
torch.catの引数は、torch.cat(tensorsdim=0*out=None)です。
それぞれの引数の説明を見ていきます。

tensors

入力テンソルが入ったlistまたはtupleです。

注意点として、listまたはtuple内のテンソルの次元は同じでなくてはいけません。

dim

テンソルが連結される次元を設定できます。

デフォルトはdim=0です。

out

出力テンソルです。

では、dimoutの説明がこれだけだと分かりにくいと思うので、この2つについて詳しく説明していきます。

dimの詳細

dimは連結させる次元を指定できます。
この意味を以下の3次元のテンソル同士を連結させることで確かめます。

m = torch.arange(8).reshape(2, 2, 2)
"""
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
"""
m.size()
# torch.Size([2, 2, 2])
dim=0を指定した場合
torch.cat(tensors=(m, m), dim=0)
"""
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]],

        [[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
"""
# torch.Size([4, 2, 2])

サイズを見てみると、torch.Size([2, 2, 2])torch.Size([4, 2, 2])になっているのが分かりますね。つまり、dim=0を指定すると、1番目の次元に連結されるということです。pythonは0から始まるものなので、dim=0と指定するという訳です。

dim=1を指定した場合
torch.cat(tensors=(m, m), dim=1)
"""
tensor([[[0, 1],
         [2, 3],
         [0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7],
         [4, 5],
         [6, 7]]])
"""
# torch.Size([2, 4, 2])

同様に、サイズを見てみると、torch.Size([2, 2, 2])torch.Size([2, 4, 2])になっています。つまり、dim=1を指定すると、2番目の次元に連結されるということです。

dim=2を指定した場合
torch.cat(tensors=(m, m), dim=2)
"""
tensor([[[0, 1, 0, 1],
         [2, 3, 2, 3]],

        [[4, 5, 4, 5],
         [6, 7, 6, 7]]])
"""
# torch.Size([2, 2, 4])

同様に、サイズを見てみると、torch.Size([2, 2, 2])torch.Size([2, 2, 4])になっています。つまり、dim=2を指定すると、3番目の次元に連結されるということです。

ここまで理解できていれば、dim=-1などの意味も分かってくると思います。

pythonではa = [0, 1, 2]のようなリストがあった場合、a[-1]は2を表しますよね。

同じように考えると、今回はdim=-1dim=2と等しいということが分かると思います。

連結前のテンソルの次元によって変化するので、その点は注意が必要です。

outの詳細

outにテンソルを指定すると上書きされます。(多分この説明であっているはずです。間違っていたら指摘お願いします。)
色々試しているので、動作についてはそこから感じ取ってください。

# 以下のテンソルを連結します
m = torch.arange(8).reshape(2, 2, 2)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
out=連結前のテンソルを指定した場合
torch.cat(tensors=(m, m), out=m)
# RuntimeError: 0unsupported operation: 
# the input tensors cannot refer to any of the output memory locations. 
# Found overlap in input tensor 0

入力テンソルと出力テンソルは異なる必要があります。

out=torch.tensor([])
を指定した場合
n = torch.tensor([])
# tensor([])

torch.cat(tensors=(m, m), out=n)
"""
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]],

        [[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
"""

print(n)
"""
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]],

        [[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
"""

nが上書きされてますね。
また、mはtorch.LongTensorでしたが、連結後はtorch.FloatTensorになっています。
これは、outに指定したテンソルに合わせるためだと思います。

まとめ

  • dimはどの次元を基準に連結させるかを決めることができる
  • outにテンソルを指定すると、連結後のテンソルに上書きされる
あわせて読みたい
【Pytorch】torch.sumのdim=-1, 0, (1, 1)などの意味とは?実際にコードを動かして検証してみた githubでコードを見ていると、torch.sum(m, dim=-1)のように傍目では意味が分からない部分が出てきたので、torch.sum()について色々と調べてみました。 【の役割】 torc...

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次