Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】torch.argmaxの使い方・引数を徹底解説!インデックスを取得する理由についても解説!

PyTorchのargmax関数は、データの中から最大値を持つ要素のインデックスを取得するための便利な関数です。
データ解析や機械学習の実装において、重要な役割を果たします。

ドキュメント: torch.argmax — PyTorch 2.0 documentation

目次

イメージを把握しよう

使い方などを見る前にtorch.argmaxのイメージを把握しておきましょう。

このようにデータ内にある一番大きい要素のインデックス(Pythonは0から数える)を取得できます。

複数の最大値や最小値を取得したい場合はtorch.topkを利用できます。

あわせて読みたい
【Pytorch】torch.topkの意味・使い方・引数を徹底解説!テンソルの最大値や最小値に対応する要素やイン... torch.topkは、PyTorchのテンソル(Tensor)に対して、指定した次元ごとにk個の最大値または最小値を見つけるための関数です。この関数は、指定した次元で最大値または...

torch.argmaxの引数

torch.argmaxは次のように簡単に利用することができます。

import torch

# テンソルの作成
tensor = torch.tensor([1, 3, 5, 2, 4])

# 最大値のインデックスを取得
max_index = torch.argmax(tensor)

print(max_index)  # 出力: tensor(2)

# テンソルの特定のインデックスの値を取得
max_value = tensor[max_index]

print(max_value)  # 出力: tensor(5)

では、引数について見ていきましょう。

  • input
    入力テンソルです。torch.argmax関数は、このテンソル内の最大値のインデックスを返します。
  • dim
    縮約する次元を指定します。もしNoneの場合、平坦化された入力のargmaxが返されます。
    dimを指定すると、指定した次元ごとに最大値のインデックスが返されます。
    dim=0とすることで列ごとの最大値のインデックスが、dim=1とすることで行ごとの最大値のインデックスが取得されます。
    値として負の値を指定できます。たとえば、dim=-1は最後の次元を表します。
  • keepdim
    出力テンソルが指定した次元を保持するかどうかを指定します。
    dim=Noneの場合は無視されます。
    keepdim=False(デフォルト)の場合、出力テンソルは指定した次元を保持せずに返されます。

torch.argmaxの使い方・入出力例

入力と出力例を見るとどうなっているのか分かりやすいと思います。

dimの影響を理解する

import torch

# 1次元テンソルの場合
input_tensor = torch.tensor([3, 1, 5, 2, 4])
max_index = torch.argmax(input_tensor)
print(max_index)  # 出力: tensor(2)
# input_tensorの最大値のインデックスは2

# 2次元テンソルの場合
input_tensor = torch.tensor([[3, 1, 5], [2, 4, 6]])
max_index = torch.argmax(input_tensor, dim=1)
print(max_index)  # 出力: tensor([2, 2])
# dim=1により、各行で最大値のインデックスを取得
# 最初の行の最大値のインデックスは2、2番目の行の最大値のインデックスも2

# 2次元テンソルの場合
input_tensor = torch.tensor([[3, 1, 5], [2, 4, 6]])
max_index = torch.argmax(input_tensor, dim=-1)
print(max_index)  # 出力: tensor([2, 2])
# dim=-1は最後の次元を表し、この場合は各行での最大値のインデックスを取得しています

# 3次元テンソルの場合
input_tensor = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
max_index = torch.argmax(input_tensor, dim=-2)
print(max_index)  # 出力: tensor([[1, 1, 1], [1, 1, 1]])
# dim=-2は最後から2番目の次元を表し、この場合は各列ごとに最大値のインデックスを取得しています

keepdimの影響を理解する

import torch

# 2次元テンソルの場合
input_tensor = torch.tensor([[3, 1, 5], [2, 4, 6]])
max_index = torch.argmax(input_tensor, dim=1, keepdim=False)
print(max_index)  # 出力: tensor([2, 2])
# keepdim=Falseの場合、出力テンソルは指定した次元を保持せずに返される

