【Pytorch】torch.narrowの意味と引数について徹底解説!

torch.narrowについて、例を用いて使い方をまとめてみました。

目次

torch.narrowとは

torch.narrowは、入力テンソルを狭めた新しいテンソルを返すことができます。

narrow(狭まる)という意味から分かりますね。

torch.narrowの引数

では、torch.narrowの引数について見ていきます。
torch.catの引数は、torch.narrow(inputdimstartlength)です。
それぞれの引数の説明を見ていきます。

input

入力テンソルです。

dim

狭める次元を設定できます。

start

始点となる次元です。

length

終了する次元までの長さです。

torch.narrowの注意点

dimで指定した次元における長さは、start+lengthの長さを超えていなければいけません。逆に、start+lengthの長さは、dimで指定した次元における長さ以下でないといけません。
この条件を満たさない場合は、RuntimeError: start (0) + length (6) exceeds dimension size (4).
のようなエラーが出ます。

それぞれの引数の働きを検証

説明文だけでは良く分からないので、実際にコードを動かしてそれぞれの引数の働きを確かめます。

m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
"""

上記のテンソルを基準にそれぞれの引数の働きを見ていきます。

dim=0の場合

# dim=0, start=0, length=1
torch.narrow(m, 0, 0, 1)
"""
tensor([[0, 1, 2, 3]])
"""

# dim=0, start=0, length=2
torch.narrow(m, 0, 0, 2)
"""
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])
"""

# dim=0, start=1, length=2
torch.narrow(m, 0, 1, 2)
"""
tensor([[ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
"""

lengthは、start基準でいくつに狭めるのかを設定するものだと分かります。
インデックスを指定している訳ではないので注意が必要ですね。

dim=1の場合

# dim=1, start=0, length=1
torch.narrow(m, 1, 0, 1)
"""
tensor([[0],
        [4],
        [8]])
"""

# dim=1, start=0, length=2
torch.narrow(m, 1, 0, 2)
"""
tensor([[0, 1],
        [4, 5],
        [8, 9]])
"""

# dim=1, start=1, length=2
torch.narrow(m, 0, 1, 2)
"""
tensor([[ 1,  2],
        [ 5,  6],
        [ 9, 10]])
"""

dim=1の場合も同様にlengthは、start基準でいくつに狭めるのかを設定するものだと分かります。

dimについて

dimという引数は、torch.cattorch.sumなどにも出てきます。

また、dim=-1のように負の数を指定する場合もよく見かけます。

これは、テンソルの次元数を考えると良く分かります。

次元の数が長さのリストを考えてみましょう。

a = [0, 1, 2]

上記のリストで2を取り出す時は、a[2]とすれば取得できますよね。または、リストの最後の要素を取得するという意味で捉えるとa[-1]でも取得できます。

つまり、dim=-1とする場合も同様に考えることができます。

まとめ

  • troch.narrowは入力テンソルを狭めた新しいテンソルを返す
  • start~start+lengthまでを返す
あわせて読みたい
【Pytorch】torch.catの意味や引数を詳しく解説!dim=-1を設定した場合は? torch.catはgithubでモデルの構成を見ていると、かなり登場するので色々と調べてみました。例を使って色々と試しています。 【の意味】 torch.catの役割は、与えられた...
あわせて読みたい
【Pytorch】torch.sumのdim=-1, 0, (1, 1)などの意味とは?実際にコードを動かして検証してみた githubでコードを見ていると、torch.sum(m, dim=-1)のように傍目では意味が分からない部分が出てきたので、torch.sum()について色々と調べてみました。 【の役割】 torc...

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次