Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.reshapeの引数・使い方を徹底解説!20個のコード例を用意!torch.viewとの違いも解説!

torch.reshapeは、指定したテンソルを新しい形状に変換します。
この関数は、テンソルの要素数が変わらないように形状を変更します。
つまり、テンソルの形状を変更するだけであり、要素の値や順序は変更されません

ドキュメント:torch.reshape — PyTorch 2.0 documentation

目次

イメージを掴もう

torch.reshapeの引数

  • input
    形状を変更したいテンソルを指定します。
  • shape
    テンソルを変換した後の形状を指定する整数のタプルまたはリスト

torch.viewの使い方・入出力例

コード例をみてtorch.viewの挙動を理解してください。

基本的なコード例

import torch

# 1. テンソルを1次元に変形
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
y1 = torch.reshape(x1, (-1,))
print(x1.shape, "->", y1.shape)  # torch.Size([2, 3]) -> torch.Size([6])

# 2. テンソルを2次元に変形
x2 = torch.tensor([1, 2, 3, 4, 5, 6])
y2 = torch.reshape(x2, (2, -1))
print(x2.shape, "->", y2.shape)  # torch.Size([6]) -> torch.Size([2, 3])

# 3. テンソルを3次元に変形
x3 = torch.tensor([1, 2, 3, 4, 5, 6])
y3 = torch.reshape(x3, (1, 2, -1))
print(x3.shape, "->", y3.shape)  # torch.Size([6]) -> torch.Size([1, 2, 3])

# 4. テンソルを元の形状に戻す
x4 = torch.tensor([[1, 2], [3, 4], [5, 6]])
y4 = torch.reshape(x4, (3, 2))
print(x4.shape, "->", y4.shape)  # torch.Size([3, 2]) -> torch.Size([3, 2])

# 5. テンソルの次元を追加
x5 = torch.tensor([1, 2, 3, 4])
y5 = torch.reshape(x5, (2, 2, 1))
print(x5.shape, "->", y5.shape)  # torch.Size([4]) -> torch.Size([2, 2, 1])

# 6. テンソルの次元を削除
x6 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
y6 = torch.reshape(x6, (-1, 2))
print(x6.shape, "->", y6.shape)  # torch.Size([2, 2, 2]) -> torch.Size([4, 2])

# 7. テンソルの要素数を変更しない形状変更
x7 = torch.tensor([1, 2, 3, 4, 5, 6])
y7 = torch.reshape(x7, (6, 1))
print(x7.shape, "->", y7.shape)  # torch.Size([6]) -> torch.Size([6, 1])

# 8. テンソルの一部の次元を結合
x8 = torch.tensor([[1, 2], [3, 4]])
y8 = torch.reshape(x8, (-1,))
print(x8.shape, "->", y8.shape)  # torch.Size([2, 2]) -> torch.Size([4])

# 9. テンソルを転置
x9 = torch.tensor([[1, 2, 3], [4, 5, 6]])
y9 = torch.reshape(x9, (3, 2)).t()
print(x9.shape, "->", y9.shape)  # torch.Size([2, 3]) -> torch.Size([2, 3])

# 10. テンソルを展開
x10 = torch.tensor([[1, 2], [3, 4], [5, 6]])
y10 = torch.reshape(x10, (-1,))
print(x10.shape, "->", y10.shape)  # torch.Size([3, 2]) -> torch.Size([6])

応用的なコード例

import torch

# 11. テンソルの要素を逆順に並べる
x11 = torch.tensor([[1, 2, 3], [4, 5, 6]])
y11 = torch.reshape(torch.flip(x11, [0, 1]), (-1,))
print(x11.shape, "->", y11.shape)  # torch.Size([2, 3]) -> torch.Size([6])

# 12. テンソルの要素をブロックごとに結合
x12 = torch.tensor([[1, 2], [3, 4]])
y12 = torch.reshape(x12, (2, 1, 2)).repeat(1, 3, 1).reshape(2, -1)
print(x12.shape, "->", y12.shape)  # torch.Size([2, 2]) -> torch.Size([2, 6])

# 13. テンソルをバッチ次元として処理
x13 = torch.tensor([1, 2, 3, 4])
y13 = torch.reshape(x13, (1, 2, 2))
print(x13.shape, "->", y13.shape)  # torch.Size([4]) -> torch.Size([1, 2, 2])

# 14. テンソルの特定の次元を固定して変形
x14 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y14 = torch.reshape(x14, (3, 3, 1)).permute(1, 0, 2)
print(x14.shape, "->", y14.shape)  # torch.Size([3, 3]) -> torch.Size([3, 3, 1])

# 15. テンソルを要素ごとに積として変形
x15 = torch.tensor([1, 2, 3, 4])
y15 = torch.reshape(x15, (-1, 1)) * torch.reshape(x15, (1, -1))
print(x15.shape, "->", y15.shape)  # torch.Size([4]) -> torch.Size([4, 4])

