Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】nn.Softmax・F.softmaxの使い方・引数・ソースコードを徹底解説!それぞれの違いも解説!

softmax関数は、入力されたベクトルを確率分布として解釈するための関数です。
各要素を正規化して、0から1の範囲に収めることで、各要素の値を確率として解釈することができます。
主に多クラス分類問題で最終的な出力の活性化関数として使用されます。

ドキュメント

torch.nn.functional.softmax — PyTorch 2.0 documentation

torch.nn.Softmax — PyTorch 2.0 documentation

目次

イメージを把握しよう

使い方などを見る前にSoftmax関数のイメージを把握しておきましょう。

勝手なイメージです。

ある要素を容器1の器に入れ込むために上手く操作して合計を1にするイメージです。

F.softmaxとnn.Softmaxの違い

F.softmax

  • 学習パラメータを持たない関数
  • 計算グラフ内でのみ使用され、パラメータの更新は行われない
  • 単一の計算ステップ内でのみ有効

nn.Softmax

  • 学習パラメータを持つモジュール
  • モデルの一部として定義され、学習中にパラメータの更新が行われる
  • モデルの保存と再利用が可能
あわせて読みたい
【Pytorch】nn.Module(nn.○○)とnn.functional(nn.functional.○○)の違いを徹底解説!どちらを使用す... PyTorchにはニューラルネットワークのモデルを構築するための2つの主要な方法があります。それがnn.Module(例えばnn.Linear、nn.Conv2dなど)とnn.functional(例えばn...

F.softmaxの引数

F.softmax関数は以下のように簡単に書くことができます。

import torch
import torch.nn.functional as F

# 入力テンソルの作成
inputs = torch.tensor([1.0, 2.0, 3.0])

# ソフトマックス関数の適用
probs = F.softmax(inputs, dim=0)

# 結果の出力
print(probs) # tensor([0.0900, 0.2447, 0.6652])

では、引数について見ていきましょう。

  • input
    ソフトマックス関数を適用する入力テンソル
  • dim
    ソフトマックス関数の計算対象とする次元を指定します。
    デフォルトではNoneであり、この場合は最後の次元に対してソフトマックス関数が適用されます。
    負の値を指定することもできます。
  • _stacklevel
    関数内で発生した例外やデバッグ情報に正しいスタックトレースを表示するために使用されます。このパラメータは通常、内部の関数呼び出しのレベルを指定するために使用されます。なのであまり動作には関係ないため今回は触りません。
  • dtype
    返されるテンソルのデータ型を指定します。
    指定された場合、入力テンソルは指定されたデータ型にキャストされます。
    データ型のオーバーフローを防ぐために使用することができます。

F.softmaxの使い方・入出力例

入力と出力の例を見て挙動を理解しましょう。

dimの影響を理解する

import torch
import torch.nn.functional as F

# 入力テンソルの作成
input_tensor = torch.tensor([[1.0, 2.0, 3.0],
                             [4.0, 5.0, 6.0]])

# dim=0の場合
output_dim0 = F.softmax(input_tensor, dim=0)
print("dim=0:\n", output_dim0)
# 出力結果:
# tensor([[0.0474, 0.0474, 0.0474],
#        [0.9526, 0.9526, 0.9526]])

# dim=1の場合
output_dim1 = F.softmax(input_tensor, dim=1)
print("dim=1:\n", output_dim1)
# 出力結果:
# tensor([[0.0900, 0.2447, 0.6652],
#        [0.0900, 0.2447, 0.6652]])

# dim=-1の場合(最後の次元)
output_dim_minus1 = F.softmax(input_tensor, dim=-1)
print("dim=-1:\n", output_dim_minus1)
# 出力結果:
# tensor([[0.0900, 0.2447, 0.6652],
#        [0.0900, 0.2447, 0.6652]])

# dim=-2の場合(最後から2番目の次元)
output_dim_minus2 = F.softmax(input_tensor, dim=-2)
print("dim=-2:\n", output_dim_minus2)
# 出力結果:
# tensor([[0.0474, 0.0474, 0.0474],
#        [0.9526, 0.9526, 0.9526]])

dtypeの影響を理解する

import torch
import torch.nn.functional as F

# 入力テンソルの作成
input_tensor = torch.tensor([1.0, 2.0, 3.0])

# ソフトマックス関数の適用(dtype指定あり)
output_float = F.softmax(input_tensor, dim=0, dtype=torch.float32)
output_double = F.softmax(input_tensor, dim=0, dtype=torch.float64)

print(output_float) # tensor([0.0900, 0.2447, 0.6652])
print(output_double) # tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64)

F.softmaxのソースコードの解説

ソースコードはこちらになります。

def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor:
    # もしinputがtorch関数を持っている場合は、torch関数の処理を呼び出して返す
    if has_torch_function_unary(input):
        return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    # もしdimが指定されていない場合は、softmaxを計算する次元を取得する
    if dim is None:
        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
    # もしdtypeが指定されていない場合は、指定された次元でsoftmaxを計算する
    if dtype is None:
        ret = input.softmax(dim)
    else:
        # dtypeが指定されている場合は、指定された次元とdtypeでsoftmaxを計算する
        ret = input.softmax(dim, dtype=dtype)
    return ret

ではそれぞれの関数の役割を見ていきます。

has_torch_function_unary

handle_torch_function関数は、__torch_function__メソッドのオーバーライドをチェックするための機能を実装しています。
この関数は、PyTorchの公開APIで呼び出された関数に対して、引数をチェックする役割を担っています。

_get_softmax_dim

def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
    # "Implicit dimension choice for {} has been deprecated. 
    # Change the call to include dim=X as an argument."という警告メッセージを表示します。
    # {}の部分には`name`が入ります。
    # stacklevelを指定して、スタックトレースでの表示位置を制御します。
    warnings.warn(
        "Implicit dimension choice for {} has been deprecated. "
        "Change the call to include dim=X as an argument.".format(name),
        stacklevel=stacklevel,
    )
    # 入力テンソルの次元数に応じて、返す次元を決定します。
    if ndim == 0 or ndim == 1 or ndim == 3:
        ret = 0
    else:
        ret = 1
    return ret

この関数は、softmax関数で使用する次元を取得するための補助関数です。

input.softmax

この部分でsoftmax関数の処理を行っています。

nn.Softmaxの引数

nn.Softmaxも次のように簡単に書くことができます。

import torch
import torch.nn as nn

# 入力テンソルの作成
input_tensor = torch.tensor([[1.0, 2.0, 3.0]])

# Softmaxのインスタンスを作成
softmax = nn.Softmax(dim=1)

# Softmaxを適用
output = softmax(input_tensor)

print(input_tensor) # tensor([[1., 2., 3.]])
print(output) # tensor([[0.0900, 0.2447, 0.6652]])

次に、nn.Softmaxの引数を見ていきます。

  • dim
    Softmaxを計算する次元です。
    この次元の各スライスごとに和が1になります。

nn.Softmaxの使い方・入出力例

nn.Softmaxの挙動も見ておきます。

dimの影響を理解する

import torch
import torch.nn as nn

# 入力テンソルの作成
input_tensor = torch.tensor([[1.0, 2.0, 3.0],
                             [4.0, 5.0, 6.0]])

# Softmaxのインスタンスを作成
softmax_dim0 = nn.Softmax(dim=0)
softmax_dim1 = nn.Softmax(dim=1)
softmax_dim_minus1 = nn.Softmax(dim=-1)
softmax_dim_minus2 = nn.Softmax(dim=-2)

# dim=0の場合
output_dim0 = softmax_dim0(input_tensor)
print(output_dim0)
# 出力:
# tensor([[0.0474, 0.0474, 0.0474],
#        [0.9526, 0.9526, 0.9526]])

# dim=1の場合
output_dim1 = softmax_dim1(input_tensor)
print(output_dim1)
# 出力:
# tensor([[0.0900, 0.2447, 0.6652],
#         [0.0900, 0.2447, 0.6652]])

# dim=-1の場合
output_dim_minus1 = softmax_dim_minus1(input_tensor)
print(output_dim_minus1)
# 出力:
# tensor([[0.0900, 0.2447, 0.6652],
#         [0.0900, 0.2447, 0.6652]])

# dim=-2の場合
output_dim_minus2 = softmax_dim_minus2(input_tensor)
print(output_dim_minus2)
# 出力:
# tensor([[0.0474, 0.0474, 0.0474],
#         [0.9526, 0.9526, 0.9526]])

nn.Softmaxのソースコードの解説

ソースコードはこちらです。

class Softmax(Module):
    # モジュールが定数として扱う属性を指定します
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        # 親クラスの初期化メソッドを呼び出します
        super().__init__()
        # dim属性を設定します
        self.dim = dim

    def __setstate__(self, state):
        # 親クラスの__setstate__メソッドを呼び出します
        super().__setstate__(state)
        # dim属性が存在しない場合は、Noneで初期化します
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        # F.softmax関数を使用して入力をsoftmax変換し、結果を返します
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        # オブジェクトの追加の文字列表現を返します
        return 'dim={dim}'.format(dim=self.dim)

Moduleクラスを継承したSoftmaxクラスを定義しています。Softmaxクラスは、モジュール内でのsoftmax変換を実装しています。

まとめ

ソフトマックス関数(Softmax function)は、与えられた実数の集合を確率分布に変換するために使用される非線形関数です。主に多クラス分類問題において、出力層の活性化関数として広く使われます。
今回解説したようにPytorchを利用すれば簡単に実装することができます。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次