Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【PyTorch】ベクトルを操作する代表的な関数を解説!view, reshape, transpose, unsqueeze, matmul, einsumなどを使えるようになろう!

PyTorchを使っているコードを見ているとベクトルを操作する関数がたくさん出てくると思います。しかし、どんな機能なのかわからないと自分でコードを変更しようとしても難しいです。そこで今回は、実際にベクトルを操作しながらそれぞれの関数を解説していきます。ドキュメントはこちらです。

目次

view

あわせて読みたい
【Pytorch】torch.viewの引数・使い方を徹底解説!エラーが出るコードとは?コード例を豊富に用意! torch.viewは、テンソルの形状(shape)を変更するための関数です。この関数を使用すると、テンソルの要素数は変わらずに、形状を自由に変更することができます。 ドキ...

viewは元のTensorと同じデータで違う配列のデータを返してくれます。

第一引数は行数、第二引数は列数を表し、(元の配列の行数)×(元の配列の列数)=(第一引数)×(第二引数)でなくてはいけません。

githubなどでコードを見ているとview(-1)やview(-1, 5)などがよく見られます。

これは、第一引数が-1なら(元の配列の行数)×(元の配列の列数)=(第二引数に合わせた値)×(第二引数)となり自動的に合わせてくれます。

第一引数しかなく第一引数が-1だった場合は、行数1の配列となります。

import torch
x1 = torch.randn(4,5)
x1
"""
tensor([[ 2.0152, -0.4889,  0.6884,  0.0827, -0.0111],
        [-0.2107,  0.4966,  0.5579, -0.3411,  0.2335],
        [-0.4020, -0.5581, -0.7476, -1.5015, -0.4678],
        [-0.0260, -0.1089,  0.5966,  1.9690, -0.2223]])
"""

x2 = x1.view(10, 2)
x2
"""
tensor([[ 2.0152, -0.4889],
        [ 0.6884,  0.0827],
        [-0.0111, -0.2107],
        [ 0.4966,  0.5579],
        [-0.3411,  0.2335],
        [-0.4020, -0.5581],
        [-0.7476, -1.5015],
        [-0.4678, -0.0260],
        [-0.1089,  0.5966],
        [ 1.9690, -0.2223]])
"""

x3 = x1.view(-1)
x3
"""
tensor([ 2.0152, -0.4889,  0.6884,  0.0827, -0.0111, -0.2107,  0.4966,  0.5579,
        -0.3411,  0.2335, -0.4020, -0.5581, -0.7476, -1.5015, -0.4678, -0.0260,
        -0.1089,  0.5966,  1.9690, -0.2223])
"""

x4 = x1.view(-1, 2)
x4
"""
tensor([[ 2.0152, -0.4889],
        [ 0.6884,  0.0827],
        [-0.0111, -0.2107],
        [ 0.4966,  0.5579],
        [-0.3411,  0.2335],
        [-0.4020, -0.5581],
        [-0.7476, -1.5015],
        [-0.4678, -0.0260],
        [-0.1089,  0.5966],
        [ 1.9690, -0.2223]])
"""

また、元の配列はメモリ上においても順番に並んでいなければエラーがでます。次の例では、メモリ上でx1を転置したものをサイズ変更しようとしてエラーが出ています。

x1.T.view(10, 2)
"""
RuntimeError: view size is not compatible with input tensor's size and stride 
(at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""

エラーに書いてある通り、この解決策としては次の二つあります。

# viewの前にcontiguous()をつける
x1.T.contiguous().view(10, 2)

# reshapeを使う
x1.T.reshape(10, 2)

reshape

あわせて読みたい
【Pytorch】torch.reshapeの引数・使い方を徹底解説!20個のコード例を用意!torch.viewとの違いも解説! torch.reshapeは、指定したテンソルを新しい形状に変換します。この関数は、テンソルの要素数が変わらないように形状を変更します。つまり、テンソルの形状を変更するだ...

reshapeはviewとほぼ同じ働きをします。違いとして、reshapeの場合はメモリ上の並び順は違って大丈夫という点です。

x1 = torch.randn(4,5)
x1
"""
tensor([[ 2.0152, -0.4889,  0.6884,  0.0827, -0.0111],
        [-0.2107,  0.4966,  0.5579, -0.3411,  0.2335],
        [-0.4020, -0.5581, -0.7476, -1.5015, -0.4678],
        [-0.0260, -0.1089,  0.5966,  1.9690, -0.2223]])
"""

x2 = x1.reshape(10, 2)
x2
"""
tensor([[ 2.0152, -0.4889],
        [ 0.6884,  0.0827],
        [-0.0111, -0.2107],
        [ 0.4966,  0.5579],
        [-0.3411,  0.2335],
        [-0.4020, -0.5581],
        [-0.7476, -1.5015],
        [-0.4678, -0.0260],
        [-0.1089,  0.5966],
        [ 1.9690, -0.2223]])
"""

x3 = x1.T.reshape(10, 2)
x3
"""
tensor([[ 2.0152, -0.2107],
        [-0.4020, -0.0260],
        [-0.4889,  0.4966],
        [-0.5581, -0.1089],
        [ 0.6884,  0.5579],
        [-0.7476,  0.5966],
        [ 0.0827, -0.3411],
        [-1.5015,  1.9690],
        [-0.0111,  0.2335],
        [-0.4678, -0.2223]])

