Global Average Pooling vs Flatten|CNNの最終層、 どっちが精度・速度で有利か?【Keras実験】

投稿日:2026年3月15日日曜日 最終更新日:

CIFAR-10 CNN Global Average Pooling Google Colab Keras 過学習

X f B! P L
Global Average Pooling vs Flatten|CNNの最終層、 どっちが精度・速度で有利か?【Keras実験】 アイキャッチ画像

KerasでCNNを書くとき、Conv2Dの後に必ずFlattenを入れている方も多いと思います。

ただ、MobileNetV2やResNetのコードを読むと、FlattenではなくGlobalAveragePooling2D(GAP)が使われています。「どちらが良いのか?何が違うのか?」

今回はGoogle ColabとCIFAR-10を使い、2パターンを実験で比較しました。

この記事を読むとわかること:

  • FlattenとGlobal Average Poolingで何が変わるか(パラメータ数・精度・過学習)
  • val_lossから見た過学習への耐性の差
  • KerasでGAPに書き換える方法(1行の変更)

FlattenとGlobal Average Poolingの違いとは

Flattenはどう動くか

Flattenは、Conv2Dが出力した特徴マップ(例:8×8×128)を全要素を1列に並べて1次元ベクトル化します。8×8×128 = 8,192次元のベクトルになり、その後のDense層へ渡されます。

空間情報を全て保持できる一方で、Dense層への入力次元が大きくなるためパラメータ数が多くなります。例えば Dense(128) への接続では 8,192 × 128 = 約105万パラメータが必要です。

GlobalAveragePooling2Dはどう動くか

GlobalAveragePooling2D(GAP)は、各チャンネルの特徴マップ全体を1つの平均値に集約します。8×8×128の特徴マップは「128個の平均値」つまり128次元ベクトルになります。

Dense(128) への接続は 128 × 128 = 約1.6万パラメータで済むため、Flattenと比べてパラメータ数が大幅に削減されます。

なぜGAPが過学習を抑えやすいのか

パラメータ数が少ない=モデルの表現の自由度が下がるため、訓練データへの過度な適合(過学習)が起きにくくなります。また、GAPはDropoutなしでも正則化効果を持つとされており、MobileNetV2・ResNet・EfficientNetなど多くの現代的なCNNがFlattenではなくGAPを採用しているのはこの理由です。

※MobileNetV2でのGAP使用例は → 【Keras/CNN】MobileNetV2で転移学習してみた

実験設定:2パターンのモデルコード

使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。

build_model(use_gap=True/False) の引数1つで切り替えられる設計にしました。それ以外の条件(Conv2D層数・filters数・エポック数)は全て同一です。

実験コード

環境準備(最初に一度だけ実行)

# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-ipafont-mincho
The following NEW packages will be installed:
  fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 37 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 2s (3,770 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 121852 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 44.0 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了

import・データ準備・モデル構築関数

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time

# CIFAR-10データの読み込み・前処理
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test  = x_test.astype('float32')  / 255.0

def build_model(use_gap: bool):
    """
    use_gap=True  → GlobalAveragePooling2D を使用
    use_gap=False → Flatten を使用
    """
    inputs = keras.layers.Input(shape=(32, 32, 3))
    x = keras.layers.Conv2D(64,  (3, 3), activation='relu', padding='same')(inputs)
    x = keras.layers.MaxPooling2D((2, 2))(x)
    x = keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = keras.layers.MaxPooling2D((2, 2))(x)

    # ここだけが異なる ──────────────────────
    if use_gap:
        x = keras.layers.GlobalAveragePooling2D()(x)  # GAP: 128次元
    else:
        x = keras.layers.Flatten()(x)               # Flatten: 8×8×128=8192次元
    # ──────────────────────────────────────

    x = keras.layers.Dense(128, activation='relu')(x)
    outputs = keras.layers.Dense(10, activation='softmax')(x)
    name = 'B_GAP' if use_gap else 'A_Flatten'
    return keras.Model(inputs, outputs, name=name)

def compile_and_fit(model):
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    start = time.time()
    history = model.fit(
        x_train, y_train,
        epochs=30,
        batch_size=64,
        validation_split=0.2,
        verbose=1
    )
    elapsed = time.time() - start
    return history, elapsed
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 21s 0us/step

2パターンの学習実行

print("\n=== Pattern A:Flatten ===")
model_A = build_model(use_gap=False)
print(model_A.summary())
history_A, time_A = compile_and_fit(model_A)
print(f"学習時間:{time_A:.1f}秒 パラメータ数:{model_A.count_params():,}")

print("\n=== Pattern B:GlobalAveragePooling2D ===")
model_B = build_model(use_gap=True)
print(model_B.summary())
history_B, time_B = compile_and_fit(model_B)
print(f"学習時間:{time_B:.1f}秒 パラメータ数:{model_B.count_params():,}")

# テストデータで最終精度を評価
test_results = {
    'A:Flatten': model_A.evaluate(x_test, y_test, verbose=0),
    'B:GAP':     model_B.evaluate(x_test, y_test, verbose=0),
}
実行結果をクリックして内容を開く
=== Pattern A:Flatten ===
Model: "A_Flatten"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 32, 32, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 32, 32, 64)     │         1,792 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 16, 16, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 16, 16, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 8, 8, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 8192)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │     1,048,704 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,125,642 (4.29 MB)
 Trainable params: 1,125,642 (4.29 MB)
 Non-trainable params: 0 (0.00 B)
