【Keras】glorot_uniformとhe_normalの違いをCNNで比較|ReLUならどれを選ぶ?

投稿日:2026年3月2日月曜日 最終更新日:

CNN Dense Google Colab He Keras MNIST Xavier

X f B! P L
kernel_initializer を変えると精度は変わる? Xavier vs He をCNNで比較 アイキャッチ画像

KerasでCNNを書くとき、kernel_initializerを意識したことはありますか?

デフォルト(glorot_uniform)のままにしている方が多いと思いますが、活性化関数によっては初期化を変えると収束が速くなることがあります。

そこで今回は、Google ColabとMNISTを使い、3種類の初期化方法を実際に動かして比較しました。

この記事を読むとわかること:

  • glorot_uniform(Xavier)とhe_normalで収束速度はどれだけ違うか
  • zerosで初期化するとどうなるか(対照実験)
  • ReLUを使う場合、どの初期化を選ぶべきか

kernel_initializerとは?

kernel_initializerは、モデル学習開始時に各層の重みをどんな値で初期化するかを指定するパラメータです。

「なぜ初期値が重要なのか」というと、初期値が悪いと勾配消失・爆発が起きやすくなり、学習が遅くなったり収束しなかったりするためです。

今回比較する3種類を簡単に整理します。

  • glorot_uniform(Xavier均一分布):Kerasのデフォルト。sigmoid・tanhとの相性が良い。
  • he_normal(He正規分布):ReLU系活性化関数向けに設計。分散をReLUの特性に合わせて調整している。
  • zeros:全ての重みを0で初期化。対称性の問題で学習が進まない「悪い例」として使用。

※各手法の数式・理論的背景は → Dense層の重み初期化とは?XavierとHeの違いをKeras実験で比較

実験設定:3パターンのモデルコード

使用環境はGoogle Colab(GPU:T4)、データセットはMNISTです。

モデルはConv2D×2層 + Dense層のCNNを使い、kernel_initializerの指定だけを変えた3パターンで比較しました。

それ以外の条件(層数・ユニット数・Optimizer・エポック数)は全て同一です。

実験コード

環境準備(最初に一度だけ実行)

# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-ipafont-mincho
The following NEW packages will be installed:
  fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 2 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 3s (3,211 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 117540 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 36.5 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了

前準備と学習

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib

# MNISTデータの読み込み・前処理
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test  = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# 共通コンパイル&学習関数
def compile_and_fit(model):
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model.fit(
        x_train, y_train,
        epochs=20,
        batch_size=128,
        validation_split=0.2,
        verbose=1
    )

def build_cnn(initializer, name):
    return keras.Sequential([
        keras.layers.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, (3, 3), activation='relu',
                           kernel_initializer=initializer),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu',
                           kernel_initializer=initializer),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu',
                          kernel_initializer=initializer),
        keras.layers.Dense(10, activation='softmax')
    ], name=name)

print("\n=== Pattern A:glorot_uniform(デフォルト)===")
model_A = build_cnn('glorot_uniform', 'A_glorot')
history_A = compile_and_fit(model_A)

print("\n=== Pattern B:he_normal(ReLU推奨)===")
model_B = build_cnn('he_normal', 'B_he')
history_B = compile_and_fit(model_B)

print("\n=== Pattern C:zeros(対照実験)===")
model_C = build_cnn('zeros', 'C_zeros')
history_C = compile_and_fit(model_C)

# テストデータで最終精度を評価
test_results = {
    'A:glorot_uniform': model_A.evaluate(x_test, y_test, verbose=0),
    'B:he_normal':      model_B.evaluate(x_test, y_test, verbose=0),
    'C:zeros':          model_C.evaluate(x_test, y_test, verbose=0),
}
実行結果をクリックして内容を開く
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step

