【Pytorch】torch.sumの使い方・引数を徹底解説!dim=-1, 0, (1, 1)などの意味とは?

torch.sumの役割は、与えられたテンソルの要素の合計を返すことです。
イメージを掴もう

torch.sumの引数
引数は、torch.sum(input, dim, keepdim=False, *, dtype=None) です。
それぞれの意味を見ていきます。
- input
入力テンソルを表します。 - dim
削減する次元を表します。 - keepdim
出力テンソルにdimを保持するかどうかを決めます。
デフォルトはFalseです。 - dtype
返されるテンソルの希望するデータ型を設定します。
指定された場合、入力テンソルは演算が行われる前にdtypeにキャストされます。
dimとkeepdimについてはもっと詳しく説明します。
dimの詳しい説明
dimを設定しない場合dimをint型で設定する場合dimをtuple型で設定する場合
に分けられます。
dim=-1などは、「int型で設定する場合」で説明しています。
dimを設定しない場合
まずは、dimを設定しなかった場合にどのような働きをするのか見ていきます。
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""
torch.sum(m)
# tensor(120)dimを設定しない場合は、テンソルのすべての要素の合計が返ってくることが分かります。
dimをint型で設定する場合
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""次に、dimをint型で設定する場合にどのような働きをするのか見ていきます。
上記のテンソルの場合に設定できるdimは、-2~1の整数です。
それ以外の整数を入れると、次のようなエラーが出ます。IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
# 以下のテンソルをもとにそれぞれのdimの場合を見ていきます
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""dim=1の場合torch.sum(m, dim=1)
# tensor([ 6, 22, 38, 54])同じ行の要素を合計していることが分かります。
dim=0の場合torch.sum(m, dim=0)
# tensor([24, 28, 32, 36])同じ列の要素を合計していることが分かります。
dim=-1の場合torch.sum(m, dim=-1)
# tensor([ 6, 22, 38, 54])同じ行の要素を合計していることが分かります。
この結果は、dim=1の時と同じですね。これは、i_list = [0, 1]のリストにおいてi_list[-1]が1であることからも納得いくのではないでしょうか。
dim=-2の場合torch.sum(m, dim=-2)
# tensor([24, 28, 32, 36])同じ列の要素を合計していることが分かります。
この結果は、dim=0の時と同じです。これもリストで考えると分かると思います。
dimをtuple型で設定する場合
続いて、dimをtuple型で設定する場合を見ていきます。
同じように以下のテンソルの場合で見ていきます。
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""dim=(1, )の場合torch.sum(m, dim=(1, ))
# tensor([ 6, 22, 38, 54])このような与え方をした場合は、int型で与えた場合と同じ挙動をします。
dim=(1, 0)の場合torch.sum(m, dim=(1, 0))
# tensor(120)これは、torch.sum(m, dim=1)としてから、torch.sum(m, dim=0)とした場合と同じ挙動をしています。
keepdimの詳しい説明
keepdimを設定した場合にどのような挙動をするのか見ていきます。
keepdim=Falseの場合この場合は、keepdimを設定しない場合と同じ挙動をします。
keepdim=Trueの場合torch.sum(m, dim=1, keepdim=True)
"""
tensor([[ 6],
[22],
[38],
[54]])
"""
# torch.Size([4, 1])keepdim=Trueとすると、次元数を削減せずに合計処理します。
なので、「出力テンソルにdimを保持するかどうかを決めます」という説明だったんですね。
まとめ
torch.sum()はテンソルの合計を求めることができるdimはどの次元を削減するのかを設定できるkeepdimを設定すると次元を削減せずに合計できる




コメント