None
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 17ms/step - accuracy: 0.4804 - loss: 1.4496 - val_accuracy: 0.6044 - val_loss: 1.1438
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6337 - loss: 1.0428 - val_accuracy: 0.6505 - val_loss: 0.9944
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.6896 - loss: 0.8893 - val_accuracy: 0.6800 - val_loss: 0.9341
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7275 - loss: 0.7818 - val_accuracy: 0.6889 - val_loss: 0.9057
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.7581 - loss: 0.6897 - val_accuracy: 0.7028 - val_loss: 0.8963
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7883 - loss: 0.5995 - val_accuracy: 0.7202 - val_loss: 0.8540
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8192 - loss: 0.5171 - val_accuracy: 0.7171 - val_loss: 0.8705
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.8500 - loss: 0.4324 - val_accuracy: 0.7193 - val_loss: 0.9373
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8747 - loss: 0.3651 - val_accuracy: 0.7121 - val_loss: 0.9719
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.8989 - loss: 0.2930 - val_accuracy: 0.7139 - val_loss: 1.0618
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9196 - loss: 0.2357 - val_accuracy: 0.7067 - val_loss: 1.1721
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9352 - loss: 0.1907 - val_accuracy: 0.7111 - val_loss: 1.2412
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9470 - loss: 0.1524 - val_accuracy: 0.7030 - val_loss: 1.3994
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9574 - loss: 0.1272 - val_accuracy: 0.7046 - val_loss: 1.4998
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9652 - loss: 0.1058 - val_accuracy: 0.7021 - val_loss: 1.5913
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.9700 - loss: 0.0909 - val_accuracy: 0.7061 - val_loss: 1.6949
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9699 - loss: 0.0902 - val_accuracy: 0.6982 - val_loss: 1.8477
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.9708 - loss: 0.0846 - val_accuracy: 0.7022 - val_loss: 1.8459
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9748 - loss: 0.0735 - val_accuracy: 0.7018 - val_loss: 1.9694
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9789 - loss: 0.0616 - val_accuracy: 0.6973 - val_loss: 2.0612
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9779 - loss: 0.0656 - val_accuracy: 0.6927 - val_loss: 2.1966
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9785 - loss: 0.0617 - val_accuracy: 0.6909 - val_loss: 2.1484
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9783 - loss: 0.0626 - val_accuracy: 0.6972 - val_loss: 2.1378
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9774 - loss: 0.0669 - val_accuracy: 0.7007 - val_loss: 2.2786
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9816 - loss: 0.0536 - val_accuracy: 0.6892 - val_loss: 2.3319
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9779 - loss: 0.0639 - val_accuracy: 0.6989 - val_loss: 2.4350
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.9834 - loss: 0.0491 - val_accuracy: 0.6951 - val_loss: 2.5914
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9806 - loss: 0.0566 - val_accuracy: 0.6909 - val_loss: 2.5629
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9829 - loss: 0.0513 - val_accuracy: 0.7031 - val_loss: 2.5082
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9866 - loss: 0.0391 - val_accuracy: 0.7091 - val_loss: 2.5783
学習時間:140.0秒 パラメータ数:1,125,642

=== Pattern B:GlobalAveragePooling2D ===
Model: "B_GAP"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, 32, 32, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 32, 32, 64)     │         1,792 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 16, 16, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_3 (Conv2D)               │ (None, 16, 16, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 8, 8, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d        │ (None, 128)            │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 128)            │        16,512 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 93,450 (365.04 KB)
 Trainable params: 93,450 (365.04 KB)
 Non-trainable params: 0 (0.00 B)
