Llama-2が登場!8bit+LoRAでRLHFファインチューニングを試す方法はこちら

【Pytorch】nn.Identityの役割・使い方・メリットを徹底解説!入力を変換せずにそのまま出力する方法

nn.Identityは、入力をそのまま出力する恒等演算を行います。
通常、ニューラルネットワークの各層は、入力データを変換して新しい表現を生成しますが、入力データをそのまま出力に伝えたいという場合もあり、この際にnn.Identityを用いることでそのまま出力することができます。

nn.Identity — PyTorch 2.0 documentation

目次

イメージを掴もう

nn.Identityのソースコード

ただ、inputを返しているのが分かります。

class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input

nn.Identityの使い方

基本的には以下のように簡単に使うことができます。

import torch
import torch.nn as nn

input_tensor = torch.randn(2, 3)
print(input_tensor)
# tensor([[ 0.1095, -0.4588,  0.5012],
#        [ 0.1171, -1.0373,  1.8625]])

identity = nn.Identity()
output_tensor = identity(input_tensor)
print(output_tensor)
# tensor([[ 0.1095, -0.4588,  0.5012],
#        [ 0.1171, -1.0373,  1.8625]])

スキップ接続(ResNetスタイル)

この例では、ResidualBlockというクラスを定義しています。このクラスは、ResNetスタイルのネットワークにおける基本的なブロックを表現しています。入力と出力のチャンネル数が異なる場合、スキップ接続にはnn.Conv2dを使用し、チャンネル数を揃えます。チャンネル数が同じ場合はnn.Identityを使用してスキップ接続を行います。

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # スキップ接続
        if in_channels != out_channels:
            self.identity = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.identity = nn.Identity()
        
        self.relu2 = nn.ReLU()
    
    def forward(self, x):
        identity = self.identity(x)  # スキップ接続
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += identity  # スキップ接続の和を取る
        out = self.relu2(out)
        
        return out
>>> ResidualBlock(3, 3)
ResidualBlock(
  (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (identity): Identity()
  (relu2): ReLU()
)

プーリング層のスキップ

この例では、PoolingSkipというクラスを定義しています。このクラスでは、プーリング層の後にnn.Identityを挿入しています。これにより、プーリングされた特徴マップと元の特徴マップの和を取ることができます。

import torch.nn as nn

class PoolingSkip(nn.Module):
    def __init__(self):
        super(PoolingSkip, self).__init__()
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.identity = nn.Identity()  # nn.Identityでスキップ
        
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    
    def forward(self, x):
        pooled = self.pool(x)
        identity = self.identity(pooled)  # スキップ
        
        out = self.conv(x)
        
        out += identity  # スキップの和を取る
        
        return out
>>> PoolingSkip()
PoolingSkip(
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (identity): Identity()
  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

グラフ構造の保存と再利用

この例では、Graphというクラスを定義しています。複数のノード(全結合層)の出力を結合し、nn.Identityを使用してグラフ構造を保存します。これにより、後続のモジュールでグラフの一部を再利用することができます。

import torch
import torch.nn as nn

class Graph(nn.Module):
    def __init__(self):
        super(Graph, self).__init__()
        
        self.node1 = nn.Linear(10, 10)
        self.node2 = nn.Linear(10, 10)
        self.node3 = nn.Linear(10, 10)
        
        self.graph = nn.Identity()
        
    def forward(self, x):
        out1 = self.node1(x)
        out2 = self.node2(x)
        out3 = self.node3(x)
        
        out = torch.cat((out1, out2, out3), dim=1)
        
        out = self.graph(out)
        
        return out
>>> Graph()
Graph(
  (node1): Linear(in_features=10, out_features=10, bias=True)
  (node2): Linear(in_features=10, out_features=10, bias=True)
  (node3): Linear(in_features=10, out_features=10, bias=True)
  (graph): Identity()
)

nn.Identityのメリット

nn.Identityは、重みやパラメータを必要としない演算のプレースホルダとして使用することができる関数なので以下のような利点があります。

一部の層を削除したり、置き換えたりできる

以下の例では、事前に学習したモデル(ここではtorchvision.models.resnet50を使用)の最終全結合層をnn.Identityに置き換えます。

これにより、全結合層の前の特徴ベクトルを取得することができます。

import torch
from torchvision import models
from torch import nn

# 事前に学習済みのモデルをロードします
model1 = models.resnet50(pretrained=True)
model2 = models.resnet50(pretrained=True)

# model1の最終全結合層をIdentityに置き換えます
model1.fc = nn.Identity()

model1.eval()
model2.eval()

# ランダムな入力を生成します
x = torch.rand(1, 3, 224, 224)

# モデルに入力し、出力を比較します
y1 = model1(x)
y2 = model2(x)

print(y1.shape)
# torch.Size([1, 2048])
print(y2.shape)
# torch.Size([1, 1000])

入力と出力の形状を変更することができる

正確には、nn.Identity自体は入力されたテンソルをそのまま出力するため、入力と出力の形状を変更するという特性は持っていません。

ただし、nn.Identityが出力形状を変更しないという特性を利用して、特定の層をnn.Identityに置き換えることで、その層が行っていた出力形状の変更を無効にする、という使い方が可能です。

例えば、あるネットワークが最終層でFlatten層を持っていて、これによりテンソルの形状を変更(たとえば、[バッチサイズ, チャンネル数, 高さ, 幅]から[バッチサイズ, チャンネル数, 高さ, 幅]に変更)しているとします。このFlatten層をnn.Identityに置き換えると、テンソルの形状はFlattenされずにそのまま出力されます。

nn.Identityを使う際の注意点

入力と出力の形状に互換性があることを確認する必要があります。
以下に、この注意点を無視した場合にどのようなエラーが発生するかの例を示します。

import torch
import torch.nn as nn

class IncompatibleLayers(nn.Module):
    def __init__(self):
        super(IncompatibleLayers, self).__init__()
        self.fc1 = nn.Linear(1000, 10)
        self.identity = nn.Identity()
        self.fc2 = nn.Linear(100, 1)  # 入力は100次元を期待している

    def forward(self, x):
        x = self.fc1(x)
        x = self.identity(x)  # ここで10次元のデータが出力される
        x = self.fc2(x)  # しかし、次の層は100次元のデータを期待している
        return x

model = IncompatibleLayers()

# ダミーの入力データを作成します
input_data = torch.rand(1, 1000)

# モデルを実行します
output = model(input_data)  # エラーが発生します
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x10 and 100x1)

この例では、nn.Identityの前後でテンソルの形状が異なるため、エラーが発生します。

まとめ

nn.Identityは、単純な恒等関数であるため、学習可能なパラメータを持ちません。そのため、モデルのパラメータの数や学習の対象となることはありません。ただし、ネットワーク内の情報の流れを制御する用途において便利です。

この記事が気に入ったら
フォローしてね!

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

目次