損失関数を変えると精度は変わるのか?
KerasでCNNを組むとき、損失関数に sparse_categorical_crossentropy を使うのが"お約束"になっていませんか?
でも categorical_crossentropy との違いや、クラス不均衡対策として注目される Focal Loss を試したことはあるでしょうか。
今回はGoogle ColabとCIFAR-10を使い、3種類の損失関数を完全同一条件で比較しました。
- sparse_categorical と categorical_crossentropy の使い分け
- Focal Loss の仕組みとKerasへの実装方法
- CIFAR-10での精度・損失・学習速度の実験結果
3つの損失関数のおさらい
| 損失関数 | ラベル形式 | 特徴 |
|---|---|---|
| sparse_categorical_crossentropy | 整数ラベル 例: 3 |
最も一般的。ラベルをone-hotに変換する手間が不要 |
| categorical_crossentropy | one-hotラベル 例: [0,0,0,1,…] |
ラベルを事前にone-hot変換する必要あり。内部の計算はsparseと同等 |
| Focal Loss | 整数 or one-hot (実装次第) |
クラス不均衡に強い。簡単なサンプルへの損失を小さくし、難しいサンプルに集中させる |
sparse と categorical の違いは「ラベル形式だけ」で、数値的には同じ計算をしています。一方、Focal Lossは損失の計算式そのものが異なります。
Focal Loss の式
通常のCross Entropy(\(CE\) )に対して、Focal Loss(\(FL\) )は予測確率\(p_t\) に応じた重みを掛け合わせます。
$$FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$
- 正しく分類できている(
pが高い)サンプル → 重みが小さくなり損失が減る - 間違えやすい(
pが低い)サンプル → 重みが大きくなり損失が増える
γ(ガンマ)= 0 のときは通常のCross Entropyと同じになります。一般的には γ = 2 がよく使われます。
以前は
tensorflow_addons.losses.SigmoidFocalCrossEntropy が使われていましたが、tensorflow-addons は2024年以降メンテナンス終了となりました。本記事ではKerasのみでFocal Lossをカスタム実装します。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。損失関数の種類以外の条件はすべて同一にし、損失関数の影響だけを取り出します。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 42 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (25.9 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122354 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 72.3 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備・Focal Loss実装
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
# ── データ準備 ──────────────────────────────────────────
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# categorical_crossentropy 用の one-hot ラベル
y_train_oh = keras.utils.to_categorical(y_train, 10)
y_test_oh = keras.utils.to_categorical(y_test, 10)
NUM_CLASSES = 10
# ── Focal Loss カスタム実装 ──────────────────────────────
class FocalLoss(keras.losses.Loss):
"""Focal Loss(gamma=2, alpha=0.25)
Lin et al., "Focal Loss for Dense Object Detection", 2017
"""
def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
super().__init__(**kwargs)
self.gamma = gamma
self.alpha = alpha
def call(self, y_true, y_pred):
# y_true: (batch, 10) one-hot
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
ce = -y_true * tf.math.log(y_pred) # cross entropy per class
p_t = tf.reduce_sum(y_true * y_pred, axis=-1, keepdims=True)
focal_weight = self.alpha * tf.pow(1.0 - p_t, self.gamma)
loss = focal_weight * ce
return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
print("Focal Loss 実装完了")
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 32s 0us/step Focal Loss 実装完了
モデル構築関数
def build_model(name):
# 日本語・記号を含むモデル名をサニタイズ
sanitized_name = name.replace(':', '_')
# name_scope で名前の衝突を防ぐ
with tf.name_scope(sanitized_name):
model = keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(NUM_CLASSES, activation='softmax'),
], name=sanitized_name)
return model
def compile_and_fit(model, loss_fn, x_tr, y_tr):
model.compile(optimizer='adam',
loss=loss_fn,
metrics=['accuracy'])
start = time.time()
history = model.fit(x_tr, y_tr, epochs=30, batch_size=64,
validation_split=0.2, verbose=1)
return history, time.time() - start
3パターンの学習実行
# (label, loss_fn, train用ラベル, test用ラベル)
configs = [
('A_sparse_categorical', 'sparse_categorical_crossentropy', y_train, y_test),
('B_categorical', 'categorical_crossentropy', y_train_oh, y_test_oh),
('C_Focal_Loss', FocalLoss(), y_train_oh, y_test_oh),
]
histories, times, scores = {}, {}, {}
for label, loss_fn, y_tr, y_te in configs:
keras.backend.clear_session() # セッションをクリアして名前の競合を防ぐ
print(f"\n=== {label} ===")
model = build_model(label)
h, t = compile_and_fit(model, loss_fn, x_train, y_tr)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
s = model.evaluate(x_test, y_te, verbose=0)
histories[label] = h
times[label] = t
scores[label] = s
print(f"学習時間:{t:.1f}秒 test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_sparse_categorical === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.2723 - loss: 1.9213 - val_accuracy: 0.3584 - val_loss: 1.7292 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3853 - loss: 1.6433 - val_accuracy: 0.4375 - val_loss: 1.5412 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4437 - loss: 1.5072 - val_accuracy: 0.4754 - val_loss: 1.4214 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4790 - loss: 1.4194 - val_accuracy: 0.4809 - val_loss: 1.4257 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4992 - loss: 1.3722 - val_accuracy: 0.5016 - val_loss: 1.3404 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5190 - loss: 1.3189 - val_accuracy: 0.5420 - val_loss: 1.2623 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5315 - loss: 1.2860 - val_accuracy: 0.5368 - val_loss: 1.2685 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5457 - loss: 1.2508 - val_accuracy: 0.5521 - val_loss: 1.2439 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5571 - loss: 1.2223 - val_accuracy: 0.5779 - val_loss: 1.1527 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5685 - loss: 1.1899 - val_accuracy: 0.5761 - val_loss: 1.1507 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5763 - loss: 1.1648 - val_accuracy: 0.5779 - val_loss: 1.1480 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5850 - loss: 1.1425 - val_accuracy: 0.5865 - val_loss: 1.1395 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5912 - loss: 1.1251 - val_accuracy: 0.6085 - val_loss: 1.0799 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6051 - loss: 1.0996 - val_accuracy: 0.6147 - val_loss: 1.0627 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6114 - loss: 1.0763 - val_accuracy: 0.6135 - val_loss: 1.0641 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6177 - loss: 1.0594 - val_accuracy: 0.6237 - val_loss: 1.0290 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6235 - loss: 1.0453 - val_accuracy: 0.6322 - val_loss: 1.0055 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6302 - loss: 1.0273 - val_accuracy: 0.6234 - val_loss: 1.0410 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6382 - loss: 1.0083 - val_accuracy: 0.6375 - val_loss: 1.0170 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6380 - loss: 1.0025 - val_accuracy: 0.6559 - val_loss: 0.9589 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6461 - loss: 0.9808 - val_accuracy: 0.6533 - val_loss: 0.9675 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6492 - loss: 0.9683 - val_accuracy: 0.6544 - val_loss: 0.9762 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6565 - loss: 0.9585 - val_accuracy: 0.6528 - val_loss: 0.9475 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6623 - loss: 0.9453 - val_accuracy: 0.6519 - val_loss: 0.9567 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6633 - loss: 0.9342 - val_accuracy: 0.6482 - val_loss: 0.9647 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6696 - loss: 0.9239 - val_accuracy: 0.6651 - val_loss: 0.9215 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6737 - loss: 0.9133 - val_accuracy: 0.6766 - val_loss: 0.9030 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6760 - loss: 0.9018 - val_accuracy: 0.6755 - val_loss: 0.9041 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6829 - loss: 0.8872 - val_accuracy: 0.6789 - val_loss: 0.9047 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6874 - loss: 0.8783 - val_accuracy: 0.6806 - val_loss: 0.8926 学習時間:125.4秒 test_accuracy:0.6739 === B_categorical === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.2661 - loss: 1.9221 - val_accuracy: 0.3552 - val_loss: 1.7029 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 6ms/step - accuracy: 0.3701 - loss: 1.6788 - val_accuracy: 0.3927 - val_loss: 1.6269 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4213 - loss: 1.5762 - val_accuracy: 0.4567 - val_loss: 1.4989 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4600 - loss: 1.4792 - val_accuracy: 0.4849 - val_loss: 1.4361 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4878 - loss: 1.4062 - val_accuracy: 0.5098 - val_loss: 1.3508 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5071 - loss: 1.3535 - val_accuracy: 0.5269 - val_loss: 1.2948 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5229 - loss: 1.3094 - val_accuracy: 0.5380 - val_loss: 1.2710 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5352 - loss: 1.2710 - val_accuracy: 0.5484 - val_loss: 1.2388 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5465 - loss: 1.2463 - val_accuracy: 0.5629 - val_loss: 1.1991 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5597 - loss: 1.2166 - val_accuracy: 0.5676 - val_loss: 1.1869 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5682 - loss: 1.1898 - val_accuracy: 0.5819 - val_loss: 1.1575 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5732 - loss: 1.1696 - val_accuracy: 0.5888 - val_loss: 1.1318 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5844 - loss: 1.1461 - val_accuracy: 0.5954 - val_loss: 1.1079 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5942 - loss: 1.1213 - val_accuracy: 0.5791 - val_loss: 1.1435 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6029 - loss: 1.1003 - val_accuracy: 0.6038 - val_loss: 1.0855 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6079 - loss: 1.0837 - val_accuracy: 0.6151 - val_loss: 1.0622 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6184 - loss: 1.0580 - val_accuracy: 0.6201 - val_loss: 1.0569 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6242 - loss: 1.0480 - val_accuracy: 0.6270 - val_loss: 1.0323 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6313 - loss: 1.0301 - val_accuracy: 0.6216 - val_loss: 1.0520 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6356 - loss: 1.0168 - val_accuracy: 0.6277 - val_loss: 1.0399 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6384 - loss: 1.0070 - val_accuracy: 0.6371 - val_loss: 1.0001 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6449 - loss: 0.9860 - val_accuracy: 0.6501 - val_loss: 0.9698 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6515 - loss: 0.9762 - val_accuracy: 0.6442 - val_loss: 0.9822 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6542 - loss: 0.9581 - val_accuracy: 0.6581 - val_loss: 0.9639 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6617 - loss: 0.9427 - val_accuracy: 0.6475 - val_loss: 0.9697 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6654 - loss: 0.9318 - val_accuracy: 0.6566 - val_loss: 0.9535 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6752 - loss: 0.9160 - val_accuracy: 0.6667 - val_loss: 0.9334 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6731 - loss: 0.9124 - val_accuracy: 0.6537 - val_loss: 0.9821 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6777 - loss: 0.9013 - val_accuracy: 0.6603 - val_loss: 0.9338 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6791 - loss: 0.8909 - val_accuracy: 0.6783 - val_loss: 0.9033 学習時間:126.0秒 test_accuracy:0.6753 === C_Focal_Loss === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.2670 - loss: 0.3579 - val_accuracy: 0.3643 - val_loss: 0.2968 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 7s 7ms/step - accuracy: 0.3703 - loss: 0.2884 - val_accuracy: 0.4208 - val_loss: 0.2664 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4349 - loss: 0.2574 - val_accuracy: 0.4741 - val_loss: 0.2368 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4753 - loss: 0.2358 - val_accuracy: 0.4851 - val_loss: 0.2263 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4937 - loss: 0.2238 - val_accuracy: 0.5114 - val_loss: 0.2147 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5102 - loss: 0.2139 - val_accuracy: 0.5220 - val_loss: 0.2057 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5255 - loss: 0.2058 - val_accuracy: 0.5411 - val_loss: 0.1966 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5343 - loss: 0.2010 - val_accuracy: 0.5493 - val_loss: 0.1906 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5487 - loss: 0.1930 - val_accuracy: 0.5640 - val_loss: 0.1857 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5566 - loss: 0.1872 - val_accuracy: 0.5747 - val_loss: 0.1783 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5653 - loss: 0.1835 - val_accuracy: 0.5727 - val_loss: 0.1760 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5773 - loss: 0.1777 - val_accuracy: 0.5867 - val_loss: 0.1729 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5869 - loss: 0.1727 - val_accuracy: 0.5887 - val_loss: 0.1733 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5907 - loss: 0.1693 - val_accuracy: 0.5974 - val_loss: 0.1644 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5969 - loss: 0.1669 - val_accuracy: 0.6093 - val_loss: 0.1608 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6028 - loss: 0.1620 - val_accuracy: 0.5989 - val_loss: 0.1638 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6137 - loss: 0.1586 - val_accuracy: 0.5893 - val_loss: 0.1725 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6194 - loss: 0.1558 - val_accuracy: 0.6242 - val_loss: 0.1530 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6254 - loss: 0.1518 - val_accuracy: 0.6249 - val_loss: 0.1534 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6298 - loss: 0.1490 - val_accuracy: 0.6244 - val_loss: 0.1530 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6344 - loss: 0.1462 - val_accuracy: 0.6310 - val_loss: 0.1472 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6420 - loss: 0.1426 - val_accuracy: 0.6373 - val_loss: 0.1451 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6477 - loss: 0.1407 - val_accuracy: 0.6429 - val_loss: 0.1437 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6516 - loss: 0.1377 - val_accuracy: 0.6388 - val_loss: 0.1429 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6559 - loss: 0.1354 - val_accuracy: 0.6491 - val_loss: 0.1402 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6597 - loss: 0.1339 - val_accuracy: 0.6525 - val_loss: 0.1379 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 6s 7ms/step - accuracy: 0.6595 - loss: 0.1317 - val_accuracy: 0.6383 - val_loss: 0.1405 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6668 - loss: 0.1292 - val_accuracy: 0.6567 - val_loss: 0.1366 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6715 - loss: 0.1271 - val_accuracy: 0.6735 - val_loss: 0.1308 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6745 - loss: 0.1251 - val_accuracy: 0.6636 - val_loss: 0.1366 学習時間:129.1秒 test_accuracy:0.6603
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(fontsize=9); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('loss_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離)────────────────
fig2, axes2 = plt.subplots(3, 1, figsize=(7, 14))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(label)
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('loss_overfit.png', dpi=150)
plt.show()
# ── サマリー ─────────────────────────────────────────────
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>22} | {'Val Acc':>8} | {'Test Acc':>9} | {'Val Loss':>9} | {'Time(s)':>8}")
print("-" * 70)
for label, h in histories.items():
val_acc = h.history['val_accuracy'][-1]
val_loss = h.history['val_loss'][-1]
test_acc = scores[label][1]
t = times[label]
print(f"{label:>22} | {val_acc:>8.4f} | {test_acc:>9.4f} | {val_loss:>9.4f} | {t:>8.1f}")
print("-" * 70)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Val Loss | Time(s)
----------------------------------------------------------------------
A_sparse_categorical | 0.6806 | 0.6739 | 0.8926 | 125.4
B_categorical | 0.6783 | 0.6753 | 0.9033 | 126.0
C_Focal_Loss | 0.6636 | 0.6603 | 0.1366 | 129.1
----------------------------------------------------------------------
実験結果
実験はGoogle Colab(T4 GPU)、CIFAR-10、エポック数30・バッチサイズ64の条件で実施しました。
精度グラフ
損失グラフ
A_sparse_categorical
B_categorical
C_Focal_Loss
| パターン | val_accuracy | test_accuracy | val_loss | 学習時間 |
|---|---|---|---|---|
| A:sparse_categorical_crossentropy | 68.06% | 67.39% | 0.8926 | 125.4秒 |
| B:categorical_crossentropy | 67.83% | 67.53% | 0.9033 | 126.0秒 |
| C:Focal Loss(γ=2, α=0.25) | 66.36% | 66.03% | 0.1366 ※ | 129.1秒 |
※ Focal Lossのval_lossは損失のスケール自体が異なるため、A・Bとの直接比較はできません(後述)。
考察
① sparse と categorical は「ほぼ同等」を確認
A(sparse)とB(categorical)のtest_accuracyの差はわずか0.14%(67.39% vs 67.53%)で、誤差の範囲内です。val_lossもほぼ同水準(0.8926 vs 0.9033)で、ラベルの形式が違うだけで計算は同一であることが実験でも裏付けられました。ラベルの形式以外に本質的な違いはなく、今回の差は学習のばらつきの範囲と考えられます。
ラベルが整数のまま使えるsparseは、to_categorical() の変換が不要でコードがシンプルになります。特別な理由がなければ sparse_categorical_crossentropy を使うのがベストです。
② Focal Loss は均衡データでは精度が下がった
C(Focal Loss)のtest_accuracyは66.03%で、ベースライン(A:67.39%)より約1.4%低下しました。CIFAR-10はクラスあたり5,000枚・10クラスの完全均衡データセットです。Focal Lossが「難しいサンプルに集中する」という性質が、均衡データでは必ずしもメリットにならなかった結果です。
Focal Lossは「簡単に分類できるサンプルの損失を小さくする」ため、均衡データでは一部のサンプルを意図的に無視することになります。CIFAR-10ではどのクラスも同等の難易度で出現するため、この重み付けが逆効果になったと考えられます。
③ Focal Loss の val_loss を他と比較してはいけない
今回の結果でもっとも注意が必要なのが、Focal Lossのval_loss(0.1366)です。AやBのval_loss(約0.90)と比べて極端に小さく見えますが、これは「精度が高い」ことを意味しません。
Focal Loss ではalpha と (1 - p_t)^γ によって損失のスケールが通常のクロスエントロピーとは変化します。Focal Lossはハイパーパラメータによって損失のスケール自体が変化するため、A・Bと損失値を直接比較することに意味はありません。 最終的な評価は必ず精度(accuracy)で行う必要があります。
| 比較 | test_accuracy の差 |
|---|---|
| A(sparse)vs B(categorical) | +0.14%(実質同等) |
| A(sparse)vs C(Focal Loss) | −1.36%(Focal Lossが下) |
④ Focal Loss が真価を発揮するのはどんな場面か
Focal Lossはそもそも物体検出(RetinaNet)のために提案された損失関数です。背景(簡単なサンプル)が圧倒的に多く、物体(難しいサンプル)が少ないという極端な不均衡を扱うために設計されています。
CIFAR-10のような均衡な多クラス分類タスクでは、standard crossentropyで十分です。Focal Lossを試す価値があるのは、以下のようなケースです。
- クラス間のサンプル数に大きな偏りがあるデータセット(例:10:1以上の不均衡)
- 物体検出タスクで背景クラスが大多数を占める場合
- 医療画像など「稀なクラスを見逃したくない」場面
まとめ
| 損失関数 | ラベル形式 | クラス均衡データ | クラス不均衡データ | 実験精度 |
|---|---|---|---|---|
| sparse_categorical_crossentropy | 整数 | ✅ 最もシンプル・高速 | △ 標準的 | 67.39% |
| categorical_crossentropy | one-hot | ✅ sparseと同等 | △ 標準的 | 67.53% |
| Focal Loss(γ=2, α=0.25) | one-hot | △ 均衡データでは逆効果 | ✅ 不均衡に強い | 66.03% |
sparse と categorical は実質同じなので、ラベルが整数ならsparseを選べばOKです。Focal Lossは均衡データでは精度が下がることが今回の実験で確認されました。Focal Lossはクラス不均衡が深刻な場面(物体検出・医療画像など)に温存しておくのが賢明です。
関連記事もあわせてどうぞ:
- 活性化関数の比較 → ReLUとGELUを比較してみた:Activation関数で精度は変わる?【Keras×CIFAR-10実験】
- 正則化手法の比較 → L2・Dropout・AdamWを組み合わせると過学習は防げる?【Keras×CIFAR-10実験】
- AdamW vs Adam 比較 → AdamWとAdamの違いを実験で確かめる【Keras】
- BatchNorm vs Dropout 比較 → BatchNormalizationとDropout、どちらが過学習に強い?【Keras実験】






0 件のコメント:
コメントを投稿