KerasでCNNを書くとき、Conv2Dの後に必ずFlattenを入れている方も多いと思います。
ただ、MobileNetV2やResNetのコードを読むと、FlattenではなくGlobalAveragePooling2D(GAP)が使われています。「どちらが良いのか?何が違うのか?」
今回はGoogle ColabとCIFAR-10を使い、2パターンを実験で比較しました。
この記事を読むとわかること:
- FlattenとGlobal Average Poolingで何が変わるか(パラメータ数・精度・過学習)
- val_lossから見た過学習への耐性の差
- KerasでGAPに書き換える方法(1行の変更)
FlattenとGlobal Average Poolingの違いとは
Flattenはどう動くか
Flattenは、Conv2Dが出力した特徴マップ(例:8×8×128)を全要素を1列に並べて1次元ベクトル化します。8×8×128 = 8,192次元のベクトルになり、その後のDense層へ渡されます。
空間情報を全て保持できる一方で、Dense層への入力次元が大きくなるためパラメータ数が多くなります。例えば Dense(128) への接続では 8,192 × 128 = 約105万パラメータが必要です。
GlobalAveragePooling2Dはどう動くか
GlobalAveragePooling2D(GAP)は、各チャンネルの特徴マップ全体を1つの平均値に集約します。8×8×128の特徴マップは「128個の平均値」つまり128次元ベクトルになります。
Dense(128) への接続は 128 × 128 = 約1.6万パラメータで済むため、Flattenと比べてパラメータ数が大幅に削減されます。
なぜGAPが過学習を抑えやすいのか
パラメータ数が少ない=モデルの表現の自由度が下がるため、訓練データへの過度な適合(過学習)が起きにくくなります。また、GAPはDropoutなしでも正則化効果を持つとされており、MobileNetV2・ResNet・EfficientNetなど多くの現代的なCNNがFlattenではなくGAPを採用しているのはこの理由です。
※MobileNetV2でのGAP使用例は → 【Keras/CNN】MobileNetV2で転移学習してみた
実験設定:2パターンのモデルコード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。
build_model(use_gap=True/False) の引数1つで切り替えられる設計にしました。それ以外の条件(Conv2D層数・filters数・エポック数)は全て同一です。
実験コード
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 37 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 2s (3,770 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 121852 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 44.0 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・モデル構築関数
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
# CIFAR-10データの読み込み・前処理
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
def build_model(use_gap: bool):
"""
use_gap=True → GlobalAveragePooling2D を使用
use_gap=False → Flatten を使用
"""
inputs = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
x = keras.layers.MaxPooling2D((2, 2))(x)
x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
# ここだけが異なる ──────────────────────
if use_gap:
x = keras.layers.GlobalAveragePooling2D()(x) # GAP: 128次元
else:
x = keras.layers.Flatten()(x) # Flatten: 8×8×128=8192次元
# ──────────────────────────────────────
x = keras.layers.Dense(128, activation='relu')(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
name = 'B_GAP' if use_gap else 'A_Flatten'
return keras.Model(inputs, outputs, name=name)
def compile_and_fit(model):
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
start = time.time()
history = model.fit(
x_train, y_train,
epochs=30,
batch_size=64,
validation_split=0.2,
verbose=1
)
elapsed = time.time() - start
return history, elapsed
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 21s 0us/step
2パターンの学習実行
print("\n=== Pattern A:Flatten ===")
model_A = build_model(use_gap=False)
print(model_A.summary())
history_A, time_A = compile_and_fit(model_A)
print(f"学習時間:{time_A:.1f}秒 パラメータ数:{model_A.count_params():,}")
print("\n=== Pattern B:GlobalAveragePooling2D ===")
model_B = build_model(use_gap=True)
print(model_B.summary())
history_B, time_B = compile_and_fit(model_B)
print(f"学習時間:{time_B:.1f}秒 パラメータ数:{model_B.count_params():,}")
# テストデータで最終精度を評価
test_results = {
'A:Flatten': model_A.evaluate(x_test, y_test, verbose=0),
'B:GAP': model_B.evaluate(x_test, y_test, verbose=0),
}
実行結果をクリックして内容を開く
=== Pattern A:Flatten === Model: "A_Flatten" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ flatten (Flatten) │ (None, 8192) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 128) │ 1,048,704 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 1,125,642 (4.29 MB) Trainable params: 1,125,642 (4.29 MB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 17ms/step - accuracy: 0.4804 - loss: 1.4496 - val_accuracy: 0.6044 - val_loss: 1.1438 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6337 - loss: 1.0428 - val_accuracy: 0.6505 - val_loss: 0.9944 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.6896 - loss: 0.8893 - val_accuracy: 0.6800 - val_loss: 0.9341 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7275 - loss: 0.7818 - val_accuracy: 0.6889 - val_loss: 0.9057 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.7581 - loss: 0.6897 - val_accuracy: 0.7028 - val_loss: 0.8963 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7883 - loss: 0.5995 - val_accuracy: 0.7202 - val_loss: 0.8540 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8192 - loss: 0.5171 - val_accuracy: 0.7171 - val_loss: 0.8705 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.8500 - loss: 0.4324 - val_accuracy: 0.7193 - val_loss: 0.9373 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8747 - loss: 0.3651 - val_accuracy: 0.7121 - val_loss: 0.9719 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8989 - loss: 0.2930 - val_accuracy: 0.7139 - val_loss: 1.0618 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9196 - loss: 0.2357 - val_accuracy: 0.7067 - val_loss: 1.1721 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9352 - loss: 0.1907 - val_accuracy: 0.7111 - val_loss: 1.2412 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9470 - loss: 0.1524 - val_accuracy: 0.7030 - val_loss: 1.3994 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9574 - loss: 0.1272 - val_accuracy: 0.7046 - val_loss: 1.4998 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9652 - loss: 0.1058 - val_accuracy: 0.7021 - val_loss: 1.5913 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.9700 - loss: 0.0909 - val_accuracy: 0.7061 - val_loss: 1.6949 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9699 - loss: 0.0902 - val_accuracy: 0.6982 - val_loss: 1.8477 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.9708 - loss: 0.0846 - val_accuracy: 0.7022 - val_loss: 1.8459 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9748 - loss: 0.0735 - val_accuracy: 0.7018 - val_loss: 1.9694 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9789 - loss: 0.0616 - val_accuracy: 0.6973 - val_loss: 2.0612 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9779 - loss: 0.0656 - val_accuracy: 0.6927 - val_loss: 2.1966 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9785 - loss: 0.0617 - val_accuracy: 0.6909 - val_loss: 2.1484 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9783 - loss: 0.0626 - val_accuracy: 0.6972 - val_loss: 2.1378 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9774 - loss: 0.0669 - val_accuracy: 0.7007 - val_loss: 2.2786 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9816 - loss: 0.0536 - val_accuracy: 0.6892 - val_loss: 2.3319 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9779 - loss: 0.0639 - val_accuracy: 0.6989 - val_loss: 2.4350 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.9834 - loss: 0.0491 - val_accuracy: 0.6951 - val_loss: 2.5914 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9806 - loss: 0.0566 - val_accuracy: 0.6909 - val_loss: 2.5629 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9829 - loss: 0.0513 - val_accuracy: 0.7031 - val_loss: 2.5082 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9866 - loss: 0.0391 - val_accuracy: 0.7091 - val_loss: 2.5783 学習時間:140.0秒 パラメータ数:1,125,642 === Pattern B:GlobalAveragePooling2D === Model: "B_GAP" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_1 (InputLayer) │ (None, 32, 32, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.2711 - loss: 1.9112 - val_accuracy: 0.3058 - val_loss: 1.7901 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3864 - loss: 1.6565 - val_accuracy: 0.4244 - val_loss: 1.5736 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4424 - loss: 1.5254 - val_accuracy: 0.4579 - val_loss: 1.4987 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4764 - loss: 1.4454 - val_accuracy: 0.4663 - val_loss: 1.4509 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5037 - loss: 1.3739 - val_accuracy: 0.4875 - val_loss: 1.3984 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5187 - loss: 1.3259 - val_accuracy: 0.5143 - val_loss: 1.3384 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5354 - loss: 1.2822 - val_accuracy: 0.5491 - val_loss: 1.2527 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5535 - loss: 1.2385 - val_accuracy: 0.5512 - val_loss: 1.2391 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5652 - loss: 1.2039 - val_accuracy: 0.5695 - val_loss: 1.2003 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5781 - loss: 1.1720 - val_accuracy: 0.5649 - val_loss: 1.1861 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5864 - loss: 1.1481 - val_accuracy: 0.5900 - val_loss: 1.1421 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5954 - loss: 1.1200 - val_accuracy: 0.5736 - val_loss: 1.1796 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.6061 - loss: 1.0965 - val_accuracy: 0.6015 - val_loss: 1.1201 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6119 - loss: 1.0769 - val_accuracy: 0.6013 - val_loss: 1.1006 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6164 - loss: 1.0649 - val_accuracy: 0.5979 - val_loss: 1.1086 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6258 - loss: 1.0383 - val_accuracy: 0.6069 - val_loss: 1.0816 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6315 - loss: 1.0257 - val_accuracy: 0.6191 - val_loss: 1.0513 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6396 - loss: 1.0038 - val_accuracy: 0.6274 - val_loss: 1.0557 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6418 - loss: 0.9981 - val_accuracy: 0.6293 - val_loss: 1.0231 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6495 - loss: 0.9765 - val_accuracy: 0.6272 - val_loss: 1.0350 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6566 - loss: 0.9588 - val_accuracy: 0.6422 - val_loss: 1.0008 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6604 - loss: 0.9457 - val_accuracy: 0.6455 - val_loss: 0.9905 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6639 - loss: 0.9355 - val_accuracy: 0.6388 - val_loss: 1.0139 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6671 - loss: 0.9249 - val_accuracy: 0.6412 - val_loss: 1.0002 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6729 - loss: 0.9115 - val_accuracy: 0.6409 - val_loss: 1.0284 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6758 - loss: 0.9020 - val_accuracy: 0.6445 - val_loss: 0.9960 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6817 - loss: 0.8905 - val_accuracy: 0.6685 - val_loss: 0.9330 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6881 - loss: 0.8740 - val_accuracy: 0.6611 - val_loss: 0.9442 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6915 - loss: 0.8666 - val_accuracy: 0.6609 - val_loss: 0.9455 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6932 - loss: 0.8619 - val_accuracy: 0.6690 - val_loss: 0.9362 学習時間:122.3秒 パラメータ数:93,450
グラフ(val_accuracy + val_loss + train_loss比較)+ サマリー
histories = {
'A:Flatten': history_A,
'B:GAP': history_B,
}
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gap_flatten_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離を見る)────────
fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(f'{label}:train vs val loss')
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gap_flatten_overfit.png', dpi=150)
plt.show()
# ── 最終結果サマリー ─────────────────────────────────
key_order = ['A:Flatten', 'B:GAP']
time_list = {'A:Flatten': time_A, 'B:GAP': time_B}
params_list = {'A:Flatten': model_A.count_params(), 'B:GAP': model_B.count_params()}
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>12} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>10}")
print("-" * 60)
for key in key_order:
val_acc = histories[key].history['val_accuracy'][-1]
test_loss, test_acc = test_results[key]
elapsed = time_list[key]
params = params_list[key]
print(f"{key:>12} | {val_acc:>8.4f} | {test_acc:>9.4f} | {elapsed:>8.1f} | {params:>10,}")
print("-" * 60)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s) | Params
------------------------------------------------------------
A:Flatten | 0.7091 | 0.6911 | 140.0 | 1,125,642
B:GAP | 0.6690 | 0.6639 | 122.3 | 93,450
------------------------------------------------------------
実験結果①:精度・損失の比較
val_accuracy の比較(30エポック)
グラフから読み取れること:
- 最終精度(test_accuracy)はGAPとFlattenで大きな差はありません。
- 精度だけを見ると「どちらでも良い」という結論になりますが、次のval_lossに注目することで重要な違いが見えてきます。
val_loss から見る過学習の度合い
このグラフがこの記事で最も注目すべき部分です:
- Flatten(パターンA)は後半のエポックでtrain_lossとval_lossの乖離が大きくなります。これは過学習が進んでいることを意味します。
- GAP(パターンB)は乖離が相対的に小さく抑えられており、過学習への耐性が高いことが確認できます。
パラメータ数の差がそのまま過学習のしやすさの差として表れた結果です。
| パターン | 最終 val_accuracy | 最終 test_accuracy | train vs val 乖離 | 過学習の度合い |
|---|---|---|---|---|
| A:Flatten | 70.91% | 69.11% | 大きい | やや多い |
| B:GAP | 66.90% | 66.39% | 小さい | 少ない |
実験結果②:パラメータ数と学習時間
精度は似ていても、コストには大きな差があります。
| パターン | Test Accuracy | パラメータ数 | 学習時間(秒) | 評価 |
|---|---|---|---|---|
| A:Flatten | 69.11% | 1,125,642 | 140.0秒 | △パラメータ過多 |
| B:GAP | 66.9% | 93,450 | 122.3秒 | ◎ 軽量・過学習耐性 |
GAPに切り替えるとパラメータ数が大幅に削減されます。精度がほぼ同等であれば、より少ないパラメータで同じ精度が出る方が良いモデルと言えます。
モバイルデバイスへのデプロイや、Colabの限られたメモリ・時間制限の中で実験する場合にも、GAPは実用的なメリットがあります。
結論:どちらを選ぶべきか
実験結果をもとに、状況別の結論をまとめます。
スクラッチでCNNを書く場合 → GAPを積極的に採用してよい
精度は若干Flattenが上回りますが、パラメータ数が1/12になる点を考えるとGAPは十分に実用的な選択肢です。特に過学習が気になる場合やモデルを軽量に保ちたい場合に有効です。
転移学習・ファインチューニングを行う場合 → GAPが標準的な選択
MobileNetV2・EfficientNet・ResNetなどの事前学習済みモデルはほぼ全てGAPを採用しています。これらを使う場合は自然とGAPになります。
Flattenが適しているケースは?
空間的な位置情報を細かく活用したい場合(例:物体検出の中間処理など)はFlattenが有効なこともあります。ただし一般的な画像分類タスクでは、GAPで十分なケースがほとんどです。
正直なところ、今回の実験では精度の差は小さい結果になることが多いです。「精度はほぼ同じで、パラメータは少なく、過学習しにくい」という理由でGAPを選ぶのが現代的なCNN設計の常識になっています。
補足:GAP使用時はDense層の前に追加するだけ
既存のFlatten使用コードをGAPに変更するには、1行を差し替えるだけです。
# Before:Flatten
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(128, activation='relu')(x)
# After:GlobalAveragePooling2D(1行変えるだけ)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(128, activation='relu')(x)
まとめ
今回はKerasとCIFAR-10を使い、FlattenとGlobal Average Poolingを比較しました。
- 精度はFlattenが約2〜3%上回った(ただしパラメータ数は1/12)
- パラメータ数・学習時間はGAPが大幅に少ない
- 現代的なCNN設計ではGAPが標準(MobileNetV2・ResNet・EfficientNetなど)
関連記事もあわせてどうぞ:
- CNNの基本的な作り方 → Google ColabでKerasを使った画像分類入門(MNIST編)
- GAPが使われている転移学習の実例 → 【CNN入門】Kerasで画像分類モデルを作る方法|基本構造からMobileNetV2までやさしく解説
- Conv2Dのfilters数の選び方 → Conv2DのFilters数(32 vs 64 vs 128)を変えると 精度はどう変わる?【CIFAR-10実験】
- CIFAR-10での過学習対策 → CIFAR-10で学ぶ:過学習(オーバーフィッティング)の判定方法と実践的対策





0 件のコメント:
コメントを投稿