Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.sumの使い方・引数を徹底解説!dim=-1, 0, (1, 1)などの意味とは?

torch.sumの役割は、与えられたテンソルの要素の合計を返すことです。

目次

イメージを掴もう

torch.sumの引数

引数は、torch.sum(input, dim, keepdim=False, *, dtype=None) です。
それぞれの意味を見ていきます。

  • input
    入力テンソルを表します。
  • dim
    削減する次元を表します。
  • keepdim
    出力テンソルにdimを保持するかどうかを決めます。
    デフォルトはFalseです。
  • dtype
    返されるテンソルの希望するデータ型を設定します。
    指定された場合、入力テンソルは演算が行われる前にdtypeにキャストされます。

dimとkeepdimについてはもっと詳しく説明します。

dimの詳しい説明

  • dimを設定しない場合
  • dimをint型で設定する場合
  • dimをtuple型で設定する場合

に分けられます。

dim=-1などは、「int型で設定する場合」で説明しています。

dimを設定しない場合

まずは、dimを設定しなかった場合にどのような働きをするのか見ていきます。

m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""

torch.sum(m)
# tensor(120)

dimを設定しない場合は、テンソルのすべての要素の合計が返ってくることが分かります。

dimをint型で設定する場合

m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""

次に、dimをint型で設定する場合にどのような働きをするのか見ていきます。
上記のテンソルの場合に設定できるdimは、-2~1の整数です。
それ以外の整数を入れると、次のようなエラーが出ます。
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

上記のエラーが出る理由は、mの次元数が2でそれ以上の値を設定したからです。
入力テンソルの次元数を増やせば、入れられる整数の範囲も広がります。式で表すと、-(入力テンソルの次元数)~(入力テンソルの次元数-1)になります。
今回は、分かりやすくするために4×4の2次元のテンソルの場合を見ていきます。

# 以下のテンソルをもとにそれぞれのdimの場合を見ていきます
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""
dim=1の場合
torch.sum(m, dim=1)
# tensor([ 6, 22, 38, 54])

同じ行の要素を合計していることが分かります。

dim=0の場合
torch.sum(m, dim=0)
# tensor([24, 28, 32, 36])

同じ列の要素を合計していることが分かります。

dim=-1の場合
torch.sum(m, dim=-1)
# tensor([ 6, 22, 38, 54])

同じ行の要素を合計していることが分かります。
この結果は、dim=1の時と同じですね。これは、i_list = [0, 1]のリストにおいてi_list[-1]が1であることからも納得いくのではないでしょうか。

dim=-2の場合
torch.sum(m, dim=-2)
# tensor([24, 28, 32, 36])

同じ列の要素を合計していることが分かります。
この結果は、dim=0の時と同じです。これもリストで考えると分かると思います。

dimをtuple型で設定する場合

続いて、dimをtuple型で設定する場合を見ていきます。
同じように以下のテンソルの場合で見ていきます。

m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""
dim=(1, )の場合
torch.sum(m, dim=(1, ))
# tensor([ 6, 22, 38, 54])

このような与え方をした場合は、int型で与えた場合と同じ挙動をします。

dim=(1, 0)の場合
torch.sum(m, dim=(1, 0))
# tensor(120)

これは、torch.sum(m, dim=1)としてから、torch.sum(m, dim=0)とした場合と同じ挙動をしています。

keepdimの詳しい説明

keepdimを設定した場合にどのような挙動をするのか見ていきます。

keepdim=Falseの場合

この場合は、keepdimを設定しない場合と同じ挙動をします。

keepdim=Trueの場合
torch.sum(m, dim=1, keepdim=True)
"""
tensor([[ 6],
        [22],
        [38],
        [54]])
"""
# torch.Size([4, 1])

keepdim=Trueとすると、次元数を削減せずに合計処理します。
なので、「出力テンソルにdimを保持するかどうかを決めます」という説明だったんですね。

まとめ

  • torch.sum()はテンソルの合計を求めることができる
  • dimはどの次元を削減するのかを設定できる
  • keepdimを設定すると次元を削減せずに合計できる
あわせて読みたい
【Pytorch】torch.catの使い方・引数を徹底解説!テンソルを結合する方法 torch.catはPyTorchでテンソル(多次元配列)を結合するための関数です。この関数を使用すると、指定した次元でテンソルを連結することができます。 ドキュメント:torc...
あわせて読みたい
【PyTorch】ベクトルを操作する代表的な関数を解説!view, reshape, transpose, unsqueeze, matmul, ein... PyTorchを使っているコードを見ているとベクトルを操作する関数がたくさん出てくると思います。しかし、どんな機能なのかわからないと自分でコードを変更しようとしても難しいです。そこで今回は、実際にベクトルを操作しながらそれぞれの関数を解説していきます。
あわせて読みたい
【Pytorch】nn.Linearの引数・ソースコードを徹底解説! torch.nn.Linearは基本的な全結合層です。今回は、nn.Linearの引数とソースコードをしっかりと説明していきます。曖昧な理解を直していきましょう。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次