Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.powとは?引数を徹底解説!Integers to negative integer powers are not allowed.と出た場合の対処法も紹介

torch.powはどのような働きをするのかを調べてみました。

目次

torch.powとは

torch.powは、入力の各要素のべき乗をexponentで表し、その結果をテンソルとして返します。

torch.powの引数

では、torch.powの引数を見ていきます。

torch.powの引数は、torch.pow(inputexponent*out=None)です。

それぞれの引数の説明を見ていきます。

input

入力テンソルです。

exponent

指数の値です。

out

出力テンソルです。

exponentについて、例をみながらどのような働きをするのか見ていきます。

exponentの詳細

  • 負数の場合
  • 0の場合
  • 正数の場合
  • 少数の場合
  • テンソルの場合

を見ていきます。

負数の場合

m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
"""

上のテンソルをinputとして見ていきます。

torch.pow(m, -2)
"""
RuntimeError: Integers to negative integer powers are not allowed.
"""

エラーが出ましたね。

このような場合は、inputのテンソルのタイプをfloatにすると大丈夫です。

m = torch.arange(12, dtype=float).reshape(3, 4)
"""
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]], dtype=torch.float64)
"""
torch.pow(m, -2)
"""
tensor([[   inf, 1.0000, 0.2500, 0.1111],
        [0.0625, 0.0400, 0.0278, 0.0204],
        [0.0156, 0.0123, 0.0100, 0.0083]], dtype=torch.float64)
"""

それぞれの要素が-2乗されているのが分かります。

0の場合

m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
"""

torch.pow(m, 0)
"""
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])
"""

0の場合も同様に各要素が0乗されているのが分かります。

正数の場合

m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
"""

torch.pow(m, 2)
"""
tensor([[  0,   1,   4,   9],
        [ 16,  25,  36,  49],
        [ 64,  81, 100, 121]])
"""

各要素が2乗されているのが分かります。

少数の場合

m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
"""

torch.pow(m, 2.5)
"""
tensor([[  0.0000,   1.0000,   5.6569,  15.5885],
        [ 32.0000,  55.9017,  88.1816, 129.6418],
        [181.0193, 243.0000, 316.2278, 401.3116]])
"""

各要素が2.5乗されているのが分かります。

テンソルの場合

テンソルをexponentに設定する場合は、inputのテンソルとサイズを同じにする必要があります。

m1 = torch.arange(6).reshape(2, 3)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
"""
m2 = torch.arange(6).reshape(3, 2)
"""
tensor([[0, 1],
        [2, 3],
        [4, 5]])
"""

torch.pow(m1, m2)
"""
RuntimeError: The size of tensor a (3) must match 
the size of tensor b (2) at non-singleton dimension 1
"""

上記のようなエラーが出ます。

では、サイズを揃えてテンソルを設定した場合を見ていきます。

m1 = torch.arange(6).reshape(2, 3)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
"""
m2 = torch.arange(6).reshape(2, 3)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
"""

torch.pow(m1, m2)
"""
tensor([[   1,    1,    4],
        [  27,  256, 3125]])
"""

(m1の各要素)の(m2の各要素)乗が出力になっています。

まとめ

  • torch.powは、入力の各要素をexponentの値を指数としてべき乗する
  • exponentには、数またはテンソルを入力できる
あわせて読みたい
【Pytorch】torch.catの使い方・引数を徹底解説!テンソルを結合する方法 torch.catはPyTorchでテンソル(多次元配列)を結合するための関数です。この関数を使用すると、指定した次元でテンソルを連結することができます。 ドキュメント:torc...
あわせて読みたい
【Pytorch】nn.Linearの引数・ソースコードを徹底解説! torch.nn.Linearは基本的な全結合層です。今回は、nn.Linearの引数とソースコードをしっかりと説明していきます。曖昧な理解を直していきましょう。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次