Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.catの使い方・引数を徹底解説!テンソルを結合する方法

torch.catはPyTorchでテンソル(多次元配列)を結合するための関数です。
この関数を使用すると、指定した次元でテンソルを連結することができます。

ドキュメント:torch.cat — PyTorch 2.0 documentation

目次

イメージを掴もう

イメージを掴むと使いやすくなると思うので画像にしました。

torch.catの引数

torch.catは基本的に以下のように使うことができます。

import torch

# サンプルのテンソルを作成します
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# テンソルを連結します
result = torch.cat((tensor1, tensor2))
print(result)
# 出力:
# tensor([1, 2, 3, 4, 5, 6])

では引数について見ていきます。

  • tensors
    結合するテンソルのリストまたはタプルを指定します。
  • dim
    結合する次元のインデックスを指定します。
    デフォルト値は0です。
  • out
    結果を格納する出力テンソルです。
    指定されない場合は新しいテンソルが作成されます。

dimとoutについて具体例を用いて詳しく解説します。

dimの詳細

dimは連結させる次元を指定できます。
この意味を以下の3次元のテンソル同士を連結させることで確かめます。

m = torch.arange(8).reshape(2, 2, 2)
"""
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
"""
m.size()
# torch.Size([2, 2, 2])
dim=0を指定した場合
torch.cat(tensors=(m, m), dim=0)
"""
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]],

        [[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
"""
# torch.Size([4, 2, 2])

サイズを見てみると、torch.Size([2, 2, 2])torch.Size([4, 2, 2])になっているのが分かりますね。つまり、dim=0を指定すると、1番目の次元に連結されるということです。pythonは0から始まるものなので、dim=0と指定するという訳です。

dim=1を指定した場合
torch.cat(tensors=(m, m), dim=1)
"""
tensor([[[0, 1],
         [2, 3],
         [0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7],
         [4, 5],
         [6, 7]]])
"""
# torch.Size([2, 4, 2])

同様に、サイズを見てみると、torch.Size([2, 2, 2])torch.Size([2, 4, 2])になっています。つまり、dim=1を指定すると、2番目の次元に連結されるということです。

dim=2を指定した場合
torch.cat(tensors=(m, m), dim=2)
"""
tensor([[[0, 1, 0, 1],
         [2, 3, 2, 3]],

        [[4, 5, 4, 5],
         [6, 7, 6, 7]]])
"""
# torch.Size([2, 2, 4])

同様に、サイズを見てみると、torch.Size([2, 2, 2])torch.Size([2, 2, 4])になっています。つまり、dim=2を指定すると、3番目の次元に連結されるということです。

ここまで理解できていれば、dim=-1などの意味も分かってくると思います。

pythonではa = [0, 1, 2]のようなリストがあった場合、a[-1]は2を表しますよね。

同じように考えると、今回はdim=-1dim=2と等しいということが分かると思います。

連結前のテンソルの次元によって変化するので、その点は注意が必要です。

outの詳細

outにテンソルを指定すると上書きされます。
色々試しているので、動作についてはそこから感じ取ってください。

# 以下のテンソルを連結します
m = torch.arange(8).reshape(2, 2, 2)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
out=連結前のテンソルを指定した場合
torch.cat(tensors=(m, m), out=m)
# RuntimeError: 0unsupported operation: 
# the input tensors cannot refer to any of the output memory locations. 
# Found overlap in input tensor 0

入力テンソルと出力テンソルは異なる必要があります。

out=torch.tensor([])
を指定した場合
n = torch.tensor([])
# tensor([])

torch.cat(tensors=(m, m), out=n)
"""
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]],

        [[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
"""

print(n)
"""
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]],

        [[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
"""

nが上書きされてますね。
また、mはtorch.LongTensorでしたが、連結後はtorch.FloatTensorになっています。
これは、outに指定したテンソルに合わせるためだと思います。

まとめ

  • dimはどの次元を基準に連結させるかを決めることができる
  • outにテンソルを指定すると、連結後のテンソルに上書きされる
あわせて読みたい
【Pytorch】torch.sumの使い方・引数を徹底解説!dim=-1, 0, (1, 1)などの意味とは? torch.sumの役割は、与えられたテンソルの要素の合計を返すことです。 【イメージを掴もう】 【の引数】 引数は、torch.sum(input, dim, keepdim=False, *, dtype=None)...

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次