はじめに
ディープラーニングでは「学習率(Learning Rate)」がモデルの性能に大きな影響を与えます。今回は、Kerasを使って手書き数字(MNIST)分類タスクにおいて、異なる学習率で訓練を行い、精度の違いを比較してみました。
学習率(Learning Rate)とは、ニューラルネットワークを訓練するときに、パラメータ(重み)をどれだけ更新するかを決める重要なハイパーパラメータです。
学習率が小さい | 学習率が大きい |
学習が遅い | 学習が速いが不安定 |
安定して収束 | 発散の可能性 |
精度が上がりにくい | 最小値にたどり着けないことも |
実験の概要
- データセット:MNIST(手書き数字画像)
- モデル:シンプルなCNN(Conv2D + MaxPool + Dense)
- 使用したOptimizer:SGD
- 比較する学習率:
0.0001
,0.001
,0.01
,0.1
ポイント:SGDは最も基本的なOptimizerなので、学習率の変化による影響が最も顕著に現れるため、今回の実験に使用しています。 一方、AdamやRMSpropは「自動で調整してくれる」ように見えるが、学習率によって最終精度や学習速度が変わります。
準備:MNISTデータの読み込み
from tensorflow import keras
from tensorflow.keras import layers
# MNISTデータの読み込み
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 正規化とチャンネル次元追加
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train[..., None]
x_test = x_test[..., None]
モデル定義のコード
model = keras.Sequential([
layers.Input(shape=(28, 28, 1)),
layers.Conv2D(32, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10, activation='softmax')
])
学習率を変えて訓練
学習率を0.0001、0.001、0.01、0.1と変えて訓練を行う。
results = {}
histories = {}
for lr in [0.0001, 0.001, 0.01, 0.1]:
print(f"Learning rate: {lr}")
optimizer = keras.optimizers.SGD(learning_rate=lr)
model.compile(optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
histories[lr] = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
results[lr] = {
"val_acc": histories[lr].history["val_accuracy"][-1],
"test_acc": test_acc
}
実際の訓練状況
Learning rate: 0.0001 Epoch 1/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9987 - loss: 0.0055 - val_accuracy: 0.9847 - val_loss: 0.0610 Epoch 2/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - accuracy: 0.9989 - loss: 0.0049 - val_accuracy: 0.9851 - val_loss: 0.0605 Epoch 3/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9993 - loss: 0.0045 - val_accuracy: 0.9851 - val_loss: 0.0601 Epoch 4/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9993 - loss: 0.0044 - val_accuracy: 0.9851 - val_loss: 0.0597 Epoch 5/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9994 - loss: 0.0041 - val_accuracy: 0.9852 - val_loss: 0.0594 Epoch 6/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9992 - loss: 0.0042 - val_accuracy: 0.9852 - val_loss: 0.0591 Epoch 7/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9993 - loss: 0.0041 - val_accuracy: 0.9852 - val_loss: 0.0589 Epoch 8/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9996 - loss: 0.0038 - val_accuracy: 0.9854 - val_loss: 0.0587 Epoch 9/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9995 - loss: 0.0041 - val_accuracy: 0.9852 - val_loss: 0.0585 Epoch 10/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9996 - loss: 0.0037 - val_accuracy: 0.9852 - val_loss: 0.0583 Learning rate: 0.001 Epoch 1/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9993 - loss: 0.0037 - val_accuracy: 0.9855 - val_loss: 0.0572 Epoch 2/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9998 - loss: 0.0031 - val_accuracy: 0.9856 - val_loss: 0.0566 Epoch 3/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 10s 3ms/step - accuracy: 0.9998 - loss: 0.0029 - val_accuracy: 0.9855 - val_loss: 0.0562 Epoch 4/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9997 - loss: 0.0028 - val_accuracy: 0.9852 - val_loss: 0.0558 Epoch 5/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9998 - loss: 0.0027 - val_accuracy: 0.9852 - val_loss: 0.0556 Epoch 6/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 4ms/step - accuracy: 0.9998 - loss: 0.0024 - val_accuracy: 0.9852 - val_loss: 0.0554 Epoch 7/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 12s 5ms/step - accuracy: 0.9998 - loss: 0.0024 - val_accuracy: 0.9852 - val_loss: 0.0553 Epoch 8/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9998 - loss: 0.0024 - val_accuracy: 0.9853 - val_loss: 0.0552 Epoch 9/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9997 - loss: 0.0025 - val_accuracy: 0.9853 - val_loss: 0.0551 Epoch 10/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9998 - loss: 0.0025 - val_accuracy: 0.9856 - val_loss: 0.0550 Learning rate: 0.01 Epoch 1/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9999 - loss: 0.0022 - val_accuracy: 0.9862 - val_loss: 0.0549 Epoch 2/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9999 - loss: 0.0019 - val_accuracy: 0.9861 - val_loss: 0.0547 Epoch 3/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9998 - loss: 0.0018 - val_accuracy: 0.9866 - val_loss: 0.0549 Epoch 4/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9999 - loss: 0.0016 - val_accuracy: 0.9868 - val_loss: 0.0555 Epoch 5/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9999 - loss: 0.0016 - val_accuracy: 0.9867 - val_loss: 0.0559 Epoch 6/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - accuracy: 0.9999 - loss: 0.0015 - val_accuracy: 0.9866 - val_loss: 0.0557 Epoch 7/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9999 - loss: 0.0015 - val_accuracy: 0.9862 - val_loss: 0.0559 Epoch 8/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 4ms/step - accuracy: 1.0000 - loss: 0.0014 - val_accuracy: 0.9862 - val_loss: 0.0562 Epoch 9/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 1.0000 - loss: 0.0014 - val_accuracy: 0.9864 - val_loss: 0.0566 Epoch 10/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 1.0000 - loss: 0.0013 - val_accuracy: 0.9863 - val_loss: 0.0567 Learning rate: 0.1 Epoch 1/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9998 - loss: 0.0018 - val_accuracy: 0.9849 - val_loss: 0.0633 Epoch 2/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9992 - loss: 0.0033 - val_accuracy: 0.9865 - val_loss: 0.0598 Epoch 3/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9999 - loss: 0.0013 - val_accuracy: 0.9843 - val_loss: 0.0708 Epoch 4/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9999 - loss: 0.0015 - val_accuracy: 0.9862 - val_loss: 0.0624 Epoch 5/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 12s 5ms/step - accuracy: 1.0000 - loss: 6.9728e-04 - val_accuracy: 0.9866 - val_loss: 0.0631 Epoch 6/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 1.0000 - loss: 5.1740e-04 - val_accuracy: 0.9861 - val_loss: 0.0636 Epoch 7/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 1.0000 - loss: 4.5181e-04 - val_accuracy: 0.9868 - val_loss: 0.0644 Epoch 8/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - accuracy: 1.0000 - loss: 4.4254e-04 - val_accuracy: 0.9865 - val_loss: 0.0654 Epoch 9/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - accuracy: 1.0000 - loss: 3.8648e-04 - val_accuracy: 0.9864 - val_loss: 0.0656 Epoch 10/10 1500/1500 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 1.0000 - loss: 3.3382e-04 - val_accuracy: 0.9865 - val_loss: 0.0668
結果の可視化
import matplotlib.pyplot as plt
for lr in [0.0001, 0.001, 0.01, 0.1]:
print(f"{lr}: {results[lr]}")
for lr in [0.0001, 0.001, 0.01, 0.1]:
plt.plot(histories[lr].history['val_accuracy'], label=f'{lr} val_accuracy')
plt.title('Validation Accuracy per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()
for lr in [0.0001, 0.001, 0.01, 0.1]:
plt.plot(histories[lr].history['val_loss'], label=f'{lr} val_loss')
plt.title('Validation Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
最終結果
0.0001: {'val_acc': 0.9852499961853027, 'test_acc': 0.9868000149726868} 0.001: {'val_acc': 0.9855833053588867, 'test_acc': 0.9872999787330627} 0.01: {'val_acc': 0.9863333106040955, 'test_acc': 0.9872999787330627} 0.1: {'val_acc': 0.9865000247955322, 'test_acc': 0.9868999719619751}
精度グラフ
損失グラフ
結果概要
Learning Rate | 最終Val Accuracy | 最終Val Loss | 特徴 |
---|---|---|---|
0.0001 | 0.9852 | 0.0583 | 非常に安定だが改善の伸びが少ない |
0.001 | 0.9856 | 0.0550 | 最もバランス良く安定 |
0.01 | 0.9863 | 0.0567 | 精度は高いが val_loss がやや悪化 |
0.1 | 0.9865 | 0.0668 | 精度は出るが不安定、過学習傾向 |
考察
- 0.001 は val_loss の推移も滑らかで、最も実用的な学習率である。
- 0.01 は精度は良いが、過学習ぎみで、エポック数を調整すべき。
- 0.1 は学習が速すぎて、val_loss が悪化しやすい。
- 0.0001 は安定だが収束が遅く、改善が鈍い。
まとめ
学習率は小さすぎても大きすぎても問題があり、タスクやモデルに応じて適切な値を選ぶ必要があります。特にSGDのようなシンプルなOptimizerでは、学習率の設定が成果を大きく左右します。
今後の応用
学習率は「固定値」ではなく、「学習率スケジューラ(Learning Rate Scheduler)」や「EarlyStopping」などと組み合わせることで、より洗練された学習が可能になります。
最後に
学習率を意識するだけで、モデルの性能を大きく改善することができます。ぜひ、あなたのプロジェクトでも試行錯誤しながら最適な学習率を見つけてみてください。
付録:Kerasのデフォルト学習率(learning rate)
Optimizer名 | クラス名 | デフォルトのlearning_rate | 特徴 |
---|---|---|---|
SGD | tf.keras.optimizers.SGD |
0.01 | 最も基本的な手法。学習率調整が重要。 |
Adam | tf.keras.optimizers.Adam |
0.001 | 自動調整あり。ほとんどのケースで高性能。 |
RMSprop | tf.keras.optimizers.RMSprop |
0.001 | 勾配の分散を利用して調整。RNN系でよく使われる。 |
Adagrad | tf.keras.optimizers.Adagrad |
0.001 | 学習率が段々小さくなる。疎なデータに強い。 |
Adadelta | tf.keras.optimizers.Adadelta |
1.0 | 学習率指定不要の工夫あり。 |
Nadam | tf.keras.optimizers.Nadam |
0.001 | Adam + Nesterov Momentum。収束が速い。 |
0 件のコメント:
コメントを投稿