# 16. テンソルを2つの要素に分割
x16 = torch.tensor([1, 2, 3, 4, 5, 6])
y16_1, y16_2 = torch.split(x16, [3, 3])
print(x16.shape, "->", y16_1.shape, y16_2.shape)  # torch.Size([6]) -> torch.Size([3]) torch.Size([3])

# 17. テンソルを特定の要素でパディング
x17 = torch.tensor([[1, 2, 3], [4, 5, 6]])
y17 = torch.nn.functional.pad(x17, (1, 2, 3, 4))
print(x17.shape, "->", y17.shape)  # torch.Size([2, 3]) -> torch.Size([9, 6])

# 18. テンソルを特定の軸で結合
x18_1 = torch.tensor([[1, 2], [3, 4]])
x18_2 = torch.tensor([[5, 6], [7, 8]])
y18 = torch.reshape(torch.stack((x18_1, x18_2), dim=1), (-1, 2))
print(x18_1.shape, x18_2.shape, "->", y18.shape)  # torch.Size([2, 2]) torch.Size([2, 2]) -> torch.Size([4, 2])

# 19. テンソルの形状を変えずに次元をシャッフル
x19 = torch.tensor([[1, 2], [3, 4], [5, 6]])
y19 = torch.reshape(x19, (3, 1, 2)).permute(2, 0, 1)
print(x19.shape, "->", y19.shape)  # torch.Size([3, 2]) -> torch.Size([2, 3, 1])

# 20. テンソルの要素をランダムに入れ替えて変形
x20 = torch.tensor([1, 2, 3, 4, 5, 6])
y20 = torch.reshape(x20[torch.randperm(x20.numel())], (2, -1))
print(x20.shape, "->", y20.shape)  # torch.Size([6]) -> torch.Size([2, 3])

torch.reshapeとtorch.viewとの違い

PyTorchでは、テンソルの形状変換には主に2つの関数、torch.reshape()とtorch.view()が使用されます。
これらの関数はテンソルの形状を変更するために使用されますが、いくつかの重要な違いがあります。

スクロールできます
torch.reshapetorch.view
役割入力テンソルと同じデータと要素数を持つが、指定された形状を持つテンソルを返す自身のテンソルと同じデータを持つが、異なる形状を持つ新しいテンソルを返す​
ビューまたはコピー可能な場合、返されるテンソルは入力のビューになる。それ以外の場合はコピーになる​新しいビューのサイズが元のサイズとストライドと互換性がある場合のみ、テンソルをビューとして変形できる。それ以外の場合、代わりにコピーする必要がある​
条件返されるテンソルが入力のビューになるためには、入力が連続的であるか、互換性のあるストライドを持つ必要がある。ただし、コピーかビューかに依存してはならない​テンソルがメモリ上で非連続的な場合、エラーが発生し、テンソルをビューとして変形することはできない​
テンソルがメモリ上で非連続的な場合エラーは発生せず、テンソルはコピーされる。エラーが発生し、テンソルをビューとして変形することはできない​
推奨される使用場面形状の互換性に依存せず、入力テンソルの新しい形状が必要な場合に使用する形状の変更が必要で、新しいビューのサイズが元のサイズとストライドと互換性がある場合、またはビューが可能かどうか不明な場合はtorch.reshapeを使用することが推奨される​
import torch

# reshapeの例

# テンソルを1次元に変形
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
y1 = torch.reshape(x1, (-1,))
print(x1.shape, "->", y1.shape)  # torch.Size([2, 3]) -> torch.Size([6])

# メモリ上で非連続なテンソルを作成
x2 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = x2[:, :-1]  # メモリ上で非連続なテンソルにする
y2 = torch.reshape(x2, (-1,))
print(x2.shape, "->", y2.shape)  # torch.Size([2, 2]) -> torch.Size([4])
# viewの例

# テンソルを1次元に変形
x3 = torch.tensor([[1, 2, 3], [4, 5, 6]])
y3 = x3.view(-1)
print(x3.shape, "->", y3.shape)  # torch.Size([2, 3]) -> torch.Size([6])

# メモリ上で非連続なテンソルを作成
x4 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x4 = x4[:, :-1]  # メモリ上で非連続なテンソルにする
y4 = x4.view(-1)
print(x4.shape, "->", y4.shape)  # torch.Size([2, 2]) -> torch.Size([4])
# 出力: RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

まとめ

torch.reshapeは、テンソル(多次元配列)の形状を変更するために使用されます。
この関数を使用することで、テンソルの次元やサイズを柔軟に変更することができます。
テンソルがメモリ上で非連続的な場合において、torch.viewと挙動が異なるので注意が必要です。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次