オープンソース版ChatGPT「Open Assistant」が開発中!Open Assistantの機能向上に直接貢献してみませんか?日本語データセットを作成する方法はこちらをクリック!

【Pytorch】torch.sumのdim=-1, 0, (1, 1)などの意味とは?実際にコードを動かして検証してみた

githubでコードを見ていると、torch.sum(m, dim=-1)のように傍目では意味が分からない部分が出てきたので、torch.sum()について色々と調べてみました。

目次

torch.sumの役割

torch.sum()の役割は、与えられたテンソルの要素の合計を返すことです。
これは、sumという単語から想像できますね。

torch.sumの引数

引数は、torch.sum(input, dim, keepdim=False, *, dtype=None) です。
それぞれの意味を見ていきます。

input

入力テンソルを表します。

dim

削減する次元を表します。

keepdim

出力テンソルにdimを保持するかどうかを決めます。
デフォルトはFalseです。

dtype

返されるテンソルの希望するデータ型を設定します。
指定された場合、入力テンソルは演算が行われる前にdtypeにキャストされます。

dimkeepdimの説明が良く分からないと思うので、実際にコードを動かして調べてみました。(私も説明文だけでは分かりませんでした)

dimの詳しい説明

  • dimを設定しない場合
  • dimをint型で設定する場合
  • dimをtuple型で設定する場合

に分けられます。

dim=-1などは、「int型で設定する場合」で説明しています。

dimを設定しない場合

まずは、dimを設定しなかった場合にどのような働きをするのか見ていきます。

m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""

torch.sum(m)
# tensor(120)

dimを設定しない場合は、テンソルのすべての要素の合計が返ってくることが分かります。

dimをint型で設定する場合

m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""

次に、dimをint型で設定する場合にどのような働きをするのか見ていきます。
上記のテンソルの場合に設定できるdimは、-2~1の整数です。
それ以外の整数を入れると、次のようなエラーが出ます。
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

上記のエラーが出る理由は、mの次元数が2でそれ以上の値を設定したからです。
入力テンソルの次元数を増やせば、入れられる整数の範囲も広がります。式で表すと、-(入力テンソルの次元数)~(入力テンソルの次元数-1)になります。
今回は、分かりやすくするために4×4の2次元のテンソルの場合を見ていきます。

# 以下のテンソルをもとにそれぞれのdimの場合を見ていきます
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""
dim=1の場合
torch.sum(m, dim=1)
# tensor([ 6, 22, 38, 54])

同じ行の要素を合計していることが分かります。

dim=0の場合
torch.sum(m, dim=0)
# tensor([24, 28, 32, 36])

同じ列の要素を合計していることが分かります。

dim=-1の場合
torch.sum(m, dim=-1)
# tensor([ 6, 22, 38, 54])

同じ行の要素を合計していることが分かります。
この結果は、dim=1の時と同じですね。これは、i_list = [0, 1]のリストにおいてi_list[-1]が1であることからも納得いくのではないでしょうか。

dim=-2の場合
torch.sum(m, dim=-2)
# tensor([24, 28, 32, 36])

同じ列の要素を合計していることが分かります。
この結果は、dim=0の時と同じです。これもリストで考えると分かると思います。

dimをtuple型で設定する場合

続いて、dimをtuple型で設定する場合を見ていきます。
同じように以下のテンソルの場合で見ていきます。

m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
"""
dim=(1, )の場合
torch.sum(m, dim=(1, ))
# tensor([ 6, 22, 38, 54])

このような与え方をした場合は、int型で与えた場合と同じ挙動をします。

dim=(1, 0)の場合
torch.sum(m, dim=(1, 0))
# tensor(120)

これは、torch.sum(m, dim=1)としてから、torch.sum(m, dim=0)とした場合と同じ挙動をしています。

keepdimの詳しい説明

keepdimを設定した場合にどのような挙動をするのか見ていきます。

keepdim=Falseの場合

この場合は、keepdimを設定しない場合と同じ挙動をします。

keepdim=Trueの場合
torch.sum(m, dim=1, keepdim=True)
"""
tensor([[ 6],
        [22],
        [38],
        [54]])
"""
# torch.Size([4, 1])

keepdim=Trueとすると、次元数を削減せずに合計処理します。
なので、「出力テンソルにdimを保持するかどうかを決めます」という説明だったんですね。

まとめ

  • torch.sum()はテンソルの合計を求めることができる
  • dimはどの次元を削減するのかを設定できる
  • keepdimを設定すると次元を削減せずに合計できる
あわせて読みたい
【Pytorch】torch.catの意味や引数を詳しく解説!dim=-1を設定した場合は? torch.catはgithubでモデルの構成を見ていると、かなり登場するので色々と調べてみました。例を使って色々と試しています。 【の意味】 torch.catの役割は、与えられた...
あわせて読みたい
【PyTorch】ベクトルを操作する代表的な関数を解説!view, reshape, transpose, unsqueeze, matmul, ein... PyTorchを使っているコードを見ているとベクトルを操作する関数がたくさん出てくると思います。しかし、どんな機能なのかわからないと自分でコードを変更しようとしても難しいです。そこで今回は、実際にベクトルを操作しながらそれぞれの関数を解説していきます。
あわせて読みたい
【Pytorch】nn.Linearの引数・ソースコードを徹底解説! torch.nn.Linearは基本的な全結合層です。今回は、nn.Linearの引数とソースコードをしっかりと説明していきます。曖昧な理解を直していきましょう。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次