=== Pattern A:glorot_uniform(デフォルト)===
Epoch 1/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 87s 220ms/step - accuracy: 0.8462 - loss: 0.5088 - val_accuracy: 0.9747 - val_loss: 0.0879
Epoch 2/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 44s 117ms/step - accuracy: 0.9792 - loss: 0.0697 - val_accuracy: 0.9797 - val_loss: 0.0617
Epoch 3/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 110ms/step - accuracy: 0.9863 - loss: 0.0470 - val_accuracy: 0.9833 - val_loss: 0.0530
Epoch 4/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 46s 123ms/step - accuracy: 0.9886 - loss: 0.0373 - val_accuracy: 0.9883 - val_loss: 0.0408
Epoch 5/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 77s 110ms/step - accuracy: 0.9919 - loss: 0.0253 - val_accuracy: 0.9872 - val_loss: 0.0415
Epoch 6/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 81s 108ms/step - accuracy: 0.9940 - loss: 0.0202 - val_accuracy: 0.9885 - val_loss: 0.0392
Epoch 7/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 45s 119ms/step - accuracy: 0.9949 - loss: 0.0159 - val_accuracy: 0.9874 - val_loss: 0.0467
Epoch 8/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 81s 115ms/step - accuracy: 0.9963 - loss: 0.0125 - val_accuracy: 0.9892 - val_loss: 0.0390
Epoch 9/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.9972 - loss: 0.0086 - val_accuracy: 0.9884 - val_loss: 0.0431
Epoch 10/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 108ms/step - accuracy: 0.9972 - loss: 0.0091 - val_accuracy: 0.9887 - val_loss: 0.0461
Epoch 11/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.9971 - loss: 0.0089 - val_accuracy: 0.9883 - val_loss: 0.0476
Epoch 12/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 108ms/step - accuracy: 0.9978 - loss: 0.0072 - val_accuracy: 0.9908 - val_loss: 0.0381
Epoch 13/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.9983 - loss: 0.0050 - val_accuracy: 0.9892 - val_loss: 0.0453
Epoch 14/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 110ms/step - accuracy: 0.9990 - loss: 0.0035 - val_accuracy: 0.9899 - val_loss: 0.0476
Epoch 15/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 84s 114ms/step - accuracy: 0.9978 - loss: 0.0071 - val_accuracy: 0.9904 - val_loss: 0.0418
Epoch 16/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 43s 114ms/step - accuracy: 0.9982 - loss: 0.0046 - val_accuracy: 0.9886 - val_loss: 0.0462
Epoch 17/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.9988 - loss: 0.0035 - val_accuracy: 0.9902 - val_loss: 0.0464
Epoch 18/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 107ms/step - accuracy: 0.9989 - loss: 0.0035 - val_accuracy: 0.9906 - val_loss: 0.0452
Epoch 19/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 108ms/step - accuracy: 0.9988 - loss: 0.0035 - val_accuracy: 0.9908 - val_loss: 0.0502
Epoch 20/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 111ms/step - accuracy: 0.9990 - loss: 0.0028 - val_accuracy: 0.9905 - val_loss: 0.0438

=== Pattern B:he_normal(ReLU推奨)===
Epoch 1/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 44s 111ms/step - accuracy: 0.8763 - loss: 0.4003 - val_accuracy: 0.9782 - val_loss: 0.0707
Epoch 2/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 110ms/step - accuracy: 0.9841 - loss: 0.0533 - val_accuracy: 0.9837 - val_loss: 0.0541
Epoch 3/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 112ms/step - accuracy: 0.9884 - loss: 0.0386 - val_accuracy: 0.9892 - val_loss: 0.0403
Epoch 4/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.9923 - loss: 0.0248 - val_accuracy: 0.9864 - val_loss: 0.0478
Epoch 5/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 110ms/step - accuracy: 0.9933 - loss: 0.0214 - val_accuracy: 0.9885 - val_loss: 0.0366
Epoch 6/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 112ms/step - accuracy: 0.9955 - loss: 0.0143 - val_accuracy: 0.9875 - val_loss: 0.0488
Epoch 7/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 44s 117ms/step - accuracy: 0.9955 - loss: 0.0130 - val_accuracy: 0.9898 - val_loss: 0.0357
Epoch 8/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 45s 119ms/step - accuracy: 0.9972 - loss: 0.0087 - val_accuracy: 0.9891 - val_loss: 0.0420
Epoch 9/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 45s 119ms/step - accuracy: 0.9979 - loss: 0.0067 - val_accuracy: 0.9876 - val_loss: 0.0562
Epoch 10/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.9967 - loss: 0.0079 - val_accuracy: 0.9880 - val_loss: 0.0422
Epoch 11/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 85s 116ms/step - accuracy: 0.9983 - loss: 0.0049 - val_accuracy: 0.9898 - val_loss: 0.0464
Epoch 12/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 43s 114ms/step - accuracy: 0.9985 - loss: 0.0045 - val_accuracy: 0.9906 - val_loss: 0.0455
Epoch 13/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 81s 111ms/step - accuracy: 0.9986 - loss: 0.0043 - val_accuracy: 0.9898 - val_loss: 0.0478
Epoch 14/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 111ms/step - accuracy: 0.9984 - loss: 0.0044 - val_accuracy: 0.9897 - val_loss: 0.0464
Epoch 15/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 111ms/step - accuracy: 0.9982 - loss: 0.0047 - val_accuracy: 0.9894 - val_loss: 0.0517
Epoch 16/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 45s 119ms/step - accuracy: 0.9982 - loss: 0.0056 - val_accuracy: 0.9911 - val_loss: 0.0418
Epoch 17/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 79s 111ms/step - accuracy: 0.9986 - loss: 0.0037 - val_accuracy: 0.9878 - val_loss: 0.0616
Epoch 18/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 111ms/step - accuracy: 0.9986 - loss: 0.0042 - val_accuracy: 0.9904 - val_loss: 0.0543
Epoch 19/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 112ms/step - accuracy: 0.9992 - loss: 0.0024 - val_accuracy: 0.9887 - val_loss: 0.0585
Epoch 20/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 83s 115ms/step - accuracy: 0.9987 - loss: 0.0041 - val_accuracy: 0.9910 - val_loss: 0.0459

