Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.topkの意味・使い方・引数を徹底解説!テンソルの最大値や最小値に対応する要素やインデックスを取得する方法

torch.topkは、PyTorchのテンソル(Tensor)に対して、指定した次元ごとにk個の最大値または最小値を見つけるための関数です。
この関数は、指定した次元で最大値または最小値のk個の要素と、それらの要素のインデックスを返します。
言語モデルなどの引数にもあると思います。

ドキュメント:torch.topk — PyTorch 2.0 documentation

目次

イメージを掴もう

使い方などを見る前にtorch.topkのイメージを把握しておきましょう。

画像では、ある集団から大きい要素を2つ取ってきています。

torch.topkの引数

torch.topkはこのように利用することができます。

import torch

# 2次元テンソルを作成
tensor = torch.tensor([[4, 2, 8],
                       [7, 6, 1]])

# 最大値のインデックスを求める
values, indices = torch.topk(tensor, k=2)

print(values)
# 出力: tensor([[8, 4],
#               [7, 6]])

print(indices)
# 出力: tensor([[2, 0],
#               [0, 1]])

では、引数について詳しく見ていきます。

  • input
    入力テンソルです。
    一般的には1次元または多次元のテンソルが使用されます。
  • k
    返される要素の数を示す整数値です。
    入力テンソルの各次元に対して、最大値または最小値のk個の要素が返されます。
  • dim
    ソートする次元を指定するための引数です。
    デフォルトでは、最後の次元(-1)が選択されます。
    指定した次元ごとに最大値または最小値が計算されます。
  • largest
    largest=Trueの場合、最大値を返します。
    largest=Falseの場合、最小値を返します。デフォルトはTrueです。
  • sorted
    sorted=Trueの場合、返された結果はソートされます。
    sorted=Falseの場合、ソートされません。デフォルトはTrueです。
  • out
    出力テンソルの場所を指定するための引数です。
    指定することで、既存のテンソルを再利用することができます。
    出力は(namedtuple)として返され、その要素はvaluesindicesです。
    valuesは最大値または最小値のk個の要素を含み、indicesはそれらの要素のインデックスを含みます。

torch.topkの使い方・入出力例

入力と出力例を見てどうなっているのか把握しましょう。

import torch

# 入力テンソルの作成
input_tensor = torch.tensor([[4, 2, 8],
                             [7, 6, 1]])

# dim=Noneの場合(最後の次元を選択)
values, indices = torch.topk(input_tensor, k=2)
print("dim=None:")
print("Values:", values)
print("Indices:", indices)
# 出力:
# Values: tensor([[8, 4],
#                 [7, 6]])
# Indices: tensor([[2, 0],
#                 [0, 1]])

# dim=0の場合(列ごとに最大値を選択)
values, indices = torch.topk(input_tensor, k=2, dim=0)
print("dim=0:")
print("Values:", values)
print("Indices:", indices)
# 出力:
# Values: tensor([[7, 6, 8],
#                 [4, 2, 1]])
# Indices: tensor([[1, 1, 0],
#                 [0, 0, 1]])

# sorted=True, largest=Trueの場合
values, indices = torch.topk(input_tensor, k=2, sorted=True, largest=True)
print("sorted=True, largest=True:")
print("Values:", values)
print("Indices:", indices)
# 出力:
# Values: tensor([[8, 4],
#                 [7, 6]])
# Indices: tensor([[2, 0],
#                 [0, 1]])

# largest=Falseの場合(最小値を選択)
values, indices = torch.topk(input_tensor, k=2, largest=False)
print("largest=False:")
print("Values:", values)
print("Indices:", indices)
# 出力:
# Values: tensor([[2, 4],
#                 [1, 6]])
# Indices: tensor([[1, 0],
#                 [2, 1]])

# sorted=Falseの場合(ソートされない)
values, indices = torch.topk(input_tensor, k=2, sorted=False)
print("sorted=False:")
print("Values:", values)
print("Indices:", indices)
# 出力:
# Values: tensor([[8, 4],
#                 [7, 6]])
# Indices: tensor([[2, 0],
#                 [0, 1]])

# out引数を使用せずに結果を取得する例
values, indices = torch.topk(input_tensor, k=2, dim=1)
print("Without out argument:")
print("Values:", values)
print("Indices:", indices)
# 出力:
# Values: tensor([[8, 4],
#                 [7, 6]])
# Indices: tensor([[2, 0],
#                 [0, 1]])

# out引数を使用して結果を既存のテンソルに格納する例
out_values = torch.empty(2, 2, dtype=torch.int64)
out_indices = torch.empty(2, 2, dtype=torch.int64)
torch.topk(input_tensor, k=2, dim=1, out=(out_values, out_indices))
print("Using out argument:")
print("Values:", out_values)
print("Indices:", out_indices)
# 出力:
# Values: tensor([[8, 4],
#                 [7, 6]])
# Indices: tensor([[2, 0],
#                 [0, 1]])

まとめ

torch.topkは、ディープラーニングのタスクやデータ処理のさまざまな場面で役立ちます。
最大値または最小値とそのインデックスを取得することで、データの解析や選択において重要な情報を得ることができる強力な関数です。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次