Multi-Head Self-Attentionを追加すると精度は上がるか?ViT風実装をCIFAR-10で検証【Keras実験】

投稿日:2026年5月23日土曜日 最終更新日:

CIFAR-10 CNN Google Colab Keras 過学習 画像分類

X f B! P L
Multi-Head Self-Attentionを追加すると精度は上がるか?ViT風実装をCIFAR-10で検証【Keras実験】アイキャッチ画像

📘 この記事でわかること

  • Multi-Head Self-Attentionとは何か、CNNとどう組み合わせるか
  • KerasでMHSAをCNN後段に追加する実装方法
  • なしと比較したときの精度・パラメータ数・過学習の変化
  • ViT・ハイブリッドモデルの入口として何を理解しておくべきか

TransformerのAttention機構は自然言語処理で革命をもたらし、現在では画像分類(ViT)にも広く使われています。しかし「CNNに自己注意機構を足すだけで精度が上がるのか?」という疑問は意外と実験で確かめにくいテーマです。

今回はGoogle ColabとCIFAR-10を使い、CNN後段にMulti-Head Self-Attentionを追加したモデル vs 追加しないモデルを比較しました。

Channel Attentionとの違いは → Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証 もあわせてどうぞ。

Multi-Head Self-Attentionとは

Self-Attention(自己注意機構)は、入力シーケンスの各要素が他のすべての要素との関連度を計算し、重み付き和として新しい表現を生成します。Multi-Headはこれを複数の「ヘッド」で並列実行し、異なる観点の注意パターンを同時に学習します。

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

CNN+MHSAのハイブリッド構成では、Conv2Dで局所的な特徴を抽出した後、MHSAでグローバルな依存関係を捉えることができます。

手法捉える関係計算コスト代表モデル
Conv2D局所的な特徴(近傍ピクセル)低〜中VGG, ResNet
Channel Attentionチャネル間の重要度SENet, CBAM
Multi-Head Self-Attention全位置間のグローバルな依存高(系列長の2乗)ViT, Swin Transformer

実験コード

使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。MHSAの有無以外の条件は全て同一にして、効果だけを取り出します。Conv2D 2層でGAPしたあとの128次元ベクトルにMHSAを適用します。

環境準備(最初に一度だけ実行)

# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-ipafont-mincho
The following NEW packages will be installed:
  fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 51 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (23.9 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122363 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 34.7 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了

import・データ準備・モデル構築

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time

# ── データ読み込み ─────────────────────────────────────
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test  = x_test.astype('float32')  / 255.0

# ── ベースラインモデル(MHSAなし)────────────────────────
def build_baseline(name='baseline'):
    inputs = keras.Input(shape=(32, 32, 3))
    x = keras.layers.Conv2D(64,  (3, 3), activation='relu', padding='same')(inputs)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    x = keras.layers.GlobalAveragePooling2D()(x)   # → (batch, 128)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(10, activation='softmax')(x)
    return keras.Model(inputs, outputs, name=name)

# ── MHSAありモデル ────────────────────────────────────
# Conv2D後の特徴マップ (batch, 8, 8, 128) をシーケンス化してMHSAを適用
def build_mhsa(name='mhsa', num_heads=4, key_dim=32):
    inputs = keras.Input(shape=(32, 32, 3))
    x = keras.layers.Conv2D(64,  (3, 3), activation='relu', padding='same')(inputs)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    # (batch, 8, 8, 128) → (batch, 64, 128) にreshapeしてシーケンス化
    B, H, W, C = x.shape
    x = keras.layers.Reshape((H * W, C))(x)        # (batch, 64, 128)
    # Multi-Head Self-Attention
    attn_out = keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=key_dim, dropout=0.1
    )(x, x)                                         # self-attention: query=key=value=x
    x = keras.layers.Add()([x, attn_out])           # 残差接続
    x = keras.layers.LayerNormalization()(x)
    # GAPに相当する操作(シーケンス方向の平均)
    x = keras.layers.GlobalAveragePooling1D()(x)    # (batch, 128)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(10, activation='softmax')(x)
    return keras.Model(inputs, outputs, name=name)

def compile_and_fit(model, epochs=30):
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    start = time.time()
    history = model.fit(x_train, y_train, epochs=epochs, batch_size=64,
                        validation_split=0.2, verbose=1)
    return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step

2パターンの学習実行

