warmup あり vs なし|学習率ウォームアップは 本当に効くのか?Keras実験で検証

投稿日:2026年3月9日月曜日 最終更新日:

callbacks CIFAR-10 CNN Google Colab Keras Learning Rate 学習率

X f B! P L
warmup あり vs なし|学習率ウォームアップは 本当に効くのか?Keras実験で検証 アイキャッチ画像

「学習率ウォームアップ」という言葉をご存知ですか?

BERTやResNetなど大きなモデルのコードによく登場しますが、「本当に効果があるのか」「Kerasでどう書くのか」を実験で確かめた記事は意外と少ない印象です。

今回はGoogle ColabとCIFAR-10を使い、warmupあり・なしを同条件で比較実験しました。

この記事を読むとわかること:

  • 学習率ウォームアップとは何か、なぜ効くとされているのか
  • KerasでLambdaCallbackを使ってwarmupを実装する方法
  • warmupあり vs なしで初期epochの挙動がどう変わるか

学習率ウォームアップとは?なぜ最初だけ低くするのか

学習率ウォームアップとは、学習の最初の数エポックだけ学習率を低い値から徐々に増やしていくテクニックです。

なぜ効果があるとされているかというと、学習の序盤はモデルの重みが初期値に近く不安定なため、高い学習率で大きなステップを踏むと勾配が暴れて損失が急激に変動することがあります。最初だけ学習率を低く抑えることで、序盤の不安定な動きを抑制しながら徐々に本来の学習速度に移行できます。

特にBERT・ViT・大規模CNNなど、パラメータ数が多いモデルで効果が出やすいとされています。

※各種学習率スケジューラの網羅的な比較は → Kerasの学習率スケジューラ5種を徹底比較【コード付き実験】

Kerasでwarmupを実装する方法

Kerasには warmup を直接指定するパラメータはありません。LambdaCallbackを使って各エポック開始時に学習率を手動で更新する方法で実装します。

実装の考え方は単純です。「エポック番号がwarmup期間内なら学習率を線形補間で計算し、optimizer.learning_rate.assign()でOptimizerに反映する」だけです。

import tensorflow as tf
from tensorflow import keras

def warmup_callback(optimizer, warmup_epochs, base_lr, start_lr):
    """
    warmup_epochs : ウォームアップを行うエポック数(例:5)
    base_lr       : ウォームアップ後の目標学習率(例:0.001)
    start_lr      : ウォームアップ開始時の学習率(例:0.0001)
    """
    def on_epoch_begin(epoch, logs):
        if epoch < warmup_epochs:
            # 線形補間で学習率を計算
            lr = start_lr + (base_lr - start_lr) * (epoch / warmup_epochs)
        else:
            lr = base_lr
        # learning_rateの値を直接assignで更新
        optimizer.learning_rate.assign(lr)
        print(f"  lr = {lr:.6f}")
    return keras.callbacks.LambdaCallback(on_epoch_begin=on_epoch_begin)

このCallbackは汎用的に使えます。warmup_epochs・base_lr・start_lrを引数で自由に変えられるため、今後の実験でも再利用できます。

実験設定:warmupあり vs なしのコード

使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。

モデルはConv2D×2層 + Denseの構成で固定し、warmupの有無だけを変えた2パターンを比較します。

実験コード

環境準備(最初に一度だけ実行)

# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-ipafont-mincho
The following NEW packages will be installed:
  fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 37 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 3s (2,787 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 121852 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 9.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了

import・データ準備・モデル・warmup関数

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib

# CIFAR-10データの読み込み・前処理
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test  = x_test.astype('float32')  / 255.0

BASE_LR      = 0.001   # ベース学習率(warmup後の目標値)
START_LR     = 0.0001  # warmup開始時の学習率(BASE_LRの1/10)
WARMUP_EPOCHS = 5      # warmupを行うエポック数
TOTAL_EPOCHS  = 30

def build_model():
    return keras.Sequential([
        keras.layers.Input(shape=(32, 32, 3)),
        keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10,  activation='softmax')
    ])

def warmup_callback(optimizer, warmup_epochs, base_lr, start_lr):
    def on_epoch_begin(epoch, logs):
        if epoch < warmup_epochs:
            lr = start_lr + (base_lr - start_lr) * (epoch / warmup_epochs)
        else:
            lr = base_lr
        # learning_rateの値を直接assignで更新
        optimizer.learning_rate.assign(lr)
        print(f"  lr = {lr:.6f}")
    return keras.callbacks.LambdaCallback(on_epoch_begin=on_epoch_begin)
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 8s 0us/step

