【Pytorch】torch.sumの使い方・引数を徹底解説!dim=-1, 0, (1, 1)などの意味とは?
torch.sumの役割は、与えられたテンソルの要素の合計を返すことです。
イメージを掴もう
torch.sum
の引数
引数は、torch.sum(input, dim, keepdim=False, *, dtype=None)
です。
それぞれの意味を見ていきます。
- input
入力テンソルを表します。 - dim
削減する次元を表します。 - keepdim
出力テンソルにdim
を保持するかどうかを決めます。
デフォルトはFalse
です。 - dtype
返されるテンソルの希望するデータ型を設定します。
指定された場合、入力テンソルは演算が行われる前にdtype
にキャストされます。
dimとkeepdimについてはもっと詳しく説明します。
dim
の詳しい説明
dim
を設定しない場合dim
をint型で設定する場合dim
をtuple型で設定する場合
に分けられます。
dim=-1
などは、「int型で設定する場合」で説明しています。
dim
を設定しない場合
まずは、dim
を設定しなかった場合にどのような働きをするのか見ていきます。
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""
torch.sum(m)
# tensor(120)
dim
を設定しない場合は、テンソルのすべての要素の合計が返ってくることが分かります。
dim
をint型で設定する場合
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""
次に、dim
をint型で設定する場合にどのような働きをするのか見ていきます。
上記のテンソルの場合に設定できるdim
は、-2~1の整数です。
それ以外の整数を入れると、次のようなエラーが出ます。IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
# 以下のテンソルをもとにそれぞれのdimの場合を見ていきます
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""
dim=1
の場合torch.sum(m, dim=1)
# tensor([ 6, 22, 38, 54])
同じ行の要素を合計していることが分かります。
dim=0
の場合torch.sum(m, dim=0)
# tensor([24, 28, 32, 36])
同じ列の要素を合計していることが分かります。
dim=-1
の場合torch.sum(m, dim=-1)
# tensor([ 6, 22, 38, 54])
同じ行の要素を合計していることが分かります。
この結果は、dim=1
の時と同じですね。これは、i_list = [0, 1]
のリストにおいてi_list[-1]
が1であることからも納得いくのではないでしょうか。
dim=-2
の場合torch.sum(m, dim=-2)
# tensor([24, 28, 32, 36])
同じ列の要素を合計していることが分かります。
この結果は、dim=0
の時と同じです。これもリストで考えると分かると思います。
dim
をtuple型で設定する場合
続いて、dim
をtuple型で設定する場合を見ていきます。
同じように以下のテンソルの場合で見ていきます。
m = torch.arange(16).reshape(4, -1)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""
dim=(1, )
の場合torch.sum(m, dim=(1, ))
# tensor([ 6, 22, 38, 54])
このような与え方をした場合は、int型で与えた場合と同じ挙動をします。
dim=(1, 0)
の場合torch.sum(m, dim=(1, 0))
# tensor(120)
これは、torch.sum(m, dim=1)
としてから、torch.sum(m, dim=0)
とした場合と同じ挙動をしています。
keepdim
の詳しい説明
keepdimを設定した場合にどのような挙動をするのか見ていきます。
keepdim=False
の場合この場合は、keepdim
を設定しない場合と同じ挙動をします。
keepdim=True
の場合torch.sum(m, dim=1, keepdim=True)
"""
tensor([[ 6],
[22],
[38],
[54]])
"""
# torch.Size([4, 1])
keepdim=True
とすると、次元数を削減せずに合計処理します。
なので、「出力テンソルにdim
を保持するかどうかを決めます」という説明だったんですね。
まとめ
torch.sum()
はテンソルの合計を求めることができるdim
はどの次元を削減するのかを設定できるkeepdim
を設定すると次元を削減せずに合計できる
コメント