Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.matmulの引数・使い方を徹底解説!2つのテンソルの行列乗算を計算する方法

torch.matmulは、PyTorchのテンソルを操作する際に使用される行列積の関数です。
この関数は、与えられたテンソルの行列積を計算し、新しいテンソルを返します。
異なる次元のテンソルに対しても適用することができます。

ドキュメント:torch.matmul — PyTorch 2.0 documentation

目次

イメージを掴もう

torch.matmulの引数

  • input
    行列積の計算に使用される最初のテンソルです。
  • other
    行列積の計算に使用される2番目のテンソルです。
  • out(オプション)
    出力テンソルです。指定された場合、計算結果はこのテンソルに格納されます。

関数の動作について

関数の動作は入力テンソルの次元によって以下のように異なります。

スクロールできます
入力の次元計算結果のタイプ
1D × 1Dスカラー(ドット積)
2D × 2D行列-行列の積
1D × 2D行列-ベクトルの積
2D × 1D行列-ベクトルの積
N-D × 1Dバッチ化された行列-ベクトルの積
1D × N-Dバッチ化された行列-ベクトルの積
N-D × N-Dバッチ化された行列-行列の積

上記の表でN-Dは、2以上の次元数を表します。また、バッチ化された行列積では、適切な次元の追加と削除が行われます。

これらの情報を基に、適切な次元のテンソルを使用してtorch.matmul()関数を呼び出すことで、適切な行列積を計算することができます。

torch.matmulの使い方・入出力例

1D × 1D (スカラーのドット積)

import torch

# 入力テンソルの作成
input_1d = torch.tensor([1, 2, 3])
other_1d = torch.tensor([4, 5, 6])

# torch.matmul()を使用してドット積を計算
result = torch.matmul(input_1d, other_1d)

print(result)  # スカラー値が出力される
# 出力: tensor(32)

2D × 2D (行列-行列の積)

import torch

# 入力テンソルの作成
input_2d = torch.tensor([[1, 2], [3, 4]])
other_2d = torch.tensor([[5, 6], [7, 8]])

# torch.matmul()を使用して行列-行列の積を計算
result = torch.matmul(input_2d, other_2d)

print(result)  # 行列-行列の積が出力される
# 出力: tensor([[19, 22],
#               [43, 50]])

1D × 2D (行列-ベクトルの積)

import torch

# 入力テンソルの作成
input_1d = torch.tensor([1, 2])
other_2d = torch.tensor([[3, 4], [5, 6]])

# torch.matmul()を使用して行列-ベクトルの積を計算
result = torch.matmul(input_1d, other_2d)

print(result)  # 行列-ベクトルの積が出力される
# 出力: tensor([13, 16])

2D × 1D (行列-ベクトルの積)

import torch

# 入力テンソルの作成
input_2d = torch.tensor([[1, 2], [3, 4]])
other_1d = torch.tensor([5, 6])

# torch.matmul()を使用して行列-ベクトルの積を計算
result = torch.matmul(input_2d, other_1d)

print(result)  # 行列-ベクトルの積が出力される
# 出力: tensor([17, 39])

N-D × 1D (バッチ化された行列-ベクトルの積)

import torch

# 入力テンソルの作成
input_nd = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
other_1d = torch.tensor([9, 10])

# torch.matmul()を使用してバッチ化された行列-ベクトルの積を計算
result = torch.matmul(input_nd, other_1d)

print(result)  # バッチ化された行列-ベクトルの積が出力される
# 出力: tensor([[29, 67], 
#               [53, 125]])

1D × N-D (バッチ化された行列-ベクトルの積)

import torch

# 入力テンソルの作成
input_1d = torch.tensor([1, 2])
other_nd = torch.tensor([[[3, 4], [5, 6]], [[7, 8], [9, 10]]])

# torch.matmul()を使用してバッチ化された行列-ベクトルの積を計算
result = torch.matmul(input_1d, other_nd)

print(result)  # バッチ化された行列-ベクトルの積が出力される
# 出力: tensor([[13, 16], 
#               [17, 20]])

N-D × N-D (バッチ化された行列-行列の積)

import torch

# 入力テンソルの作成
input_nd = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
other_nd = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

# torch.matmul()を使用してバッチ化された行列-行列の積を計算
result = torch.matmul(input_nd, other_nd)

print(result)  # バッチ化された行列-行列の積が出力される
# 出力: tensor([[[ 31,  34],
#                 [ 71,  78]],
#                [[155, 166],
#                 [211, 226]]])

2D × 2Dの場合(行列-行列の積)でoutを使用する例

import torch

# 入力テンソルの作成
input_2d = torch.tensor([[1, 2], [3, 4]])
other_2d = torch.tensor([[5, 6], [7, 8]])

# 出力テンソルの事前の準備
out_tensor = torch.empty((2, 2), dtype=torch.int64) # 2x2のテンソルを作成

# torch.matmul()を使用して行列-行列の積を計算し、結果を指定した出力テンソルに格納
torch.matmul(input_2d, other_2d, out=out_tensor)

print(out_tensor)  # 行列-行列の積が出力される
# tensor([[19, 22],
#         [43, 50]])

まとめ

torch.matmulは、PyTorchのテンソルを操作する際に使用される行列積の関数です。
torch.matmul()を適切に使用することで、異なる次元のテンソル間での行列積を柔軟かつ効果的に計算できます。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次