Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.einsumの引数・使い方を徹底解説!アインシュタインの縮約規則を利用して複雑なテンソル操作を短い文字列を使って行う

torch.einsumは、入力されたテンソルの要素の積を、アインシュタインの縮約規則に基づいて指定された次元に沿って合計する関数です。この規則により、多くの共通する多次元線形代数の配列操作を、短縮形式で表現できます。

ドキュメント:torch.einsum — PyTorch 2.0 documentation

アインシュタインの縮約記法(アインシュタインの縮約規則またはアインシュタインの縮約表記とも呼ばれる)は、数学、特に数理物理学における線形代数の使用において、式中の一連の添字付き項目に対する総和を暗示する表記法の規則です。これは、一連の項目に対する総和を示すことで簡潔さを実現します。アインシュタインはこれを1916年に物理学に導入しました。
参考:wiki(アインシュタインの縮約記法)

目次

torch.einsumの引数

  • equation
    アインシュタイン縮約規則の添字を指定します。
  • operands
    アインシュタイン縮約を計算するテンソルを指定します。

torch.einsumの使い方・入出力例

どのような挙動をするのか分かるような具体的なコード例を挙げます。

対角成分の和

# 対角成分の和 (Trace)
matrix = torch.randn(4, 4)
trace = torch.einsum('ii', matrix)
print("対角成分の和")
print(trace)  # 行列の対角成分の和(トレース)が出力されます。
print(trace.shape) # 出力: torch.Size([])

対角行列

# 対角行列
matrix = torch.randn(4, 4)
diag_matrix = torch.einsum('ii->i', matrix)
print("対角行列")
print(diag_matrix)  # 行列の対角成分が出力されます。
print(diag_matrix.shape) # 出力: torch.Size([4])

転置

# ベクトルの作成
x = torch.tensor([[1, 2, 3]])

# 転置
y = torch.einsum('ij->ji', x)

print(y) # 出力: tensor([[1], [2], [3]])

合計

# ベクトルの作成
x = torch.tensor([1, 2, 3])

# 合計
y = torch.einsum('i->', x)

print(y) # 出力: tensor(6)

列ごとの合計

# 列ごとの合計
matrix = torch.randn(4, 4)
col_sum = torch.einsum('ij->j', matrix)
print("列ごとの合計")
print(col_sum)  # 各列の合計が出力されます。
print(col_sum.shape) # 出力: torch.Size([4])

行ごとの合計

# 行ごとの合計
matrix = torch.randn(4, 4)
row_sum = torch.einsum('ij->i', matrix)
print("行ごとの合計")
print(row_sum)  # 各行の合計が出力されます。
print(row_sum.shape) # 出力: torch.Size([4])

全成分の合計

# 全成分の合計
matrix = torch.randn(4, 4)
total_sum = torch.einsum('ij->', matrix)
print("全成分の合計")
print(total_sum)  # 全成分の合計が出力されます。
print(total_sum.shape) # 出力: torch.Size([])

要素ごとの積

# ベクトルの作成
x = torch.tensor([1, 2, 3])
z = torch.tensor([4, 5, 6])

# 要素ごとの積
y = torch.einsum('i,i->i', x, z)

print(y) # 出力: tensor([ 4, 10, 18])

外積

# 外積
x = torch.randn(5)
y = torch.randn(4)
outer_product = torch.einsum('i,j->ij', x, y)
print("外積")
print(outer_product)  # ベクトルxとyの外積が出力されます。
print(outer_product.shape) # 出力: torch.Size([5, 4]
# ベクトルの作成
x = torch.tensor([1, 2, 3])
z = torch.tensor([4, 5, 6])

# 外積
y = torch.einsum('i,j->ij', x, z)

print(y) # 出力: tensor([[ 4,  5,  6],
#                        [ 8, 10, 12],
#                        [12, 15, 18]])

ベクトルの内積

# ベクトルの内積
x = torch.randn(5)
y = torch.randn(5)
inner_product = torch.einsum('i,i->', x, y)
print("ベクトルの内積")
print(inner_product)  # ベクトルxとyの内積が出力されます。
print(inner_product.shape) # 出力: torch.Size([])
# ベクトルの作成
x = torch.tensor([1, 2, 3])
z = torch.tensor([4, 5, 6])

# ドット積
y = torch.einsum('i,i->', x, z)

print(y) # 出力: tensor(32)

行列-ベクトル積

# 行列とベクトルの作成
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
z = torch.tensor([7, 8, 9])

# 行列-ベクトル積
y = torch.einsum('ij,j->i', x, z)

print(y) # 出力: tensor([ 50, 122])

行列-行列積

# 行列の作成
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
z = torch.tensor([[7, 8], [9, 10], [11, 12]])

# 行列-行列積
y = torch.einsum('ik,kj->ij', x, z)

print(y) # 出力: tensor([[ 58,  64],
#                        [139, 154]])

バッチ行列積

# バッチ行列積
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
batched_matmul = torch.einsum('bij,bjk->bik', As, Bs)
print("バッチ行列積")
print(batched_matmul)  # バッチ行列積が出力されます。
print(batched_matmul.shape) # 出力: torch.Size([3, 2, 4])
# 3次元テンソルの作成
x = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
z = torch.tensor([[[13, 14], [15, 16], [17, 18]], [[19, 20], [21, 22], [23, 24]]])

# バッチ行列積
y = torch.einsum('bik,bkj->bij', x, z)

print(y) # 出力: tensor([[[ 94, 100],
#                        [229, 244]],
#                        [[508, 532],
#                        [697, 730]]])

まとめ

torch.einsumはアインシュタインの縮約規則を利用して複雑なテンソル操作を短い文字列を使って表現できる関数です。
複雑なテンソル操作を行いたいときでも、コードは簡潔で読みやすくなります。
また、torch.einsumはエラーの可能性を低減し、テンソル操作の結果を予測しやすくします。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次