KerasでモデルをFitするとき、batch_size はなんとなく 32 や 64 にしていませんか?
バッチサイズは精度・学習時間・収束の安定性に影響する重要なハイパーパラメータです。今回はGoogle ColabとCIFAR-10を使い、batch_sizeを16・64・256の3パターンで比較しました。
なお、MNISTでのバッチサイズ比較は → バッチサイズを変えると精度や学習効率はどう変わる?画像分類で徹底比較 をご覧ください。本記事はCIFAR-10・GAP構成でより踏み込んだ比較を行います。
- バッチサイズを変えると精度・学習時間・収束の安定性がどう変わるか
- 小さすぎる・大きすぎるバッチサイズで何が起きるか
- CIFAR-10+GAP構成での最適なバッチサイズの目安
バッチサイズとは?変えると何が起きるか
バッチサイズは1回の重み更新に使うデータ数です。
| batch_size | 1エポックの更新回数(訓練データ4万枚の場合) | 期待される挙動 |
|---|---|---|
| 16(小さい) | 2,500回 | 更新頻度が高くノイズが多い。汎化性能が上がりやすいが不安定・学習時間長 |
| 64(中程度) | 625回 | バランスが良い。多くのタスクで標準的に使われる |
| 256(大きい) | 157回 | 更新が安定・学習時間短。ただし局所最適解に陥りやすく汎化性能が下がることも |
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。batch_size以外の条件は全て同一にして、バッチサイズの影響だけを取り出します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 42 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 3s (2,832 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122354 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 62.2 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・モデル構築関数
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
def build_model(name):
return keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax'),
], name=name)
def compile_and_fit(model, batch_size):
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
start = time.time()
history = model.fit(
x_train, y_train,
epochs=30,
batch_size=batch_size,
validation_split=0.2,
verbose=1
)
return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 18s 0us/step
3パターンの学習実行
configs = [(16, 'A_batch16'), (64, 'B_batch64'), (256, 'C_batch256')]
histories, times, scores = {}, {}, {}
for batch_size, name in configs:
print(f"\n=== {name} ===")
model = build_model(name)
h, t = compile_and_fit(model, batch_size)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_')[1]
histories[label] = h
times[label] = t
scores[label] = s
print(f"学習時間:{t:.1f}秒 test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_batch16 === Epoch 1/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 25s 7ms/step - accuracy: 0.3146 - loss: 1.8155 - val_accuracy: 0.3951 - val_loss: 1.6440 Epoch 2/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 28s 4ms/step - accuracy: 0.4491 - loss: 1.5047 - val_accuracy: 0.4872 - val_loss: 1.4078 Epoch 3/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.4961 - loss: 1.3832 - val_accuracy: 0.5306 - val_loss: 1.2816 Epoch 4/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.5230 - loss: 1.3015 - val_accuracy: 0.5407 - val_loss: 1.2454 Epoch 5/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.5497 - loss: 1.2374 - val_accuracy: 0.5409 - val_loss: 1.2399 Epoch 6/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.5699 - loss: 1.1819 - val_accuracy: 0.5685 - val_loss: 1.1768 Epoch 7/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 13s 5ms/step - accuracy: 0.5882 - loss: 1.1390 - val_accuracy: 0.5999 - val_loss: 1.0900 Epoch 8/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6053 - loss: 1.0953 - val_accuracy: 0.6120 - val_loss: 1.0536 Epoch 9/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6172 - loss: 1.0622 - val_accuracy: 0.6163 - val_loss: 1.0645 Epoch 10/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6282 - loss: 1.0300 - val_accuracy: 0.6418 - val_loss: 1.0063 Epoch 11/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6389 - loss: 1.0026 - val_accuracy: 0.6210 - val_loss: 1.0349 Epoch 12/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6515 - loss: 0.9751 - val_accuracy: 0.6477 - val_loss: 0.9804 Epoch 13/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6578 - loss: 0.9504 - val_accuracy: 0.6394 - val_loss: 1.0019 Epoch 14/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6649 - loss: 0.9333 - val_accuracy: 0.6637 - val_loss: 0.9465 Epoch 15/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6734 - loss: 0.9076 - val_accuracy: 0.6820 - val_loss: 0.8956 Epoch 16/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6827 - loss: 0.8866 - val_accuracy: 0.6734 - val_loss: 0.9118 Epoch 17/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 14s 6ms/step - accuracy: 0.6871 - loss: 0.8750 - val_accuracy: 0.6787 - val_loss: 0.8999 Epoch 18/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.6914 - loss: 0.8585 - val_accuracy: 0.6867 - val_loss: 0.8781 Epoch 19/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7006 - loss: 0.8354 - val_accuracy: 0.6771 - val_loss: 0.9135 Epoch 20/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7052 - loss: 0.8260 - val_accuracy: 0.6955 - val_loss: 0.8671 Epoch 21/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7099 - loss: 0.8102 - val_accuracy: 0.6979 - val_loss: 0.8638 Epoch 22/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7166 - loss: 0.7953 - val_accuracy: 0.7028 - val_loss: 0.8539 Epoch 23/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7200 - loss: 0.7832 - val_accuracy: 0.6972 - val_loss: 0.8644 Epoch 24/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7233 - loss: 0.7705 - val_accuracy: 0.6945 - val_loss: 0.8566 Epoch 25/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7287 - loss: 0.7572 - val_accuracy: 0.7033 - val_loss: 0.8415 Epoch 26/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7338 - loss: 0.7449 - val_accuracy: 0.6942 - val_loss: 0.8691 Epoch 27/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7362 - loss: 0.7342 - val_accuracy: 0.6997 - val_loss: 0.8533 Epoch 28/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 21s 4ms/step - accuracy: 0.7406 - loss: 0.7234 - val_accuracy: 0.6904 - val_loss: 0.8809 Epoch 29/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 9s 4ms/step - accuracy: 0.7456 - loss: 0.7118 - val_accuracy: 0.7026 - val_loss: 0.8432 Epoch 30/30 2500/2500 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7509 - loss: 0.7019 - val_accuracy: 0.7172 - val_loss: 0.8246 学習時間:355.8秒 test_accuracy:0.7089 === B_batch64 === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.2689 - loss: 1.9214 - val_accuracy: 0.3603 - val_loss: 1.7100 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3604 - loss: 1.6923 - val_accuracy: 0.3889 - val_loss: 1.6307 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4168 - loss: 1.5721 - val_accuracy: 0.4653 - val_loss: 1.4703 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.4587 - loss: 1.4687 - val_accuracy: 0.4859 - val_loss: 1.4006 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4878 - loss: 1.3955 - val_accuracy: 0.5062 - val_loss: 1.3518 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5045 - loss: 1.3500 - val_accuracy: 0.5137 - val_loss: 1.3334 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5225 - loss: 1.3052 - val_accuracy: 0.5157 - val_loss: 1.3012 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5358 - loss: 1.2657 - val_accuracy: 0.5595 - val_loss: 1.2136 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5501 - loss: 1.2355 - val_accuracy: 0.5475 - val_loss: 1.2552 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5567 - loss: 1.2127 - val_accuracy: 0.5650 - val_loss: 1.2076 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5688 - loss: 1.1844 - val_accuracy: 0.5749 - val_loss: 1.1848 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5785 - loss: 1.1580 - val_accuracy: 0.5961 - val_loss: 1.1117 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5897 - loss: 1.1344 - val_accuracy: 0.5905 - val_loss: 1.1116 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5956 - loss: 1.1107 - val_accuracy: 0.6078 - val_loss: 1.0762 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6076 - loss: 1.0893 - val_accuracy: 0.6165 - val_loss: 1.0504 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6131 - loss: 1.0695 - val_accuracy: 0.6186 - val_loss: 1.0463 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6190 - loss: 1.0524 - val_accuracy: 0.6391 - val_loss: 1.0217 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6307 - loss: 1.0307 - val_accuracy: 0.6336 - val_loss: 1.0148 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6340 - loss: 1.0156 - val_accuracy: 0.6425 - val_loss: 1.0045 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6381 - loss: 0.9983 - val_accuracy: 0.6376 - val_loss: 1.0080 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6482 - loss: 0.9814 - val_accuracy: 0.6470 - val_loss: 0.9738 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6500 - loss: 0.9678 - val_accuracy: 0.6466 - val_loss: 0.9899 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6545 - loss: 0.9551 - val_accuracy: 0.6547 - val_loss: 0.9673 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6591 - loss: 0.9399 - val_accuracy: 0.6576 - val_loss: 0.9428 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6671 - loss: 0.9280 - val_accuracy: 0.6689 - val_loss: 0.9121 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6720 - loss: 0.9132 - val_accuracy: 0.6741 - val_loss: 0.9168 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6790 - loss: 0.9004 - val_accuracy: 0.6734 - val_loss: 0.9156 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6786 - loss: 0.8931 - val_accuracy: 0.6823 - val_loss: 0.8888 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6827 - loss: 0.8838 - val_accuracy: 0.6749 - val_loss: 0.9126 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6887 - loss: 0.8703 - val_accuracy: 0.6838 - val_loss: 0.8838 学習時間:124.1秒 test_accuracy:0.6761 === C_batch256 === Epoch 1/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 9s 36ms/step - accuracy: 0.2200 - loss: 2.0542 - val_accuracy: 0.2927 - val_loss: 1.8794 Epoch 2/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 19ms/step - accuracy: 0.2933 - loss: 1.8405 - val_accuracy: 0.3423 - val_loss: 1.7421 Epoch 3/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3403 - loss: 1.7300 - val_accuracy: 0.3805 - val_loss: 1.6555 Epoch 4/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3703 - loss: 1.6694 - val_accuracy: 0.3704 - val_loss: 1.6683 Epoch 5/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3990 - loss: 1.6122 - val_accuracy: 0.4050 - val_loss: 1.6039 Epoch 6/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 5s 19ms/step - accuracy: 0.4168 - loss: 1.5766 - val_accuracy: 0.4301 - val_loss: 1.5615 Epoch 7/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4336 - loss: 1.5401 - val_accuracy: 0.4418 - val_loss: 1.5169 Epoch 8/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4479 - loss: 1.4985 - val_accuracy: 0.4732 - val_loss: 1.4456 Epoch 9/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4645 - loss: 1.4569 - val_accuracy: 0.4700 - val_loss: 1.4492 Epoch 10/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 19ms/step - accuracy: 0.4793 - loss: 1.4243 - val_accuracy: 0.4887 - val_loss: 1.4007 Epoch 11/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4913 - loss: 1.3964 - val_accuracy: 0.4949 - val_loss: 1.3885 Epoch 12/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4970 - loss: 1.3719 - val_accuracy: 0.5000 - val_loss: 1.3637 Epoch 13/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5067 - loss: 1.3508 - val_accuracy: 0.5191 - val_loss: 1.3246 Epoch 14/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5156 - loss: 1.3281 - val_accuracy: 0.5325 - val_loss: 1.2952 Epoch 15/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5188 - loss: 1.3129 - val_accuracy: 0.5294 - val_loss: 1.2890 Epoch 16/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5233 - loss: 1.3020 - val_accuracy: 0.5342 - val_loss: 1.2754 Epoch 17/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5306 - loss: 1.2857 - val_accuracy: 0.5304 - val_loss: 1.2797 Epoch 18/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5371 - loss: 1.2727 - val_accuracy: 0.5371 - val_loss: 1.2859 Epoch 19/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 19ms/step - accuracy: 0.5422 - loss: 1.2584 - val_accuracy: 0.5541 - val_loss: 1.2318 Epoch 20/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 5s 18ms/step - accuracy: 0.5481 - loss: 1.2395 - val_accuracy: 0.5243 - val_loss: 1.2764 Epoch 21/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5504 - loss: 1.2346 - val_accuracy: 0.5482 - val_loss: 1.2261 Epoch 22/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 5s 19ms/step - accuracy: 0.5569 - loss: 1.2183 - val_accuracy: 0.5681 - val_loss: 1.1830 Epoch 23/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5603 - loss: 1.2117 - val_accuracy: 0.5584 - val_loss: 1.2174 Epoch 24/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5611 - loss: 1.2100 - val_accuracy: 0.5609 - val_loss: 1.2076 Epoch 25/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5659 - loss: 1.1950 - val_accuracy: 0.5577 - val_loss: 1.2040 Epoch 26/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 5s 18ms/step - accuracy: 0.5681 - loss: 1.1887 - val_accuracy: 0.5742 - val_loss: 1.1669 Epoch 27/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5715 - loss: 1.1728 - val_accuracy: 0.5817 - val_loss: 1.1503 Epoch 28/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5791 - loss: 1.1635 - val_accuracy: 0.5812 - val_loss: 1.1436 Epoch 29/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 5s 19ms/step - accuracy: 0.5811 - loss: 1.1550 - val_accuracy: 0.5886 - val_loss: 1.1386 Epoch 30/30 157/157 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5856 - loss: 1.1494 - val_accuracy: 0.5811 - val_loss: 1.1466 学習時間:104.0秒 test_accuracy:0.5790
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('batchsize_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(収束の安定性を見る)────────
fig2, axes2 = plt.subplots(3, 1, figsize=(7, 15))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(f'{label}')
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('batchsize_stability.png', dpi=150)
plt.show()
# ── 最終結果サマリー ─────────────────────────────────
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>12} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8}")
print("-" * 48)
for label in ['batch16', 'batch64', 'batch256']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
print(f"{label:>12} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f}")
print("-" * 48)
結果サマリ
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s)
------------------------------------------------
batch16 | 0.7172 | 0.7089 | 355.8
batch64 | 0.6838 | 0.6761 | 124.1
batch256 | 0.5811 | 0.5790 | 104.0
------------------------------------------------
実験結果
精度グラフ
損失グラフ
batch16
batch64
batch256
| パターン | 最終 val_accuracy | 最終 test_accuracy | 学習時間 | 収束の傾向 |
|---|---|---|---|---|
| A:batch_size=16 | 71.72% | 70.89% | 355.8秒 | 最高精度・最長学習時間 |
| B:batch_size=64 | 68.38% | 67.61% | 124.1秒 | 精度・速度のバランス型 |
| C:batch_size=256 | 58.11% | 57.90% | 104.0秒 | 最低精度・最短学習時間 |
考察
① batch_size=16(小さい)が最高精度になった理由
batch_size=16 が最高精度(70.89%)となり、batch64より約3.3%、batch256より約13%上回りました。1エポックあたり2,500回という高頻度の更新が、より細かいパラメータの調整を可能にした結果です。
小さいバッチサイズでは各バッチのサンプル数が少ないため勾配にノイズが乗りやすくなります。このノイズが正則化として機能し、シャープな局所最適解を避けて汎化性能の高い解に収束しやすくなるという効果があります。今回はDropout=0.2という正則化と組み合わさり、特に効果が出た可能性があります。
ただし学習時間は355.8秒と、batch64の約2.9倍かかっています。精度と計算コストのトレードオフを意識した上で選択する必要があります。
② batch_size=256(大きい)が大幅に精度を落とした理由
batch_size=256 のtest_accuracy(57.90%)はbatch16比で約13%低く、3パターンの中で最も大きな差が出ました。1エポックあたりの更新回数は157回と少なく、勾配の推定は安定しますが、更新回数が少なすぎて30エポックでは十分に学習できなかった可能性があります。
また大きいバッチサイズでは「シャープな局所最適解」に収束しやすくなることが研究で知られています。Adamのデフォルト学習率(0.001)はbatch64前後を前提に設計されているため、batch256でデフォルト学習率をそのまま使うと学習が不十分になりやすいのです。batch256で精度を上げたい場合は学習率を大きくする(例:0.003〜0.01)かエポック数を増やすことを検討する必要があります。
③ バッチサイズと学習率・エポック数はセットで調整すべき
今回の実験では学習率はAdamのデフォルト(0.001)・エポック数は30に固定したため、バッチサイズを変えた場合の不利が最も大きく出たのはbatch256です。
「Linear Scaling Rule」という経験則では、バッチサイズをk倍にしたら学習率もk倍にすると良いとされています。この法則に従えば:
| batch_size | 今回の学習率(固定) | Linear Scaling Ruleに基づく推奨学習率 |
|---|---|---|
| 16 | 0.001 | 0.00025(÷4) |
| 64(基準) | 0.001 | 0.001 |
| 256 | 0.001 | 0.004(×4) |
今回はあえて学習率を固定することで「バッチサイズだけを変えた場合の純粋な影響」を観察しています。実際の開発ではバッチサイズと学習率をセットで調整することで、より公平な比較と最適な精度が得られます。
④ 精度と学習時間のトレードオフ
3パターンの精度と学習時間を比較すると次のようになります。
| パターン | Test Acc | 学習時間 | 精度/時間(効率) |
|---|---|---|---|
| batch16 | 70.89% | 355.8秒 | 0.199%/秒 |
| batch64 | 67.61% | 124.1秒 | 0.545%/秒 |
| batch256 | 57.90% | 104.0秒 | 0.557%/秒 |
精度を時間で割った「効率」で見ると、batch64とbatch256がほぼ同等で、batch16は最も非効率です。精度を最優先にするならbatch16、時間あたりの精度効率を重視するならbatch64が最適という結果になりました。
まとめ
- batch16が最高精度(70.89%)。小さいバッチサイズの勾配ノイズが正則化として機能した
- batch256は最低精度(57.90%)。デフォルト学習率・30エポックでは学習が不十分になった。学習率かエポック数の調整が必要
- 精度を時間効率で見るとbatch64が最もバランスが良い(0.545%/秒)
- バッチサイズを変える場合は学習率をセットで調整するのが本来の正しいアプローチ
- 精度最優先 → batch16〜32、速度と精度のバランス → batch64、速度優先かつ学習率調整あり → batch128〜256
関連記事もあわせてどうぞ:
- MNISTでのバッチサイズ比較 → バッチサイズを変えると精度や学習効率はどう変わる?画像分類で徹底比較
- Adamのlearning_rate比較 → Adamのlearning_rateを変えると何が起きる?(0.001 vs 0.01 vs 0.0001)【Keras×CIFAR-10実験】
- Dropout率の比較実験 → Dropoutの割合(0.0 vs 0.2 vs 0.5)を変えると過学習はどう変わる?【Keras×CIFAR-10実験】
- EarlyStoppingの実験 → EarlyStoppingは本当に効く?restore_best_weightsとpatienceを実験で比較【Keras×CIFAR-10】
▶EN English Summary
Batch size determines how many samples are used in each weight update during training.
It is one of the most important hyperparameters in deep learning, affecting
accuracy, training speed, and convergence stability.
This article compares three values — 16, 64, and 256 —
on CIFAR-10 using Keras.
| Batch size | Updates per epoch* | Characteristics |
|---|---|---|
16 (small) | 2,500 | Noisy updates, better generalization, slower training |
64 (medium) | 625 | Balanced — widely used as a default |
256 (large) | 157 | Stable updates, faster training, risk of sharp minima |
* Based on 40,000 training samples (CIFAR-10 train split)
Key findings16: highest validation accuracy but slowest training; more prone to instability.64: best overall balance of accuracy and training time.256: fastest to train but validation accuracy dropped slightly — a sign of reduced generalization.
Large batches tend to converge to sharp minima in the loss landscape, which generalize poorly to unseen data. Small batches introduce gradient noise that acts as implicit regularization, helping the model find flat minima that generalize better. This is known as the large-batch training problem.
Practical recommendation
Start with batch_size=64 as a default for CIFAR-10.
If you have limited GPU memory, 32 is also a safe choice.
Avoid going above 256 without a corresponding increase in learning rate
(e.g. linear scaling rule: lr × batch_size / 64).
model.fit(x_train, y_train, batch_size=64, epochs=50)






0 件のコメント:
コメントを投稿