Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.unsqueezeの引数・使い方を徹底解説!どのような操作が行われているかを分かりやすく解説!

torch.unsqueezeは、PyTorchの関数であり、テンソルに新しい次元を挿入するための操作を行います。
挿入される次元の大きさは1であり、元のテンソルの次元数が1つ増えます。

ドキュメント:torch.unsqueeze — PyTorch 2.0 documentation

目次

イメージを掴もう

torch.unsqueezeの引数

  • input
    操作を行いたいテンソルを指定します。
  • dim
    新しく挿入したい次元のインデックスを指定します。

torch.unsqueezeの使い方・入出力例

dimの違いによってどのような操作がされるのか見ていきます。

次元0に挿入する場合

import torch

# サンプルテンソルの作成
tensor = torch.tensor([1, 2, 3])

# 次元0に挿入
unsqueeze_tensor = torch.unsqueeze(tensor, 0)

print("挿入後のテンソル:")
print(unsqueeze_tensor) # 出力: tensor([[1, 2, 3]])
print("挿入後のテンソルの形状:", unsqueeze_tensor.shape) # torch.Size([1, 3])

次元1に挿入する場合

import torch

# サンプルテンソルの作成
tensor = torch.tensor([1, 2, 3])

# 次元1に挿入
unsqueeze_tensor = torch.unsqueeze(tensor, 1)

print("挿入後のテンソル:")
print(unsqueeze_tensor) # 出力: tensor([[1], [2], [3]])
print("挿入後のテンソルの形状:", unsqueeze_tensor.shape) # 出力: torch.Size([3, 1])

負の次元に挿入する場合

import torch

# サンプルテンソルの作成
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 次元-1に挿入 (最後の次元に挿入)
unsqueeze_tensor = torch.unsqueeze(tensor, -1)

print("挿入後のテンソル:")
print(unsqueeze_tensor)
# 出力:
# tensor([[[1],
#          [2],
#          [3]],
#
#         [[4],
#          [5],
#          [6]]])
print("挿入後のテンソルの形状:", unsqueeze_tensor.shape)
# 出力: torch.Size([2, 3, 1])
import torch

# サンプルテンソルの作成
tensor = torch.tensor([[1, 2], [3, 4]])

# 負の次元(-2)に挿入
unsqueeze_tensor = torch.unsqueeze(tensor, -2)

print("挿入後のテンソル:")
print(unsqueeze_tensor)
# 出力: tensor([[[1, 2]], [[3, 4]]])
print("挿入後のテンソルの形状:", unsqueeze_tensor.shape)
# 出力: torch.Size([2, 1, 2])

次元の範囲外に挿入する場合

import torch

# サンプルテンソルの作成
tensor = torch.tensor([1, 2, 3])

# 次元3に挿入 (範囲外)
unsqueeze_tensor = torch.unsqueeze(tensor, 3)
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 3)

範囲外に挿入するとエラーが起こります。

新しい次元を挿入できることのメリット

新しい次元を挿入出来たら何が嬉しいのかについて解説します。

テンソルの形状を制御できる

unsqueezeを使用して新しい次元を挿入することで、テンソルの形状を自由に制御することができます。
これは、特定の形状のテンソルが必要な機械学習モデルや演算にとって重要です。
例えば、畳み込みニューラルネットワーク(CNN)では、通常、バッチサイズ×チャネル数×高さ×幅の4次元テンソルが入力として必要です。
単一の画像は通常3次元テンソルですが、unsqueezeを使用して新しい次元を追加することでバッチサイズの次元を追加し、CNNに適した形状に変換することができます。

ブロードキャストが容易になる

ブロードキャストは、形状の異なるテンソル間での演算を可能にする強力な機能ですが、そのためにはテンソルの次元が適切に整合している必要があります。
unsqueezeを使用して新しい次元を挿入することで、ブロードキャストを容易に行うことができます。通常はエラーとなるようなテンソル同士の演算でも、unsqueezeを使用してテンソルに新しい次元を追加して形状を変化させることで、ブロードキャストが可能になります。

データの解釈を変えられる

テンソルの形状は、データの解釈に直接関連しています。unsqueezeを使用して新しい次元を挿入することで、データの解釈を変えることができます。例えば、形状が(12, 32, 32)のテンソルは、色チャネル数が12である32×32ピクセルの画像と解釈できますが、unsqueezeを使用して新しい次元を追加して形状を(1, 3, 32, 32)にすることで、1つの色チャネル数が12である32×32ピクセルの画像からなるバッチと解釈することができます。

演算の効率向上が期待できる

特定の次元を挿入することで、計算上の利便性や効率を向上させることができます。
例えば、行列とベクトルの積を計算する際、ベクトルを行ベクトルまたは列ベクトルとして扱うことが必要です。
この場合、unsqueezeを使用してベクトルに新しい次元を追加し、適切な形状にすることで行列とベクトルの積を効率的に計算できます。

まとめ

torch.unsqueezeは、モデルの入力データの形状を変更する際に便利です。
torch.unsqueezeは、PyTorchの関数であり、テンソルに新しい次元を挿入するための操作を行います。
例えば、1次元のテンソルを2次元に変換することや、2次元のテンソルを3次元に変換することができます。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次