None
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.2711 - loss: 1.9112 - val_accuracy: 0.3058 - val_loss: 1.7901
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3864 - loss: 1.6565 - val_accuracy: 0.4244 - val_loss: 1.5736
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4424 - loss: 1.5254 - val_accuracy: 0.4579 - val_loss: 1.4987
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4764 - loss: 1.4454 - val_accuracy: 0.4663 - val_loss: 1.4509
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5037 - loss: 1.3739 - val_accuracy: 0.4875 - val_loss: 1.3984
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5187 - loss: 1.3259 - val_accuracy: 0.5143 - val_loss: 1.3384
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5354 - loss: 1.2822 - val_accuracy: 0.5491 - val_loss: 1.2527
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5535 - loss: 1.2385 - val_accuracy: 0.5512 - val_loss: 1.2391
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5652 - loss: 1.2039 - val_accuracy: 0.5695 - val_loss: 1.2003
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5781 - loss: 1.1720 - val_accuracy: 0.5649 - val_loss: 1.1861
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5864 - loss: 1.1481 - val_accuracy: 0.5900 - val_loss: 1.1421
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5954 - loss: 1.1200 - val_accuracy: 0.5736 - val_loss: 1.1796
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.6061 - loss: 1.0965 - val_accuracy: 0.6015 - val_loss: 1.1201
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6119 - loss: 1.0769 - val_accuracy: 0.6013 - val_loss: 1.1006
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6164 - loss: 1.0649 - val_accuracy: 0.5979 - val_loss: 1.1086
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6258 - loss: 1.0383 - val_accuracy: 0.6069 - val_loss: 1.0816
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6315 - loss: 1.0257 - val_accuracy: 0.6191 - val_loss: 1.0513
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6396 - loss: 1.0038 - val_accuracy: 0.6274 - val_loss: 1.0557
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6418 - loss: 0.9981 - val_accuracy: 0.6293 - val_loss: 1.0231
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6495 - loss: 0.9765 - val_accuracy: 0.6272 - val_loss: 1.0350
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6566 - loss: 0.9588 - val_accuracy: 0.6422 - val_loss: 1.0008
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6604 - loss: 0.9457 - val_accuracy: 0.6455 - val_loss: 0.9905
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6639 - loss: 0.9355 - val_accuracy: 0.6388 - val_loss: 1.0139
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6671 - loss: 0.9249 - val_accuracy: 0.6412 - val_loss: 1.0002
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6729 - loss: 0.9115 - val_accuracy: 0.6409 - val_loss: 1.0284
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6758 - loss: 0.9020 - val_accuracy: 0.6445 - val_loss: 0.9960
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6817 - loss: 0.8905 - val_accuracy: 0.6685 - val_loss: 0.9330
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6881 - loss: 0.8740 - val_accuracy: 0.6611 - val_loss: 0.9442
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6915 - loss: 0.8666 - val_accuracy: 0.6609 - val_loss: 0.9455
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6932 - loss: 0.8619 - val_accuracy: 0.6690 - val_loss: 0.9362
学習時間:122.3秒 パラメータ数:93,450

グラフ(val_accuracy + val_loss + train_loss比較)+ サマリー

histories = {
    'A:Flatten': history_A,
    'B:GAP':     history_B,
}

# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes[0].plot(h.history['val_accuracy'], label=label)
    axes[1].plot(h.history['val_loss'],     label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gap_flatten_comparison.png', dpi=150)
plt.show()

# ── train_loss vs val_loss(過学習の乖離を見る)────────
fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))
for i, (label, h) in enumerate(histories.items()):
    axes2[i].plot(h.history['loss'],     label='train_loss')
    axes2[i].plot(h.history['val_loss'], label='val_loss')
    axes2[i].set_title(f'{label}:train vs val loss')
    axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gap_flatten_overfit.png', dpi=150)
plt.show()

# ── 最終結果サマリー ─────────────────────────────────
key_order   = ['A:Flatten', 'B:GAP']
time_list   = {'A:Flatten': time_A, 'B:GAP': time_B}
params_list = {'A:Flatten': model_A.count_params(), 'B:GAP': model_B.count_params()}

