Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証【Keras実験】

投稿日:2026年5月17日日曜日 最終更新日:

Attention CBAM Channel Attention CIFAR-10 CNN Google Colab Keras 画像分類

X f B! P L
Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証【Keras実験】 アイキャッチ画像

Channel Attentionとは何か、使うと何がうれしいのか

CNNの各Conv層は複数のチャンネル(フィルター)を持っています。たとえば128チャンネルある場合、どのチャンネルが今の画像分類に役立つかはネットワークが決められません——全チャンネルを均等に扱うだけです。

Channel Attentionはこの問題を解決します。各チャンネルに対して「重要度スコア(0〜1)」を計算し、重要なチャンネルを強調・不要なチャンネルを抑制する仕組みです。SENet(Squeeze-and-Excitation Networks)や CBAM(Convolutional Block Attention Module)で採用されており、精度向上に有効なことが多くの論文で示されています。

📘 この記事でわかること

  • Channel Attentionの仕組みと数式(Squeeze → Excitation)
  • KerasでCBAM風 Channel Attentionをカスタムレイヤーとして実装する方法
  • Channel Attentionあり vs なし の精度・過学習の差(CIFAR-10実験)
  • どんな場面で使うと効果的か

Channel Attentionの仕組み

処理の流れは大きく2ステップです。

① Squeeze(チャンネルごとのグローバル情報を集約)

特徴マップ(H × W × C)に対し、空間方向(H × W)をGlobal Average Poolingで潰し、チャンネルごとのスカラー値(1 × 1 × C)を得ます。これが「チャンネルの要約情報」です。

$$z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j)$$

② Excitation(重要度スコアを計算して掛け算)

集約したベクトル(C次元)をDense層でボトルネック圧縮(C/r次元)→ relu → Dense層(C次元)→ sigmoid に通します。出力は各チャンネルの重要度スコア(0〜1)。元の特徴マップにこのスコアを掛け算(スケーリング)します。

$$s = \sigma\!\left(W_2\,\delta\!\left(W_1\,z\right)\right), \quad \hat{x}_c = s_c \cdot x_c$$

ここで \( r \) はリダクション比(reduction ratio)と呼ばれ、中間のDense層の次元数を \( C/r \) に圧縮します。今回は r = 8 を使います。

実験コード

使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。Channel Attentionの有無以外の条件は全て同一にします。

環境準備(最初に一度だけ実行)

# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-ipafont-mincho
The following NEW packages will be installed:
  fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 51 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (18.0 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122412 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 63.2 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了

import・データ準備

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test  = x_test.astype('float32')  / 255.0
print("データ準備完了")
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 5s 0us/step
データ準備完了

Channel Attentionレイヤーの実装

Kerasのカスタムレイヤーとして実装します。reduction_ratio(= r)で中間次元数を制御します。

class ChannelAttention(keras.layers.Layer):
    """CBAM風 Channel Attention レイヤー
    
    Squeeze(GAP)→ Excitation(FC → relu → FC → sigmoid)→ スケーリング
    """
    def __init__(self, reduction_ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        channels = input_shape[-1]
        reduced  = max(channels // self.reduction_ratio, 1)
        self.gap = keras.layers.GlobalAveragePooling2D()
        self.fc1 = keras.layers.Dense(reduced, activation='relu',
                                      use_bias=False)
        self.fc2 = keras.layers.Dense(channels, activation='sigmoid',
                                      use_bias=False)
        super().build(input_shape)

    def call(self, x):
        # Squeeze: (B, H, W, C) → (B, C)
        z = self.gap(x)
        # Excitation: (B, C) → (B, C/r) → (B, C)
        s = self.fc2(self.fc1(z))
        # スケーリング: (B, C) → (B, 1, 1, C) にreshapeして掛け算
        s = tf.reshape(s, (-1, 1, 1, tf.shape(s)[-1]))
        return x * s

    def get_config(self):
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config

⚠️ ハマりポイント

  • GlobalAveragePooling2Dの出力は2次元(B, C)になります。元の特徴マップ(B, H, W, C)に掛け算するには tf.reshape(s, (-1, 1, 1, C)) で形状を合わせる必要があります。形状を合わせないとエラーになります。
  • CBAMの原論文ではGAPに加えてGlobal Max Poolingも使いますが、今回はシンプルさを優先してGAPのみで実装しています。
  • reduction_ratioが大きすぎると中間次元数が0以下になるため、max(channels // r, 1) で最低1を確保しています。

モデル構築関数(あり / なし)

def build_model(use_attention, name):
    inputs = keras.Input(shape=(32, 32, 3))

    # Block 1
    x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    if use_attention:
        x = ChannelAttention(reduction_ratio=8)(x)

    # Block 2
    x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    if use_attention:
        x = ChannelAttention(reduction_ratio=8)(x)

    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(10, activation='softmax')(x)

    return keras.Model(inputs, outputs, name=name)

def compile_and_fit(model):
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    start = time.time()
    history = model.fit(x_train, y_train, epochs=30, batch_size=64,
                        validation_split=0.2, verbose=1)
    return history, time.time() - start

2パターンの学習実行

configs = [
    (False, 'A_no_attention'),
    (True,  'B_channel_attention'),
]
histories, times, scores, params = {}, {}, {}, {}

for use_attn, name in configs:
    print(f"\n=== {name} ===")
    model = build_model(use_attn, name)
    model.summary()
    h, t = compile_and_fit(model)
    s = model.evaluate(x_test, y_test, verbose=0)
    label = name.split('_', 1)[1]          # 'no_attention' / 'channel_attention'
    histories[label] = h
    times[label]     = t
    scores[label]    = s
    params[label]    = model.count_params()
    print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_no_attention ===
Model: "A_no_attention"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 32, 32, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 32, 32, 64)     │         1,792 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 16, 16, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 16, 16, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 8, 8, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d        │ (None, 128)            │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │        16,512 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 93,450 (365.04 KB)
 Trainable params: 93,450 (365.04 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 13ms/step - accuracy: 0.2661 - loss: 1.9191 - val_accuracy: 0.3548 - val_loss: 1.7288
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3636 - loss: 1.6903 - val_accuracy: 0.4049 - val_loss: 1.5964
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4256 - loss: 1.5565 - val_accuracy: 0.4683 - val_loss: 1.4680
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4636 - loss: 1.4612 - val_accuracy: 0.4975 - val_loss: 1.3767
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.4915 - loss: 1.3965 - val_accuracy: 0.4874 - val_loss: 1.4022
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5100 - loss: 1.3423 - val_accuracy: 0.5168 - val_loss: 1.3254
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.5261 - loss: 1.2996 - val_accuracy: 0.5410 - val_loss: 1.2705
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.5393 - loss: 1.2610 - val_accuracy: 0.5626 - val_loss: 1.2186
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5488 - loss: 1.2425 - val_accuracy: 0.5604 - val_loss: 1.2120
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5594 - loss: 1.2105 - val_accuracy: 0.5774 - val_loss: 1.1689
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5698 - loss: 1.1825 - val_accuracy: 0.5783 - val_loss: 1.1614
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5778 - loss: 1.1622 - val_accuracy: 0.5781 - val_loss: 1.1812
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5895 - loss: 1.1393 - val_accuracy: 0.5914 - val_loss: 1.1195
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5974 - loss: 1.1189 - val_accuracy: 0.6041 - val_loss: 1.0943
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6028 - loss: 1.0970 - val_accuracy: 0.6034 - val_loss: 1.0893
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6094 - loss: 1.0774 - val_accuracy: 0.6109 - val_loss: 1.0713
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6169 - loss: 1.0617 - val_accuracy: 0.6059 - val_loss: 1.1050
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6228 - loss: 1.0439 - val_accuracy: 0.6349 - val_loss: 1.0211
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6307 - loss: 1.0267 - val_accuracy: 0.6393 - val_loss: 1.0111
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6351 - loss: 1.0134 - val_accuracy: 0.6171 - val_loss: 1.0742
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6397 - loss: 1.0002 - val_accuracy: 0.6285 - val_loss: 1.0339
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6461 - loss: 0.9831 - val_accuracy: 0.6529 - val_loss: 0.9770
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6495 - loss: 0.9704 - val_accuracy: 0.6386 - val_loss: 1.0035
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6571 - loss: 0.9549 - val_accuracy: 0.6444 - val_loss: 0.9881
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6594 - loss: 0.9425 - val_accuracy: 0.6586 - val_loss: 0.9652
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6658 - loss: 0.9307 - val_accuracy: 0.6532 - val_loss: 0.9820
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6714 - loss: 0.9155 - val_accuracy: 0.6605 - val_loss: 0.9445
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6775 - loss: 0.9017 - val_accuracy: 0.6693 - val_loss: 0.9281
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6788 - loss: 0.8979 - val_accuracy: 0.6717 - val_loss: 0.9271
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6844 - loss: 0.8834 - val_accuracy: 0.6754 - val_loss: 0.9167
学習時間:139.1秒 パラメータ数:93,450 test_accuracy:0.6682

=== B_channel_attention ===
Model: "B_channel_attention"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, 32, 32, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 32, 32, 64)     │         1,792 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 16, 16, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ channel_attention               │ (None, 16, 16, 64)     │         1,024 │
│ (ChannelAttention)              │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_3 (Conv2D)               │ (None, 16, 16, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 8, 8, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ channel_attention_1             │ (None, 8, 8, 128)      │         4,096 │
│ (ChannelAttention)              │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d_3      │ (None, 128)            │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_6 (Dense)                 │ (None, 128)            │        16,512 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_7 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 98,570 (385.04 KB)
 Trainable params: 98,570 (385.04 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 9ms/step - accuracy: 0.2373 - loss: 1.9902 - val_accuracy: 0.3121 - val_loss: 1.7881
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3390 - loss: 1.7326 - val_accuracy: 0.3642 - val_loss: 1.6694
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3877 - loss: 1.6282 - val_accuracy: 0.4286 - val_loss: 1.5523
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4315 - loss: 1.5301 - val_accuracy: 0.4628 - val_loss: 1.4603
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4658 - loss: 1.4502 - val_accuracy: 0.4874 - val_loss: 1.3932
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4899 - loss: 1.3958 - val_accuracy: 0.5046 - val_loss: 1.3408
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5047 - loss: 1.3535 - val_accuracy: 0.5204 - val_loss: 1.3034
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5189 - loss: 1.3122 - val_accuracy: 0.5334 - val_loss: 1.2638
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5354 - loss: 1.2780 - val_accuracy: 0.5386 - val_loss: 1.2522
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5427 - loss: 1.2550 - val_accuracy: 0.5550 - val_loss: 1.2059
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5529 - loss: 1.2279 - val_accuracy: 0.5495 - val_loss: 1.2221
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5645 - loss: 1.1995 - val_accuracy: 0.5736 - val_loss: 1.1648
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5739 - loss: 1.1733 - val_accuracy: 0.5836 - val_loss: 1.1349
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5804 - loss: 1.1504 - val_accuracy: 0.5973 - val_loss: 1.1036
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5933 - loss: 1.1295 - val_accuracy: 0.5645 - val_loss: 1.1857
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6039 - loss: 1.1022 - val_accuracy: 0.5997 - val_loss: 1.1036
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6101 - loss: 1.0848 - val_accuracy: 0.5996 - val_loss: 1.1203
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6187 - loss: 1.0689 - val_accuracy: 0.6206 - val_loss: 1.0487
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6241 - loss: 1.0462 - val_accuracy: 0.6304 - val_loss: 1.0248
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6339 - loss: 1.0245 - val_accuracy: 0.6301 - val_loss: 1.0183
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6364 - loss: 1.0202 - val_accuracy: 0.6371 - val_loss: 1.0078
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6389 - loss: 0.9981 - val_accuracy: 0.6375 - val_loss: 1.0083
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6445 - loss: 0.9880 - val_accuracy: 0.6466 - val_loss: 0.9834
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6533 - loss: 0.9673 - val_accuracy: 0.6465 - val_loss: 0.9805
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6591 - loss: 0.9573 - val_accuracy: 0.6481 - val_loss: 0.9787
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6632 - loss: 0.9425 - val_accuracy: 0.6420 - val_loss: 0.9913
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6666 - loss: 0.9322 - val_accuracy: 0.6619 - val_loss: 0.9354
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6708 - loss: 0.9214 - val_accuracy: 0.6612 - val_loss: 0.9526
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6749 - loss: 0.9143 - val_accuracy: 0.6588 - val_loss: 0.9393
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6801 - loss: 0.8991 - val_accuracy: 0.6654 - val_loss: 0.9345
学習時間:128.6秒 パラメータ数:98,570 test_accuracy:0.6631

グラフ+サマリー

# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes[0].plot(h.history['val_accuracy'], label=label)
    axes[1].plot(h.history['val_loss'],     label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('channel_attention_comparison.png', dpi=150)
plt.show()

# ── train vs val(過学習チェック)────────────────────
fig2, axes2 = plt.subplots(2, 1, figsize=(7, 10))
for i, (label, h) in enumerate(histories.items()):
    axes2[i].plot(h.history['loss'],     label='train_loss')
    axes2[i].plot(h.history['val_loss'], label='val_loss')
    axes2[i].set_title(label)
    axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('channel_attention_overfit.png', dpi=150)
plt.show()

print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>20} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 68)
for label in ['no_attention', 'channel_attention']:
    val_acc  = histories[label].history['val_accuracy'][-1]
    test_acc = scores[label][1]
    t        = times[label]
    p        = params[label]
    print(f"{label:>20} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 68)

最終結果サマリー

===== 最終結果サマリー =====
             Pattern |  Val Acc |  Test Acc |  Time(s) |       Params
--------------------------------------------------------------------
        no_attention |   0.6754 |    0.6682 |    139.1 |       93,450
   channel_attention |   0.6654 |    0.6631 |    128.6 |       98,570
--------------------------------------------------------------------

実験結果

精度グラフ

精度グラフ

損失グラフ

損失グラフ

no_attention

no_attention

channel_attention

channel_attention
パターン 最終 val_accuracy 最終 test_accuracy パラメータ数 学習時間
A:Channel Attentionなし 67.54% 66.82% 93,450 139.1秒
B:Channel Attentionあり 66.54% 66.31% 98,570 128.6秒

考察

① 今回はChannel Attentionなしが上回った——なぜか?

結果は val_accuracy で 67.54%(なし)対 66.54%(あり)と、Channel Attentionを追加した方が約1%低いという予想外の結果になりました。test_accuracyも同様に 66.82% 対 66.31% でなし側が上回っています。

これにはいくつかの原因が考えられます。

  • モデルが浅すぎる:今回のベースモデルはConv2Dが2層だけです。Channel AttentionはResNet・VGGのような深いネットワークで効果を発揮する設計です。浅いモデルでは各チャンネルがまだ十分に「役割分担」できておらず、重み付けの恩恵が薄い可能性があります。
  • チャンネル数が少ない:Block 1は64ch、Block 2は128chと比較的少ない構成です。SENetの原論文では256ch以上の深い層でChannel Attentionの効果が大きいとされています。
  • 最適化の難化:Channel Attentionのsigmoidスケーリングが加わることで損失の勾配の流れが変わり、同じエポック数・学習率では十分に収束しきれなかった可能性があります。

② パラメータ増加量の確認

Channel Attentionを2箇所追加したことによるパラメータ増加は 98,570 − 93,450 = 5,120パラメータでした。計算値(Block1: 64×8+8×64=1,024、Block2: 128×16+16×128=4,096 の合計5,120)と完全に一致しています。増加率は約5.5%です。これは非常に小さなコスト——問題はコストではなく、効果が出るかどうかです。

③ 学習時間はむしろ短縮された

Channel Attentionありの学習時間は128.6秒で、なし(139.1秒)より約10秒短くなっています。Channel Attentionのスケーリングが一種の正則化として働き、初期エポックの収束を助けた可能性があります。パラメータ増加があっても学習時間は増えないことが確認できました。

④ Channel Attentionが効果を発揮する条件

条件 効果の出やすさ 理由
深いネットワーク(ResNet等) ◎ 出やすい チャンネル数が多く役割分担が進んでいる
チャンネル数が多い(256ch以上) ○ 出やすい 重み付けの選択肢が多い
浅いCNN(今回) △ 出にくい チャンネルの役割分担が未熟
転移学習モデル(EfficientNet等) ◎ 出やすい 既に豊かな特徴表現があり選択が有効

⑤ reduction_ratioの選び方の目安

reduction_ratio(r) 中間次元数(Ch=128の場合) 特徴
4 32 表現力大・パラメータ増加大
8(今回) 16 原論文推奨。バランスが良い
16 8 軽量・チャンネル数が少ないモデル向け

まとめ

  • Channel Attentionは「どのチャンネルに注目するか」をGAP+Dense(C→C/r→C→sigmoid)で学習し、特徴マップを選択的にスケーリングする仕組み
  • 追加パラメータは5,120(約5.5%増)と非常に小さく、学習時間も増加しなかった
  • 今回の実験では Channel Attentionなし の方が約1%高い精度となった。浅いCNN・少ないチャンネル数の構成では効果が出にくいことを実験で確認できた
  • Channel Attentionの本領は深いネットワーク(ResNet等)や多チャンネル構成(256ch以上)で発揮される。転移学習モデルに差し込む用途が最も効果的
  • カスタムレイヤーとして実装しておけば、Functional APIのConv2D直後に1行で差し込めるので、大きなモデルに試す際の再利用コストはゼロ
  • 「Channel Attentionを追加すれば必ず精度が上がる」ではなく、ネットワークの深さ・チャンネル数・タスクの複雑さに依存する——これが今回最大の学びです

関連記事もあわせてどうぞ: