📘 この記事でわかること
- SEブロック単体・Residual単体・組み合わせ・なしの4パターンを一度に比較する方法
- 2つの手法を組み合わせたときに相乗効果が生まれるかどうか
- 精度・過学習・パラメータ数それぞれへの影響
- 実務で「どの組み合わせを選ぶか」の判断基準
SEブロックも残差接続(Residual)も、単体では効果が実証されています。では2つを同時に使ったらどうなるか?相乗効果が出るのか、それとも片方だけで十分なのか。今回はCIFAR-10で4パターンを一気に比較しました。
各手法の単体検証は以下も参考にしてください。
- SEブロック単体 → SEブロック(Squeeze-and-Excitation)を追加すると精度は上がるか?
- Residual単体 → Residual接続(スキップ接続)あり vs なし を比較
SEブロックとResidual接続のおさらい
Residual接続(スキップ接続)は入力をブロックの出力にそのまま足すことで勾配消失を防ぎ、深いネットワークを安定して学習させます。
SEブロック(Squeeze-and-Excitation)はチャネルごとの重要度をGlobal Average Poolingで「搾り出し(Squeeze)」、Denseで重み付けして「励起(Excitation)」する軽量なAttention機構です。
\[ \tilde{x} = F_{scale}(x,\, s) = x \cdot \sigma\!\left(W_2\,\delta(W_1\, z)\right), \quad z = \frac{1}{HW}\sum_{h,w} x_{h,w} \]
2つを組み合わせることで、「深さの安定化(Residual)」と「チャネル重要度の調整(SE)」を同時に得ることが期待できます。果たして実験結果はその期待に応えたでしょうか?
| パターン | Residual | SEブロック | 狙い |
|---|---|---|---|
| A:ベースライン | ✗ | ✗ | 比較基準 |
| B:Residualのみ | ✓ | ✗ | 勾配安定・過学習抑制 |
| C:SEのみ | ✗ | ✓ | チャネル重要度の調整 |
| D:SE + Residual | ✓ | ✓ | 相乗効果を狙う |
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。SEブロック・Residualの有無以外の条件は全て同一にして、それぞれの寄与を取り出します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 51 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (28.5 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122363 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 54.1 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・ブロック定義
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
# ── データ読み込み ─────────────────────────────────────
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# ── SEブロック定義 ─────────────────────────────────────
def se_block(x, ratio=8):
"""Squeeze-and-Excitation Block"""
filters = x.shape[-1]
# Squeeze: チャネルごとのグローバル平均
se = keras.layers.GlobalAveragePooling2D()(x) # (batch, C)
se = keras.layers.Reshape((1, 1, filters))(se) # (batch, 1, 1, C)
# Excitation: 2層DenseでチャネルAttention重みを生成
se = keras.layers.Dense(filters // ratio, activation='relu', use_bias=False)(se)
se = keras.layers.Dense(filters, activation='sigmoid', use_bias=False)(se)
return keras.layers.Multiply()([x, se]) # チャネルワイズ重み付け
# ── Conv ブロック(BN + ReLU付き) ─────────────────────
def conv_block(x, filters, kernel_size=3):
x = keras.layers.Conv2D(filters, kernel_size, padding='same', use_bias=False)(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Activation('relu')(x)
return x
# ── 4パターンのモデル構築 ──────────────────────────────
def build_model(name, use_residual=False, use_se=False):
inputs = keras.Input(shape=(32, 32, 3))
x = conv_block(inputs, 64)
x = keras.layers.MaxPooling2D((2, 2))(x)
# ブロック1
shortcut = x
out = conv_block(x, 128)
out = conv_block(out, 128)
if use_se:
out = se_block(out, ratio=8)
if use_residual:
# チャネル数が変わるので 1×1 Conv でshortcutを調整
shortcut = keras.layers.Conv2D(128, 1, padding='same', use_bias=False)(shortcut)
shortcut = keras.layers.BatchNormalization()(shortcut)
out = keras.layers.Add()([out, shortcut])
x = keras.layers.MaxPooling2D((2, 2))(out)
# ブロック2
shortcut2 = x
out2 = conv_block(x, 256)
out2 = conv_block(out2, 256)
if use_se:
out2 = se_block(out2, ratio=8)
if use_residual:
shortcut2 = keras.layers.Conv2D(256, 1, padding='same', use_bias=False)(shortcut2)
shortcut2 = keras.layers.BatchNormalization()(shortcut2)
out2 = keras.layers.Add()([out2, shortcut2])
x = keras.layers.GlobalAveragePooling2D()(out2)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.3)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
return keras.Model(inputs, outputs, name=name)
def compile_and_fit(model, epochs=30):
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
start = time.time()
history = model.fit(x_train, y_train, epochs=epochs, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 11s 0us/step
4パターンの学習実行
configs = [
('A_baseline', False, False),
('B_residual', True, False),
('C_se', False, True),
('D_se_residual', True, True),
]
histories, times, scores, params = {}, {}, {}, {}
for name, use_res, use_se in configs:
print(f"\n=== {name} ===")
model = build_model(name, use_residual=use_res, use_se=use_se)
h, t = compile_and_fit(model)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_', 1)[1]
histories[label] = h
times[label] = t
scores[label] = s
params[label] = model.count_params()
print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_baseline === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 21s 16ms/step - accuracy: 0.5010 - loss: 1.3793 - val_accuracy: 0.4396 - val_loss: 1.6545 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 12ms/step - accuracy: 0.6460 - loss: 1.0004 - val_accuracy: 0.5950 - val_loss: 1.1252 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 11s 18ms/step - accuracy: 0.7087 - loss: 0.8293 - val_accuracy: 0.6524 - val_loss: 1.0070 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 13ms/step - accuracy: 0.7532 - loss: 0.7107 - val_accuracy: 0.6429 - val_loss: 1.0778 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 12ms/step - accuracy: 0.7890 - loss: 0.6131 - val_accuracy: 0.6676 - val_loss: 1.0202 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.8161 - loss: 0.5360 - val_accuracy: 0.5819 - val_loss: 1.5081 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.8361 - loss: 0.4730 - val_accuracy: 0.6693 - val_loss: 1.0717 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 12ms/step - accuracy: 0.8583 - loss: 0.4114 - val_accuracy: 0.6398 - val_loss: 1.2132 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.8767 - loss: 0.3593 - val_accuracy: 0.7164 - val_loss: 0.9505 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8932 - loss: 0.3124 - val_accuracy: 0.7586 - val_loss: 0.8198 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9046 - loss: 0.2766 - val_accuracy: 0.7141 - val_loss: 1.0403 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9190 - loss: 0.2351 - val_accuracy: 0.6722 - val_loss: 1.4325 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9313 - loss: 0.2036 - val_accuracy: 0.7605 - val_loss: 0.8748 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9386 - loss: 0.1764 - val_accuracy: 0.6898 - val_loss: 1.2430 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9456 - loss: 0.1593 - val_accuracy: 0.7621 - val_loss: 0.8903 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9536 - loss: 0.1363 - val_accuracy: 0.7419 - val_loss: 1.1569 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9603 - loss: 0.1167 - val_accuracy: 0.7600 - val_loss: 1.2443 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9611 - loss: 0.1127 - val_accuracy: 0.7700 - val_loss: 1.0161 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9650 - loss: 0.1026 - val_accuracy: 0.6290 - val_loss: 2.3993 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9664 - loss: 0.0975 - val_accuracy: 0.6444 - val_loss: 1.8738 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9709 - loss: 0.0863 - val_accuracy: 0.7429 - val_loss: 1.3325 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9715 - loss: 0.0827 - val_accuracy: 0.7908 - val_loss: 1.0279 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9751 - loss: 0.0747 - val_accuracy: 0.7839 - val_loss: 1.0695 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9771 - loss: 0.0680 - val_accuracy: 0.7843 - val_loss: 0.9813 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9771 - loss: 0.0648 - val_accuracy: 0.7234 - val_loss: 1.6567 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9780 - loss: 0.0656 - val_accuracy: 0.8109 - val_loss: 0.8687 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9793 - loss: 0.0611 - val_accuracy: 0.7869 - val_loss: 1.0697 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9795 - loss: 0.0601 - val_accuracy: 0.7445 - val_loss: 1.4373 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 13ms/step - accuracy: 0.9804 - loss: 0.0563 - val_accuracy: 0.6932 - val_loss: 1.6232 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9816 - loss: 0.0543 - val_accuracy: 0.7507 - val_loss: 1.3225 学習時間:272.5秒 パラメータ数:1,145,162 test_accuracy:0.7517 === B_residual === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 20s 19ms/step - accuracy: 0.4909 - loss: 1.4058 - val_accuracy: 0.4806 - val_loss: 1.3650 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 14s 15ms/step - accuracy: 0.6375 - loss: 1.0151 - val_accuracy: 0.4145 - val_loss: 1.8112 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.7046 - loss: 0.8442 - val_accuracy: 0.4723 - val_loss: 2.1169 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.7501 - loss: 0.7202 - val_accuracy: 0.5744 - val_loss: 1.2563 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.7854 - loss: 0.6197 - val_accuracy: 0.6792 - val_loss: 1.0155 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8092 - loss: 0.5432 - val_accuracy: 0.5325 - val_loss: 2.0607 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8365 - loss: 0.4740 - val_accuracy: 0.6023 - val_loss: 1.2024 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8597 - loss: 0.4157 - val_accuracy: 0.6509 - val_loss: 1.3299 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8778 - loss: 0.3578 - val_accuracy: 0.6867 - val_loss: 1.0979 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8935 - loss: 0.3074 - val_accuracy: 0.6482 - val_loss: 1.3541 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9075 - loss: 0.2708 - val_accuracy: 0.7109 - val_loss: 1.0997 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9208 - loss: 0.2274 - val_accuracy: 0.7210 - val_loss: 1.0349 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9329 - loss: 0.1962 - val_accuracy: 0.7269 - val_loss: 1.0687 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 14ms/step - accuracy: 0.9397 - loss: 0.1751 - val_accuracy: 0.7498 - val_loss: 1.0163 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9471 - loss: 0.1504 - val_accuracy: 0.6321 - val_loss: 1.9068 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9513 - loss: 0.1426 - val_accuracy: 0.7905 - val_loss: 0.8223 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9610 - loss: 0.1127 - val_accuracy: 0.7428 - val_loss: 1.1915 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9642 - loss: 0.1066 - val_accuracy: 0.7939 - val_loss: 0.9305 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9646 - loss: 0.1033 - val_accuracy: 0.5745 - val_loss: 2.5853 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9683 - loss: 0.0923 - val_accuracy: 0.7527 - val_loss: 1.1583 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9718 - loss: 0.0827 - val_accuracy: 0.7680 - val_loss: 1.1941 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9743 - loss: 0.0750 - val_accuracy: 0.7320 - val_loss: 1.4908 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9758 - loss: 0.0686 - val_accuracy: 0.7869 - val_loss: 1.0109 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9781 - loss: 0.0657 - val_accuracy: 0.6248 - val_loss: 2.2460 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9778 - loss: 0.0648 - val_accuracy: 0.7205 - val_loss: 1.6026 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9782 - loss: 0.0639 - val_accuracy: 0.7917 - val_loss: 0.9875 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9811 - loss: 0.0571 - val_accuracy: 0.7834 - val_loss: 1.0851 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9825 - loss: 0.0518 - val_accuracy: 0.7590 - val_loss: 1.2902 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9803 - loss: 0.0573 - val_accuracy: 0.7453 - val_loss: 1.3956 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9826 - loss: 0.0497 - val_accuracy: 0.7667 - val_loss: 1.1933 学習時間:287.7秒 パラメータ数:1,187,658 test_accuracy:0.7635 === C_se === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 18s 17ms/step - accuracy: 0.5034 - loss: 1.3595 - val_accuracy: 0.3969 - val_loss: 2.0409 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.6567 - loss: 0.9728 - val_accuracy: 0.4379 - val_loss: 1.8056 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.7195 - loss: 0.8110 - val_accuracy: 0.3706 - val_loss: 3.3363 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.7643 - loss: 0.6858 - val_accuracy: 0.6126 - val_loss: 1.2420 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.7967 - loss: 0.5912 - val_accuracy: 0.6013 - val_loss: 1.2500 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8248 - loss: 0.5134 - val_accuracy: 0.6808 - val_loss: 0.9687 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.8472 - loss: 0.4488 - val_accuracy: 0.6825 - val_loss: 1.0340 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.8685 - loss: 0.3876 - val_accuracy: 0.6509 - val_loss: 1.2048 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8818 - loss: 0.3473 - val_accuracy: 0.7666 - val_loss: 0.7699 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9021 - loss: 0.2927 - val_accuracy: 0.7440 - val_loss: 0.8714 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9112 - loss: 0.2589 - val_accuracy: 0.5757 - val_loss: 1.8692 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9227 - loss: 0.2274 - val_accuracy: 0.7362 - val_loss: 0.9520 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9337 - loss: 0.1945 - val_accuracy: 0.6888 - val_loss: 1.2127 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9446 - loss: 0.1618 - val_accuracy: 0.7649 - val_loss: 0.8613 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 14ms/step - accuracy: 0.9471 - loss: 0.1527 - val_accuracy: 0.7778 - val_loss: 0.8750 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9536 - loss: 0.1379 - val_accuracy: 0.6978 - val_loss: 1.3429 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 14ms/step - accuracy: 0.9582 - loss: 0.1241 - val_accuracy: 0.7501 - val_loss: 1.0609 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 14ms/step - accuracy: 0.9647 - loss: 0.1028 - val_accuracy: 0.7735 - val_loss: 0.9707 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9671 - loss: 0.0958 - val_accuracy: 0.8226 - val_loss: 0.7650 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9694 - loss: 0.0905 - val_accuracy: 0.7597 - val_loss: 1.1820 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9706 - loss: 0.0844 - val_accuracy: 0.7674 - val_loss: 1.0639 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9736 - loss: 0.0787 - val_accuracy: 0.7136 - val_loss: 1.5434 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9755 - loss: 0.0727 - val_accuracy: 0.6521 - val_loss: 2.2848 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 14ms/step - accuracy: 0.9769 - loss: 0.0702 - val_accuracy: 0.7485 - val_loss: 1.3225 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9779 - loss: 0.0646 - val_accuracy: 0.8047 - val_loss: 0.9423 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9816 - loss: 0.0550 - val_accuracy: 0.8013 - val_loss: 1.0058 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9810 - loss: 0.0562 - val_accuracy: 0.8012 - val_loss: 0.9391 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9837 - loss: 0.0496 - val_accuracy: 0.7488 - val_loss: 1.4797 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.9820 - loss: 0.0542 - val_accuracy: 0.8116 - val_loss: 0.8923 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9826 - loss: 0.0521 - val_accuracy: 0.8168 - val_loss: 0.9121 学習時間:264.9秒 パラメータ数:1,165,642 test_accuracy:0.8057 === D_se_residual === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 20s 18ms/step - accuracy: 0.4808 - loss: 1.4297 - val_accuracy: 0.4073 - val_loss: 1.6514 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.6365 - loss: 1.0219 - val_accuracy: 0.5459 - val_loss: 1.3926 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.7015 - loss: 0.8492 - val_accuracy: 0.5505 - val_loss: 1.4690 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.7516 - loss: 0.7168 - val_accuracy: 0.6661 - val_loss: 0.9424 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.7873 - loss: 0.6158 - val_accuracy: 0.6247 - val_loss: 1.2277 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 15ms/step - accuracy: 0.8139 - loss: 0.5416 - val_accuracy: 0.7233 - val_loss: 0.8242 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.8366 - loss: 0.4714 - val_accuracy: 0.6054 - val_loss: 1.4788 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.8575 - loss: 0.4144 - val_accuracy: 0.6271 - val_loss: 1.5656 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.8794 - loss: 0.3572 - val_accuracy: 0.6074 - val_loss: 1.5490 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.8924 - loss: 0.3120 - val_accuracy: 0.5481 - val_loss: 2.3595 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9035 - loss: 0.2768 - val_accuracy: 0.6651 - val_loss: 1.4218 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9198 - loss: 0.2337 - val_accuracy: 0.7438 - val_loss: 0.9495 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9286 - loss: 0.2051 - val_accuracy: 0.7532 - val_loss: 0.9021 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9407 - loss: 0.1771 - val_accuracy: 0.7892 - val_loss: 0.8444 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9483 - loss: 0.1485 - val_accuracy: 0.6489 - val_loss: 1.5408 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9536 - loss: 0.1360 - val_accuracy: 0.7720 - val_loss: 0.9113 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9587 - loss: 0.1221 - val_accuracy: 0.7383 - val_loss: 1.3147 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 15ms/step - accuracy: 0.9647 - loss: 0.1044 - val_accuracy: 0.6803 - val_loss: 1.6729 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9654 - loss: 0.0996 - val_accuracy: 0.7969 - val_loss: 0.9229 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9700 - loss: 0.0875 - val_accuracy: 0.7357 - val_loss: 1.4885 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9699 - loss: 0.0857 - val_accuracy: 0.7858 - val_loss: 1.0484 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9737 - loss: 0.0776 - val_accuracy: 0.7405 - val_loss: 1.3531 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9751 - loss: 0.0737 - val_accuracy: 0.6242 - val_loss: 2.4559 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9761 - loss: 0.0700 - val_accuracy: 0.7770 - val_loss: 1.0909 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9791 - loss: 0.0602 - val_accuracy: 0.7007 - val_loss: 1.5695 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 15ms/step - accuracy: 0.9786 - loss: 0.0637 - val_accuracy: 0.7593 - val_loss: 1.3016 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.9803 - loss: 0.0585 - val_accuracy: 0.6029 - val_loss: 2.4961 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9810 - loss: 0.0542 - val_accuracy: 0.7718 - val_loss: 1.3454 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9804 - loss: 0.0569 - val_accuracy: 0.7994 - val_loss: 1.0702 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.9825 - loss: 0.0532 - val_accuracy: 0.7694 - val_loss: 1.1959 学習時間:292.0秒 パラメータ数:1,208,138 test_accuracy:0.7600
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('se_residual_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離) ────────────
fig2, axes2 = plt.subplots(2, 2, figsize=(14, 10))
axes2 = axes2.flatten()
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(label)
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('se_residual_overfit.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>15} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 63)
for label in ['baseline', 'residual', 'se', 'se_residual']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
p = params[label]
print(f"{label:>15} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 63)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s) | Params
---------------------------------------------------------------
baseline | 0.7507 | 0.7517 | 272.5 | 1,145,162
residual | 0.7667 | 0.7635 | 287.7 | 1,187,658
se | 0.8168 | 0.8057 | 264.9 | 1,165,642
se_residual | 0.7694 | 0.7600 | 292.0 | 1,208,138
---------------------------------------------------------------
実験結果
精度グラフ
損失グラフ
A:ベースライン
B:Residualのみ
C:SEのみ
D:SE + Residual
| パターン | 最終 val_accuracy | 最終 test_accuracy | パラメータ数 | 学習時間 |
|---|---|---|---|---|
| A:ベースライン | 0.7507 | 0.7517 | 1,145,162 | 272.5秒 |
| B:Residualのみ | 0.7667 | 0.7635 | 1,187,658 | 287.7秒 |
| C:SEのみ | 0.8168 | 0.8057 | 1,165,642 | 264.9秒 |
| D:SE + Residual | 0.7694 | 0.7600 | 1,208,138 | 292.0秒 |
考察
① 予想外の結果:SEのみが最高精度、SE+Residualが伸び悩んだ
最も精度が高かったのはC(SEのみ)で test_accuracy 80.57%でした。「組み合わせで相乗効果」という期待に反し、D(SE+Residual)は 76.00% にとどまり、SEのみに約4.6ポイント及ばない結果になりました。
精度の順位をまとめると次のようになります。
② なぜSE+Residualがセ単体に負けたのか
最大の原因として考えられるのはval_lossの不安定さです。学習ログを見ると、Dパターンはエポック10(val_loss 2.36)、エポック23(val_loss 2.46)、エポック27(val_loss 2.50)と大きく跳ね上がる場面が複数回ありました。一方Cパターンは収束が比較的安定しており、30エポック時点のval_lossも 0.91 とDの 1.20 を大きく下回っています。
Residual接続のAdd()によって勾配の流れが変わり、SEブロックが生成するチャネルAttention重みの最適化と干渉した可能性があります。今回の構成(浅い2ブロック構成)ではResidualによる勾配安定の恩恵が小さく、むしろSEの効果を打ち消す方向に働いたと考えられます。
③ Residual単体の効果は限定的だった
B(Residualのみ)はベースライン比で +1.18ポイント(75.17% → 76.35%)の改善にとどまりました。今回のモデルは2ブロックと浅く、勾配消失が問題になりにくい構成のため、Residual接続の本来の強み(深いネットの学習安定化)が十分に発揮されなかったと考えられます。
④ SEブロックは最もコスト効率が高い
SEのみ(C)はベースライン比でパラメータ数の増加が 20,480(+1.8%)と最小級でありながら、精度向上は +5.40ポイントと最大でした。学習時間もベースラインより短い(264.9秒 vs 272.5秒)という結果で、4パターン中で最もコスト効率が高い手法であることが確認できました。
⑤ 組み合わせが有効な場面
SE-ResNetが有効に機能する場面としては、今回より深いネットワーク(ResNet-50以上)や大規模データセットが挙げられます。浅い構成ではSEの効果が支配的であり、Residualの追加メリットが相対的に小さくなります。実務では「まずSEを試し、モデルをさらに深くする段階でResidualを加える」という順序が合理的です。
⚠️ ハマりポイント
- Residualのチャネル数の不一致:入出力のチャネル数が異なる場合(例:64→128)、shortcutに 1×1 Conv を挟まないとAdd()がエラーになる。
- SEブロックのratio設定:ratio が大きすぎると中間層のユニット数が少なくなりすぎて表現力が落ちる。ratio=8〜16が一般的な目安。
- 組み合わせは必ずしも相乗効果にならない:今回の実験のように、浅いネットではSE+Residualの組み合わせがSE単体を下回るケースがある。モデルの深さやデータ規模に応じて判断することが重要。
- val_lossの乖離に注意:4パターンすべてでtrain_accuracyが97〜98%に達しているにもかかわらずval_accuracyは70〜80%台と大きく乖離している。EarlyStoppingやデータ拡張の併用を推奨。
✅まとめ
- 4パターン中で最高精度はSEのみ(test_accuracy 80.57%)。SE+Residualの組み合わせは期待に反して伸び悩んだ
- SEブロックはパラメータ増加が +1.8% と最小限でありながら精度向上 +5.40ポイントと、コスト効率が最も高かった
- Residual単体の効果は浅いネットでは限定的(+1.18ポイント)。深いネットほど本来の強みが発揮される
- SE+Residualの組み合わせが有効なのは、ResNet-50以上の深いネットワークや大規模データセットの場面。浅い構成では「まずSE単体」が合理的な選択
- 全パターンでtrain/val精度の乖離が大きく、データ拡張やEarlyStoppingとの併用が推奨される
関連記事
- SEブロック単体の検証 → SEブロック(Squeeze-and-Excitation)を追加すると精度は上がるか?
- Residual単体の検証 → Residual接続(スキップ接続)あり vs なし を比較
- Channel Attentionとの比較 → Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証
- Bottleneck構造との組み合わせ → Bottleneck構造あり vs なし|1×1 Convの役割をCIFAR-10で実験検証







0 件のコメント:
コメントを投稿