print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>12} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>10}")
print("-" * 60)
for key in key_order:
    val_acc = histories[key].history['val_accuracy'][-1]
    test_loss, test_acc = test_results[key]
    elapsed = time_list[key]
    params  = params_list[key]
    print(f"{key:>12} | {val_acc:>8.4f} | {test_acc:>9.4f} | {elapsed:>8.1f} | {params:>10,}")
print("-" * 60)

最終結果サマリー

===== 最終結果サマリー =====
     Pattern |  Val Acc |  Test Acc |  Time(s) |     Params
------------------------------------------------------------
   A:Flatten |   0.7091 |    0.6911 |    140.0 |  1,125,642
       B:GAP |   0.6690 |    0.6639 |    122.3 |     93,450
------------------------------------------------------------

実験結果①:精度・損失の比較

val_accuracy の比較(30エポック)

val_accuracy GAPとFlatten
val_loss GAPとFlatten

グラフから読み取れること:

  • 最終精度(test_accuracy)はGAPとFlattenで大きな差はありません。
  • 精度だけを見ると「どちらでも良い」という結論になりますが、次のval_lossに注目することで重要な違いが見えてきます。

val_loss から見る過学習の度合い

Flatten train_lossとval_loss
GAP train_lossとval_loss

このグラフがこの記事で最も注目すべき部分です:

  • Flatten(パターンA)は後半のエポックでtrain_lossとval_lossの乖離が大きくなります。これは過学習が進んでいることを意味します。
  • GAP(パターンB)は乖離が相対的に小さく抑えられており、過学習への耐性が高いことが確認できます。

パラメータ数の差がそのまま過学習のしやすさの差として表れた結果です。

パターン最終 val_accuracy最終 test_accuracytrain vs val 乖離過学習の度合い
A:Flatten70.91%69.11%大きいやや多い
B:GAP66.90%66.39%小さい少ない

実験結果②:パラメータ数と学習時間

精度は似ていても、コストには大きな差があります。

パターンTest Accuracyパラメータ数学習時間(秒)評価
A:Flatten69.11%1,125,642140.0秒△パラメータ過多
B:GAP66.9%93,450122.3秒◎ 軽量・過学習耐性

GAPに切り替えるとパラメータ数が大幅に削減されます。精度がほぼ同等であれば、より少ないパラメータで同じ精度が出る方が良いモデルと言えます。

モバイルデバイスへのデプロイや、Colabの限られたメモリ・時間制限の中で実験する場合にも、GAPは実用的なメリットがあります。

結論:どちらを選ぶべきか

実験結果をもとに、状況別の結論をまとめます。

スクラッチでCNNを書く場合 → GAPを積極的に採用してよい

精度は若干Flattenが上回りますが、パラメータ数が1/12になる点を考えるとGAPは十分に実用的な選択肢です。特に過学習が気になる場合やモデルを軽量に保ちたい場合に有効です。

転移学習・ファインチューニングを行う場合 → GAPが標準的な選択

MobileNetV2・EfficientNet・ResNetなどの事前学習済みモデルはほぼ全てGAPを採用しています。これらを使う場合は自然とGAPになります。

Flattenが適しているケースは?

空間的な位置情報を細かく活用したい場合(例:物体検出の中間処理など)はFlattenが有効なこともあります。ただし一般的な画像分類タスクでは、GAPで十分なケースがほとんどです。

正直なところ、今回の実験では精度の差は小さい結果になることが多いです。「精度はほぼ同じで、パラメータは少なく、過学習しにくい」という理由でGAPを選ぶのが現代的なCNN設計の常識になっています。

補足:GAP使用時はDense層の前に追加するだけ

既存のFlatten使用コードをGAPに変更するには、1行を差し替えるだけです。

# Before:Flatten
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(128, activation='relu')(x)

# After:GlobalAveragePooling2D(1行変えるだけ)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(128, activation='relu')(x)
GAPの後のDense層はオプションです。最終分類層(softmax)の直前にGAPを置いて、Dense層を省略する構成もよく見られます。MobileNetV2がこの構成を採用しています。

まとめ

今回はKerasとCIFAR-10を使い、FlattenとGlobal Average Poolingを比較しました。

  • 精度はFlattenが約2〜3%上回った(ただしパラメータ数は1/12)
  • パラメータ数・学習時間はGAPが大幅に少ない
  • 現代的なCNN設計ではGAPが標準(MobileNetV2・ResNet・EfficientNetなど)

関連記事もあわせてどうぞ: