BatchNormalizationの効果をKerasで実験|あり vs なしを比較

投稿日:2025年9月8日月曜日 最終更新日:

BatchNormalization Google Colab Keras MNIST

X f B! L
Batch Normalization解説記事のアイキャッチ画像(イメージ)

本記事では、Batch Normalization(BN)を使った場合と使わない場合で、学習の安定性・収束速度・最終精度がどう変わるかを MNIST + CNN で実験し、分かりやすくまとめています。
そのまま動かせる Google Colabコード付き なので、「どう効くのかを実際に確かめたい」方に最適です。


Batch Normalizationとは?(超シンプルに)

Batch Normalization(BN)は、各層の出力を標準化して学習を安定化させる手法です。勾配が暴れにくくなるため、 「速く・安定して学習が進む」というメリットがあります。

なぜ効くのか?直感的な理解

  • 勾配が安定する:スケールが整うことで爆発や消失を防ぎやすい
  • 収束が速い:多少高めの学習率でも壊れにくい
  • わずかな正則化効果:ミニバッチ統計の揺らぎが過学習を抑制

TIP(実務)
Conv層では Conv → BatchNorm → ReLU が定番。
Dense層でも Dense → BatchNorm → ReLU がよく使われます。

実験設定(MNIST・CNN)

  • データ: MNIST(28×28グレースケール・10クラス)
  • モデル: 同一アーキテクチャで BNありBNなし を比較
  • 評価: トレーニング/検証の精度とロスの推移、最終テスト精度
  • 実行環境: Google Colab + TensorFlow/Keras(CPUでも数分)

Colabで実行:BNあり/なしを比較

以下のコードをColabに貼り付ければ、そのまま実験できます(ランダム性により結果は毎回少し変わります)。

# %% Colab-ready: TensorFlow/Keras MNIST BN vs No-BN
import os, random, numpy as np, tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

# 再現性(目安)
seed = 42
random.seed(seed); np.random.seed(seed); tf.random.set_seed(seed)

# 1) データセット読み込み
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = (x_train.astype("float32") / 255.0)[..., None]  # (60000,28,28,1)
x_test  = (x_test.astype("float32")  / 255.0)[..., None]

# 検証分割
x_train, x_val = x_train[:-5000], x_train[-5000:]
y_train, y_val = y_train[:-5000], y_train[-5000:]

def conv_block(filters, use_bn):
    block = keras.Sequential()
    block.add(layers.Conv2D(filters, 3, padding="same", use_bias=not use_bn))
    if use_bn:
        block.add(layers.BatchNormalization())
    block.add(layers.ReLU())
    return block

def build_model(use_bn=False):
    inputs = keras.Input(shape=(28,28,1))
    x = conv_block(32, use_bn)(inputs)
    x = layers.MaxPooling2D()(x)
    x = conv_block(64, use_bn)(x)
    x = layers.MaxPooling2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, use_bias=not use_bn)(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs, outputs, name=f"mnist_cnn_bn_{use_bn}")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

bn_false = build_model(use_bn=False)
bn_true  = build_model(use_bn=True)

callbacks = [keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=3, restore_best_weights=True)]

hist_no = bn_false.fit(
    x_train, y_train, validation_data=(x_val, y_val),
    epochs=10, batch_size=128, callbacks=callbacks, verbose=1)

hist_bn = bn_true.fit(
    x_train, y_train, validation_data=(x_val, y_val),
    epochs=10, batch_size=128, callbacks=callbacks, verbose=1)

test_no = bn_false.evaluate(x_test, y_test, verbose=0)
test_bn = bn_true.evaluate(x_test, y_test, verbose=0)

print("=== Test Accuracy ===")
print("No-BN :", round(float(test_no[1]), 4))
print("BN    :", round(float(test_bn[1]), 4))

# 観察用:最終エポックの検証精度を出力
print("ValAcc(No-BN):", round(hist_no.history["val_accuracy"][-1], 4))
print("ValAcc(BN)   :", round(hist_bn.history["val_accuracy"][-1], 4))

