Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.viewの引数・使い方を徹底解説!エラーが出るコードとは?コード例を豊富に用意!

torch.viewは、テンソルの形状(shape)を変更するための関数です。
この関数を使用すると、テンソルの要素数は変わらずに、形状を自由に変更することができます。

ドキュメント:Tensor Views — PyTorch 2.0 documentation

目次

イメージを掴もう

torch.viewの引数

  • shape
    テンソルの新しい形状を指定するための整数値のシーケンスです。
    各整数は、その次元のサイズを表します。

torch.viewの使い方・入出力例

入力と出力例を見てどうなっているのか把握しましょう。

import torch

# 1次元テンソルの作成
x = torch.tensor([1, 2, 3, 4, 5, 6])

# テンソルの形状を変更する
y = x.view(6)
print(y)
# 出力: tensor([1, 2, 3, 4, 5, 6])
# 元の形状を維持したままの1次元テンソル

z = x.view(2, 3)
print(z)
# 出力:
# tensor([[1, 2, 3],
#         [4, 5, 6]])
# 2行3列の2次元テンソルに変換されます

w = x.view(3, -1)
print(w)
# 出力:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])
# 3行2列の2次元テンソルに変換されます(-1を指定して自動的にサイズを決定)

a = x.view(6, 1)
print(a)
# 出力:
# tensor([[1],
#         [2],
#         [3],
#         [4],
#         [5],
#         [6]])
# 6行1列の2次元テンソルに変換されます

b = x.view(2, 1, 3)
print(b)
# 出力:
# tensor([[[1, 2, 3]],
# 
#         [[4, 5, 6]]])
# 2チャンネルの1行3列の3次元テンソルに変換されます

c = x.view(-1, 2, 1)
print(c)
# 出力:
# tensor([[[1],
#          [2]],
# 
#         [[3],
#          [4]],
# 
#         [[5],
#          [6]]])
# 3チャンネルの2行1列の3次元テンソルに変換されます

# 3次元テンソルの作成
d = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

# テンソルの形状を変更する
e = d.view(4, 2)
print(e)
# 出力:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])
# 元の形状を維持したままの2次元テンソル

f = d.view(-1)
print(f)
# 出力: tensor([1, 2, 3, 4, 5, 6, 7, 8])
# テンソルを1次元にフラット化します(自動的に要素数に合わせた形状に変換)

g = d.view(2, -1, 2)
print(g)
# 出力:
# tensor([[[1, 2],
#          [3, 4]],
# 
#         [[5, 6],
#          [7, 8]]])
# 2チャンネルの2行2列の3次元テンソルに変換されます(-1を指定して自動的にサイズを決定)

エラーが出る場合とは

viewは、元のテンソルと整合しない形状を指定した場合にエラーを発生させることがあります。

viewがエラーを出すケースをいくつか説明します。

import torch

# エラーを発生させる例

# 1次元テンソルの作成
x = torch.tensor([1, 2, 3, 4, 5, 6])

# サイズが一致しない形状への変換
y = x.view(4)  # エラー: サイズが一致しません
# RuntimeError: shape '[4]' is invalid for input of size 6

# 要素数が一致しない形状への変換
z = x.view(2, 3, 2)  # エラー: 要素数が一致しません
# RuntimeError: shape '[2, 3, 2]' is invalid for input of size 6

# 負の値が複数含まれる形状への変換
w = x.view(-1, -1)  # エラー: 負の値が複数含まれています
# RuntimeError: Only one dimension can be inferred

また、メモリ上での配列の順番が連続していない場合に、サイズ変更時にエラーが発生します。

# メモリ上で連続しない順番のテンソルを作成
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = x1.t()  # x1を転置したテンソル

# サイズ変更を試みる
reshaped_tensor = x2.view(6)
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

解決策としては、reshapeを使う方法とcontiguousを使う方法があります。

# reshapeを使う方法
import torch

x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = x1.t()

reshaped_tensor = x2.reshape(6)  # reshape()を使用して形状変更

print(reshaped_tensor) # tensor([1, 4, 2, 5, 3, 6])
# contiguousを使う方法
import torch

x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = x1.t()

x2 = x2.contiguous()  # contiguous()を使用してテンソルを連続化
reshaped_tensor = x2.view(6)

print(reshaped_tensor) # tensor([1, 4, 2, 5, 3, 6])

まとめ

torch.viewを使用すると、テンソルの形状を柔軟に変更することができます。
適切な形状を指定することで、データの表現やニューラルネットワークの入力に適したテンソルを作成できます。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次