configs = [
    ('baseline', build_baseline),
    ('mhsa',     build_mhsa),
]
histories, times, scores, params = {}, {}, {}, {}

for label, builder in configs:
    print(f"\n=== {label} ===")
    model = builder(name=label)
    print(model.summary())
    h, t = compile_and_fit(model)
    s = model.evaluate(x_test, y_test, verbose=0)
    histories[label] = h
    times[label]     = t
    scores[label]    = s
    params[label]    = model.count_params()
    print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== baseline ===
Model: "baseline"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 32, 32, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 32, 32, 64)     │         1,792 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 16, 16, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 16, 16, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 8, 8, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d        │ (None, 128)            │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │        16,512 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 93,450 (365.04 KB)
 Trainable params: 93,450 (365.04 KB)
 Non-trainable params: 0 (0.00 B)
None
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - accuracy: 0.2789 - loss: 1.9064 - val_accuracy: 0.3747 - val_loss: 1.7024
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3839 - loss: 1.6467 - val_accuracy: 0.4309 - val_loss: 1.5472
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4368 - loss: 1.5371 - val_accuracy: 0.4792 - val_loss: 1.4520
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4710 - loss: 1.4425 - val_accuracy: 0.4997 - val_loss: 1.3721
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4939 - loss: 1.3734 - val_accuracy: 0.5163 - val_loss: 1.3181
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5155 - loss: 1.3235 - val_accuracy: 0.5195 - val_loss: 1.3086
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5321 - loss: 1.2790 - val_accuracy: 0.5249 - val_loss: 1.2782
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5452 - loss: 1.2485 - val_accuracy: 0.5352 - val_loss: 1.2625
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5534 - loss: 1.2221 - val_accuracy: 0.5653 - val_loss: 1.1788
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.5674 - loss: 1.1884 - val_accuracy: 0.5778 - val_loss: 1.1514
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5738 - loss: 1.1675 - val_accuracy: 0.5874 - val_loss: 1.1358
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5849 - loss: 1.1390 - val_accuracy: 0.5903 - val_loss: 1.1249
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5959 - loss: 1.1165 - val_accuracy: 0.5913 - val_loss: 1.1206
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.6086 - loss: 1.0906 - val_accuracy: 0.5885 - val_loss: 1.1285
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6104 - loss: 1.0790 - val_accuracy: 0.6199 - val_loss: 1.0560
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6169 - loss: 1.0577 - val_accuracy: 0.6128 - val_loss: 1.0596
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6251 - loss: 1.0395 - val_accuracy: 0.6125 - val_loss: 1.0543
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6298 - loss: 1.0250 - val_accuracy: 0.6217 - val_loss: 1.0419
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6385 - loss: 1.0044 - val_accuracy: 0.6359 - val_loss: 1.0111
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6453 - loss: 0.9875 - val_accuracy: 0.6419 - val_loss: 0.9942
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6499 - loss: 0.9826 - val_accuracy: 0.6343 - val_loss: 0.9970
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6525 - loss: 0.9682 - val_accuracy: 0.6519 - val_loss: 0.9610
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6601 - loss: 0.9485 - val_accuracy: 0.6451 - val_loss: 0.9816
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6628 - loss: 0.9404 - val_accuracy: 0.6610 - val_loss: 0.9400
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6687 - loss: 0.9212 - val_accuracy: 0.6478 - val_loss: 0.9783
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6727 - loss: 0.9130 - val_accuracy: 0.6441 - val_loss: 0.9793
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6758 - loss: 0.9063 - val_accuracy: 0.6636 - val_loss: 0.9357
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6833 - loss: 0.8877 - val_accuracy: 0.6770 - val_loss: 0.8993
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6860 - loss: 0.8786 - val_accuracy: 0.6744 - val_loss: 0.9082
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6895 - loss: 0.8687 - val_accuracy: 0.6747 - val_loss: 0.9069
学習時間:130.2秒 パラメータ数:93,450 test_accuracy:0.6676