実行結果

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
Epoch 1/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 12s 12ms/step - accuracy: 0.8475 - loss: 0.4886 - val_accuracy: 0.9804 - val_loss: 0.0707
Epoch 2/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 12s 5ms/step - accuracy: 0.9817 - loss: 0.0601 - val_accuracy: 0.9894 - val_loss: 0.0449
Epoch 3/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9885 - loss: 0.0384 - val_accuracy: 0.9874 - val_loss: 0.0493
Epoch 4/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9909 - loss: 0.0295 - val_accuracy: 0.9902 - val_loss: 0.0408
Epoch 5/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9928 - loss: 0.0228 - val_accuracy: 0.9894 - val_loss: 0.0411
Epoch 6/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - accuracy: 0.9950 - loss: 0.0167 - val_accuracy: 0.9896 - val_loss: 0.0435
Epoch 7/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9963 - loss: 0.0128 - val_accuracy: 0.9892 - val_loss: 0.0436
Epoch 1/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 8s 12ms/step - accuracy: 0.9353 - loss: 0.2316 - val_accuracy: 0.4478 - val_loss: 1.7774
Epoch 2/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 6s 5ms/step - accuracy: 0.9907 - loss: 0.0360 - val_accuracy: 0.9868 - val_loss: 0.0466
Epoch 3/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9962 - loss: 0.0171 - val_accuracy: 0.9908 - val_loss: 0.0316
Epoch 4/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9988 - loss: 0.0080 - val_accuracy: 0.9796 - val_loss: 0.0655
Epoch 5/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9989 - loss: 0.0057 - val_accuracy: 0.9894 - val_loss: 0.0410
Epoch 6/10
430/430 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - accuracy: 0.9985 - loss: 0.0058 - val_accuracy: 0.9886 - val_loss: 0.0472
=== Test Accuracy ===
No-BN : 0.9848
BN    : 0.9915
ValAcc(No-BN): 0.9892
ValAcc(BN)   : 0.9886

観察ポイント(収束・汎化・安定性)

  • 収束速度: 同じエポック数ならBNありモデルの方が検証精度の立ち上がりが早いことが多い。
  • 汎化: BNにより検証ロスの乱高下が減り、過学習に入りにくい挙動が見える場合がある。
  • 安定性: バッチ統計でスケールが整い、学習率を少し上げても破綻しにくい傾向。

※ 実際の数値は実行環境・乱数により変動します。上記は一般的に観察されやすい傾向です。

うまく使うコツ & よくある落とし穴

コツ

  • 層の順序: Conv/Dense → BatchNorm → 活性化(ReLUなど)。
  • 学習率: BN導入時は 1e-3 から開始し、学習が安定する上限を探る。
  • ドロップアウトとの併用: まずBNだけで安定させ、必要に応じて最後段に少量のDropoutを追加。

落とし穴

  • バッチサイズが極端に小さい: 統計が不安定。32〜128程度を目安に。
  • 推論時の挙動: 学習時と推論時でBNは動作が異なるため、model.eval()(PyTorch)や training=False(TF一部API)などの扱いに注意。
  • 転移学習時: 事前学習モデル内のBNを凍結/微調整する戦略で精度が変わる。両方試すと良い。

CIFAR-10でも試してみた:MNISTとの違い

MNISTは28×28のグレースケール画像で、比較的シンプルなデータセットです。そのため前のセクションで見たようにBNあり・なしの差が出にくい傾向があります。では、より難しいCIFAR-10(32×32カラー画像・10クラス)ではどうでしょうか?同じ構成でBNあり・なしを比較してみます。

コード(CIFAR-10版)

MNISTのコードから変更するのはデータ読み込みとinput shapeの2箇所だけです。

import random, numpy as np, tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

seed = 42
random.seed(seed); np.random.seed(seed); tf.random.set_seed(seed)

