📘 この記事でわかること
- Multi-Head Self-Attentionとは何か、CNNとどう組み合わせるか
- KerasでMHSAをCNN後段に追加する実装方法
- なしと比較したときの精度・パラメータ数・過学習の変化
- ViT・ハイブリッドモデルの入口として何を理解しておくべきか
TransformerのAttention機構は自然言語処理で革命をもたらし、現在では画像分類(ViT)にも広く使われています。しかし「CNNに自己注意機構を足すだけで精度が上がるのか?」という疑問は意外と実験で確かめにくいテーマです。
今回はGoogle ColabとCIFAR-10を使い、CNN後段にMulti-Head Self-Attentionを追加したモデル vs 追加しないモデルを比較しました。
Channel Attentionとの違いは → Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証 もあわせてどうぞ。
Multi-Head Self-Attentionとは
Self-Attention(自己注意機構)は、入力シーケンスの各要素が他のすべての要素との関連度を計算し、重み付き和として新しい表現を生成します。Multi-Headはこれを複数の「ヘッド」で並列実行し、異なる観点の注意パターンを同時に学習します。
\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]
CNN+MHSAのハイブリッド構成では、Conv2Dで局所的な特徴を抽出した後、MHSAでグローバルな依存関係を捉えることができます。
| 手法 | 捉える関係 | 計算コスト | 代表モデル |
|---|---|---|---|
| Conv2D | 局所的な特徴(近傍ピクセル) | 低〜中 | VGG, ResNet |
| Channel Attention | チャネル間の重要度 | 低 | SENet, CBAM |
| Multi-Head Self-Attention | 全位置間のグローバルな依存 | 高(系列長の2乗) | ViT, Swin Transformer |
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。MHSAの有無以外の条件は全て同一にして、効果だけを取り出します。Conv2D 2層でGAPしたあとの128次元ベクトルにMHSAを適用します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 51 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (23.9 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122363 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 34.7 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・モデル構築
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
# ── データ読み込み ─────────────────────────────────────
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# ── ベースラインモデル(MHSAなし)────────────────────────
def build_baseline(name='baseline'):
inputs = keras.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
x = keras.layers.MaxPooling2D((2, 2))(x)
x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
x = keras.layers.GlobalAveragePooling2D()(x) # → (batch, 128)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
return keras.Model(inputs, outputs, name=name)
# ── MHSAありモデル ────────────────────────────────────
# Conv2D後の特徴マップ (batch, 8, 8, 128) をシーケンス化してMHSAを適用
def build_mhsa(name='mhsa', num_heads=4, key_dim=32):
inputs = keras.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
x = keras.layers.MaxPooling2D((2, 2))(x)
x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
# (batch, 8, 8, 128) → (batch, 64, 128) にreshapeしてシーケンス化
B, H, W, C = x.shape
x = keras.layers.Reshape((H * W, C))(x) # (batch, 64, 128)
# Multi-Head Self-Attention
attn_out = keras.layers.MultiHeadAttention(
num_heads=num_heads, key_dim=key_dim, dropout=0.1
)(x, x) # self-attention: query=key=value=x
x = keras.layers.Add()([x, attn_out]) # 残差接続
x = keras.layers.LayerNormalization()(x)
# GAPに相当する操作(シーケンス方向の平均)
x = keras.layers.GlobalAveragePooling1D()(x) # (batch, 128)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
return keras.Model(inputs, outputs, name=name)
def compile_and_fit(model, epochs=30):
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
start = time.time()
history = model.fit(x_train, y_train, epochs=epochs, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
2パターンの学習実行
configs = [
('baseline', build_baseline),
('mhsa', build_mhsa),
]
histories, times, scores, params = {}, {}, {}, {}
for label, builder in configs:
print(f"\n=== {label} ===")
model = builder(name=label)
print(model.summary())
h, t = compile_and_fit(model)
s = model.evaluate(x_test, y_test, verbose=0)
histories[label] = h
times[label] = t
scores[label] = s
params[label] = model.count_params()
print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== baseline === Model: "baseline" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - accuracy: 0.2789 - loss: 1.9064 - val_accuracy: 0.3747 - val_loss: 1.7024 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3839 - loss: 1.6467 - val_accuracy: 0.4309 - val_loss: 1.5472 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4368 - loss: 1.5371 - val_accuracy: 0.4792 - val_loss: 1.4520 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4710 - loss: 1.4425 - val_accuracy: 0.4997 - val_loss: 1.3721 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4939 - loss: 1.3734 - val_accuracy: 0.5163 - val_loss: 1.3181 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5155 - loss: 1.3235 - val_accuracy: 0.5195 - val_loss: 1.3086 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5321 - loss: 1.2790 - val_accuracy: 0.5249 - val_loss: 1.2782 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5452 - loss: 1.2485 - val_accuracy: 0.5352 - val_loss: 1.2625 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5534 - loss: 1.2221 - val_accuracy: 0.5653 - val_loss: 1.1788 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.5674 - loss: 1.1884 - val_accuracy: 0.5778 - val_loss: 1.1514 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5738 - loss: 1.1675 - val_accuracy: 0.5874 - val_loss: 1.1358 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5849 - loss: 1.1390 - val_accuracy: 0.5903 - val_loss: 1.1249 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5959 - loss: 1.1165 - val_accuracy: 0.5913 - val_loss: 1.1206 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 10ms/step - accuracy: 0.6086 - loss: 1.0906 - val_accuracy: 0.5885 - val_loss: 1.1285 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6104 - loss: 1.0790 - val_accuracy: 0.6199 - val_loss: 1.0560 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6169 - loss: 1.0577 - val_accuracy: 0.6128 - val_loss: 1.0596 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6251 - loss: 1.0395 - val_accuracy: 0.6125 - val_loss: 1.0543 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6298 - loss: 1.0250 - val_accuracy: 0.6217 - val_loss: 1.0419 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6385 - loss: 1.0044 - val_accuracy: 0.6359 - val_loss: 1.0111 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6453 - loss: 0.9875 - val_accuracy: 0.6419 - val_loss: 0.9942 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6499 - loss: 0.9826 - val_accuracy: 0.6343 - val_loss: 0.9970 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6525 - loss: 0.9682 - val_accuracy: 0.6519 - val_loss: 0.9610 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6601 - loss: 0.9485 - val_accuracy: 0.6451 - val_loss: 0.9816 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6628 - loss: 0.9404 - val_accuracy: 0.6610 - val_loss: 0.9400 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6687 - loss: 0.9212 - val_accuracy: 0.6478 - val_loss: 0.9783 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6727 - loss: 0.9130 - val_accuracy: 0.6441 - val_loss: 0.9793 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6758 - loss: 0.9063 - val_accuracy: 0.6636 - val_loss: 0.9357 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6833 - loss: 0.8877 - val_accuracy: 0.6770 - val_loss: 0.8993 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6860 - loss: 0.8786 - val_accuracy: 0.6744 - val_loss: 0.9082 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6895 - loss: 0.8687 - val_accuracy: 0.6747 - val_loss: 0.9069 学習時間:130.2秒 パラメータ数:93,450 test_accuracy:0.6676 === mhsa === Model: "mhsa" ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input_layer_1 │ (None, 32, 32, 3) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_2 (Conv2D) │ (None, 32, 32, │ 1,792 │ input_layer_1[0]… │ │ │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ max_pooling2d_2 │ (None, 16, 16, │ 0 │ conv2d_2[0][0] │ │ (MaxPooling2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, │ 73,856 │ max_pooling2d_2[… │ │ │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ max_pooling2d_3 │ (None, 8, 8, 128) │ 0 │ conv2d_3[0][0] │ │ (MaxPooling2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ reshape (Reshape) │ (None, 64, 128) │ 0 │ max_pooling2d_3[… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ multi_head_attenti… │ (None, 64, 128) │ 66,048 │ reshape[0][0], │ │ (MultiHeadAttentio… │ │ │ reshape[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add (Add) │ (None, 64, 128) │ 0 │ reshape[0][0], │ │ │ │ │ multi_head_atten… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ layer_normalization │ (None, 64, 128) │ 256 │ add[0][0] │ │ (LayerNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ global_average_poo… │ (None, 128) │ 0 │ layer_normalizat… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_2 (Dense) │ (None, 128) │ 16,512 │ global_average_p… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dropout_2 (Dropout) │ (None, 128) │ 0 │ dense_2[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_3 (Dense) │ (None, 10) │ 1,290 │ dropout_2[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘ Total params: 159,754 (624.04 KB) Trainable params: 159,754 (624.04 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 16ms/step - accuracy: 0.3187 - loss: 1.7970 - val_accuracy: 0.4501 - val_loss: 1.4820 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.4768 - loss: 1.4164 - val_accuracy: 0.5290 - val_loss: 1.2856 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5367 - loss: 1.2675 - val_accuracy: 0.5384 - val_loss: 1.2272 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5663 - loss: 1.1854 - val_accuracy: 0.5914 - val_loss: 1.1106 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5942 - loss: 1.1173 - val_accuracy: 0.6136 - val_loss: 1.0637 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.6137 - loss: 1.0637 - val_accuracy: 0.6036 - val_loss: 1.0875 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6374 - loss: 1.0075 - val_accuracy: 0.6335 - val_loss: 1.0211 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 9ms/step - accuracy: 0.6546 - loss: 0.9629 - val_accuracy: 0.6392 - val_loss: 0.9950 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6675 - loss: 0.9227 - val_accuracy: 0.6728 - val_loss: 0.9256 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6771 - loss: 0.8987 - val_accuracy: 0.6377 - val_loss: 1.0198 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6907 - loss: 0.8674 - val_accuracy: 0.6882 - val_loss: 0.8676 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7034 - loss: 0.8317 - val_accuracy: 0.6734 - val_loss: 0.9250 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7118 - loss: 0.8108 - val_accuracy: 0.6853 - val_loss: 0.8999 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7203 - loss: 0.7866 - val_accuracy: 0.6910 - val_loss: 0.8668 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7295 - loss: 0.7597 - val_accuracy: 0.6850 - val_loss: 0.8977 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7357 - loss: 0.7438 - val_accuracy: 0.6878 - val_loss: 0.9062 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7419 - loss: 0.7237 - val_accuracy: 0.7052 - val_loss: 0.8445 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7510 - loss: 0.6965 - val_accuracy: 0.6983 - val_loss: 0.8644 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7574 - loss: 0.6820 - val_accuracy: 0.7052 - val_loss: 0.8553 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7605 - loss: 0.6710 - val_accuracy: 0.6993 - val_loss: 0.8753 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7689 - loss: 0.6515 - val_accuracy: 0.7082 - val_loss: 0.8551 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7699 - loss: 0.6429 - val_accuracy: 0.7046 - val_loss: 0.8583 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7781 - loss: 0.6209 - val_accuracy: 0.7149 - val_loss: 0.8414 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7850 - loss: 0.6027 - val_accuracy: 0.6967 - val_loss: 0.9190 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 9ms/step - accuracy: 0.7898 - loss: 0.5921 - val_accuracy: 0.7137 - val_loss: 0.8583 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.7910 - loss: 0.5824 - val_accuracy: 0.7121 - val_loss: 0.8587 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.7975 - loss: 0.5652 - val_accuracy: 0.6947 - val_loss: 0.9252 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.8008 - loss: 0.5566 - val_accuracy: 0.7016 - val_loss: 0.9223 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.8062 - loss: 0.5435 - val_accuracy: 0.7158 - val_loss: 0.8756 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.8099 - loss: 0.5335 - val_accuracy: 0.7167 - val_loss: 0.8897 学習時間:166.2秒 パラメータ数:159,754 test_accuracy:0.7133
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('mhsa_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離) ────────────
fig2, axes2 = plt.subplots(2, 1, figsize=(7, 10))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(f'{label}')
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('mhsa_overfit.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>10} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 58)
for label in ['baseline', 'mhsa']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
p = params[label]
print(f"{label:>10} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 58)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s) | Params
----------------------------------------------------------
baseline | 0.6747 | 0.6676 | 130.2 | 93,450
mhsa | 0.7167 | 0.7133 | 166.2 | 159,754
----------------------------------------------------------
実験結果
精度グラフ
損失グラフ
ベースラインモデル
MHSAありモデル
| パターン | 最終 val_accuracy | 最終 test_accuracy | パラメータ数 | 学習時間 |
|---|---|---|---|---|
| A:MHSAなし(ベースライン) | 0.6747 | 0.6676 | 93,450 | 130.2秒 |
| B:MHSAあり | 0.7167 | 0.7133 | 159,754 | 166.2秒 |
考察
① MHSAを追加すると精度が約4.6ポイント向上した
test_accuracyはベースラインの66.76% → 71.33%と約4.6ポイント向上しました。MHSAが8×8=64個の空間パッチ間のグローバルな依存関係を学習することで、CNNが捉えにくかった「離れた位置の特徴の関連性」をモデルが活用できるようになったと考えられます。
② train_lossとval_lossの乖離に注意
MHSAありモデルは30エポック時点でtrain_accuracy 80.99%、val_accuracy 71.67%と約9ポイントの乖離があります。ベースライン(train 68.95% vs val 67.47%、乖離約1.5ポイント)と比べると過学習が顕著です。MHSAの追加によりモデルの表現力が上がった分、正則化の強化(Dropoutの増加・EarlyStopping・データ拡張)が効果的です。
③ パラメータ数と学習時間のコスト
パラメータ数はベースラインの93,450から159,754へ約1.7倍に増加し、学習時間も130.2秒→166.2秒と約28%増加しました。MHSAが追加する66,048パラメータは注意行列の計算(Q・K・V射影 + 出力射影)で消費されます。精度向上幅(+4.6ポイント)に対して計算コスト増加は比較的小さく、コスト効率は良好と言えます。
④ Channel AttentionやSEブロックとの比較
同様の条件で検証したChannel Attention(CBAM)やSEブロックは、パラメータ増加を最小限に抑えながら精度を改善する手法です。MHSAはそれらより大きな精度向上が期待できる一方、過学習リスクと計算コストも高くなります。用途や学習データ量に応じて使い分けるのが現実的です。
- Reshapeの次元に注意:Conv2D後の特徴マップは (batch, H, W, C) の4次元。MultiHeadAttentionは3次元入力 (batch, seq_len, d_model) を期待するので、Reshape((H*W, C)) が必要。
- 残差接続を忘れずに:Attention出力をそのまま使うと学習が不安定になりやすい。Add() + LayerNormalization() をセットで使う。
- 過学習しやすい:今回の実験でもtrain/val精度の乖離が約9ポイントと大きかった。Dropoutの強化やデータ拡張との併用を検討する。
- 学習が遅くなる:MHSAは行列積が増えるためCPU環境では特に遅い。GPU(T4以上)推奨。
- CNN後段にMHSAを追加することで test_accuracy が 66.76% → 71.33%(+4.57ポイント)向上した
- パラメータ数は約1.7倍(93,450 → 159,754)、学習時間は約28%増加とコスト効率は良好
- 一方でtrain/val精度の乖離が約9ポイントと過学習が顕著になった。データ拡張やDropout強化との併用を推奨
- KerasではMultiHeadAttentionレイヤーとReshapeを組み合わせるだけでCNNハイブリッド構成を実装できる
- Channel Attentionより表現力が高い分、正則化の工夫がより重要になる
関連記事
- Channel Attentionとの比較 → Channel Attentionを追加すると精度は上がるか?CBAM風実装をCIFAR-10で検証
- SEブロックの効果 → SEブロック(Squeeze-and-Excitation)を追加すると精度は上がるか?
- Residual接続の効果 → Residual接続(スキップ接続)あり vs なし を比較





0 件のコメント:
コメントを投稿