=== mhsa ===
Model: "mhsa"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer_1       │ (None, 32, 32, 3) │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_2 (Conv2D)   │ (None, 32, 32,    │      1,792 │ input_layer_1[0]… │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ max_pooling2d_2     │ (None, 16, 16,    │          0 │ conv2d_2[0][0]    │
│ (MaxPooling2D)      │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_3 (Conv2D)   │ (None, 16, 16,    │     73,856 │ max_pooling2d_2[… │
│                     │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ max_pooling2d_3     │ (None, 8, 8, 128) │          0 │ conv2d_3[0][0]    │
│ (MaxPooling2D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ reshape (Reshape)   │ (None, 64, 128)   │          0 │ max_pooling2d_3[… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ multi_head_attenti… │ (None, 64, 128)   │     66,048 │ reshape[0][0],    │
│ (MultiHeadAttentio… │                   │            │ reshape[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 64, 128)   │          0 │ reshape[0][0],    │
│                     │                   │            │ multi_head_atten… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ layer_normalization │ (None, 64, 128)   │        256 │ add[0][0]         │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 128)       │          0 │ layer_normalizat… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_2 (Dense)     │ (None, 128)       │     16,512 │ global_average_p… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_2 (Dropout) │ (None, 128)       │          0 │ dense_2[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_3 (Dense)     │ (None, 10)        │      1,290 │ dropout_2[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 159,754 (624.04 KB)
 Trainable params: 159,754 (624.04 KB)
 Non-trainable params: 0 (0.00 B)
None
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 16ms/step - accuracy: 0.3187 - loss: 1.7970 - val_accuracy: 0.4501 - val_loss: 1.4820
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.4768 - loss: 1.4164 - val_accuracy: 0.5290 - val_loss: 1.2856
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5367 - loss: 1.2675 - val_accuracy: 0.5384 - val_loss: 1.2272
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5663 - loss: 1.1854 - val_accuracy: 0.5914 - val_loss: 1.1106
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5942 - loss: 1.1173 - val_accuracy: 0.6136 - val_loss: 1.0637
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.6137 - loss: 1.0637 - val_accuracy: 0.6036 - val_loss: 1.0875
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6374 - loss: 1.0075 - val_accuracy: 0.6335 - val_loss: 1.0211
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 9ms/step - accuracy: 0.6546 - loss: 0.9629 - val_accuracy: 0.6392 - val_loss: 0.9950
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6675 - loss: 0.9227 - val_accuracy: 0.6728 - val_loss: 0.9256
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6771 - loss: 0.8987 - val_accuracy: 0.6377 - val_loss: 1.0198
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6907 - loss: 0.8674 - val_accuracy: 0.6882 - val_loss: 0.8676
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7034 - loss: 0.8317 - val_accuracy: 0.6734 - val_loss: 0.9250
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7118 - loss: 0.8108 - val_accuracy: 0.6853 - val_loss: 0.8999
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7203 - loss: 0.7866 - val_accuracy: 0.6910 - val_loss: 0.8668
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7295 - loss: 0.7597 - val_accuracy: 0.6850 - val_loss: 0.8977
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7357 - loss: 0.7438 - val_accuracy: 0.6878 - val_loss: 0.9062
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7419 - loss: 0.7237 - val_accuracy: 0.7052 - val_loss: 0.8445
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7510 - loss: 0.6965 - val_accuracy: 0.6983 - val_loss: 0.8644
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7574 - loss: 0.6820 - val_accuracy: 0.7052 - val_loss: 0.8553
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7605 - loss: 0.6710 - val_accuracy: 0.6993 - val_loss: 0.8753
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7689 - loss: 0.6515 - val_accuracy: 0.7082 - val_loss: 0.8551
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7699 - loss: 0.6429 - val_accuracy: 0.7046 - val_loss: 0.8583
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7781 - loss: 0.6209 - val_accuracy: 0.7149 - val_loss: 0.8414
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7850 - loss: 0.6027 - val_accuracy: 0.6967 - val_loss: 0.9190
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 9ms/step - accuracy: 0.7898 - loss: 0.5921 - val_accuracy: 0.7137 - val_loss: 0.8583
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7910 - loss: 0.5824 - val_accuracy: 0.7121 - val_loss: 0.8587
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7975 - loss: 0.5652 - val_accuracy: 0.6947 - val_loss: 0.9252
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.8008 - loss: 0.5566 - val_accuracy: 0.7016 - val_loss: 0.9223
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.8062 - loss: 0.5435 - val_accuracy: 0.7158 - val_loss: 0.8756
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.8099 - loss: 0.5335 - val_accuracy: 0.7167 - val_loss: 0.8897
学習時間:166.2秒 パラメータ数:159,754 test_accuracy:0.7133

グラフ+サマリー

# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes[0].plot(h.history['val_accuracy'], label=label)
    axes[1].plot(h.history['val_loss'],     label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('mhsa_comparison.png', dpi=150)
plt.show()

# ── train_loss vs val_loss(過学習の乖離) ────────────
fig2, axes2 = plt.subplots(2, 1, figsize=(7, 10))
for i, (label, h) in enumerate(histories.items()):
    axes2[i].plot(h.history['loss'],     label='train_loss')
    axes2[i].plot(h.history['val_loss'], label='val_loss')
    axes2[i].set_title(f'{label}')
    axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('mhsa_overfit.png', dpi=150)
plt.show()

print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>10} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 58)
for label in ['baseline', 'mhsa']:
    val_acc  = histories[label].history['val_accuracy'][-1]
    test_acc = scores[label][1]
    t        = times[label]
    p        = params[label]
    print(f"{label:>10} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 58)

最終結果サマリー

===== 最終結果サマリー =====
   Pattern |  Val Acc |  Test Acc |  Time(s) |       Params
----------------------------------------------------------
  baseline |   0.6747 |    0.6676 |    130.2 |       93,450
      mhsa |   0.7167 |    0.7133 |    166.2 |      159,754
----------------------------------------------------------

実験結果

精度グラフ

精度グラフ

損失グラフ

損失グラフ

ベースラインモデル

ベースラインモデル

MHSAありモデル

MHSAありモデル
パターン最終 val_accuracy最終 test_accuracyパラメータ数学習時間
A:MHSAなし(ベースライン)0.67470.667693,450130.2秒
B:MHSAあり0.71670.7133159,754166.2秒

考察

① MHSAを追加すると精度が約4.6ポイント向上した

test_accuracyはベースラインの66.76% → 71.33%と約4.6ポイント向上しました。MHSAが8×8=64個の空間パッチ間のグローバルな依存関係を学習することで、CNNが捉えにくかった「離れた位置の特徴の関連性」をモデルが活用できるようになったと考えられます。

② train_lossとval_lossの乖離に注意

MHSAありモデルは30エポック時点でtrain_accuracy 80.99%、val_accuracy 71.67%と約9ポイントの乖離があります。ベースライン(train 68.95% vs val 67.47%、乖離約1.5ポイント)と比べると過学習が顕著です。MHSAの追加によりモデルの表現力が上がった分、正則化の強化(Dropoutの増加・EarlyStopping・データ拡張)が効果的です。

③ パラメータ数と学習時間のコスト

パラメータ数はベースラインの93,450から159,754へ約1.7倍に増加し、学習時間も130.2秒→166.2秒と約28%増加しました。MHSAが追加する66,048パラメータは注意行列の計算(Q・K・V射影 + 出力射影)で消費されます。精度向上幅(+4.6ポイント)に対して計算コスト増加は比較的小さく、コスト効率は良好と言えます。

④ Channel AttentionやSEブロックとの比較

同様の条件で検証したChannel Attention(CBAM)SEブロックは、パラメータ増加を最小限に抑えながら精度を改善する手法です。MHSAはそれらより大きな精度向上が期待できる一方、過学習リスクと計算コストも高くなります。用途や学習データ量に応じて使い分けるのが現実的です。

⚠️ ハマりポイント
  • Reshapeの次元に注意:Conv2D後の特徴マップは (batch, H, W, C) の4次元。MultiHeadAttentionは3次元入力 (batch, seq_len, d_model) を期待するので、Reshape((H*W, C)) が必要。
  • 残差接続を忘れずに:Attention出力をそのまま使うと学習が不安定になりやすい。Add() + LayerNormalization() をセットで使う。
  • 過学習しやすい:今回の実験でもtrain/val精度の乖離が約9ポイントと大きかった。Dropoutの強化やデータ拡張との併用を検討する。
  • 学習が遅くなる:MHSAは行列積が増えるためCPU環境では特に遅い。GPU(T4以上)推奨。
✅ まとめ
  • CNN後段にMHSAを追加することで test_accuracy が 66.76% → 71.33%(+4.57ポイント)向上した
  • パラメータ数は約1.7倍(93,450 → 159,754)、学習時間は約28%増加とコスト効率は良好
  • 一方でtrain/val精度の乖離が約9ポイントと過学習が顕著になった。データ拡張やDropout強化との併用を推奨
  • KerasではMultiHeadAttentionレイヤーとReshapeを組み合わせるだけでCNNハイブリッド構成を実装できる
  • Channel Attentionより表現力が高い分、正則化の工夫がより重要になる

関連記事