# ── CIFAR-10 読み込み(MNISTからの変更点①)──────────
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test  = x_test.astype("float32")  / 255.0

x_train, x_val = x_train[:-5000], x_train[-5000:]
y_train, y_val = y_train[:-5000], y_train[-5000:]

def conv_block(filters, use_bn):
    block = keras.Sequential()
    block.add(layers.Conv2D(filters, 3, padding="same", use_bias=not use_bn))
    if use_bn:
        block.add(layers.BatchNormalization())
    block.add(layers.ReLU())
    return block

def build_model(use_bn=False):
    inputs = keras.Input(shape=(32, 32, 3))  # 変更点②:input_shape
    x = conv_block(32, use_bn)(inputs)
    x = layers.MaxPooling2D()(x)
    x = conv_block(64, use_bn)(x)
    x = layers.MaxPooling2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, use_bias=not use_bn)(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

# EarlyStoppingなし・固定30エポックで比較
for use_bn in [False, True]:
    label = "BNあり" if use_bn else "BNなし"
    print(f"\n===== {label} =====")
    model = build_model(use_bn=use_bn)
    hist = model.fit(x_train, y_train,
                     validation_data=(x_val, y_val),
                     epochs=30, batch_size=128, verbose=1)
    test = model.evaluate(x_test, y_test, verbose=0)
    val_acc = max(hist.history['val_accuracy'])
    print(f"{label}:best_val_accuracy={val_acc:.4f} test_accuracy={test[1]:.4f}")
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 14s 0us/step

===== BNなし =====
Epoch 1/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 7s 13ms/step - accuracy: 0.4494 - loss: 1.5351 - val_accuracy: 0.5426 - val_loss: 1.2923
Epoch 2/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.5868 - loss: 1.1711 - val_accuracy: 0.5950 - val_loss: 1.1450
Epoch 3/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.6416 - loss: 1.0248 - val_accuracy: 0.6384 - val_loss: 1.0419
Epoch 4/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.6770 - loss: 0.9299 - val_accuracy: 0.6586 - val_loss: 0.9914
Epoch 5/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.7046 - loss: 0.8589 - val_accuracy: 0.6712 - val_loss: 0.9503
Epoch 6/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.7254 - loss: 0.7990 - val_accuracy: 0.6882 - val_loss: 0.9160
Epoch 7/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.7417 - loss: 0.7482 - val_accuracy: 0.6946 - val_loss: 0.8864
Epoch 8/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.7587 - loss: 0.6998 - val_accuracy: 0.6990 - val_loss: 0.8779
Epoch 9/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.7734 - loss: 0.6569 - val_accuracy: 0.7140 - val_loss: 0.8590
Epoch 10/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.7854 - loss: 0.6195 - val_accuracy: 0.7164 - val_loss: 0.8668
Epoch 11/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.7954 - loss: 0.5845 - val_accuracy: 0.7138 - val_loss: 0.8904
Epoch 12/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.8102 - loss: 0.5466 - val_accuracy: 0.7098 - val_loss: 0.9375
Epoch 13/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.8269 - loss: 0.5005 - val_accuracy: 0.7082 - val_loss: 0.9747
Epoch 14/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.8418 - loss: 0.4613 - val_accuracy: 0.7116 - val_loss: 0.9884
Epoch 15/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.8514 - loss: 0.4344 - val_accuracy: 0.6990 - val_loss: 1.0352
Epoch 16/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.8553 - loss: 0.4198 - val_accuracy: 0.7058 - val_loss: 1.0163
Epoch 17/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.8580 - loss: 0.4061 - val_accuracy: 0.7042 - val_loss: 1.0268
Epoch 18/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.8696 - loss: 0.3763 - val_accuracy: 0.6990 - val_loss: 1.0816
Epoch 19/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.8788 - loss: 0.3487 - val_accuracy: 0.6908 - val_loss: 1.1960
Epoch 20/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.8875 - loss: 0.3239 - val_accuracy: 0.6856 - val_loss: 1.2390
Epoch 21/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.8959 - loss: 0.3000 - val_accuracy: 0.6800 - val_loss: 1.3187
Epoch 22/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.9028 - loss: 0.2784 - val_accuracy: 0.6846 - val_loss: 1.3382
Epoch 23/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9100 - loss: 0.2536 - val_accuracy: 0.6832 - val_loss: 1.3863
Epoch 24/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9128 - loss: 0.2458 - val_accuracy: 0.6798 - val_loss: 1.4176
Epoch 25/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9105 - loss: 0.2471 - val_accuracy: 0.6788 - val_loss: 1.5320
Epoch 26/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9140 - loss: 0.2425 - val_accuracy: 0.6646 - val_loss: 1.6540
Epoch 27/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9224 - loss: 0.2190 - val_accuracy: 0.6762 - val_loss: 1.5737
Epoch 28/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9384 - loss: 0.1776 - val_accuracy: 0.6818 - val_loss: 1.6426
Epoch 29/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 2s 7ms/step - accuracy: 0.9467 - loss: 0.1539 - val_accuracy: 0.6846 - val_loss: 1.7134
Epoch 30/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9512 - loss: 0.1426 - val_accuracy: 0.6934 - val_loss: 1.7816
BNなし:best_val_accuracy=0.7164 test_accuracy=0.6733

===== BNあり =====
Epoch 1/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 10s 19ms/step - accuracy: 0.5779 - loss: 1.1945 - val_accuracy: 0.3310 - val_loss: 2.3573
Epoch 2/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - accuracy: 0.7122 - loss: 0.8294 - val_accuracy: 0.6486 - val_loss: 0.9739
Epoch 3/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.7767 - loss: 0.6571 - val_accuracy: 0.6152 - val_loss: 1.1427
Epoch 4/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.8297 - loss: 0.5169 - val_accuracy: 0.5866 - val_loss: 1.3512
Epoch 5/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.8770 - loss: 0.3958 - val_accuracy: 0.5194 - val_loss: 1.7318
Epoch 6/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9117 - loss: 0.3010 - val_accuracy: 0.5908 - val_loss: 1.4540
Epoch 7/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9336 - loss: 0.2332 - val_accuracy: 0.6068 - val_loss: 1.4358
Epoch 8/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9545 - loss: 0.1696 - val_accuracy: 0.6498 - val_loss: 1.2906
Epoch 9/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.9715 - loss: 0.1210 - val_accuracy: 0.6256 - val_loss: 1.6121
Epoch 10/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.9800 - loss: 0.0903 - val_accuracy: 0.6014 - val_loss: 1.8824
Epoch 11/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9829 - loss: 0.0754 - val_accuracy: 0.6034 - val_loss: 2.0884
Epoch 12/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9846 - loss: 0.0656 - val_accuracy: 0.5380 - val_loss: 2.5366
Epoch 13/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.9826 - loss: 0.0664 - val_accuracy: 0.5426 - val_loss: 2.5628
Epoch 14/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9835 - loss: 0.0615 - val_accuracy: 0.5908 - val_loss: 2.1027
Epoch 15/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9864 - loss: 0.0507 - val_accuracy: 0.6514 - val_loss: 1.6660
Epoch 16/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.9908 - loss: 0.0407 - val_accuracy: 0.6692 - val_loss: 1.6354
Epoch 17/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9925 - loss: 0.0326 - val_accuracy: 0.6756 - val_loss: 1.6630
Epoch 18/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9923 - loss: 0.0309 - val_accuracy: 0.6638 - val_loss: 1.7410
Epoch 19/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9922 - loss: 0.0316 - val_accuracy: 0.6750 - val_loss: 1.7515
Epoch 20/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9926 - loss: 0.0292 - val_accuracy: 0.6700 - val_loss: 1.6944
Epoch 21/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.9903 - loss: 0.0336 - val_accuracy: 0.6846 - val_loss: 1.6626
Epoch 22/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9920 - loss: 0.0294 - val_accuracy: 0.6202 - val_loss: 2.3051
Epoch 23/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9936 - loss: 0.0256 - val_accuracy: 0.6774 - val_loss: 1.8275
Epoch 24/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9937 - loss: 0.0234 - val_accuracy: 0.6792 - val_loss: 1.7882
Epoch 25/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.9926 - loss: 0.0273 - val_accuracy: 0.6546 - val_loss: 2.0435
Epoch 26/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9930 - loss: 0.0254 - val_accuracy: 0.6656 - val_loss: 1.9317
Epoch 27/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9952 - loss: 0.0189 - val_accuracy: 0.6748 - val_loss: 1.9493
Epoch 28/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.9965 - loss: 0.0153 - val_accuracy: 0.6598 - val_loss: 1.9936
Epoch 29/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9952 - loss: 0.0183 - val_accuracy: 0.6852 - val_loss: 1.8923
Epoch 30/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9940 - loss: 0.0210 - val_accuracy: 0.6896 - val_loss: 1.9063
BNあり:best_val_accuracy=0.6896 test_accuracy=0.6749

MNISTとCIFAR-10の結果比較

データセット BNなし(test_accuracy) BNあり(test_accuracy)
MNIST 99.1% 99.2% +0.1pt(ほぼ同じ)
CIFAR-10 67.33% 67.49% +0.16pt(ほぼ同じ)

予想外の結果:CIFAR-10でもBNの差はほとんど出なかった

test_accuracyはBNなし 67.33% vs BNあり 67.49% と差はわずか +0.16ポイント。MNISTと同様に、この浅い構成ではBNの有無が最終精度にほとんど影響しないという結果になりました。

⚠️ ハマりポイント:EarlyStoppingとBNの相性

最初にEarlyStoppingを使った実験ではBNありのtest_accuracy=0.31という異常値が出ました。BNありモデルはCIFAR-10の序盤でval_accuracyが不安定になりやすく、patience=5で早々に打ち切られ、restore_best_weightsが正しく機能しなかったことが原因です。BNを使う場合はEarlyStoppingのpatienceを大きめに設定するか、固定エポックで比較するのが安全です。

なぜ浅いネットではBNの差が出にくいのか

BNの本来の強みは深いネットワークでの勾配安定化にあります。今回の2Conv構成は層が浅いため、BNなしでも勾配消失は起きにくく、BNの恩恵が精度差として現れにくいと考えられます。

BNが効果を発揮しやすい条件は以下の通りです。

条件BNの効果
浅いネット(Conv 2〜3層)小さい(本実験の結果)
深いネット(Conv 5層以上・ResNetなど)大きい(収束安定・精度向上)
大きいバッチサイズ(128以上)安定して効果が出る
小さいバッチサイズ(16以下)統計が不安定になりむしろ不利な場合も

「BNを使えば必ず精度が上がる」わけではなく、モデルの深さとバッチサイズに依存するというのがこの実験から得られる実践的な結論です。

まとめ

  • BNは学習の安定化収束の加速に寄与しやすく、MNISTでもその傾向を観察しやすい。
  • 実装は簡単:Conv/Dense → BN → 活性化 を基本に設計。
  • 学習率・バッチサイズ・Dropoutなどと合わせて最適点を探すと、より効果が出る。

FAQ

Q. 小規模データでもBNは有効?

A. 有効なことが多いですが、バッチサイズが極小だと不安定になることがあります。GroupNormやLayerNormを試すのも手です。

Q. 活性化の前と後、どっちに置く?

A. 本記事では一般的な実装として活性化の前(Conv/Dense → BN → ReLU)を推奨しています。

Q. BNとDropoutはどちらを先に?

A. まずBNで安定化し、必要に応じて最終段にDropoutを少量追加する構成を試すと、チューニングしやすいです。