max_index = torch.argmax(input_tensor, dim=1, keepdim=True)
print(max_index)  # 出力: tensor([[2], [2]])
# keepdim=Trueの場合、出力テンソルは指定した次元を保持して返される

# 3次元テンソルの場合
input_tensor = torch.tensor([[[3, 1, 5], [2, 4, 6]], [[7, 2, 4], [6, 9, 2]]])
max_index = torch.argmax(input_tensor, dim=2, keepdim=False)
print(max_index)
# 出力: tensor([[2, 2],
#                [0, 1]])
# keepdim=Falseの場合、出力テンソルは指定した次元を保持せずに返される

max_index = torch.argmax(input_tensor, dim=2, keepdim=True)
print(max_index)
# 出力: tensor([[[2],
#                [2]],
#
#               [[0],
#                [1]]])
# keepdim=Trueの場合、出力テンソルは指定した次元を保持して返される

複数最大値があった場合の挙動を理解する

import torch

# 1次元テンソルの場合
input_tensor = torch.tensor([3, 1, 5, 5, 2, 4])
max_index = torch.argmax(input_tensor)
print(max_index)  # 出力: tensor(2)
# input_tensorの最大値のインデックスは2

# 複数の最大値がある場合
input_tensor = torch.tensor([3, 1, 5, 5, 2, 4])
max_indices = torch.where(input_tensor == torch.max(input_tensor))[0]
print(max_indices)  # 出力: tensor([2, 3])
# torch.whereを使用して、最大値と同じ値を持つインデックスを取得

# インデックスをランダムに選択する場合
import random
random_max_index = random.choice(max_indices)
print(random_max_index)  # 出力: 2または3
# max_indicesからランダムに1つのインデックスを選択する

# すべての最大値のインデックスを取得する場合
all_max_indices = torch.nonzero(input_tensor == torch.max(input_tensor)).squeeze()
print(all_max_indices)  # 出力: tensor([2, 3])
# torch.nonzeroを使用して、最大値と同じ値を持つすべてのインデックスを取得
# squeeze()を使用して、不要な次元を削除

このコード例では、複数の最大値が存在する場合の挙動と対処法を示しています。
最初の例では、torch.argmaxを使用して最大値のインデックスを取得しましたが、この場合、最初に見つかった最大値のインデックスしか返されません。
そのため、2つ目の例では、torch.whereを使用して、最大値と同じ値を持つすべてのインデックスを取得しています。これにより、複数の最大値がある場合でも、すべての最大値のインデックスを取得できます。
さらに、ランダムに1つの最大値のインデックスを選択する場合には、random.choiceを使用してmax_indicesからランダムに1つのインデックスを選択します。
また、最後の例では、torch.nonzeroを使用して、最大値と同じ値を持つすべてのインデックスを取得しています。これにより、すべての最大値のインデックスを取得することができます。
これらの対処法を使用することで、複数の最大値がある場合に適切に対処することができます。

インデックスを取得する理由

torch.argmax関数は、テンソル内の最大値のインデックスを返すために設計されています。

データ内の最大値そのものではなく、その位置(インデックス)を取得する理由にはいくつかのメリットがあります。

  • データ内の最大値だけではなく、その位置情報も重要な場合があります。たとえば、最大値が出現する位置を特定する必要がある場合や、最大値を含む要素にアクセスする必要がある場合などがあります。インデックスの取得により、これらの操作を簡単に実行することができます。
  • 最大値そのものを取得するためには、すべての要素を比較していく必要があります。一方、インデックスを取得するためには、要素の比較が必要なくなります。インデックスを返すことで、最大値の位置を直接特定することができ、演算の効率が向上します。
  • torch.argmaxは多次元配列にも適用できます。この場合、各次元ごとに最大値の位置を取得することができます。たとえば、2次元のテンソルでは、行ごとまたは列ごとに最大値の位置を取得することができます。

理由はこれだけではないかもしれませんが、こんな感じだと思います。

まとめ

torch.argmax関数は、テンソル内の最大値のインデックスを取得するための関数です。
様々な利点があります。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次