【Pytorch】torch.powとは?引数を徹底解説!Integers to negative integer powers are not allowed.と出た場合の対処法も紹介
torch.powはどのような働きをするのかを調べてみました。
目次
torch.pow
とは
torch.pow
は、入力の各要素のべき乗をexponent
で表し、その結果をテンソルとして返します。
torch.pow
の引数
では、torch.pow
の引数を見ていきます。
torch.pow
の引数は、torch.pow(input, exponent, *, out=None)
です。
それぞれの引数の説明を見ていきます。
input
入力テンソルです。
exponent
指数の値です。
out
出力テンソルです。
exponent
について、例をみながらどのような働きをするのか見ていきます。
exponent
の詳細
- 負数の場合
- 0の場合
- 正数の場合
- 少数の場合
- テンソルの場合
を見ていきます。
負数の場合
m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
"""
上のテンソルをinput
として見ていきます。
torch.pow(m, -2)
"""
RuntimeError: Integers to negative integer powers are not allowed.
"""
エラーが出ましたね。
このような場合は、input
のテンソルのタイプをfloatにすると大丈夫です。
m = torch.arange(12, dtype=float).reshape(3, 4)
"""
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]], dtype=torch.float64)
"""
torch.pow(m, -2)
"""
tensor([[ inf, 1.0000, 0.2500, 0.1111],
[0.0625, 0.0400, 0.0278, 0.0204],
[0.0156, 0.0123, 0.0100, 0.0083]], dtype=torch.float64)
"""
それぞれの要素が-2乗されているのが分かります。
0の場合
m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
"""
torch.pow(m, 0)
"""
tensor([[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]])
"""
0の場合も同様に各要素が0乗されているのが分かります。
正数の場合
m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
"""
torch.pow(m, 2)
"""
tensor([[ 0, 1, 4, 9],
[ 16, 25, 36, 49],
[ 64, 81, 100, 121]])
"""
各要素が2乗されているのが分かります。
少数の場合
m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
"""
torch.pow(m, 2.5)
"""
tensor([[ 0.0000, 1.0000, 5.6569, 15.5885],
[ 32.0000, 55.9017, 88.1816, 129.6418],
[181.0193, 243.0000, 316.2278, 401.3116]])
"""
各要素が2.5乗されているのが分かります。
テンソルの場合
テンソルをexponent
に設定する場合は、input
のテンソルとサイズを同じにする必要があります。
m1 = torch.arange(6).reshape(2, 3)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
m2 = torch.arange(6).reshape(3, 2)
"""
tensor([[0, 1],
[2, 3],
[4, 5]])
"""
torch.pow(m1, m2)
"""
RuntimeError: The size of tensor a (3) must match
the size of tensor b (2) at non-singleton dimension 1
"""
上記のようなエラーが出ます。
では、サイズを揃えてテンソルを設定した場合を見ていきます。
m1 = torch.arange(6).reshape(2, 3)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
m2 = torch.arange(6).reshape(2, 3)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
torch.pow(m1, m2)
"""
tensor([[ 1, 1, 4],
[ 27, 256, 3125]])
"""
(m1の各要素)の(m2の各要素)乗が出力になっています。
まとめ
torch.pow
は、入力の各要素をexponentの値を指数としてべき乗する- exponentには、数またはテンソルを入力できる
【Pytorch】nn.Linearの引数・ソースコードを徹底解説!
torch.nn.Linearは基本的な全結合層です。今回は、nn.Linearの引数とソースコードをしっかりと説明していきます。曖昧な理解を直していきましょう。
コメント