=== Pattern C:zeros(対照実験)===
Epoch 1/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 106ms/step - accuracy: 0.1118 - loss: 2.3020 - val_accuracy: 0.1060 - val_loss: 2.3019
Epoch 2/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 107ms/step - accuracy: 0.1139 - loss: 2.3012 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 3/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 108ms/step - accuracy: 0.1149 - loss: 2.3009 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 4/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.1160 - loss: 2.3006 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 5/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 110ms/step - accuracy: 0.1172 - loss: 2.3008 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 6/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 107ms/step - accuracy: 0.1164 - loss: 2.3009 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 7/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 109ms/step - accuracy: 0.1125 - loss: 2.3011 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 8/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 43s 114ms/step - accuracy: 0.1136 - loss: 2.3012 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 9/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 81s 112ms/step - accuracy: 0.1137 - loss: 2.3008 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 10/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 112ms/step - accuracy: 0.1129 - loss: 2.3011 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 11/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 109ms/step - accuracy: 0.1138 - loss: 2.3012 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 12/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 108ms/step - accuracy: 0.1150 - loss: 2.3008 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 13/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 112ms/step - accuracy: 0.1150 - loss: 2.3009 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 14/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 107ms/step - accuracy: 0.1135 - loss: 2.3010 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 15/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 108ms/step - accuracy: 0.1140 - loss: 2.3011 - val_accuracy: 0.1060 - val_loss: 2.3020
Epoch 16/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 42s 112ms/step - accuracy: 0.1132 - loss: 2.3010 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 17/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 82s 112ms/step - accuracy: 0.1155 - loss: 2.3008 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 18/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 41s 108ms/step - accuracy: 0.1115 - loss: 2.3015 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 19/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 40s 106ms/step - accuracy: 0.1126 - loss: 2.3012 - val_accuracy: 0.1060 - val_loss: 2.3021
Epoch 20/20
375/375 ━━━━━━━━━━━━━━━━━━━━ 43s 111ms/step - accuracy: 0.1169 - loss: 2.3006 - val_accuracy: 0.1060 - val_loss: 2.3020

学習曲線の可視化と最終結果サマリー

histories = {
    'A:glorot_uniform': history_A,
    'B:he_normal':      history_B,
    'C:zeros':          history_C,
}

# ── 全体グラフ(20エポック)──────────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes[0].plot(h.history['val_accuracy'], label=label)
    axes[1].plot(h.history['val_loss'],     label=label)
axes[0].set_title('val_accuracy の比較(全20エポック)')
axes[1].set_title('val_loss の比較(全20エポック)')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('init_comparison_full.png', dpi=150)
plt.show()

# ── 全体グラフ(20エポック):zeros除外 ──────────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    if label != 'C:zeros':  # 「C:zeros」を除外
        axes[0].plot(h.history['val_accuracy'], label=label)
        axes[1].plot(h.history['val_loss'],     label=label)
axes[0].set_title('val_accuracy の比較(全20エポック):zeros除外')
axes[1].set_title('val_loss の比較(全20エポック):zeros除外')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('init_comparison_full2.png', dpi=150)
plt.show()

# ── 序盤拡大グラフ(1〜5エポック)──────────────────
fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    if label != 'C:zeros':  # 「C:zeros」を除外
        axes2[0].plot(h.history['val_accuracy'][:5], label=label, marker='o')
        axes2[1].plot(h.history['val_loss'][:5],     label=label, marker='o')
axes2[0].set_title('val_accuracy の比較(序盤1〜5エポック)')
axes2[1].set_title('val_loss の比較(序盤1〜5エポック)')
for ax in axes2:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('init_comparison_early.png', dpi=150)
plt.show()

# ── 最終結果サマリー ─────────────────────────────
key_order = ['A:glorot_uniform', 'B:he_normal', 'C:zeros']
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>20} | {'Train Acc':>10} | {'Val Acc':>8} | {'Test Acc':>9}")
print("-" * 58)
for key in key_order:
    h = histories[key]
    test_loss, test_acc = test_results[key]
    train_acc = h.history['accuracy'][-1]
    val_acc   = h.history['val_accuracy'][-1]
    print(f"{key:>20} | {train_acc:>10.4f} | {val_acc:>8.4f} | {test_acc:>9.4f}")