Pattern A:warmup なし

print("\n=== Pattern A:warmup なし ===")
model_A   = build_model()
optimizer_A = keras.optimizers.Adam(learning_rate=BASE_LR)
model_A.compile(
    optimizer=optimizer_A,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
history_A = model_A.fit(
    x_train, y_train,
    epochs=TOTAL_EPOCHS,
    batch_size=64,
    validation_split=0.2,
    verbose=1
)
実行結果をクリックして内容を開く
=== Pattern A:warmup なし ===
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 12s 9ms/step - accuracy: 0.3825 - loss: 1.6982 - val_accuracy: 0.6033 - val_loss: 1.1400
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6338 - loss: 1.0647 - val_accuracy: 0.6453 - val_loss: 1.0169
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6881 - loss: 0.8920 - val_accuracy: 0.6782 - val_loss: 0.9396
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.7374 - loss: 0.7550 - val_accuracy: 0.7077 - val_loss: 0.8641
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7748 - loss: 0.6520 - val_accuracy: 0.7143 - val_loss: 0.8560
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8050 - loss: 0.5633 - val_accuracy: 0.7096 - val_loss: 0.9107
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.8372 - loss: 0.4718 - val_accuracy: 0.7129 - val_loss: 0.9086
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8615 - loss: 0.4006 - val_accuracy: 0.7187 - val_loss: 0.9301
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8896 - loss: 0.3221 - val_accuracy: 0.7161 - val_loss: 0.9946
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9166 - loss: 0.2485 - val_accuracy: 0.7044 - val_loss: 1.1332
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9346 - loss: 0.2001 - val_accuracy: 0.7171 - val_loss: 1.1722
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9447 - loss: 0.1638 - val_accuracy: 0.7044 - val_loss: 1.2767
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.9610 - loss: 0.1209 - val_accuracy: 0.7027 - val_loss: 1.3919
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9617 - loss: 0.1130 - val_accuracy: 0.7088 - val_loss: 1.4896
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9695 - loss: 0.0929 - val_accuracy: 0.7006 - val_loss: 1.6493
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9739 - loss: 0.0815 - val_accuracy: 0.6963 - val_loss: 1.7709
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9759 - loss: 0.0709 - val_accuracy: 0.6971 - val_loss: 1.7601
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9811 - loss: 0.0570 - val_accuracy: 0.6981 - val_loss: 1.8924
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9774 - loss: 0.0670 - val_accuracy: 0.6971 - val_loss: 2.0042
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9822 - loss: 0.0528 - val_accuracy: 0.6934 - val_loss: 2.0100
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9814 - loss: 0.0558 - val_accuracy: 0.6860 - val_loss: 2.1607
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9820 - loss: 0.0512 - val_accuracy: 0.6987 - val_loss: 2.3172
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9779 - loss: 0.0662 - val_accuracy: 0.6934 - val_loss: 2.2710
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9824 - loss: 0.0521 - val_accuracy: 0.6887 - val_loss: 2.3849
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9838 - loss: 0.0474 - val_accuracy: 0.6950 - val_loss: 2.4208
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.9861 - loss: 0.0391 - val_accuracy: 0.7004 - val_loss: 2.4474
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 7ms/step - accuracy: 0.9844 - loss: 0.0444 - val_accuracy: 0.6947 - val_loss: 2.4415
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9806 - loss: 0.0556 - val_accuracy: 0.6962 - val_loss: 2.5873
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9853 - loss: 0.0428 - val_accuracy: 0.6953 - val_loss: 2.5537
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9878 - loss: 0.0369 - val_accuracy: 0.6888 - val_loss: 2.5468

Pattern B:warmup あり(5エポック線形増加)

print("\n=== Pattern B:warmup あり(5エポック)===")
model_B   = build_model()
optimizer_B = keras.optimizers.Adam(learning_rate=START_LR)  # 初期値はSTART_LR
model_B.compile(
    optimizer=optimizer_B,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
warmup_cb = warmup_callback(
    optimizer=optimizer_B,
    warmup_epochs=WARMUP_EPOCHS,
    base_lr=BASE_LR,
    start_lr=START_LR
)
history_B = model_B.fit(
    x_train, y_train,
    epochs=TOTAL_EPOCHS,
    batch_size=64,
    validation_split=0.2,
    callbacks=[warmup_cb],
    verbose=1
)

# テストデータで最終精度を評価
test_results = {
    'A:warmupなし': model_A.evaluate(x_test, y_test, verbose=0),
    'B:warmupあり': model_B.evaluate(x_test, y_test, verbose=0),
}
実行結果をクリックして内容を開く
=== Pattern B:warmup あり(5エポック)===
  lr = 0.000100
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 12s 9ms/step - accuracy: 0.2996 - loss: 1.9428 - val_accuracy: 0.4768 - val_loss: 1.4945
  lr = 0.000280
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4849 - loss: 1.4413 - val_accuracy: 0.5269 - val_loss: 1.3139
  lr = 0.000460
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5733 - loss: 1.2079 - val_accuracy: 0.6204 - val_loss: 1.0881
  lr = 0.000640
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6414 - loss: 1.0284 - val_accuracy: 0.6654 - val_loss: 0.9794
  lr = 0.000820
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6765 - loss: 0.9330 - val_accuracy: 0.6848 - val_loss: 0.9186
  lr = 0.001000
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7096 - loss: 0.8242 - val_accuracy: 0.6922 - val_loss: 0.8980
  lr = 0.001000
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.7535 - loss: 0.7114 - val_accuracy: 0.7177 - val_loss: 0.8407
  lr = 0.001000
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7982 - loss: 0.5851 - val_accuracy: 0.7135 - val_loss: 0.8494
  lr = 0.001000
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8273 - loss: 0.4945 - val_accuracy: 0.7266 - val_loss: 0.8439
  lr = 0.001000
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.8745 - loss: 0.3668 - val_accuracy: 0.7156 - val_loss: 0.9181
  lr = 0.001000
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9031 - loss: 0.2861 - val_accuracy: 0.7201 - val_loss: 1.0318
  lr = 0.001000
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9310 - loss: 0.2042 - val_accuracy: 0.7217 - val_loss: 1.0764
  lr = 0.001000
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9486 - loss: 0.1545 - val_accuracy: 0.7131 - val_loss: 1.1130
  lr = 0.001000
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9623 - loss: 0.1176 - val_accuracy: 0.7175 - val_loss: 1.2721
  lr = 0.001000
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9749 - loss: 0.0811 - val_accuracy: 0.7153 - val_loss: 1.3557
  lr = 0.001000
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9757 - loss: 0.0757 - val_accuracy: 0.7111 - val_loss: 1.4932
  lr = 0.001000
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9723 - loss: 0.0846 - val_accuracy: 0.7155 - val_loss: 1.5977
  lr = 0.001000
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9795 - loss: 0.0592 - val_accuracy: 0.7139 - val_loss: 1.6486
  lr = 0.001000
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9789 - loss: 0.0634 - val_accuracy: 0.6973 - val_loss: 1.7539
  lr = 0.001000
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9749 - loss: 0.0744 - val_accuracy: 0.7157 - val_loss: 1.8188
  lr = 0.001000
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9839 - loss: 0.0481 - val_accuracy: 0.7030 - val_loss: 1.9211
  lr = 0.001000
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9846 - loss: 0.0478 - val_accuracy: 0.7091 - val_loss: 1.8856
  lr = 0.001000
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.9790 - loss: 0.0637 - val_accuracy: 0.7045 - val_loss: 2.0432
  lr = 0.001000
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9831 - loss: 0.0507 - val_accuracy: 0.7082 - val_loss: 2.0653
  lr = 0.001000
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9848 - loss: 0.0442 - val_accuracy: 0.7056 - val_loss: 1.9828
  lr = 0.001000
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9853 - loss: 0.0450 - val_accuracy: 0.7062 - val_loss: 2.1443
  lr = 0.001000
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9836 - loss: 0.0525 - val_accuracy: 0.7019 - val_loss: 2.0845
  lr = 0.001000
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.9897 - loss: 0.0318 - val_accuracy: 0.7066 - val_loss: 2.2455
  lr = 0.001000
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9863 - loss: 0.0408 - val_accuracy: 0.7114 - val_loss: 2.2354
  lr = 0.001000
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9845 - loss: 0.0460 - val_accuracy: 0.7023 - val_loss: 2.3389

学習曲線の可視化と最終結果サマリー

histories = {
    'A:warmupなし': history_A,
    'B:warmupあり': history_B,
}

# ── 全体グラフ(30エポック)──────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes[0].plot(h.history['val_accuracy'], label=label)
    axes[1].plot(h.history['val_loss'],     label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('warmup_full.png', dpi=150)
plt.show()

# ── 序盤拡大グラフ(1〜8エポック)──────────────────
EARLY = 8
fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes2[0].plot(h.history['val_accuracy'][:EARLY], label=label, marker='o')
    axes2[1].plot(h.history['val_loss'][:EARLY],     label=label, marker='o')
axes2[0].set_title(f'val_accuracy の比較(序盤1〜{EARLY}エポック)')
axes2[1].set_title(f'val_loss の比較(序盤1〜{EARLY}エポック)')
for ax in axes2:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('warmup_early.png', dpi=150)
plt.show()

# ── 最終結果サマリー ─────────────────────────────────
key_order = ['A:warmupなし', 'B:warmupあり']
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>16} | {'Val Acc @5ep':>13} | {'Final Val Acc':>14} | {'Test Acc':>9}")
print("-" * 62)
for key in key_order:
    h = histories[key]
    val_acc_5  = h.history['val_accuracy'][4]   # 5epoch目(0-indexed)
    val_acc_final = h.history['val_accuracy'][-1]
    test_loss, test_acc = test_results[key]
    print(f"{key:>16} | {val_acc_5:>13.4f} | {val_acc_final:>14.4f} | {test_acc:>9.4f}")
print("-" * 62)

最終結果サマリー

===== 最終結果サマリー =====
         Pattern |  Val Acc @5ep |  Final Val Acc |  Test Acc
--------------------------------------------------------------
      A:warmupなし |        0.7143 |         0.6888 |    0.6866
      B:warmupあり |        0.6848 |         0.7023 |    0.6939
--------------------------------------------------------------

実験結果:初期epochの挙動に注目

全体の学習曲線比較(30エポック)

精度グラフ

warmup あり vs なし 学習曲線比較 精度グラフ

損失グラフ

warmup あり vs なし 学習曲線比較 損失グラフ

全体グラフから読み取れること:

  • 最終精度(30エポック時点)はwarmupあり・なしで大きな差はありません。
  • val_lossはwarmupありの方が後半にかけてやや安定する傾向が見られます。
  • 注目点として、5エポック時点ではwarmupありの精度がなしを下回っていますが、最終的に逆転しています。warmupによって序盤をゆっくり進んだ分、後半の収束が安定したと考えられます。

序盤(1〜8エポック)の拡大グラフ

精度グラフ

warmup あり vs なし 序盤学習曲線比較 精度グラフ

損失グラフ

warmup あり vs なし 序盤学習曲線比較 損失グラフ

序盤を拡大すると違いが見えてきます:

  • パターンA(warmupなし)は1エポック目から高い学習率で動くため、val_lossが最初に大きく振れることがあります。
  • パターンB(warmupあり)は序盤のval_lossが安定しており、5エポックかけてゆっくり加速していく様子が確認できます。
パターンval_acc(5epoch時点)val_acc(最終)test_acc序盤val_lossの安定性
A:warmupなし71.43%68.88%68.66%やや不安定
B:warmupあり68.48%(低め)70.23%69.39%安定

結論:warmupはどんな場合に効くのか

今回の実験結果をもとに結論をまとめます。

CIFAR-10 + シンプルなCNNでは → 最終精度への効果は限定的

今回の実験では最終精度の差は小さい結果でした。モデルが比較的シンプルで、学習が元々安定しているためです。

大きなモデル・転移学習では → warmupが特に有効

事前学習済みモデル(MobileNetV2・EfficientNetなど)をファインチューニングする場合、序盤に高い学習率で動くと事前学習で得た重みを壊してしまうリスクがあります。warmupで最初を低く抑えることで、この問題を軽減できます。

損失が序盤に大きく振れるときは → warmupを試す価値あり

学習の初期にlossが不安定・発散気味になる場合は、warmupが有効な対策になります。

一方、安定して学習できているモデルにわざわざwarmupを入れる必要はありません。「序盤が不安定なとき・大きなモデルのとき」に導入を検討するのが実用的な判断基準です。

補足:warmupのエポック数・開始学習率の決め方

よく使われる経験則を紹介します。

warmupのエポック数:総エポック数の10〜20%

今回は30エポック中5エポック(約17%)を使いました。一般的には5〜10エポックが多く使われます。

開始学習率:ベース学習率の1/10

今回は BASE_LR=0.001 に対して START_LR=0.0001(1/10)を使いました。1/100にすると立ち上がりが遅すぎて非効率になるため、1/10が無難です。

これらはあくまで経験則です。モデルや学習の安定性に応じて調整してみてください。

まとめ

今回はKerasとCIFAR-10を使い、学習率ウォームアップのあり・なしを実験で比較しました。

  • warmupはLambdaCallbackで実装できる(Kerasに直接設定はない)
  • シンプルなCNNでの最終精度差は小さい(序盤のval_lossの安定性に差が出る)
  • 大きなモデル・転移学習・序盤が不安定なときに導入を検討

学習率スケジューリングを深掘りしたい方は関連記事もどうぞ: