KerasでCNNを組むとき、活性化関数は「とりあえずReLU」にしていませんか?
近年、BERT・GPT・Vision Transformerなど最新モデルの多くでGELUが採用されています。画像分類CNNでも差が出るのでしょうか。今回はGoogle ColabとCIFAR-10を使い、活性化関数をReLUとGELUの2パターンで比較しました。
なお、Activation関数全般の比較(ReLU・ELU・Swish・Sigmoidなど)は → Activation関数を変えると精度はどう変わる?(relu vs gelu vs elu)【Keras×CIFAR-10実験】 をご覧ください。本記事はReLU vs GELUに絞り、仕組みから実験結果まで踏み込んで比較します。
📘 この記事でわかること
- ReLUとGELUの仕組みの違い
- 活性化関数を変えると精度・損失・学習速度がどう変わるか
- CIFAR-10+GAP構成でどちらを使うべきかの判断基準
ReLUとGELUの違い
まず2つの活性化関数を簡単に整理します。
| 項目 | ReLU | GELU |
|---|---|---|
| 正式名称 | Rectified Linear Unit | Gaussian Error Linear Unit |
| 式(概略) | max(0, x) | x · Φ(x) ※Φは標準正規分布のCDF |
| 負の入力への出力 | 完全に0(ハードなゲート) | 小さな負値を少し通す(ソフトなゲート) |
| 計算コスト | 低い | ReLUより高い(近似計算を使用) |
| 採用例 | VGG、ResNet(初期)など | BERT、GPT、ViT、EfficientNetV2など |
ReLUは x > 0 のときそのまま通し、それ以下は0に丸めます。シンプルで高速な反面、「死んだニューロン(Dying ReLU)」問題が起こることがあります。
GELUはガウス分布の累積分布関数(CDF)を使い、入力をソフトにゲートします。負の値も確率的に少し通すため、勾配消失が起きにくく、深いネットワークで安定しやすいとされています。Kerasでは近似式で実装されており、計算コストは増えますがReLUと同じ感覚で使えます。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。活性化関数以外の条件はすべて同一にして、活性化関数の影響だけを取り出します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 42 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 1s (6,998 kB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122354 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 57.9 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・モデル構築関数
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
def build_model(activation, name):
return keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation=activation, padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation=activation, padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation=activation), # ← ここだけ変える
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax'),
], name=name)
def compile_and_fit(model):
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
start = time.time()
history = model.fit(x_train, y_train, epochs=30, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 9s 0us/step
2パターンの学習実行
configs = [('relu', 'A_relu'), ('gelu', 'B_gelu')]
histories, times, scores, params = {}, {}, {}, {}
for activation, name in configs:
print(f"\n=== {name} ===")
model = build_model(activation, name)
print(model.summary())
h, t = compile_and_fit(model)
s = model.evaluate(x_test, y_test, verbose=0)
label = name.split('_')[1]
histories[label] = h
times[label] = t
scores[label] = s
params[label] = model.count_params()
print(f"学習時間:{t:.1f}秒 パラメータ数:{model.count_params():,} test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_relu === Model: "A_relu" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 20s 18ms/step - accuracy: 0.2608 - loss: 1.9359 - val_accuracy: 0.3262 - val_loss: 1.7449 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 11ms/step - accuracy: 0.3548 - loss: 1.7059 - val_accuracy: 0.3922 - val_loss: 1.6394 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.4109 - loss: 1.5894 - val_accuracy: 0.4492 - val_loss: 1.5052 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4545 - loss: 1.4871 - val_accuracy: 0.4582 - val_loss: 1.4747 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.4798 - loss: 1.4172 - val_accuracy: 0.5032 - val_loss: 1.3732 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5041 - loss: 1.3644 - val_accuracy: 0.5117 - val_loss: 1.3386 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5189 - loss: 1.3194 - val_accuracy: 0.5284 - val_loss: 1.2783 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5326 - loss: 1.2777 - val_accuracy: 0.5489 - val_loss: 1.2318 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5451 - loss: 1.2476 - val_accuracy: 0.5611 - val_loss: 1.2123 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5582 - loss: 1.2154 - val_accuracy: 0.5790 - val_loss: 1.1668 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5694 - loss: 1.1857 - val_accuracy: 0.5803 - val_loss: 1.1494 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5826 - loss: 1.1588 - val_accuracy: 0.5716 - val_loss: 1.1540 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5913 - loss: 1.1326 - val_accuracy: 0.5994 - val_loss: 1.0933 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5968 - loss: 1.1185 - val_accuracy: 0.6048 - val_loss: 1.0789 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6037 - loss: 1.0931 - val_accuracy: 0.6150 - val_loss: 1.0596 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6138 - loss: 1.0686 - val_accuracy: 0.6142 - val_loss: 1.0634 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6204 - loss: 1.0538 - val_accuracy: 0.6189 - val_loss: 1.0463 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6283 - loss: 1.0306 - val_accuracy: 0.6248 - val_loss: 1.0288 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6334 - loss: 1.0181 - val_accuracy: 0.6249 - val_loss: 1.0474 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6376 - loss: 1.0040 - val_accuracy: 0.6385 - val_loss: 1.0097 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6452 - loss: 0.9830 - val_accuracy: 0.6539 - val_loss: 0.9648 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.6511 - loss: 0.9715 - val_accuracy: 0.6537 - val_loss: 0.9572 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6532 - loss: 0.9600 - val_accuracy: 0.6534 - val_loss: 0.9618 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6588 - loss: 0.9498 - val_accuracy: 0.6672 - val_loss: 0.9319 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 11ms/step - accuracy: 0.6624 - loss: 0.9398 - val_accuracy: 0.6614 - val_loss: 0.9390 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 13s 20ms/step - accuracy: 0.6701 - loss: 0.9241 - val_accuracy: 0.6564 - val_loss: 0.9428 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 15ms/step - accuracy: 0.6731 - loss: 0.9159 - val_accuracy: 0.6728 - val_loss: 0.9070 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 15ms/step - accuracy: 0.6785 - loss: 0.9024 - val_accuracy: 0.6789 - val_loss: 0.9030 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 13ms/step - accuracy: 0.6819 - loss: 0.8874 - val_accuracy: 0.6659 - val_loss: 0.9151 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 14ms/step - accuracy: 0.6854 - loss: 0.8778 - val_accuracy: 0.6704 - val_loss: 0.9210 学習時間:179.1秒 パラメータ数:93,450 test_accuracy:0.6679 === B_gelu === Model: "B_gelu" ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d_2 (Conv2D) │ (None, 32, 32, 64) │ 1,792 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling2d_1 │ (None, 128) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘ Total params: 93,450 (365.04 KB) Trainable params: 93,450 (365.04 KB) Non-trainable params: 0 (0.00 B) None Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 10s 10ms/step - accuracy: 0.2686 - loss: 1.9377 - val_accuracy: 0.3235 - val_loss: 1.7765 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.3748 - loss: 1.6663 - val_accuracy: 0.4225 - val_loss: 1.5712 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.4339 - loss: 1.5412 - val_accuracy: 0.4356 - val_loss: 1.5097 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4699 - loss: 1.4475 - val_accuracy: 0.4933 - val_loss: 1.3871 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4970 - loss: 1.3785 - val_accuracy: 0.4762 - val_loss: 1.4027 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5178 - loss: 1.3239 - val_accuracy: 0.5426 - val_loss: 1.2753 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5379 - loss: 1.2734 - val_accuracy: 0.5462 - val_loss: 1.2363 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 10ms/step - accuracy: 0.5498 - loss: 1.2360 - val_accuracy: 0.5587 - val_loss: 1.2172 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5629 - loss: 1.2026 - val_accuracy: 0.5765 - val_loss: 1.1674 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5759 - loss: 1.1706 - val_accuracy: 0.5788 - val_loss: 1.1702 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.5870 - loss: 1.1402 - val_accuracy: 0.5962 - val_loss: 1.1148 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5941 - loss: 1.1163 - val_accuracy: 0.6040 - val_loss: 1.1069 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 15ms/step - accuracy: 0.6043 - loss: 1.0938 - val_accuracy: 0.6069 - val_loss: 1.0828 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 9ms/step - accuracy: 0.6136 - loss: 1.0703 - val_accuracy: 0.6074 - val_loss: 1.0921 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6203 - loss: 1.0530 - val_accuracy: 0.6172 - val_loss: 1.0552 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6274 - loss: 1.0315 - val_accuracy: 0.6247 - val_loss: 1.0271 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6333 - loss: 1.0164 - val_accuracy: 0.6292 - val_loss: 1.0168 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6402 - loss: 1.0006 - val_accuracy: 0.6428 - val_loss: 0.9946 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6439 - loss: 0.9857 - val_accuracy: 0.6391 - val_loss: 0.9987 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6522 - loss: 0.9650 - val_accuracy: 0.6438 - val_loss: 0.9867 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6572 - loss: 0.9511 - val_accuracy: 0.6421 - val_loss: 0.9789 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6604 - loss: 0.9435 - val_accuracy: 0.6562 - val_loss: 0.9476 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6674 - loss: 0.9290 - val_accuracy: 0.6419 - val_loss: 0.9867 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6687 - loss: 0.9224 - val_accuracy: 0.6492 - val_loss: 0.9741 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6750 - loss: 0.9082 - val_accuracy: 0.6584 - val_loss: 0.9491 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6758 - loss: 0.8978 - val_accuracy: 0.6732 - val_loss: 0.9224 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6838 - loss: 0.8826 - val_accuracy: 0.6652 - val_loss: 0.9374 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6826 - loss: 0.8804 - val_accuracy: 0.6724 - val_loss: 0.9123 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6895 - loss: 0.8663 - val_accuracy: 0.6609 - val_loss: 0.9468 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6918 - loss: 0.8580 - val_accuracy: 0.6715 - val_loss: 0.9203 学習時間:148.7秒 パラメータ数:93,450 test_accuracy:0.6731
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('relu_vs_gelu_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離)────────────
fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(f'{label}:train vs val loss')
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('relu_vs_gelu_overfit.png', dpi=150)
plt.show()
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>8} | {'Val Acc':>8} | {'Test Acc':>9} | {'Time(s)':>8} | {'Params':>10}")
print("-" * 54)
for label in ['relu', 'gelu']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
p = params[label]
print(f"{label:>8} | {val_acc:>8.4f} | {test_acc:>9.4f} | {t:>8.1f} | {p:>10,}")
print("-" * 54)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Time(s) | Params
------------------------------------------------------
relu | 0.6704 | 0.6679 | 179.1 | 93,450
gelu | 0.6715 | 0.6731 | 148.7 | 93,450
------------------------------------------------------
実験結果
精度グラフ
損失グラフ
ReLU
GELU
| パターン | 最終 val_accuracy | 最終 test_accuracy | パラメータ数 | 学習時間 |
|---|---|---|---|---|
| A:ReLU | 67.04% | 66.79% | 93,450 | 179.1秒 |
| B:GELU | 67.15% | 67.31% | 93,450 | 148.7秒 |
考察
① GELUがReLUをわずかに上回った
今回の実験ではGELUがReLUをtest_accuracyで+0.52%上回る結果になりました。パラメータ数はまったく同じ(93,450)なので、この差は純粋に活性化関数の違いによるものです。
| 比較 | val_accuracy の差 | test_accuracy の差 | パラメータ変化 |
|---|---|---|---|
| ReLU → GELU | +0.11%(67.04% → 67.15%) | +0.52%(66.79% → 67.31%) | 変化なし(93,450) |
精度差は小さいものの、パラメータ数ゼロ・コード1行の変更('relu' → 'gelu')で得られる改善としては十分意味があります。
② なぜGELUが上回るのか
ReLUは負の入力を完全に0にするため、一度「死んだ」ニューロンは学習が進んでも復活しません。これが「Dying ReLU問題」で、深いネットワークほど顕在化します。
GELUは負の値に対しても小さな勾配が残ります。この「少しだけ通す」設計が勾配消失を起きにくくし、特にエポックが進んだ後半の学習安定性に寄与していると考えられます。グラフを見ると、中盤以降でGELUが安定して高い精度を維持している様子が確認できます。
③ 学習時間について
今回の実験では、ReLUが179.1秒、GELUが148.7秒とGELUの方が約30秒速い結果になりました。理論上はGELUがtanhを使った近似計算を行うためReLUより遅くなるはずですが、今回は逆の結果になっています。
これはColabのGPU割り当て変動によるものと考えられます。ColabのT4 GPUは他のユーザーとリソースを共有しており、同一コードでも実行タイミングによって学習時間が数十秒単位でブレることがあります。そのため今回の学習時間の差は、活性化関数の違いではなく実行環境のノイズである可能性が高いです。
学習時間を正確に比較したい場合は、同一セッション内で複数回実行して平均を取ることをおすすめします。
④ 過学習への影響はどうか
train_lossとval_lossの乖離グラフを確認すると、ReLUとGELUで過学習の度合いに大きな差はありませんでした。Dropout(0.2)がどちらにも同等に効いており、活性化関数の変更単体が過学習を増やすわけではないことが確認できます。
もしGELUを使ってモデルを深くしたり、Dense層を増やしたりする場合は改めて過学習の確認が必要です。
関連記事もあわせてどうぞ:
- 学習率(Learning Rate)の比較 → Adam学習率の最適値は?3パターンの比較実験【Keras】
- Dropout(過学習対策)の比較 → Dropoutの割合(0.0 vs 0.2 vs 0.5)を変えると過学習はどう変わる?【Keras×CIFAR-10実験】
- Conv2Dの層数の比較 → Conv2Dの層数を変えると精度はどう変わる?(1層 vs 2層 vs 3層)【Keras×CIFAR-10実験】
- Dense層のユニット数の比較 → Dense層のユニット数を変えると精度はどう変わる?(32 vs 128 vs 512)【Keras×CIFAR-10実験】
- Optimizer(最適化手法)の比較 → optimizerを変えると精度はどう変わる?(Adam vs SGD vs RMSprop)【Keras×CIFAR-10実験】
- Batch Size(バッチサイズ)の比較 → バッチサイズを変えると精度はどう変わる?(16 vs 64 vs 256)【Keras×CIFAR-10実験】
- Activation関数の比較 → Activation関数を変えると精度はどう変わる?(relu vs gelu vs elu)【Keras×CIFAR-10実験】
- Pooling設定の比較 → GAPモデルでMaxPoolingのpool_sizeを変えると精度はどう変わる?【Keras実験・予想外の結果】





0 件のコメント:
コメントを投稿