KerasのCNNでDense層のユニット数をいくつにするか、迷ったことはありませんか?
「大きくすれば精度が上がるはず」——直感的にはそう思えますが、過学習のリスクも上がります。今回はGoogle ColabとCIFAR-10を使い、Dense層のユニット数を32・128・512の3パターンで比較しました。
なお、MNISTでのDense層ユニット数比較は → Dense層のユニット数を変えると何が起きるか?CNNでの実験結果 をご覧ください。本記事はCIFAR-10・GAP構成でより踏み込んだ比較を行います。
📘 この記事でわかること
- Dense層のユニット数を変えると精度・パラメータ数・過学習の度合いがどう変わるか
- 小さすぎる・大きすぎるユニット数で何が起きるか
- CIFAR-10+GAP構成での最適なユニット数の目安
ユニット数を変えると何が起きるか
Dense層のユニット数はGAPで集約された特徴をどれだけの次元数で処理するかを決めます。
| ユニット数 | Dense層のParam #(GAP後=128次元の場合) | 期待される挙動 |
|---|---|---|
| 32 | (128+1)× 32 = 4,128 | 表現力が小さい。シンプルなタスクなら十分だがCIFAR-10では不足する可能性あり |
| 128 | (128+1)× 128 = 16,512 | 今回の実験で使ってきた標準値。バランスが良い |
| 512 | (128+1)× 512 = 66,048 | 表現力が大きい。精度向上が期待できるが過学習リスクも上がる |
今回はGAP後のベクトルが128次元(Conv2D 2層・フィルター128)なので、Dense層への入力は128次元です。ユニット数を増やすほどパラメータ数が増えますが、GAP使用モデルではFlattenと比べてパラメータの増加は限定的です。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。Dense層のユニット数以外の条件は全て同一にして、ユニット数の影響だけを取り出します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 42 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 3s (2,539 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122354 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 69.1 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・モデル構築関数
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
def build_model(units, name):
return keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(units, activation='relu'), # ← ここだけ変える
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax'),
], name=name)
def compile_and_fit(model):
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
start = time.time()
history = model.fit(x_train, y_train, epochs=30, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 23s 0us/step
3パターンの学習実行
configs = [(32, 'A_units32'), (128, 'B_units128'), (512, 'C_units512')]
histories, times, scores, params = {}, {}, {}, {}
for units, name in configs:
print(f"\n=== {name} ===")
model = build_model(units, name)
print(model.summary())
h, t = compile_and_fit(model)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_')[1]
histories[label] = h
times[label] = t
scores[label] = s
params[label] = model.count_params()
print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_units32 === Model: "A_units32" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 32) │ 4,128 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 330 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 80,106 (312.91 KB) Trainable params: 80,106 (312.91 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 14ms/step - accuracy: 0.2338 - loss: 2.0020 - val_accuracy: 0.2994 - val_loss: 1.8155 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 11ms/step - accuracy: 0.3140 - loss: 1.7816 - val_accuracy: 0.3806 - val_loss: 1.6739 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.3658 - loss: 1.6767 - val_accuracy: 0.4334 - val_loss: 1.5550 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4117 - loss: 1.5810 - val_accuracy: 0.4464 - val_loss: 1.4849 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4385 - loss: 1.5189 - val_accuracy: 0.4784 - val_loss: 1.4306 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4576 - loss: 1.4681 - val_accuracy: 0.4824 - val_loss: 1.4119 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4737 - loss: 1.4310 - val_accuracy: 0.4846 - val_loss: 1.3982 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.4823 - loss: 1.4063 - val_accuracy: 0.4946 - val_loss: 1.3570 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4934 - loss: 1.3782 - val_accuracy: 0.5291 - val_loss: 1.2948 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4989 - loss: 1.3577 - val_accuracy: 0.5388 - val_loss: 1.2719 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5137 - loss: 1.3312 - val_accuracy: 0.5412 - val_loss: 1.2494 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.5190 - loss: 1.3165 - val_accuracy: 0.5519 - val_loss: 1.2292 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.5289 - loss: 1.2935 - val_accuracy: 0.5557 - val_loss: 1.2150 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.5339 - loss: 1.2741 - val_accuracy: 0.5616 - val_loss: 1.2107 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.5426 - loss: 1.2501 - val_accuracy: 0.5720 - val_loss: 1.1835 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.5465 - loss: 1.2376 - val_accuracy: 0.5521 - val_loss: 1.2184 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 11ms/step - accuracy: 0.5524 - loss: 1.2184 - val_accuracy: 0.5853 - val_loss: 1.1422 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 9ms/step - accuracy: 0.5608 - loss: 1.2079 - val_accuracy: 0.5879 - val_loss: 1.1281 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5652 - loss: 1.1884 - val_accuracy: 0.5730 - val_loss: 1.1925 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5707 - loss: 1.1817 - val_accuracy: 0.6029 - val_loss: 1.0939 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5730 - loss: 1.1630 - val_accuracy: 0.5652 - val_loss: 1.1771 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5826 - loss: 1.1478 - val_accuracy: 0.6167 - val_loss: 1.0622 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5893 - loss: 1.1327 - val_accuracy: 0.6146 - val_loss: 1.0670 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5914 - loss: 1.1254 - val_accuracy: 0.5929 - val_loss: 1.1134 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5968 - loss: 1.1078 - val_accuracy: 0.6293 - val_loss: 1.0308 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6010 - loss: 1.1007 - val_accuracy: 0.6007 - val_loss: 1.0856 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6061 - loss: 1.0880 - val_accuracy: 0.6075 - val_loss: 1.0721 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6093 - loss: 1.0801 - val_accuracy: 0.6244 - val_loss: 1.0353 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6132 - loss: 1.0657 - val_accuracy: 0.6275 - val_loss: 1.0232 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6144 - loss: 1.0621 - val_accuracy: 0.6362 - val_loss: 1.0052 学習時間:162.6秒 パラメータ数:80,106 test_accuracy:0.6424 === B_units128 === Model: "B_units128" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d_2 (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d_1 │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.2663 - loss: 1.9463 - val_accuracy: 0.3375 - val_loss: 1.7807 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3690 - loss: 1.6809 - val_accuracy: 0.4033 - val_loss: 1.6026 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4252 - loss: 1.5554 - val_accuracy: 0.4298 - val_loss: 1.5469 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4623 - loss: 1.4621 - val_accuracy: 0.4844 - val_loss: 1.4041 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4847 - loss: 1.4022 - val_accuracy: 0.5055 - val_loss: 1.3468 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5048 - loss: 1.3569 - val_accuracy: 0.5066 - val_loss: 1.3457 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5208 - loss: 1.3115 - val_accuracy: 0.5284 - val_loss: 1.2915 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5356 - loss: 1.2765 - val_accuracy: 0.5531 - val_loss: 1.2250 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5426 - loss: 1.2501 - val_accuracy: 0.5574 - val_loss: 1.2090 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5567 - loss: 1.2132 - val_accuracy: 0.5641 - val_loss: 1.1929 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5630 - loss: 1.1951 - val_accuracy: 0.5614 - val_loss: 1.1989 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.5740 - loss: 1.1707 - val_accuracy: 0.5850 - val_loss: 1.1332 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5837 - loss: 1.1482 - val_accuracy: 0.5959 - val_loss: 1.1046 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5938 - loss: 1.1271 - val_accuracy: 0.6071 - val_loss: 1.0858 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5998 - loss: 1.1111 - val_accuracy: 0.6079 - val_loss: 1.0811 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6078 - loss: 1.0899 - val_accuracy: 0.6097 - val_loss: 1.0675 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6140 - loss: 1.0739 - val_accuracy: 0.6007 - val_loss: 1.1002 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6190 - loss: 1.0597 - val_accuracy: 0.6143 - val_loss: 1.0475 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6258 - loss: 1.0424 - val_accuracy: 0.6092 - val_loss: 1.0754 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6300 - loss: 1.0265 - val_accuracy: 0.6350 - val_loss: 1.0213 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6378 - loss: 1.0122 - val_accuracy: 0.6370 - val_loss: 1.0097 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6405 - loss: 0.9970 - val_accuracy: 0.6185 - val_loss: 1.0399 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6475 - loss: 0.9797 - val_accuracy: 0.6372 - val_loss: 0.9916 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6500 - loss: 0.9747 - val_accuracy: 0.6483 - val_loss: 0.9763 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6589 - loss: 0.9519 - val_accuracy: 0.6539 - val_loss: 0.9520 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6614 - loss: 0.9498 - val_accuracy: 0.6630 - val_loss: 0.9372 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6641 - loss: 0.9349 - val_accuracy: 0.6522 - val_loss: 0.9708 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6697 - loss: 0.9239 - val_accuracy: 0.6596 - val_loss: 0.9509 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6709 - loss: 0.9162 - val_accuracy: 0.6381 - val_loss: 1.0029 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6777 - loss: 0.9010 - val_accuracy: 0.6496 - val_loss: 0.9668 学習時間:124.8秒 パラメータ数:93,450 test_accuracy:0.6543 === C_units512 === Model: "C_units512" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d_4 (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_4 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_5 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_5 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d_2 │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_4 (Dense) │ (None, 512) │ 66,048 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_2 (Dropout) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_5 (Dense) │ (None, 10) │ 5,130 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 146,826 (573.54 KB) Trainable params: 146,826 (573.54 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.2776 - loss: 1.8955 - val_accuracy: 0.3643 - val_loss: 1.6901 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4006 - loss: 1.6112 - val_accuracy: 0.4183 - val_loss: 1.5506 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4582 - loss: 1.4763 - val_accuracy: 0.4701 - val_loss: 1.4621 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4938 - loss: 1.3845 - val_accuracy: 0.4959 - val_loss: 1.3682 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5177 - loss: 1.3210 - val_accuracy: 0.5216 - val_loss: 1.2977 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5332 - loss: 1.2739 - val_accuracy: 0.5498 - val_loss: 1.2398 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5520 - loss: 1.2292 - val_accuracy: 0.5365 - val_loss: 1.2607 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5703 - loss: 1.1864 - val_accuracy: 0.5802 - val_loss: 1.1571 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5811 - loss: 1.1505 - val_accuracy: 0.5785 - val_loss: 1.1487 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5950 - loss: 1.1160 - val_accuracy: 0.5981 - val_loss: 1.0954 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6072 - loss: 1.0873 - val_accuracy: 0.6068 - val_loss: 1.1063 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6147 - loss: 1.0620 - val_accuracy: 0.6142 - val_loss: 1.0655 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6274 - loss: 1.0285 - val_accuracy: 0.6154 - val_loss: 1.0750 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6343 - loss: 1.0117 - val_accuracy: 0.6244 - val_loss: 1.0318 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6423 - loss: 0.9871 - val_accuracy: 0.6435 - val_loss: 0.9880 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6514 - loss: 0.9701 - val_accuracy: 0.6442 - val_loss: 0.9902 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6601 - loss: 0.9486 - val_accuracy: 0.6582 - val_loss: 0.9460 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6636 - loss: 0.9300 - val_accuracy: 0.6595 - val_loss: 0.9439 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6732 - loss: 0.9110 - val_accuracy: 0.6596 - val_loss: 0.9619 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6785 - loss: 0.8958 - val_accuracy: 0.6722 - val_loss: 0.9117 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6866 - loss: 0.8769 - val_accuracy: 0.6690 - val_loss: 0.9243 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6896 - loss: 0.8651 - val_accuracy: 0.6575 - val_loss: 0.9420 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6955 - loss: 0.8462 - val_accuracy: 0.6642 - val_loss: 0.9367 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7010 - loss: 0.8323 - val_accuracy: 0.6767 - val_loss: 0.9098 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.7062 - loss: 0.8184 - val_accuracy: 0.6889 - val_loss: 0.8665 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7146 - loss: 0.8024 - val_accuracy: 0.6879 - val_loss: 0.8766 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7170 - loss: 0.7910 - val_accuracy: 0.6933 - val_loss: 0.8542 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.7223 - loss: 0.7759 - val_accuracy: 0.7063 - val_loss: 0.8354 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.7250 - loss: 0.7628 - val_accuracy: 0.7078 - val_loss: 0.8349 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.7327 - loss: 0.7468 - val_accuracy: 0.7072 - val_loss: 0.8228 学習時間:125.0秒 パラメータ数:146,826 test_accuracy:0.7043
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('dense_units_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離)────────────
fig2, axes2 = plt.subplots(3, 1, figsize=(7, 14))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(f'{label}')
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('dense_units_overfit.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>10} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>12}")
print("-" * 58)
for label in ['units32', 'units128', 'units512']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
p = params[label]
print(f"{label:>10} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>12,}")
print("-" * 58)
最終結果サマリー
===== 最終結果サマリー ===== Pattern | Val Acc | Test Acc | Time(s) | Params ---------------------------------------------------------- units32 | 0.6362 | 0.6424 | 162.6 | 80,106 units128 | 0.6496 | 0.6543 | 124.8 | 93,450 units512 | 0.7072 | 0.7043 | 125.0 | 146,826 ----------------------------------------------------------
実験結果
精度グラフ
損失グラフ
unit32
unit128
unit512
| パターン | 最終 val_accuracy | 最終 test_accuracy | パラメータ数 | 学習時間 |
|---|---|---|---|---|
| A:Dense(32) | 63.62% | 64.24% | 80,106 | 162.6秒 |
| B:Dense(128) | 64.96% | 65.43% | 93,450 | 124.8秒 |
| C:Dense(512) | 70.72% | 70.43% | 146,826 | 125.0秒 |
考察
① ユニット数が増えるほど精度が一貫して上がった
今回の実験ではユニット数を増やすほど精度が明確に向上しました。
| 比較 | 精度の向上 | パラメータ増加 | 学習時間の変化 |
|---|---|---|---|
| 32 → 128 | +1.19% | 約1.17倍(80,106 → 93,450) | −37.8秒(約23%減) |
| 128 → 512 | +5.00% | 約1.57倍(93,450 → 146,826) | +0.2秒(ほぼ同じ) |
特に注目すべきは 128→512 の精度向上(+5.00%)に対して学習時間がほぼ変わらない(+0.2秒)点です。パラメータ数は約1.57倍増えていますが、GAP使用モデルではDense層のパラメータ増加がそれほど大きくないため、計算コストへの影響が小さかったと考えられます。
② GAP使用モデルではパラメータ差が小さい
今回のモデルはGAPを使っているため、Dense層への入力は常に128次元に固定されています。ユニット数を32・128・512と変えても、パラメータ数の差は最大で約1.83倍(80,106→146,826)に留まっています。
もしFlattenを使っていた場合、pool_size=(2,2)×2回後の特徴マップは8×8×128=8,192次元になるため、Dense(512)のパラメータ数は(8,192+1)×512 ≒ 419万と桁違いに膨らみます。GAPがDense層のパラメータ爆発を防いでいることがこの比較から改めて確認できます。
③ Dense(32)はパラメータが少ないのに学習時間が最長
Dense(32)のパラメータ数(80,106)は最少ですが、学習時間は162.6秒と3パターン中最長でした。Dense(128)・Dense(512)と約38秒の差があります。
これはGPUの処理効率に関係しています。GPUは大きな行列演算を並列処理することが得意ですが、ユニット数が小さすぎると並列処理の恩恵が十分に得られず、オーバーヘッドが相対的に大きくなる場合があります。「パラメータが少ない=学習が速い」という思い込みは修正が必要です。
④ Dense(512)では過学習はどの程度出たか
Dense(512)はパラメータ数が最多のため過学習のリスクが最も高いパターンです。Dropout=0.2との組み合わせでtrain_lossとval_lossの乖離がどの程度出たかをグラフで確認してください。Dense(32)・(128)と比べて乖離が大きければ、Dropout率を上げる(0.3〜0.5)か、ユニット数を抑える判断が必要です。
関連記事もあわせてどうぞ:
- MNISTでのDense層ユニット数比較 → Dense層のユニット数を変えると何が起きるか?CNNでの実験結果
- GAPとFlattenの比較 → Global Average Pooling vs Flatten|CNNの最終層、どっちが精度・速度で有利か?【Keras実験】
- model.summary()の読み方 → Kerasのmodel.summary()の読み方を徹底解説|パラメータ数の計算方法【初心者向け】
- Dropout率の比較実験 → Dropoutの割合(0.0 vs 0.2 vs 0.5)を変えると過学習はどう変わる?【Keras×CIFAR-10実験】






0 件のコメント:
コメントを投稿