KerasでCNNを学習するとき、optimizerはなんとなく adam にしていませんか?
MNISTでのSGD・Adam・RMSpropの比較は → KerasでOptimizerを比較!SGD・Adam・RMSpropの精度と学習速度を検証 で行いました。今回はより難しいCIFAR-10・GAP構成で30エポックの本格的な比較を行います。SGDにはmomentum=0.9を加えた現実的な設定で比較します。
📘 この記事でわかること
- Adam・SGD(momentum=0.9)・RMSpropでCIFAR-10の精度・収束速度がどう変わるか
- MNISTとCIFAR-10で結果に違いはあるか
- GAP使用CNNモデルでのoptimizerの選び方
Adam・SGD・RMSpropの特徴おさらい
| optimizer | 特徴 | 学習率の調整 |
|---|---|---|
| Adam | モーメンタムと適応学習率を組み合わせた手法。多くのタスクでデフォルトとして使われる | パラメータごとに自動調整 |
| SGD(momentum=0.9) | シンプルな確率的勾配降下法。momentumを加えることで収束を安定させる | 固定(手動調整が必要) |
| RMSprop | 過去の勾配の二乗平均を使って学習率を適応的に調整。RNNや非定常な問題に強い | パラメータごとに自動調整 |
今回はSGDに momentum=0.9 を設定します。デフォルトのSGD(momentum=0)は収束が非常に遅くCIFAR-10では実用的でないためです。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。optimizer以外の条件は全て同一にして、optimizerの影響だけを取り出します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 42 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 1s (6,773 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122354 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 33.9 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・モデル構築関数
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
def build_model(name):
return keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax'),
], name=name)
def compile_and_fit(model, optimizer):
model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
start = time.time()
history = model.fit(x_train, y_train, epochs=30, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
3パターンの学習実行
configs = [
(keras.optimizers.Adam(learning_rate=0.001), 'A_Adam'),
(keras.optimizers.SGD(learning_rate=0.01, momentum=0.9), 'B_SGD'),
(keras.optimizers.RMSprop(learning_rate=0.001), 'C_RMSprop'),
]
histories, times, scores = {}, {}, {}
for optimizer, name in configs:
print(f"\n=== {name} ===")
model = build_model(name)
h, t = compile_and_fit(model, optimizer)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_')[1]
histories[label] = h
times[label] = t
scores[label] = s
print(f"学習時間:{t:.1f}秒 test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_Adam === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 15s 13ms/step - accuracy: 0.2745 - loss: 1.9106 - val_accuracy: 0.3565 - val_loss: 1.7224 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.3766 - loss: 1.6691 - val_accuracy: 0.3803 - val_loss: 1.6437 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4237 - loss: 1.5572 - val_accuracy: 0.4730 - val_loss: 1.4670 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4642 - loss: 1.4628 - val_accuracy: 0.4873 - val_loss: 1.4155 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4949 - loss: 1.3852 - val_accuracy: 0.5153 - val_loss: 1.3251 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5123 - loss: 1.3351 - val_accuracy: 0.5267 - val_loss: 1.2907 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5303 - loss: 1.2942 - val_accuracy: 0.5470 - val_loss: 1.2376 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5421 - loss: 1.2614 - val_accuracy: 0.5355 - val_loss: 1.2549 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5525 - loss: 1.2295 - val_accuracy: 0.5465 - val_loss: 1.2357 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5648 - loss: 1.1994 - val_accuracy: 0.5648 - val_loss: 1.1921 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5719 - loss: 1.1772 - val_accuracy: 0.5736 - val_loss: 1.1671 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5864 - loss: 1.1465 - val_accuracy: 0.5923 - val_loss: 1.1159 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5952 - loss: 1.1217 - val_accuracy: 0.5973 - val_loss: 1.1018 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6021 - loss: 1.1001 - val_accuracy: 0.6100 - val_loss: 1.0766 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6116 - loss: 1.0816 - val_accuracy: 0.6119 - val_loss: 1.0666 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6191 - loss: 1.0603 - val_accuracy: 0.6245 - val_loss: 1.0407 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6235 - loss: 1.0431 - val_accuracy: 0.6317 - val_loss: 1.0195 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6307 - loss: 1.0284 - val_accuracy: 0.6340 - val_loss: 1.0183 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6393 - loss: 1.0065 - val_accuracy: 0.6317 - val_loss: 1.0237 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.6456 - loss: 0.9929 - val_accuracy: 0.6509 - val_loss: 0.9806 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6488 - loss: 0.9784 - val_accuracy: 0.6426 - val_loss: 0.9880 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6531 - loss: 0.9634 - val_accuracy: 0.6572 - val_loss: 0.9545 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6606 - loss: 0.9484 - val_accuracy: 0.6564 - val_loss: 0.9540 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6641 - loss: 0.9364 - val_accuracy: 0.6497 - val_loss: 0.9663 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6681 - loss: 0.9257 - val_accuracy: 0.6629 - val_loss: 0.9435 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6747 - loss: 0.9071 - val_accuracy: 0.6667 - val_loss: 0.9334 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6779 - loss: 0.9003 - val_accuracy: 0.6734 - val_loss: 0.9190 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6827 - loss: 0.8878 - val_accuracy: 0.6676 - val_loss: 0.9209 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.6872 - loss: 0.8755 - val_accuracy: 0.6825 - val_loss: 0.8913 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6920 - loss: 0.8636 - val_accuracy: 0.6842 - val_loss: 0.8843 学習時間:136.0秒 test_accuracy:0.6786 === B_SGD === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.1861 - loss: 2.1400 - val_accuracy: 0.2288 - val_loss: 2.0187 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.2740 - loss: 1.9176 - val_accuracy: 0.3263 - val_loss: 1.7936 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3227 - loss: 1.7845 - val_accuracy: 0.3432 - val_loss: 1.7740 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3563 - loss: 1.7039 - val_accuracy: 0.3924 - val_loss: 1.6234 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3923 - loss: 1.6237 - val_accuracy: 0.4226 - val_loss: 1.5475 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4219 - loss: 1.5622 - val_accuracy: 0.4442 - val_loss: 1.5165 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4496 - loss: 1.4955 - val_accuracy: 0.4806 - val_loss: 1.4262 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4711 - loss: 1.4429 - val_accuracy: 0.4948 - val_loss: 1.3907 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4886 - loss: 1.3963 - val_accuracy: 0.5119 - val_loss: 1.3441 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5041 - loss: 1.3616 - val_accuracy: 0.5356 - val_loss: 1.2905 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5181 - loss: 1.3277 - val_accuracy: 0.5492 - val_loss: 1.2564 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5330 - loss: 1.2911 - val_accuracy: 0.5415 - val_loss: 1.2639 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5401 - loss: 1.2717 - val_accuracy: 0.5665 - val_loss: 1.2111 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5484 - loss: 1.2427 - val_accuracy: 0.5668 - val_loss: 1.2129 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5596 - loss: 1.2176 - val_accuracy: 0.5674 - val_loss: 1.2030 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5691 - loss: 1.1946 - val_accuracy: 0.5680 - val_loss: 1.1778 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5774 - loss: 1.1671 - val_accuracy: 0.5759 - val_loss: 1.1608 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5824 - loss: 1.1578 - val_accuracy: 0.6058 - val_loss: 1.1033 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5931 - loss: 1.1279 - val_accuracy: 0.5972 - val_loss: 1.1212 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5993 - loss: 1.1178 - val_accuracy: 0.5953 - val_loss: 1.1461 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6035 - loss: 1.1008 - val_accuracy: 0.6116 - val_loss: 1.0870 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6036 - loss: 1.0960 - val_accuracy: 0.5960 - val_loss: 1.1130 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6142 - loss: 1.0713 - val_accuracy: 0.6257 - val_loss: 1.0470 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.6212 - loss: 1.0539 - val_accuracy: 0.6069 - val_loss: 1.0792 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.6241 - loss: 1.0452 - val_accuracy: 0.6353 - val_loss: 1.0292 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6267 - loss: 1.0379 - val_accuracy: 0.6225 - val_loss: 1.0588 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6314 - loss: 1.0193 - val_accuracy: 0.6347 - val_loss: 1.0166 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6407 - loss: 1.0056 - val_accuracy: 0.6374 - val_loss: 0.9983 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6422 - loss: 0.9950 - val_accuracy: 0.6254 - val_loss: 1.0561 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6499 - loss: 0.9788 - val_accuracy: 0.6329 - val_loss: 1.0242 学習時間:124.3秒 test_accuracy:0.6296 === C_RMSprop === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.2387 - loss: 2.0054 - val_accuracy: 0.3385 - val_loss: 1.7971 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3332 - loss: 1.7775 - val_accuracy: 0.3858 - val_loss: 1.6846 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3974 - loss: 1.6314 - val_accuracy: 0.4128 - val_loss: 1.5925 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4437 - loss: 1.5252 - val_accuracy: 0.4706 - val_loss: 1.4449 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4709 - loss: 1.4471 - val_accuracy: 0.5037 - val_loss: 1.3804 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4955 - loss: 1.3898 - val_accuracy: 0.4863 - val_loss: 1.3804 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5120 - loss: 1.3420 - val_accuracy: 0.5114 - val_loss: 1.3230 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5269 - loss: 1.2998 - val_accuracy: 0.5264 - val_loss: 1.2832 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5393 - loss: 1.2656 - val_accuracy: 0.5252 - val_loss: 1.2909 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5533 - loss: 1.2304 - val_accuracy: 0.5699 - val_loss: 1.1820 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5679 - loss: 1.1984 - val_accuracy: 0.5740 - val_loss: 1.1603 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5781 - loss: 1.1702 - val_accuracy: 0.5638 - val_loss: 1.2172 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5859 - loss: 1.1476 - val_accuracy: 0.5535 - val_loss: 1.2181 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5958 - loss: 1.1225 - val_accuracy: 0.5538 - val_loss: 1.2295 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6027 - loss: 1.0996 - val_accuracy: 0.5820 - val_loss: 1.1272 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6118 - loss: 1.0789 - val_accuracy: 0.6201 - val_loss: 1.0569 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6173 - loss: 1.0601 - val_accuracy: 0.6138 - val_loss: 1.0450 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6269 - loss: 1.0366 - val_accuracy: 0.6009 - val_loss: 1.0806 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6308 - loss: 1.0258 - val_accuracy: 0.6357 - val_loss: 1.0194 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6363 - loss: 1.0080 - val_accuracy: 0.6355 - val_loss: 1.0258 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6455 - loss: 0.9868 - val_accuracy: 0.6444 - val_loss: 0.9872 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6504 - loss: 0.9765 - val_accuracy: 0.6499 - val_loss: 0.9768 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6536 - loss: 0.9599 - val_accuracy: 0.6551 - val_loss: 0.9693 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6604 - loss: 0.9487 - val_accuracy: 0.6608 - val_loss: 0.9369 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6666 - loss: 0.9350 - val_accuracy: 0.6580 - val_loss: 0.9563 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6703 - loss: 0.9211 - val_accuracy: 0.6671 - val_loss: 0.9360 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6763 - loss: 0.9115 - val_accuracy: 0.6549 - val_loss: 0.9726 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6783 - loss: 0.8960 - val_accuracy: 0.6565 - val_loss: 0.9689 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6852 - loss: 0.8863 - val_accuracy: 0.6687 - val_loss: 0.9413 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6876 - loss: 0.8774 - val_accuracy: 0.6740 - val_loss: 0.9145 学習時間:123.0秒 test_accuracy:0.6707
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('optimizer_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(収束の安定性)────────────
fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(f'{label}')
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('optimizer_stability.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>10} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8}")
print("-" * 46)
for label in ['Adam', 'SGD', 'RMSprop']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
print(f"{label:>10} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f}")
print("-" * 46)
結果サマリ
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s)
----------------------------------------------
Adam | 0.6842 | 0.6786 | 136.0
SGD | 0.6329 | 0.6296 | 124.3
RMSprop | 0.6740 | 0.6707 | 123.0
----------------------------------------------
実験結果
精度グラフ
損失グラフ
Adam
SGD
RMSprop
| パターン | 最終 val_accuracy | 最終 test_accuracy | 学習時間 | 収束の傾向 |
|---|---|---|---|---|
| A:Adam(lr=0.001) | 68.42% | 67.86% | 136.0秒 | 最高精度・安定収束 |
| B:SGD(lr=0.01・momentum=0.9) | 63.29% | 62.96% | 124.3秒 | 収束が遅く精度も最低 |
| C:RMSprop(lr=0.001) | 67.40% | 67.07% | 123.0秒 | Adamに次ぐ精度・最短時間 |
考察
① AdamとRMSpropが接近、SGDが約5%差で最下位
結果を整理すると以下の通りです。
| 比較 | Test Acc差 | 学習時間差 |
|---|---|---|
| Adam vs RMSprop | 0.79%(Adam優位) | 13.0秒(RMSprop優位) |
| Adam vs SGD | 4.90%(Adam優位) | 11.7秒(SGD優位) |
AdamとRMSpropの精度差はわずか0.79%と非常に近い結果になりました。両者はどちらも適応学習率を持つ手法であり、CIFAR-10程度の規模では同等の性能を発揮します。学習時間はRMSpropが13秒短い(約10%)ため、精度よりも速度を重視する場面ではRMSpropも有力な選択肢です。
② SGDが約5%低精度になった理由
SGD(momentum=0.9・lr=0.01)のtest_accuracy(62.96%)はAdamより約4.9%低い結果です。SGDはAdamやRMSpropと異なり学習率を固定するため、最適な学習率の設定が精度に直結します。
今回設定したlr=0.01はSGDにとって一般的な値ですが、このモデル・このタスクに対して最適かどうかは別問題です。学習率スケジューラ(ReduceLROnPlateauなど)と組み合わせることで、SGDでもより高い精度に到達できる可能性があります。またSGDはResNetなどの大きなモデルで長期間学習させると最終的にAdamを上回るケースも報告されていますが、今回の30エポック・小規模モデルでは差が出る結果になりました。
③ MNISTとCIFAR-10で結果は変わったか
既存のMNIST版実験では5エポックで3種とも0.98前後の精度に達しており、差はほぼありませんでした。CIFAR-10ではAdamとRMSpropが約67〜68%で並び、SGDが約63%と約5%差をつけられるという明確な差が出ました。
タスクが難しくなり収束に時間がかかるほど、適応学習率を持つAdam・RMSpropが有利になる傾向が実験で確認できました。MNISTのようにシンプルなタスクでは3種の差が出にくいですが、CIFAR-10レベルではoptimizerの選択が精度に影響します。
④ 精度と学習時間の効率比較
| optimizer | Test Acc | 学習時間 | 精度/時間(効率) |
|---|---|---|---|
| Adam | 67.86% | 136.0秒 | 0.499%/秒 |
| RMSprop | 67.07% | 123.0秒 | 0.545%/秒 |
| SGD | 62.96% | 124.3秒 | 0.507%/秒 |
精度を学習時間で割った「効率」で見ると、RMSpropが最も効率が良い結果になりました。精度最優先ならAdam、精度と速度のバランスを重視するならRMSpropが有力です。
まとめ
- Adam が最高精度(67.86%)。多くのタスクでデフォルトとして信頼できる選択
- RMSprop はAdamと0.79%差・学習時間は13秒短い。精度と速度のバランスが良く有力な代替手段
- SGD は約5%低精度。学習率の固定が収束の遅さに直結。学習率スケジューラとの併用が推奨
- MNISTでは差がほぼなかったが、CIFAR-10では適応学習率(Adam・RMSprop)が優位という結果に
- 迷ったらまずAdam(lr=0.001)。速度も重視するならRMSpropも試す価値あり
関連記事もあわせてどうぞ:
- MNISTでのoptimizer比較 → KerasでOptimizerを比較!SGD・Adam・RMSpropの精度と学習速度を検証
- AdamWとAdamの比較 → 【Keras】AdamWとAdamの違いを実験で比較|どちらを使うべきか判断基準も解説
- Adamのlearning_rate比較 → Adamのlearning_rateを変えると何が起きる?(0.001 vs 0.01 vs 0.0001)【Keras×CIFAR-10実験】
- 学習率スケジューラ5種の比較 → Kerasで学習率スケジューラ5種を徹底比較|精度&収束の違いを実装・グラフ付きで解説






0 件のコメント:
コメントを投稿