「SENetって名前は聞いたことあるけど、実際どんな仕組みで、本当に精度が上がるの?」
SENet(Squeeze-and-Excitation Networks)は2017年のILSVRC(ImageNet大規模視覚認識チャレンジ)で優勝したアーキテクチャに使われた手法です。その核心にあるSEブロック(Squeeze-and-Excitation Block)は、チャンネルの重要度を動的に学習してスケーリングするシンプルな仕組みです。
前回のChannel Attention(CBAM風)の実験では「なし」が「あり」を上回るという結果になりました。今回はSENetの原論文に忠実な実装(GAPとGMPを両方使う)でCBAMとの違いも確認します。
📘 この記事でわかること
- SEブロックの仕組みとCBAM風Channel Attentionとの違い
- KerasでSEブロックをカスタムレイヤーとして実装する方法
- SEブロックあり vs なし の精度・過学習・パラメータ数の差(CIFAR-10実験)
- SENetが効果を発揮する条件と、今回の結果の考察
SEブロックとは何か――CBAMとの違い
Channel Attentionの考え方は「どのチャンネルに注目するか」をネットワークが学習するものです。SEブロックはそのChannel Attentionの元祖にあたる設計で、CBAM(前回の実験)はSEブロックをさらに発展させた手法です。
| 手法 | Squeeze方法 | Excitation方法 | 空間Attention |
|---|---|---|---|
| SENet(今回) | GAP のみ | FC → ReLU → FC → Sigmoid | なし |
| CBAM(前回) | GAP+GMP の両方 | FC → ReLU → FC → Sigmoid(両経路を加算) | あり(空間Attentionも持つ) |
SEブロックはCBAMより構造がシンプルです。今回の実装ではSENetの原論文に倣い、GAPのみのSqueezeを使います。前回の「CBAM風」実装(GAPのみ)と実は同じ構造になりますが、今回はそれが本来のSENet実装であることを明示し、正しく位置づけます。
SEブロックの仕組み
処理の流れは2ステップです。
① Squeeze(チャンネルごとの要約)
特徴マップ(H × W × C)を Global Average Pooling で空間方向に集約し、チャンネルごとのスカラー(1 × 1 × C)を得ます。
\[ z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j) \]
② Excitation(重要度スコアを計算して掛け算)
集約したベクトルを FC → ReLU → FC → Sigmoid に通し、各チャンネルの重要度スコア(0〜1)を計算します。元の特徴マップにこのスコアを掛け算(スケーリング)します。
\[ s = \sigma\!\left(W_2\,\delta\!\left(W_1\,z\right)\right), \quad \hat{x}_c = s_c \cdot x_c \]
rはリダクション比(reduction ratio)で、中間のDense層を C/r 次元に圧縮します。原論文では r=16 を推奨していますが、チャンネル数が少ない場合は r=8 が一般的です。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。SEブロックの有無以外の条件は全て同一にします。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 51 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (25.5 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122363 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 55.4 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
print("データ準備完了")
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 4s 0us/step データ準備完了
SEブロックのカスタムレイヤー実装
Kerasのカスタムレイヤーとして実装します。reduction_ratio(= r)で中間次元数を制御します。
class SEBlock(keras.layers.Layer):
"""Squeeze-and-Excitation Block(SENet原論文準拠)
Squeeze(GAP)→ Excitation(FC → ReLU → FC → Sigmoid)→ スケーリング
"""
def __init__(self, reduction_ratio=16, **kwargs):
super().__init__(**kwargs)
self.reduction_ratio = reduction_ratio
def build(self, input_shape):
channels = input_shape[-1]
reduced = max(channels // self.reduction_ratio, 1)
self.gap = keras.layers.GlobalAveragePooling2D()
self.fc1 = keras.layers.Dense(reduced, activation='relu',
use_bias=False)
self.fc2 = keras.layers.Dense(channels, activation='sigmoid',
use_bias=False)
super().build(input_shape)
def call(self, x):
# Squeeze: (B, H, W, C) → (B, C)
z = self.gap(x)
# Excitation: (B, C) → (B, C/r) → (B, C)
s = self.fc2(self.fc1(z))
# スケーリング: (B, C) → (B, 1, 1, C) にreshapeして掛け算
s = tf.reshape(s, (-1, 1, 1, tf.shape(s)[-1]))
return x * s
def get_config(self):
config = super().get_config()
config.update({'reduction_ratio': self.reduction_ratio})
return config
⚠️ ハマりポイント
GlobalAveragePooling2Dの出力は2次元(B, C)になります。元の特徴マップ(B, H, W, C)にチャンネルごとのスコアを掛け算するにはtf.reshape(s, (-1, 1, 1, C))で形状を合わせる必要があります。形状が合わないとブロードキャストに失敗してエラーになります。- reduction_ratioが大きすぎると中間次元数が0以下になります。チャンネル数が小さいモデルでは
max(channels // r, 1)のようにして最低1を確保してください。原論文推奨の r=16 はチャンネル数が256以上の場合向けです。今回のように64ch・128chの浅いCNNでは r=8 の方が安全です。 - SEブロックはConv2D後・MaxPooling後のどちらに挿入するかで挙動が変わります。今回はMaxPooling後に挿入して空間サイズが小さい状態でAttentionを計算します。原論文のSENet設計ではConvブロックの最後に挿入するのが一般的です。
モデル構築関数(あり / なし)
def build_model(use_se, name):
inputs = keras.Input(shape=(32, 32, 3))
# Block 1
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
x = keras.layers.MaxPooling2D((2, 2))(x)
if use_se:
x = SEBlock(reduction_ratio=8)(x) # ← SEブロック挿入
# Block 2
x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
if use_se:
x = SEBlock(reduction_ratio=8)(x) # ← SEブロック挿入
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
return keras.Model(inputs, outputs, name=name)
def compile_and_fit(model):
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
start = time.time()
history = model.fit(x_train, y_train, epochs=30, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
2パターンの学習実行
configs = [
(False, 'A_no_se'),
(True, 'B_se_block'),
]
histories, times, scores, params = {}, {}, {}, {}
for use_se, name in configs:
print(f"\n=== {name} ===")
model = build_model(use_se, name)
model.summary()
h, t = compile_and_fit(model)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_', 1)[1] # 'no_se' / 'se_block'
histories[label] = h
times[label] = t
scores[label] = s
params[label] = model.count_params()
print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_no_se === Model: "A_no_se" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.2704 - loss: 1.9110 - val_accuracy: 0.3466 - val_loss: 1.7526 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3762 - loss: 1.6646 - val_accuracy: 0.4259 - val_loss: 1.5714 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4308 - loss: 1.5410 - val_accuracy: 0.4480 - val_loss: 1.4790 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4669 - loss: 1.4552 - val_accuracy: 0.4771 - val_loss: 1.4267 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.4893 - loss: 1.3913 - val_accuracy: 0.4867 - val_loss: 1.3983 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5100 - loss: 1.3393 - val_accuracy: 0.5372 - val_loss: 1.2798 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.5295 - loss: 1.2894 - val_accuracy: 0.5364 - val_loss: 1.2631 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.5438 - loss: 1.2563 - val_accuracy: 0.5500 - val_loss: 1.2358 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5558 - loss: 1.2247 - val_accuracy: 0.5730 - val_loss: 1.1912 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5701 - loss: 1.1900 - val_accuracy: 0.5693 - val_loss: 1.1806 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5791 - loss: 1.1680 - val_accuracy: 0.5737 - val_loss: 1.1669 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5865 - loss: 1.1395 - val_accuracy: 0.5959 - val_loss: 1.1184 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6021 - loss: 1.1083 - val_accuracy: 0.6018 - val_loss: 1.0888 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6060 - loss: 1.0881 - val_accuracy: 0.6163 - val_loss: 1.0527 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6174 - loss: 1.0632 - val_accuracy: 0.6091 - val_loss: 1.0600 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6204 - loss: 1.0557 - val_accuracy: 0.6131 - val_loss: 1.0619 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6286 - loss: 1.0341 - val_accuracy: 0.6265 - val_loss: 1.0290 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6355 - loss: 1.0138 - val_accuracy: 0.6312 - val_loss: 1.0154 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6410 - loss: 0.9982 - val_accuracy: 0.6471 - val_loss: 0.9780 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6475 - loss: 0.9826 - val_accuracy: 0.6567 - val_loss: 0.9578 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6515 - loss: 0.9695 - val_accuracy: 0.6395 - val_loss: 0.9942 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6608 - loss: 0.9521 - val_accuracy: 0.6606 - val_loss: 0.9409 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.6618 - loss: 0.9434 - val_accuracy: 0.6637 - val_loss: 0.9412 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6680 - loss: 0.9285 - val_accuracy: 0.6542 - val_loss: 0.9612 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6746 - loss: 0.9141 - val_accuracy: 0.6775 - val_loss: 0.9047 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6775 - loss: 0.9064 - val_accuracy: 0.6356 - val_loss: 1.0197 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6784 - loss: 0.8990 - val_accuracy: 0.6746 - val_loss: 0.9070 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6829 - loss: 0.8868 - val_accuracy: 0.6632 - val_loss: 0.9290 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6878 - loss: 0.8717 - val_accuracy: 0.6830 - val_loss: 0.8956 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6910 - loss: 0.8629 - val_accuracy: 0.6726 - val_loss: 0.9188 学習時間:132.1秒 パラメータ数:93,450 test_accuracy:0.6644 === B_se_block === Model: "B_se_block" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_1 (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ se_block (SEBlock) │ (None, 16, 16, 64) │ 1,024 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ se_block_1 (SEBlock) │ (None, 8, 8, 128) │ 4,096 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d_3 │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_6 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_7 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 98,570 (385.04 KB) Trainable params: 98,570 (385.04 KB) Non-trainable params: 0 (0.00 B) Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - accuracy: 0.2395 - loss: 1.9953 - val_accuracy: 0.3028 - val_loss: 1.8310 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3340 - loss: 1.7513 - val_accuracy: 0.3511 - val_loss: 1.6927 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3923 - loss: 1.6155 - val_accuracy: 0.4234 - val_loss: 1.5542 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4345 - loss: 1.5214 - val_accuracy: 0.4552 - val_loss: 1.4621 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4604 - loss: 1.4565 - val_accuracy: 0.4803 - val_loss: 1.4011 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4830 - loss: 1.4079 - val_accuracy: 0.4967 - val_loss: 1.3677 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4965 - loss: 1.3690 - val_accuracy: 0.5008 - val_loss: 1.3377 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5113 - loss: 1.3306 - val_accuracy: 0.5130 - val_loss: 1.3188 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5275 - loss: 1.2936 - val_accuracy: 0.5323 - val_loss: 1.2705 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5369 - loss: 1.2682 - val_accuracy: 0.5467 - val_loss: 1.2359 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5479 - loss: 1.2406 - val_accuracy: 0.5602 - val_loss: 1.2014 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5591 - loss: 1.2125 - val_accuracy: 0.5731 - val_loss: 1.1756 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5673 - loss: 1.1923 - val_accuracy: 0.5793 - val_loss: 1.1594 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5775 - loss: 1.1674 - val_accuracy: 0.5811 - val_loss: 1.1606 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5877 - loss: 1.1433 - val_accuracy: 0.5912 - val_loss: 1.1287 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5911 - loss: 1.1282 - val_accuracy: 0.6004 - val_loss: 1.1013 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6008 - loss: 1.1072 - val_accuracy: 0.5984 - val_loss: 1.1128 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6084 - loss: 1.0907 - val_accuracy: 0.6065 - val_loss: 1.0857 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6158 - loss: 1.0701 - val_accuracy: 0.6158 - val_loss: 1.0538 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6219 - loss: 1.0558 - val_accuracy: 0.6238 - val_loss: 1.0397 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6269 - loss: 1.0386 - val_accuracy: 0.6205 - val_loss: 1.0453 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6306 - loss: 1.0298 - val_accuracy: 0.6213 - val_loss: 1.0493 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6378 - loss: 1.0097 - val_accuracy: 0.6103 - val_loss: 1.0713 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6365 - loss: 1.0048 - val_accuracy: 0.6298 - val_loss: 1.0235 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6453 - loss: 0.9851 - val_accuracy: 0.6532 - val_loss: 0.9693 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6499 - loss: 0.9779 - val_accuracy: 0.6471 - val_loss: 0.9837 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6539 - loss: 0.9708 - val_accuracy: 0.6438 - val_loss: 0.9882 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6583 - loss: 0.9536 - val_accuracy: 0.6569 - val_loss: 0.9617 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6612 - loss: 0.9464 - val_accuracy: 0.6524 - val_loss: 0.9663 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6659 - loss: 0.9340 - val_accuracy: 0.6530 - val_loss: 0.9627 学習時間:132.2秒 パラメータ数:98,570 test_accuracy:0.6550
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('se_block_comparison.png', dpi=150)
plt.show()
# ── train vs val(過学習チェック)────────────────────
fig2, axes2 = plt.subplots(2, 1, figsize=(7, 10))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(label)
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('se_block_overfit.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>12} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 60)
for label in ['no_se', 'se_block']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
p = params[label]
print(f"{label:>12} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 60)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s) | Params
------------------------------------------------------------
no_se | 0.6726 | 0.6644 | 132.1 | 93,450
se_block | 0.6530 | 0.6550 | 132.2 | 98,570
------------------------------------------------------------
実験結果
精度グラフ
損失グラフ
SEブロックなし
SEブロックあり
| パターン | 最終 val_accuracy | 最終 test_accuracy | パラメータ数 | 学習時間 |
|---|---|---|---|---|
| A:SEブロックなし | 67.26% | 66.44% | 93,450 | 132.1秒 |
| B:SEブロックあり | 65.30% | 65.50% | 98,570 | 132.2秒 |
考察
① SEブロックのパラメータ増加量
SEブロックを2箇所(Block1: 64ch, Block2: 128ch)に追加した場合のパラメータ増加は以下の通りです。
- Block1(64ch, r=8):fc1: 64×8=512、fc2: 8×64=512 → 1,024パラメータ
- Block2(128ch, r=8):fc1: 128×16=2,048、fc2: 16×128=2,048 → 4,096パラメータ
- 合計追加:5,120パラメータ(93,450 → 98,570、約5.5%増)
実行結果のパラメータ数(98,570)と計算値(93,450 + 5,120 = 98,570)が一致しており、SEブロックの実装が正確であることが確認できました。前回のCBAM風実装(98,570)とも同値であることから、今回の実装は「GAPのみのSqueeze」という点でCBAM風と構造的に同一です。
② SENetが本領を発揮する条件
SENetの原論文(Hu et al., 2018)ではResNet・VGG・Inceptionなどの深いネットワークにSEブロックを組み込んで検証しています。今回のような浅いCNN(Conv2D 2層)では、前回のCBAM実験と同様に効果が出にくい可能性があります。
| 条件 | 効果の出やすさ | 理由 |
|---|---|---|
| 深いネットワーク(ResNet等) | ◎ 出やすい | 多数のチャンネルが役割分担し、重み付けが有効に機能する |
| チャンネル数が多い(256ch以上) | ○ 出やすい | 重要度スコアの選択肢が増えてAttentionが有効になる |
| 浅いCNN(今回) | △ 出にくい | チャンネルの役割分担が未熟でAttentionの効果が薄い |
| 転移学習モデル(EfficientNet等) | ◎ 出やすい | すでに豊かな特徴表現があり、選択的スケーリングが有効 |
③ SENet vs CBAM ――どちらが精度で有利か
今回の実験とCBAM記事の結果を並べると、浅いCNNにおける「なし vs あり」の傾向が2回連続で確認できました。
| 手法 | なし test_accuracy | あり test_accuracy | 差 |
|---|---|---|---|
| CBAM風(前回) | 66.82% | 66.31% | −0.51% |
| SENet(今回) | 66.44% | 65.50% | −0.94% |
SENet(−0.94%)はCBAM(−0.51%)よりも「なし」との差がやや大きくなりました。今回のベースラインは前回とは別セッションの実行のため直接比較には注意が必要ですが、どちらの手法も浅いCNNでは「あり」が「なし」を上回らないという傾向は一貫しています。
浅いCNNでChannel Attention系手法が効果を出しにくい理由は、チャンネル数が少なく(64〜128ch)、各チャンネルの「役割分担」がまだ未熟な状態のためAttentionの選択機構がうまく機能しないからです。深いモデルでの効果を確認したい場合は、転移学習ベースのモデル(EfficientNetB0など)にSEブロックを差し込む実験が有効です。
④ reduction_ratioの選び方
| reduction_ratio(r) | 中間次元数(Ch=128の場合) | 特徴 |
|---|---|---|
| 4 | 32 | 表現力大・パラメータ増加大 |
| 8(今回) | 16 | チャンネル数64〜128の浅いモデル向けバランス設定 |
| 16(原論文推奨) | 8 | 深いモデル(256ch以上)で推奨。精度とコストのバランスが最良 |
✅ まとめ
- SEブロックは「チャンネルごとの重要度スコア(0〜1)」を GAP → FC → ReLU → FC → Sigmoid で計算し、元の特徴マップに掛け算するシンプルな仕組み
- 追加パラメータは5,120(93,450 → 98,570、約5.5%増)と非常に小さく、学習時間もほぼ変わらなかった(132.1秒 vs 132.2秒)
- 今回の実験ではSEブロックなし(66.44%)がSEブロックあり(65.50%)を0.94%上回った。前回のCBAM風実験(差0.51%)と合わせ、浅いCNNでは2回連続で「なし」が「あり」を上回る結果となった
- 浅いCNN(Conv2D 2層・64〜128ch)ではチャンネルの役割分担が未熟なためAttentionの選択機構が機能しにくい——これがChannel Attention系手法共通の注意点
- SENetの本領は深いネットワーク(ResNet等)や多チャンネル構成(256ch以上)で発揮される
- CBAMとSENetの最大の違いは「空間Attentionの有無」——CBAMはChannel Attentionに加えて空間Attentionも持つ分、パラメータと計算コストがやや大きい
- カスタムレイヤーとして実装しておけば
SEBlock()(x)の1行で任意のConv層の直後に差し込めるため、大きなモデルで試す際の再利用コストはゼロ
関連記事もあわせてどうぞ:
- Channel Attention(CBAM風)の実験 → Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証【Keras実験】
- Residual接続あり vs なし → Residual接続(スキップ接続)あり vs なし|ResNetの核心を実験で検証
- Bottleneck構造あり vs なし → Bottleneck構造あり vs なし|1×1 Convの役割をCIFAR-10で実験検証【Keras】
- GAPとFlattenの比較 → Global Average Pooling vs Flatten|CNNの最終層、どっちが精度・速度で有利か?【Keras実験】
- Depthwise Separable Convolution vs Conv2D → Depthwise Separable Convolution vs 通常Conv2D|パラメータ削減で精度はどう変わる?





0 件のコメント:
コメントを投稿