【Pytorch】torch.narrowの意味と引数について徹底解説!

torch.narrow
について、例を用いて使い方をまとめてみました。
torch.narrow
とは
torch.narrowは、入力テンソルを狭めた新しいテンソルを返すことができます。
narrow(狭まる)という意味から分かりますね。
torch.narrow
の引数
では、torch.narrowの引数について見ていきます。
torch.catの引数は、torch.narrow(input, dim, start, length)
です。
それぞれの引数の説明を見ていきます。
入力テンソルです。
狭める次元を設定できます。
始点となる次元です。
終了する次元までの長さです。
torch.narrowの注意点
dim
で指定した次元における長さは、start
+length
の長さを超えていなければいけません。逆に、start+lengthの長さは、dimで指定した次元における長さ以下でないといけません。
この条件を満たさない場合は、RuntimeError: start (0) + length (6) exceeds dimension size (4).
のようなエラーが出ます。
それぞれの引数の働きを検証
説明文だけでは良く分からないので、実際にコードを動かしてそれぞれの引数の働きを確かめます。
m = torch.arange(12).reshape(3, 4)
"""
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
"""
上記のテンソルを基準にそれぞれの引数の働きを見ていきます。
dim=0
の場合
# dim=0, start=0, length=1
torch.narrow(m, 0, 0, 1)
"""
tensor([[0, 1, 2, 3]])
"""
# dim=0, start=0, length=2
torch.narrow(m, 0, 0, 2)
"""
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
"""
# dim=0, start=1, length=2
torch.narrow(m, 0, 1, 2)
"""
tensor([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
"""
length
は、start
基準でいくつに狭めるのかを設定するものだと分かります。
インデックスを指定している訳ではないので注意が必要ですね。
dim=1
の場合
# dim=1, start=0, length=1
torch.narrow(m, 1, 0, 1)
"""
tensor([[0],
[4],
[8]])
"""
# dim=1, start=0, length=2
torch.narrow(m, 1, 0, 2)
"""
tensor([[0, 1],
[4, 5],
[8, 9]])
"""
# dim=1, start=1, length=2
torch.narrow(m, 0, 1, 2)
"""
tensor([[ 1, 2],
[ 5, 6],
[ 9, 10]])
"""
dim=1の場合も同様にlength
は、start
基準でいくつに狭めるのかを設定するものだと分かります。
dim
について
dimという引数は、torch.catやtorch.sumなどにも出てきます。
また、dim=-1のように負の数を指定する場合もよく見かけます。
これは、テンソルの次元数を考えると良く分かります。
次元の数が長さのリストを考えてみましょう。
a = [0, 1, 2]
上記のリストで2を取り出す時は、a[2]
とすれば取得できますよね。または、リストの最後の要素を取得するという意味で捉えるとa[-1]
でも取得できます。
つまり、dim=-1とする場合も同様に考えることができます。
まとめ
troch.narrow
は入力テンソルを狭めた新しいテンソルを返すstart
~start
+length
までを返す


コメント