はじめに
過学習に悩んだ経験はありませんか?
そんなときに有効なのが「Dropout」です。
でも、Dropout率をどの値にすれば本当に効果があるのか、迷ったことはありませんか?
この記事では、Kerasで構築したCNNモデルを使って、Dropout率0.1〜0.7の効果を徹底比較します。
画像分類タスク(MNIST)での過学習の抑制効果を、具体的なグラフと数値で検証しました。
ぜひ、あなたのモデル設計に役立ててください!
Dropoutとは?簡単におさらい
Dropoutは、ニューラルネットワークの訓練中にランダムに一部のニューロンを無効化(=出力を0に)することで、ネットワークが特定のニューロンに依存しすぎないようにする手法です。
この工夫により、モデルの汎化性能(未知データへの対応力)を高めることができます。
結果のまとめ(先に要点だけ知りたい方へ)
Dropout率 | 過学習の傾向 | テスト精度 |
---|---|---|
0.1 | やや過学習 | 97.75% |
0.3 | 安定 | 97.97% |
0.5 | やや抑制 | 97.89% |
0.7 | 抑制強すぎ | 97.29% |
結論:Dropout率は0.3が最もバランス良好でした!
実験の概要
- ライブラリ: Keras (TensorFlowバックエンド)
- データセット: MNIST
- モデル構成: 単純なCNN + Dropout(1層)
- Optimizer: Adam
- Epoch: 20
- Batch Size: 128
準備:MNISTデータの読み込み
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# データの準備(MNIST)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# チャンネル次元の追加(CNNに対応)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
# One-hotラベル
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
Dropout率を変えて訓練
本実験では、2層の畳み込み層と2層の全結合層を持つ、シンプルなCNNモデルを使用しました。
Dropoutは全結合層の直前に適用しています。
Dropout率を0.1、0.3、0.5、0.7と変えて訓練を行います。
# Dropout率のリスト
dropout_rates = [0.1, 0.3, 0.5, 0.7]
results = {}
# 各Dropout率でモデルを定義・訓練
for rate in dropout_rates:
print(f"\n=== Dropout率: {rate} ===")
model = keras.Sequential([
layers.Input(shape=(28, 28, 1)),
layers.Conv2D(32, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dropout(rate),
layers.Dense(128, activation='relu'),
layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
history = model.fit(x_train, y_train,
epochs=20,
batch_size=128,
validation_split=0.2,
verbose=1)
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"Test Accuracy: {test_acc:.4f}")
results[rate] = {
"history": history.history,
"test_acc": test_acc
}
実際の訓練状況
=== Dropout率: 0.1 === Epoch 1/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 8s 10ms/step - accuracy: 0.8374 - loss: 0.5660 - val_accuracy: 0.9770 - val_loss: 0.0807 Epoch 2/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.9787 - loss: 0.0697 - val_accuracy: 0.9851 - val_loss: 0.0537 Epoch 3/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9859 - loss: 0.0458 - val_accuracy: 0.9873 - val_loss: 0.0426 Epoch 4/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9891 - loss: 0.0355 - val_accuracy: 0.9885 - val_loss: 0.0419 Epoch 5/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9913 - loss: 0.0275 - val_accuracy: 0.9883 - val_loss: 0.0427 Epoch 6/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9933 - loss: 0.0214 - val_accuracy: 0.9888 - val_loss: 0.0416 Epoch 7/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9944 - loss: 0.0195 - val_accuracy: 0.9872 - val_loss: 0.0411 Epoch 8/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9954 - loss: 0.0148 - val_accuracy: 0.9897 - val_loss: 0.0390 Epoch 9/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9959 - loss: 0.0136 - val_accuracy: 0.9908 - val_loss: 0.0348 Epoch 10/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9972 - loss: 0.0096 - val_accuracy: 0.9905 - val_loss: 0.0372 Epoch 11/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9974 - loss: 0.0085 - val_accuracy: 0.9904 - val_loss: 0.0409 Epoch 12/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9968 - loss: 0.0099 - val_accuracy: 0.9913 - val_loss: 0.0396 Epoch 13/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9973 - loss: 0.0081 - val_accuracy: 0.9902 - val_loss: 0.0404 Epoch 14/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 4ms/step - accuracy: 0.9982 - loss: 0.0058 - val_accuracy: 0.9905 - val_loss: 0.0443 Epoch 15/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9991 - loss: 0.0031 - val_accuracy: 0.9898 - val_loss: 0.0477 Epoch 16/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9980 - loss: 0.0062 - val_accuracy: 0.9898 - val_loss: 0.0484 Epoch 17/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9984 - loss: 0.0055 - val_accuracy: 0.9914 - val_loss: 0.0403 Epoch 18/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9987 - loss: 0.0041 - val_accuracy: 0.9911 - val_loss: 0.0451 Epoch 19/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9978 - loss: 0.0051 - val_accuracy: 0.9912 - val_loss: 0.0443 Epoch 20/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9985 - loss: 0.0046 - val_accuracy: 0.9908 - val_loss: 0.0453 Test Accuracy: 0.9908 === Dropout率: 0.3 === Epoch 1/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 6s 8ms/step - accuracy: 0.8232 - loss: 0.5848 - val_accuracy: 0.9761 - val_loss: 0.0815 Epoch 2/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9746 - loss: 0.0828 - val_accuracy: 0.9852 - val_loss: 0.0501 Epoch 3/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9831 - loss: 0.0559 - val_accuracy: 0.9872 - val_loss: 0.0467 Epoch 4/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9862 - loss: 0.0433 - val_accuracy: 0.9879 - val_loss: 0.0415 Epoch 5/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9886 - loss: 0.0346 - val_accuracy: 0.9900 - val_loss: 0.0356 Epoch 6/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9907 - loss: 0.0290 - val_accuracy: 0.9872 - val_loss: 0.0429 Epoch 7/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9911 - loss: 0.0264 - val_accuracy: 0.9886 - val_loss: 0.0378 Epoch 8/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9928 - loss: 0.0240 - val_accuracy: 0.9915 - val_loss: 0.0299 Epoch 9/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9944 - loss: 0.0177 - val_accuracy: 0.9904 - val_loss: 0.0346 Epoch 10/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9943 - loss: 0.0167 - val_accuracy: 0.9900 - val_loss: 0.0379 Epoch 11/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9944 - loss: 0.0163 - val_accuracy: 0.9926 - val_loss: 0.0312 Epoch 12/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9946 - loss: 0.0143 - val_accuracy: 0.9916 - val_loss: 0.0307 Epoch 13/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9956 - loss: 0.0123 - val_accuracy: 0.9898 - val_loss: 0.0394 Epoch 14/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 4ms/step - accuracy: 0.9961 - loss: 0.0117 - val_accuracy: 0.9918 - val_loss: 0.0330 Epoch 15/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9969 - loss: 0.0096 - val_accuracy: 0.9908 - val_loss: 0.0367 Epoch 16/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9973 - loss: 0.0087 - val_accuracy: 0.9908 - val_loss: 0.0368 Epoch 17/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 4s 9ms/step - accuracy: 0.9967 - loss: 0.0091 - val_accuracy: 0.9910 - val_loss: 0.0396 Epoch 18/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9963 - loss: 0.0097 - val_accuracy: 0.9918 - val_loss: 0.0359 Epoch 19/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9974 - loss: 0.0078 - val_accuracy: 0.9918 - val_loss: 0.0337 Epoch 20/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9971 - loss: 0.0081 - val_accuracy: 0.9919 - val_loss: 0.0377 Test Accuracy: 0.9927 === Dropout率: 0.5 === Epoch 1/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 6s 8ms/step - accuracy: 0.8030 - loss: 0.6200 - val_accuracy: 0.9769 - val_loss: 0.0764 Epoch 2/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9697 - loss: 0.0996 - val_accuracy: 0.9848 - val_loss: 0.0535 Epoch 3/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9775 - loss: 0.0718 - val_accuracy: 0.9868 - val_loss: 0.0456 Epoch 4/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9829 - loss: 0.0558 - val_accuracy: 0.9889 - val_loss: 0.0375 Epoch 5/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9845 - loss: 0.0474 - val_accuracy: 0.9897 - val_loss: 0.0367 Epoch 6/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9873 - loss: 0.0384 - val_accuracy: 0.9893 - val_loss: 0.0369 Epoch 7/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9889 - loss: 0.0355 - val_accuracy: 0.9885 - val_loss: 0.0365 Epoch 8/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9889 - loss: 0.0320 - val_accuracy: 0.9917 - val_loss: 0.0313 Epoch 9/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9900 - loss: 0.0305 - val_accuracy: 0.9900 - val_loss: 0.0351 Epoch 10/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9919 - loss: 0.0238 - val_accuracy: 0.9915 - val_loss: 0.0317 Epoch 11/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9912 - loss: 0.0273 - val_accuracy: 0.9906 - val_loss: 0.0312 Epoch 12/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9925 - loss: 0.0223 - val_accuracy: 0.9916 - val_loss: 0.0309 Epoch 13/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9920 - loss: 0.0218 - val_accuracy: 0.9928 - val_loss: 0.0292 Epoch 14/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9941 - loss: 0.0169 - val_accuracy: 0.9918 - val_loss: 0.0309 Epoch 15/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9934 - loss: 0.0192 - val_accuracy: 0.9920 - val_loss: 0.0290 Epoch 16/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9941 - loss: 0.0174 - val_accuracy: 0.9922 - val_loss: 0.0344 Epoch 17/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9945 - loss: 0.0161 - val_accuracy: 0.9918 - val_loss: 0.0303 Epoch 18/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.9958 - loss: 0.0134 - val_accuracy: 0.9920 - val_loss: 0.0291 Epoch 19/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9952 - loss: 0.0138 - val_accuracy: 0.9922 - val_loss: 0.0299 Epoch 20/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9952 - loss: 0.0130 - val_accuracy: 0.9933 - val_loss: 0.0291 Test Accuracy: 0.9931 === Dropout率: 0.7 === Epoch 1/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7691 - loss: 0.7145 - val_accuracy: 0.9771 - val_loss: 0.0796 Epoch 2/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9592 - loss: 0.1344 - val_accuracy: 0.9843 - val_loss: 0.0548 Epoch 3/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9710 - loss: 0.0942 - val_accuracy: 0.9855 - val_loss: 0.0454 Epoch 4/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9733 - loss: 0.0806 - val_accuracy: 0.9882 - val_loss: 0.0390 Epoch 5/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9788 - loss: 0.0648 - val_accuracy: 0.9880 - val_loss: 0.0393 Epoch 6/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9815 - loss: 0.0616 - val_accuracy: 0.9903 - val_loss: 0.0322 Epoch 7/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9833 - loss: 0.0542 - val_accuracy: 0.9910 - val_loss: 0.0316 Epoch 8/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9838 - loss: 0.0507 - val_accuracy: 0.9904 - val_loss: 0.0309 Epoch 9/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9848 - loss: 0.0476 - val_accuracy: 0.9921 - val_loss: 0.0296 Epoch 10/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9860 - loss: 0.0410 - val_accuracy: 0.9912 - val_loss: 0.0303 Epoch 11/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9867 - loss: 0.0420 - val_accuracy: 0.9921 - val_loss: 0.0281 Epoch 12/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 4s 10ms/step - accuracy: 0.9869 - loss: 0.0378 - val_accuracy: 0.9923 - val_loss: 0.0278 Epoch 13/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.9871 - loss: 0.0382 - val_accuracy: 0.9927 - val_loss: 0.0269 Epoch 14/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 4s 5ms/step - accuracy: 0.9886 - loss: 0.0343 - val_accuracy: 0.9937 - val_loss: 0.0247 Epoch 15/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9884 - loss: 0.0341 - val_accuracy: 0.9925 - val_loss: 0.0281 Epoch 16/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9887 - loss: 0.0337 - val_accuracy: 0.9919 - val_loss: 0.0283 Epoch 17/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9899 - loss: 0.0311 - val_accuracy: 0.9932 - val_loss: 0.0253 Epoch 18/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9906 - loss: 0.0290 - val_accuracy: 0.9923 - val_loss: 0.0289 Epoch 19/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9903 - loss: 0.0286 - val_accuracy: 0.9931 - val_loss: 0.0254 Epoch 20/20 375/375 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9897 - loss: 0.0298 - val_accuracy: 0.9942 - val_loss: 0.0243 Test Accuracy: 0.9937
結果の可視化
検証精度、検証損失の推移をプロットします。
import matplotlib.pyplot as plt
# 検証精度の推移をプロット
plt.figure(figsize=(8, 5))
for rate in dropout_rates:
val_acc = results[rate]["history"]["val_accuracy"]
plt.plot(val_acc, label=f"Dropout {rate}")
plt.title("Validation Accuracy per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Validation Accuracy")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
# 検証損失の推移をプロット
plt.figure(figsize=(8, 5))
for rate in dropout_rates:
val_loss = results[rate]["history"]["val_loss"]
plt.plot(val_loss, label=f"Dropout {rate}")
plt.title('Validation Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
検証精度グラフ
検証損失グラフ
結果概要
Dropout率 | Test Accuracy |
---|---|
0.1 | 0.9908 |
0.3 | 0.9927 |
0.5 | 0.9931 |
0.7 | 0.9938(最高) |
分析
- Dropout率0.1:学習精度は非常に高く(99.08%)、早期に過学習傾向が見られました。Val_lossが10エポック以降にやや上昇。
- Dropout率0.3:精度と汎化性能のバランスが良く、過学習をある程度防ぎながら高精度を維持。
- Dropout率0.5:やや学習が遅くなったものの、Val_lossが安定しており、最終的に0.9931と非常に高い精度に到達。
- Dropout率0.7:最も正則化が効いており、学習は最も遅かったが、最終的に最高のTest Accuracy(0.9938)を達成。
まとめ:Dropout率の違いが性能に与える影響
- Dropout率が低すぎると過学習しやすい
- Dropout率が高すぎると学習がうまく進まない
- 0.3前後が最も安定した結果を示した
今後は別のデータセット(CIFAR-10など)やモデル構造(ResNetなど)でも検証してみる予定です。
この記事があなたのモデル設計の参考になれば幸いです!
0 件のコメント:
コメントを投稿