Channel Attentionとは何か、使うと何がうれしいのか
CNNの各Conv層は複数のチャンネル(フィルター)を持っています。たとえば128チャンネルある場合、どのチャンネルが今の画像分類に役立つかはネットワークが決められません——全チャンネルを均等に扱うだけです。
Channel Attentionはこの問題を解決します。各チャンネルに対して「重要度スコア(0〜1)」を計算し、重要なチャンネルを強調・不要なチャンネルを抑制する仕組みです。SENet(Squeeze-and-Excitation Networks)や CBAM(Convolutional Block Attention Module)で採用されており、精度向上に有効なことが多くの論文で示されています。
📘 この記事でわかること
- Channel Attentionの仕組みと数式(Squeeze → Excitation)
- KerasでCBAM風 Channel Attentionをカスタムレイヤーとして実装する方法
- Channel Attentionあり vs なし の精度・過学習の差(CIFAR-10実験)
- どんな場面で使うと効果的か
Channel Attentionの仕組み
処理の流れは大きく2ステップです。
① Squeeze(チャンネルごとのグローバル情報を集約)
特徴マップ(H × W × C)に対し、空間方向(H × W)をGlobal Average Poolingで潰し、チャンネルごとのスカラー値(1 × 1 × C)を得ます。これが「チャンネルの要約情報」です。
$$z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j)$$
② Excitation(重要度スコアを計算して掛け算)
集約したベクトル(C次元)をDense層でボトルネック圧縮(C/r次元)→ relu → Dense層(C次元)→ sigmoid に通します。出力は各チャンネルの重要度スコア(0〜1)。元の特徴マップにこのスコアを掛け算(スケーリング)します。
$$s = \sigma\!\left(W_2\,\delta\!\left(W_1\,z\right)\right), \quad \hat{x}_c = s_c \cdot x_c$$
ここで \( r \) はリダクション比(reduction ratio)と呼ばれ、中間のDense層の次元数を \( C/r \) に圧縮します。今回は r = 8 を使います。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。Channel Attentionの有無以外の条件は全て同一にします。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 51 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (18.0 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122412 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 63.2 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
print("データ準備完了")
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 5s 0us/step データ準備完了
Channel Attentionレイヤーの実装
Kerasのカスタムレイヤーとして実装します。reduction_ratio(= r)で中間次元数を制御します。
class ChannelAttention(keras.layers.Layer):
"""CBAM風 Channel Attention レイヤー
Squeeze(GAP)→ Excitation(FC → relu → FC → sigmoid)→ スケーリング
"""
def __init__(self, reduction_ratio=8, **kwargs):
super().__init__(**kwargs)
self.reduction_ratio = reduction_ratio
def build(self, input_shape):
channels = input_shape[-1]
reduced = max(channels // self.reduction_ratio, 1)
self.gap = keras.layers.GlobalAveragePooling2D()
self.fc1 = keras.layers.Dense(reduced, activation='relu',
use_bias=False)
self.fc2 = keras.layers.Dense(channels, activation='sigmoid',
use_bias=False)
super().build(input_shape)
def call(self, x):
# Squeeze: (B, H, W, C) → (B, C)
z = self.gap(x)
# Excitation: (B, C) → (B, C/r) → (B, C)
s = self.fc2(self.fc1(z))
# スケーリング: (B, C) → (B, 1, 1, C) にreshapeして掛け算
s = tf.reshape(s, (-1, 1, 1, tf.shape(s)[-1]))
return x * s
def get_config(self):
config = super().get_config()
config.update({'reduction_ratio': self.reduction_ratio})
return config
⚠️ ハマりポイント
GlobalAveragePooling2Dの出力は2次元(B, C)になります。元の特徴マップ(B, H, W, C)に掛け算するにはtf.reshape(s, (-1, 1, 1, C))で形状を合わせる必要があります。形状を合わせないとエラーになります。- CBAMの原論文ではGAPに加えてGlobal Max Poolingも使いますが、今回はシンプルさを優先してGAPのみで実装しています。
reduction_ratioが大きすぎると中間次元数が0以下になるため、max(channels // r, 1)で最低1を確保しています。
モデル構築関数(あり / なし)
def build_model(use_attention, name):
inputs = keras.Input(shape=(32, 32, 3))
# Block 1
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
x = keras.layers.MaxPooling2D((2, 2))(x)
if use_attention:
x = ChannelAttention(reduction_ratio=8)(x)
# Block 2
x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
if use_attention:
x = ChannelAttention(reduction_ratio=8)(x)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
return keras.Model(inputs, outputs, name=name)
def compile_and_fit(model):
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
start = time.time()
history = model.fit(x_train, y_train, epochs=30, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
2パターンの学習実行
configs = [
(False, 'A_no_attention'),
(True, 'B_channel_attention'),
]
histories, times, scores, params = {}, {}, {}, {}
for use_attn, name in configs:
print(f"\n=== {name} ===")
model = build_model(use_attn, name)
model.summary()
h, t = compile_and_fit(model)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_', 1)[1] # 'no_attention' / 'channel_attention'
histories[label] = h
times[label] = t
scores[label] = s
params[label] = model.count_params()
print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_no_attention === Model: "A_no_attention" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 13ms/step - accuracy: 0.2661 - loss: 1.9191 - val_accuracy: 0.3548 - val_loss: 1.7288 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3636 - loss: 1.6903 - val_accuracy: 0.4049 - val_loss: 1.5964 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4256 - loss: 1.5565 - val_accuracy: 0.4683 - val_loss: 1.4680 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4636 - loss: 1.4612 - val_accuracy: 0.4975 - val_loss: 1.3767 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.4915 - loss: 1.3965 - val_accuracy: 0.4874 - val_loss: 1.4022 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5100 - loss: 1.3423 - val_accuracy: 0.5168 - val_loss: 1.3254 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.5261 - loss: 1.2996 - val_accuracy: 0.5410 - val_loss: 1.2705 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.5393 - loss: 1.2610 - val_accuracy: 0.5626 - val_loss: 1.2186 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5488 - loss: 1.2425 - val_accuracy: 0.5604 - val_loss: 1.2120 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5594 - loss: 1.2105 - val_accuracy: 0.5774 - val_loss: 1.1689 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5698 - loss: 1.1825 - val_accuracy: 0.5783 - val_loss: 1.1614 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5778 - loss: 1.1622 - val_accuracy: 0.5781 - val_loss: 1.1812 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5895 - loss: 1.1393 - val_accuracy: 0.5914 - val_loss: 1.1195 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5974 - loss: 1.1189 - val_accuracy: 0.6041 - val_loss: 1.0943 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6028 - loss: 1.0970 - val_accuracy: 0.6034 - val_loss: 1.0893 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6094 - loss: 1.0774 - val_accuracy: 0.6109 - val_loss: 1.0713 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6169 - loss: 1.0617 - val_accuracy: 0.6059 - val_loss: 1.1050 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6228 - loss: 1.0439 - val_accuracy: 0.6349 - val_loss: 1.0211 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6307 - loss: 1.0267 - val_accuracy: 0.6393 - val_loss: 1.0111 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6351 - loss: 1.0134 - val_accuracy: 0.6171 - val_loss: 1.0742 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6397 - loss: 1.0002 - val_accuracy: 0.6285 - val_loss: 1.0339 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6461 - loss: 0.9831 - val_accuracy: 0.6529 - val_loss: 0.9770 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6495 - loss: 0.9704 - val_accuracy: 0.6386 - val_loss: 1.0035 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6571 - loss: 0.9549 - val_accuracy: 0.6444 - val_loss: 0.9881 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6594 - loss: 0.9425 - val_accuracy: 0.6586 - val_loss: 0.9652 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6658 - loss: 0.9307 - val_accuracy: 0.6532 - val_loss: 0.9820 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6714 - loss: 0.9155 - val_accuracy: 0.6605 - val_loss: 0.9445 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6775 - loss: 0.9017 - val_accuracy: 0.6693 - val_loss: 0.9281 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6788 - loss: 0.8979 - val_accuracy: 0.6717 - val_loss: 0.9271 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6844 - loss: 0.8834 - val_accuracy: 0.6754 - val_loss: 0.9167 学習時間:139.1秒 パラメータ数:93,450 test_accuracy:0.6682 === B_channel_attention === Model: "B_channel_attention" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_1 (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ channel_attention │ (None, 16, 16, 64) │ 1,024 │ │ (ChannelAttention) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ channel_attention_1 │ (None, 8, 8, 128) │ 4,096 │ │ (ChannelAttention) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d_3 │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_6 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_7 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 98,570 (385.04 KB) Trainable params: 98,570 (385.04 KB) Non-trainable params: 0 (0.00 B) Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 9ms/step - accuracy: 0.2373 - loss: 1.9902 - val_accuracy: 0.3121 - val_loss: 1.7881 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3390 - loss: 1.7326 - val_accuracy: 0.3642 - val_loss: 1.6694 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3877 - loss: 1.6282 - val_accuracy: 0.4286 - val_loss: 1.5523 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4315 - loss: 1.5301 - val_accuracy: 0.4628 - val_loss: 1.4603 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4658 - loss: 1.4502 - val_accuracy: 0.4874 - val_loss: 1.3932 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4899 - loss: 1.3958 - val_accuracy: 0.5046 - val_loss: 1.3408 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5047 - loss: 1.3535 - val_accuracy: 0.5204 - val_loss: 1.3034 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5189 - loss: 1.3122 - val_accuracy: 0.5334 - val_loss: 1.2638 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5354 - loss: 1.2780 - val_accuracy: 0.5386 - val_loss: 1.2522 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5427 - loss: 1.2550 - val_accuracy: 0.5550 - val_loss: 1.2059 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5529 - loss: 1.2279 - val_accuracy: 0.5495 - val_loss: 1.2221 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5645 - loss: 1.1995 - val_accuracy: 0.5736 - val_loss: 1.1648 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5739 - loss: 1.1733 - val_accuracy: 0.5836 - val_loss: 1.1349 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5804 - loss: 1.1504 - val_accuracy: 0.5973 - val_loss: 1.1036 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5933 - loss: 1.1295 - val_accuracy: 0.5645 - val_loss: 1.1857 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6039 - loss: 1.1022 - val_accuracy: 0.5997 - val_loss: 1.1036 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6101 - loss: 1.0848 - val_accuracy: 0.5996 - val_loss: 1.1203 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6187 - loss: 1.0689 - val_accuracy: 0.6206 - val_loss: 1.0487 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6241 - loss: 1.0462 - val_accuracy: 0.6304 - val_loss: 1.0248 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6339 - loss: 1.0245 - val_accuracy: 0.6301 - val_loss: 1.0183 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6364 - loss: 1.0202 - val_accuracy: 0.6371 - val_loss: 1.0078 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6389 - loss: 0.9981 - val_accuracy: 0.6375 - val_loss: 1.0083 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6445 - loss: 0.9880 - val_accuracy: 0.6466 - val_loss: 0.9834 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6533 - loss: 0.9673 - val_accuracy: 0.6465 - val_loss: 0.9805 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6591 - loss: 0.9573 - val_accuracy: 0.6481 - val_loss: 0.9787 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6632 - loss: 0.9425 - val_accuracy: 0.6420 - val_loss: 0.9913 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6666 - loss: 0.9322 - val_accuracy: 0.6619 - val_loss: 0.9354 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6708 - loss: 0.9214 - val_accuracy: 0.6612 - val_loss: 0.9526 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6749 - loss: 0.9143 - val_accuracy: 0.6588 - val_loss: 0.9393 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6801 - loss: 0.8991 - val_accuracy: 0.6654 - val_loss: 0.9345 学習時間:128.6秒 パラメータ数:98,570 test_accuracy:0.6631
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('channel_attention_comparison.png', dpi=150)
plt.show()
# ── train vs val(過学習チェック)────────────────────
fig2, axes2 = plt.subplots(2, 1, figsize=(7, 10))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(label)
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('channel_attention_overfit.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>20} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 68)
for label in ['no_attention', 'channel_attention']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
p = params[label]
print(f"{label:>20} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 68)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s) | Params
--------------------------------------------------------------------
no_attention | 0.6754 | 0.6682 | 139.1 | 93,450
channel_attention | 0.6654 | 0.6631 | 128.6 | 98,570
--------------------------------------------------------------------
実験結果
精度グラフ
損失グラフ
no_attention
channel_attention
| パターン | 最終 val_accuracy | 最終 test_accuracy | パラメータ数 | 学習時間 |
|---|---|---|---|---|
| A:Channel Attentionなし | 67.54% | 66.82% | 93,450 | 139.1秒 |
| B:Channel Attentionあり | 66.54% | 66.31% | 98,570 | 128.6秒 |
考察
① 今回はChannel Attentionなしが上回った——なぜか?
結果は val_accuracy で 67.54%(なし)対 66.54%(あり)と、Channel Attentionを追加した方が約1%低いという予想外の結果になりました。test_accuracyも同様に 66.82% 対 66.31% でなし側が上回っています。
これにはいくつかの原因が考えられます。
- モデルが浅すぎる:今回のベースモデルはConv2Dが2層だけです。Channel AttentionはResNet・VGGのような深いネットワークで効果を発揮する設計です。浅いモデルでは各チャンネルがまだ十分に「役割分担」できておらず、重み付けの恩恵が薄い可能性があります。
- チャンネル数が少ない:Block 1は64ch、Block 2は128chと比較的少ない構成です。SENetの原論文では256ch以上の深い層でChannel Attentionの効果が大きいとされています。
- 最適化の難化:Channel Attentionのsigmoidスケーリングが加わることで損失の勾配の流れが変わり、同じエポック数・学習率では十分に収束しきれなかった可能性があります。
② パラメータ増加量の確認
Channel Attentionを2箇所追加したことによるパラメータ増加は 98,570 − 93,450 = 5,120パラメータでした。計算値(Block1: 64×8+8×64=1,024、Block2: 128×16+16×128=4,096 の合計5,120)と完全に一致しています。増加率は約5.5%です。これは非常に小さなコスト——問題はコストではなく、効果が出るかどうかです。
③ 学習時間はむしろ短縮された
Channel Attentionありの学習時間は128.6秒で、なし(139.1秒)より約10秒短くなっています。Channel Attentionのスケーリングが一種の正則化として働き、初期エポックの収束を助けた可能性があります。パラメータ増加があっても学習時間は増えないことが確認できました。
④ Channel Attentionが効果を発揮する条件
| 条件 | 効果の出やすさ | 理由 |
|---|---|---|
| 深いネットワーク(ResNet等) | ◎ 出やすい | チャンネル数が多く役割分担が進んでいる |
| チャンネル数が多い(256ch以上) | ○ 出やすい | 重み付けの選択肢が多い |
| 浅いCNN(今回) | △ 出にくい | チャンネルの役割分担が未熟 |
| 転移学習モデル(EfficientNet等) | ◎ 出やすい | 既に豊かな特徴表現があり選択が有効 |
⑤ reduction_ratioの選び方の目安
| reduction_ratio(r) | 中間次元数(Ch=128の場合) | 特徴 |
|---|---|---|
| 4 | 32 | 表現力大・パラメータ増加大 |
| 8(今回) | 16 | 原論文推奨。バランスが良い |
| 16 | 8 | 軽量・チャンネル数が少ないモデル向け |
✅ まとめ
- Channel Attentionは「どのチャンネルに注目するか」をGAP+Dense(C→C/r→C→sigmoid)で学習し、特徴マップを選択的にスケーリングする仕組み
- 追加パラメータは5,120(約5.5%増)と非常に小さく、学習時間も増加しなかった
- 今回の実験では Channel Attentionなし の方が約1%高い精度となった。浅いCNN・少ないチャンネル数の構成では効果が出にくいことを実験で確認できた
- Channel Attentionの本領は深いネットワーク(ResNet等)や多チャンネル構成(256ch以上)で発揮される。転移学習モデルに差し込む用途が最も効果的
- カスタムレイヤーとして実装しておけば、Functional APIのConv2D直後に1行で差し込めるので、大きなモデルに試す際の再利用コストはゼロ
- 「Channel Attentionを追加すれば必ず精度が上がる」ではなく、ネットワークの深さ・チャンネル数・タスクの複雑さに依存する——これが今回最大の学びです
関連記事もあわせてどうぞ:
- GAPとFlattenの比較 → Global Average Pooling vs Flatten|CNNの最終層、どっちが精度・速度で有利か?【Keras実験】
- GlobalAveragePooling vs GlobalMaxPooling → GAP vs GMP|Global Poolingの種類で精度はどう変わる?【Keras×CIFAR-10実験】
- Dropout率の比較実験 → Dropoutの割合(0.0 vs 0.2 vs 0.5)を変えると過学習はどう変わる?【Keras×CIFAR-10実験】
- BatchNorm vs Dropout → BatchNormalizationとDropoutを比較【Keras実験】





0 件のコメント:
コメントを投稿