本記事は、以前紹介した
KerasでOptimizerを比較!SGD・Adam・RMSpropの精度と学習速度を検証
の続編です。
前回は基本的な最適化手法の比較を行いましたが、
今回はその中でも特に利用頻度が高いAdamを改良したAdamWに焦点を当て、理論・数式・実装・性能を包括的に検証します。
AdamWとは?
AdamW(Decoupled Weight Decay Regularization)は、Loshchilov & Hutter(2017)が提案した最適化手法です。
従来のAdamがL2正則化を勾配更新に組み込むのに対し、AdamWは「重み減衰(Weight Decay)」を独立して適用する点が特徴です。
これにより、過学習を抑えつつ安定した学習を実現し、汎化性能を向上させます。
Adamとの違い
- Adam: L2正則化が勾配更新に含まれ、減衰効果が不安定になりやすい。
- AdamW: Weight Decayを独立して適用し、過学習を抑制しながら安定した訓練を可能にする。
数式レベルの違い
AdamWでは更新式において、重み減衰項 -λ * w が勾配とは独立に加わります。
この「Decoupled(分離された)」処理がAdamとの大きな違いです。
Kerasによる実装例
from tensorflow.keras.optimizers import AdamW
optimizer = AdamW(learning_rate=1e-3, weight_decay=1e-4)
model.compile(optimizer=optimizer,
loss='categorical_crossentropy',
metrics=['accuracy'])
TensorFlow 2.11以降では、keras.optimizers.AdamW が標準実装として利用可能です。
Adam vs AdamWの比較
| 項目 | Adam | AdamW |
|---|---|---|
| Weight Decayの扱い | 勾配に含まれる | 独立して適用 |
| 汎化性能 | やや劣る | 良好 |
| 実装の簡単さ | 標準的 | TensorFlow 2.11以降で標準対応 |
実験:MNISTによる比較
手書き数字データセット「MNIST」を使い、AdamとAdamWの性能を比較しました。
CNN(畳み込みニューラルネットワーク)を構築し、精度・損失の推移を可視化しています。
from tensorflow import keras
from tensorflow.keras import layers
# MNISTデータの読み込み
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 正規化とチャンネル次元追加
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train[..., None]
x_test = x_test[..., None]
def create_model(optimizer):
model = keras.Sequential([
layers.Input(shape=(28, 28, 1)),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10, activation='softmax')
])
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]
)
return model
optimizers = {
"Adam": keras.optimizers.Adam(),
"AdamW": keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
}
results = {}
histories = {}
for name, opt in optimizers.items():
print(f"Training with {name}...")
model = create_model(opt)
histories[name] = model.fit(
x_train, y_train,
validation_split=0.2,
epochs=10
)
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
results[name] = {
"val_acc": histories[name].history["val_accuracy"][-1],
"test_acc": test_acc
}
import matplotlib.pyplot as plt
for name, opt in optimizers.items():
print(f"{name}: {results[name]}")
for name, history in histories.items():
plt.plot(history.history['val_accuracy'], label=f'{name} val_accuracy')
plt.title('Validation Accuracy per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()
for name, history in histories.items():
plt.plot(history.history['val_loss'], label=f'{name} val_loss')
plt.title('Validation Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
学習結果
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Training with Adam...
Epoch 1/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 15s 6ms/step - accuracy: 0.8855 - loss: 0.3846 - val_accuracy: 0.9769 - val_loss: 0.0818
Epoch 2/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 10s 6ms/step - accuracy: 0.9791 - loss: 0.0703 - val_accuracy: 0.9828 - val_loss: 0.0591
Epoch 3/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 10s 6ms/step - accuracy: 0.9852 - loss: 0.0485 - val_accuracy: 0.9827 - val_loss: 0.0592
Epoch 4/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9905 - loss: 0.0310 - val_accuracy: 0.9822 - val_loss: 0.0567
Epoch 5/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9933 - loss: 0.0226 - val_accuracy: 0.9827 - val_loss: 0.0595
Epoch 6/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9949 - loss: 0.0163 - val_accuracy: 0.9824 - val_loss: 0.0643
Epoch 7/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9968 - loss: 0.0110 - val_accuracy: 0.9854 - val_loss: 0.0583
Epoch 8/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9977 - loss: 0.0085 - val_accuracy: 0.9835 - val_loss: 0.0743
Epoch 9/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9981 - loss: 0.0063 - val_accuracy: 0.9828 - val_loss: 0.0757
Epoch 10/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9976 - loss: 0.0074 - val_accuracy: 0.9828 - val_loss: 0.0740
Training with AdamW...
Epoch 1/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 7s 3ms/step - accuracy: 0.8850 - loss: 0.3765 - val_accuracy: 0.9783 - val_loss: 0.0784
Epoch 2/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9797 - loss: 0.0666 - val_accuracy: 0.9797 - val_loss: 0.0686
Epoch 3/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9883 - loss: 0.0406 - val_accuracy: 0.9840 - val_loss: 0.0568
Epoch 4/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9908 - loss: 0.0297 - val_accuracy: 0.9841 - val_loss: 0.0556
Epoch 5/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9938 - loss: 0.0192 - val_accuracy: 0.9846 - val_loss: 0.0539
Epoch 6/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9961 - loss: 0.0137 - val_accuracy: 0.9852 - val_loss: 0.0587
Epoch 7/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9958 - loss: 0.0129 - val_accuracy: 0.9848 - val_loss: 0.0612
Epoch 8/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9981 - loss: 0.0075 - val_accuracy: 0.9843 - val_loss: 0.0659
Epoch 9/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9982 - loss: 0.0062 - val_accuracy: 0.9812 - val_loss: 0.0837
Epoch 10/10
1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9972 - loss: 0.0074 - val_accuracy: 0.9869 - val_loss: 0.0620
Adam: {'val_acc': 0.9828333258628845, 'test_acc': 0.9850999712944031}
AdamW: {'val_acc': 0.9869166612625122, 'test_acc': 0.9873999953269958}
描画グラフ
精度グラフ
損失グラフ
結果と考察
| オプティマイザ | Val Accuracy | Test Accuracy | Val Loss傾向 |
|---|---|---|---|
| Adam | 0.9828 | 0.9851 | 後半でやや上昇(軽度の過学習) |
| AdamW | 0.9869 | 0.9874 | 安定して低下(汎化性能が良好) |
- 収束速度は両者とも同等。
- AdamWはVal Lossが安定し、過学習が抑制。
- 最終精度ではAdamWがわずかに上回った。
まとめ:AdamとAdamWの違いと使い分け
AdamWは、Adamの利点(高速収束・安定性)を維持しつつ、過学習を軽減する改良手法です。
特に深層学習モデルやTransformer系モデルで広く採用されており、現代的な最適化のデファクトスタンダードとなりつつあります。
- 短期学習では差は小さいが、長期訓練ではAdamWが安定して高精度。
- 過学習を防ぎたい場合はAdamWを推奨。
- 今後はCIFAR-10やImageNetなどより複雑なデータでの検証も有用。
以前の記事で解説した従来の最適化手法(SGD / RMSprop / Adam)に比べ、
AdamWは正則化の取り扱いを改良することで、より高い汎化性能を示しました。
今後は他の最適化手法(例:LAMB, AdaBelief など)とも比較し、さらなる性能向上を検証してみます。



0 件のコメント:
コメントを投稿