"""

transpose

あわせて読みたい
【Pytorch】torch.transposeの引数・使い方を徹底解説!テンソルの次元を入れ替える方法 torch.transposeはテンソル(多次元配列)の次元を入れ替えるために使用されます。この関数を使用することで、テンソルの次元を簡単に変更することができます。 ドキュ...

transposeは元の配列を置換します。

x1 = torch.randn(4,5)
x1.shape
"""
torch.Size([4, 5])
"""

x2 = x1.transpose(0, 1)
x2.shape
"""
torch.Size([5, 4])
"""

unsqueeze

あわせて読みたい
【Pytorch】torch.unsqueezeの引数・使い方を徹底解説!どのような操作が行われているかを分かりやすく... torch.unsqueezeは、PyTorchの関数であり、テンソルに新しい次元を挿入するための操作を行います。挿入される次元の大きさは1であり、元のテンソルの次元数が1つ増えま...

unsqueezeは、新たに軸を追加するために用いられます。どの軸が追加されるのかは実際の出力結果を見るとわかりやすいと思います。

x1 = torch.randn(4,5)
x1.shape
"""
torch.Size([4, 5])
"""

x1.unsqueeze(0).shape
"""
torch.Size([1, 4, 5])
"""

x1.unsqueeze(1).shape
"""
torch.Size([4, 1, 5])
"""

x1.unsqueeze(2).shape
"""
torch.Size([4, 5, 1])
"""

matmul

あわせて読みたい
【Pytorch】torch.matmulの引数・使い方を徹底解説!2つのテンソルの行列乗算を計算する方法 torch.matmulは、PyTorchのテンソルを操作する際に使用される行列積の関数です。この関数は、与えられたテンソルの行列積を計算し、新しいテンソルを返します。異なる次...

matmulは、行列などの積を計算します。いろいろなパターンの計算ができます。次の例以外にもいろいろできます。

x1 = torch.randn(5, 3)
y1 = torch.rand(3, 5)
torch.matmul(x1, y1).shape
"""
torch.Size([5, 5])
"""
x1 = torch.randn(9, 2, 5, 3)
y1 = torch.rand(2, 3, 5)
torch.matmul(x1, y1).shape
"""
torch.Size([9, 2, 5, 5])
"""

正直、視覚的に何やっているのか分かりにくいですね。コードを見たときに分かりやすい関数としては次のeinsumが挙げられます。

einsum

あわせて読みたい
【Pytorch】torch.einsumの引数・使い方を徹底解説!アインシュタインの縮約規則を利用して複雑なテンソ... torch.einsumは、入力されたテンソルの要素の積を、アインシュタインの縮約規則に基づいて指定された次元に沿って合計する関数です。この規則により、多くの共通する多...

einsumは、上記のmatmulと同じように行列などの計算をします。

異なる点は、記述方法でかなり感覚的に分かりやすいです。

例えば、ミニバッチ5、行数3、列数5の行列とミニバッチ5、行数5、列数2の行列があり積を計算したい時は次のようなコードを書きます。

x1 = torch.randn(5, 3, 5)
y1 = torch.rand(5, 5, 2)
torch.einsum("bnm,bkl->bnl", x1, y1).shape
"""
torch.Size([5, 3, 2])
"""

特徴的なのは、第一引数の”bnm,bkl->bnl”ですね。

対応としては、x1のバッチサイズb, 行数n, 列数m、y1のバッチサイズb, 行数k, 列数l、出来上がったもののバッチサイズb, 行数n, 列数lとなっています。

einsumを使うとどのように行列が操作されているのか追いやすいです。他にも次のようにも使えます。

x2 = torch.randn(3)
y2 = torch.rand(4)
torch.einsum("i,j->ij", x1, y1).shape
"""
torch.Size([3, 4])
"""
x3 = torch.randn(3, 6)
y3 = torch.rand(6, 2)
torch.einsum("nm,ij->nj", x1, y1).shape
"""
torch.Size([3, 2])
"""

最後に

今回は、pytorchにおける行列関係の代表的な関数を見ていきました。言語モデルなどを作成する過程で避けては通れないところなので備忘録的にまとめました。他の関数についても順次追加していきたいと思います。

あわせて読みたい
【Pytorch】torch.catの使い方・引数を徹底解説!テンソルを結合する方法 torch.catはPyTorchでテンソル(多次元配列)を結合するための関数です。この関数を使用すると、指定した次元でテンソルを連結することができます。 ドキュメント:torc...
あわせて読みたい
【Pytorch】torch.sumの使い方・引数を徹底解説!dim=-1, 0, (1, 1)などの意味とは? torch.sumの役割は、与えられたテンソルの要素の合計を返すことです。 【イメージを掴もう】 【の引数】 引数は、torch.sum(input, dim, keepdim=False, *, dtype=None)...
あわせて読みたい
【Pytorch】nn.Linearの引数・ソースコードを徹底解説! torch.nn.Linearは基本的な全結合層です。今回は、nn.Linearの引数とソースコードをしっかりと説明していきます。曖昧な理解を直していきましょう。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次