print("-" * 58)

最終結果サマリーの実行結果

===== 最終結果サマリー =====
             Pattern |  Train Acc |  Val Acc |  Test Acc
----------------------------------------------------------
    A:glorot_uniform |     0.9986 |   0.9905 |    0.9911
         B:he_normal |     0.9994 |   0.9910 |    0.9920
             C:zeros |     0.1140 |   0.1060 |    0.1135
----------------------------------------------------------

実験結果:学習曲線グラフを3本並べて比較

val_accuracy の比較(全20エポック)

精度グラフ(全20エポック)

Keras CNN 重み初期化 比較 精度グラフ(全20エポック)

損失グラフ(全20エポック)

Keras CNN 重み初期化 比較 損失グラフ(全20エポック)

パターンCの影響で、グラフが潰れてしまっています。パターンC(zeros)は、全ての重みが同一値で初期化されるため対称性が破れず、各ニューロンが同一の勾配更新を受け続けます。その結果、有効な特徴分離が起こらず、学習がほぼ進みません。

val_accuracy の比較(全20エポック):パターンC除外

精度グラフ(全20エポック):パターンC除外

Keras CNN 重み初期化 比較 精度グラフ(全20エポック):パターンC除外

損失グラフ(全20エポック):パターンC除外

Keras CNN 重み初期化 比較 損失グラフ(全20エポック):パターンC除外

グラフから読み取れること:

  • パターンA(glorot_uniform)とB(he_normal)の最終精度はほぼ同等でしたが、序盤の収束速度に違いが見られました。
  • val_lossはhe_normalの方が序盤に低くなる傾向があり、ReLUとの相性の良さが数値にも表れています。

序盤(1〜5エポック)の拡大グラフ

精度グラフ(1〜5エポック)

Keras CNN 重み初期化 比較 精度グラフ(1〜5エポック)

損失グラフ(1〜5エポック)

Keras CNN 重み初期化 比較 損失グラフ(1〜5エポック)

序盤を拡大すると、パターンB(he_normal)の方が1〜3エポック目の立ち上がりがやや速いことが確認できます。ReLUを使うモデルではhe_normalの方が理論的に適切なため、この結果は予想通りです。

結論:ReLUを使うなら何を選ぶべきか

今回の実験結果をもとに結論をまとめます。

ReLUを使っているなら → he_normal を指定する

glorot_uniformとの最終精度の差は小さいですが、序盤の収束がやや速くなります。コードを1行変えるだけなので、とりあえず入れておくのがおすすめです。

sigmoid / tanh を使っているなら → glorot_uniform(デフォルト)のまま

Kerasのデフォルトは意図的にXavierが選ばれています。sigmoid/tanhとの相性が良く、変える必要はありません。

zerosは使わない

対称性の問題で学習がほぼ進みません。今回の実験でも明確に確認できました。

正直なところ、MNISTのようなシンプルなタスクでは glorot_uniform と he_normal の差は小さく、「どちらを選んでも大きく変わらない」というのが実測での結論です。 ただし、より深いモデルや学習が不安定なケースでは初期化の影響が大きくなるため、ReLUを使うなら意識的に he_normal を選ぶ習慣をつけておくと良いでしょう。

補足:活性化関数ごとの推奨初期化まとめ

迷ったときのリファレンスとして表にまとめました。

活性化関数推奨 initializerKerasでの指定方法理由
ReLUhe_normal / he_uniformkernel_initializer='he_normal'ReLUの非線形性(負→0)に合わせて分散を調整
LeakyReLU / ELUhe_normalkernel_initializer='he_normal'ReLU系のため同様の理由で有効
sigmoid / tanhglorot_uniform / glorot_normalkernel_initializer='glorot_uniform'(デフォルト)両方向に活性化する関数向けのXavier設計
softmax(出力層)glorot_uniformデフォルトのまま出力層はどちらでも大差なし

まとめ

今回はKerasとMNISTを使い、kernel_initializerを3パターンで比較しました。

  • ReLUを使うなら he_normal(序盤の収束がやや速い)
  • sigmoid/tanhを使うなら glorot_uniform(Kerasのデフォルトのまま)
  • zerosは使わない(対称性の問題で学習が進まない)

MNISTレベルのタスクでは最終精度の差は小さいですが、初期化の選び方を意識するだけでコードの質が上がります。

実務ではBatchNormalizationを併用することも多く、その場合は初期化の影響はさらに小さくなる傾向があります。

関連記事もあわせてどうぞ: