Question
pytorch で モデルを eval するのと、モデルパラメータすべてに requires_grad=False を指定するのと、with no_grad で実行するのはそれぞれどう違うの? コード例とともにまとめて
ご質問ありがとうございます。PyTorch におけるモデルの評価モード(eval())、勾配計算の無効化(requires_grad=False)、および torch.no_grad() コンテキストマネージャは、それぞれ目的、適用範囲、および効果が異なります。
ご要望に応じ、それぞれの違いを簡潔にまとめ、コード例を示します。
違いのまとめ 💡
| 特徴 | model.eval() | パラメータの requires_grad=False | with torch.no_grad(): |
|---|---|---|---|
| 目的 | モデル内部の動作モードを切り替え、評価時の振る舞いに変更する。 | 特定のテンソル(通常はモデルのパラメータ)に対する勾配計算を完全に無効化する。 | コンテキスト内での勾配追跡を一時的に無効化する。 |
| 適用範囲 | モデル全体(nn.Module インスタンス)。 | 個々のテンソル(torch.Tensor)。 | with ブロック内のすべての演算。 |
| 主な効果 | Dropout 層をオフにし、Batch Norm 層を学習中の移動平均ではなく、統計の固定(学習済みの平均と分散)を使用するように設定する。 | そのテンソルを含む演算グラフの勾配を記録しない。メモリ使用量を減らす。 | with ブロック内の演算で勾配追跡が行われない。最速かつ最も一般的な評価時の方法。 |
| 通常の使用場面 | モデルを評価(テスト、検証、推論)する前。 | ファインチューニングなどで特定の層を凍結(更新しない)したい時。 | モデルを評価または推論する時。 |
コード例 💻
1. model.eval() と with torch.no_grad():
これら 2 つは、通常、推論時や評価時に一緒に使用されます。
Python
import torch
import torch.nn as nn
# サンプルモデルを定義 (DropoutとBatchNormを含む)
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
# Dropout層は eval() でオフになる
self.dropout = nn.Dropout(0.5)
# BatchNorm層は eval() で統計が固定される
self.bn = nn.BatchNorm1d(10)
self.linear = nn.Linear(10, 1)
def forward(self, x):
x = self.dropout(x)
x = self.bn(x)
return self.linear(x)
model = SimpleModel()
input_data = torch.randn(16, 10)
# 1. 評価モードに切り替え
# - Dropoutが無効化 (入力=出力)
# - BatchNormの統計値が固定される
model.eval()
# 2. 勾配計算の無効化(メモリと速度の最適化)
with torch.no_grad():
output = model(input_data)
# このブロック内のすべての演算で勾配追跡は行われない
# 3. 再び学習モードに戻す場合
# model.train() 2. パラメータの requires_grad=False
これは主にファインチューニングや特徴抽出の際に、特定の層の学習を凍結するために使用されます。
Python
import torch
import torch.nn as nn
model = nn.Linear(10, 1) # サンプルとしてLinear層を使用
# 凍結したい層(例:すべてのパラメータ)を取得
for param in model.parameters():
# 勾配追跡を永久に無効化
param.requires_grad = False
# 以降の学習ループで...
# 凍結されたパラメータは、勾配が計算されないため更新されません。
# 例外: 一部の層のみを学習させたい場合
# 例: 最後のLinear層のみを学習させる
model = SimpleModel() # 前述のSimpleModelを再利用
for name, param in model.named_parameters():
if 'linear' not in name:
# 'linear'層以外のパラメータを凍結
param.requires_grad = False
else:
# 'linear'層のパラメータは学習可能のまま
pass まとめ
model.eval()は、モデルの動作(Dropout や BatchNorm)を変更します。これは推論時に必須です。with torch.no_grad():は、勾配の記録を無効化します。これは推論時の速度とメモリ効率のために推奨されます。requires_grad=Falseは、特定のパラメータを凍結(更新されないようにする)するために使用されます。
評価・推論を行う際は、model.eval() と with torch.no_grad(): をセットで使用することが標準的なプラクティスです。
他に PyTorch の機能についてご質問があれば、お気軽にお尋ねください。