MixUpって本当に効くの?
Data Augmentationの手法として MixUp を聞いたことはありますか?
「2枚の画像を半透明に重ねてラベルも混ぜる」という一見ユニークな手法で、2018年に提案されて以来、画像分類の精度向上手法として広く使われています。
今回はGoogle ColabとCIFAR-10を使い、MixUpあり vs なしを完全同一条件で比較しました。実装が少し凝っているMixUpを、tf.data パイプラインでどう組み込むかもあわせて解説します。
- MixUpの仕組みと「なぜ精度が上がるのか」の直感的な理解
- KerasとTensorFlowのみでMixUpを実装する方法(tf.py_function を使った安全な実装)
- CIFAR-10での精度・過学習への効果の実験結果と考察
MixUpとは何か
MixUpは2枚の画像とそのラベルをランダムな比率 \( \lambda \) で混ぜ合わせる手法です。
- 混合画像:\( \tilde{x} = \lambda x_i + (1-\lambda)x_j \)
- 混合ラベル:\( \tilde{y} = \lambda y_i + (1-\lambda)y_j \)
\( \lambda \) はBeta分布 \( \mathrm{Beta}(\alpha, \alpha) \) からサンプリングされます。一般的には \( \alpha = 0.2 \) がよく使われます。この設定では \( \lambda \) が0または1に近い値になりやすく、結果として「ほぼ元画像」に近いサンプルになることが多くなります。
| 項目 | 通常の学習 | MixUpあり |
|---|---|---|
| 入力画像 | 元画像そのまま | 2枚をλで合成した画像 |
| ラベル | 整数 or one-hot(ハードラベル) | 2クラスが混ざったソフトラベル |
| 損失関数 | sparse_categorical_crossentropy | categorical_crossentropy(one-hot必須) |
| 過学習への効果 | — | ソフトラベルで過信を抑制 |
なぜ精度が上がるのか
通常の学習では「この画像は猫(確率1.0)」というハードラベルで学習します。MixUpでは「猫0.7・犬0.3」というソフトラベルで学習するため、モデルが特定クラスに過度に自信を持つことを防ぎます。これが汎化性能の向上につながります。
また、2枚の画像の中間点も学習データとして使うことで、決定境界が滑らかになる効果も期待できます。
MixUpはラベルを混ぜるため、損失関数に
sparse_categorical_crossentropy(整数ラベル前提)は使えません。one-hotラベル + categorical_crossentropy が必須です。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。MixUpの有無以外の条件はすべて同一にします。
環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 100 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (27.7 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122402 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 100.7 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
import・データ準備
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
# ── データ準備 ──────────────────────────────────────────
(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = keras.datasets.cifar10.load_data()
x_train = x_train_raw.astype('float32') / 255.0
x_test = x_test_raw.astype('float32') / 255.0
y_train_raw = y_train_raw.flatten()
y_test_raw = y_test_raw.flatten()
NUM_CLASSES = 10
BATCH_SIZE = 64
EPOCHS = 30
MIXUP_ALPHA = 0.2 # Beta分布のパラメータ
# one-hot ラベル(MixUp・categorical_crossentropy用)
y_train_oh = tf.one_hot(y_train_raw, NUM_CLASSES)
y_test_oh = tf.one_hot(y_test_raw, NUM_CLASSES)
print("データ準備完了")
実行結果をクリックして内容を開く
データ準備完了
MixUp の実装(tf.data パイプライン)
tf.random.beta はTensorFlowに存在しないAPIです。np.random.beta はEagerモードでは動きますが、tf.data.map 内はGraph modeで実行されるためNumPy関数を直接呼べません。tf.py_function でラップすることで、Graph mode内でも安全にNumPy関数を呼び出せます。
def mixup_batch(images, labels, alpha=MIXUP_ALPHA):
"""1バッチ分の画像・ラベルをMixUpする関数"""
batch_size = tf.shape(images)[0]
# ✅ np.random.beta を tf.py_function でラップして安全に呼び出す
def sample_lambda(bs):
lam = np.random.beta(alpha, alpha, bs.numpy())
lam = np.maximum(lam, 1.0 - lam) # λ >= 0.5 に制限
return lam.astype(np.float32)
lam = tf.py_function(
func=sample_lambda,
inp=[batch_size],
Tout=tf.float32
)
lam.set_shape([None]) # shapeをgraph modeに伝える
# バッチをシャッフルして混合ペアを作る
indices = tf.random.shuffle(tf.range(batch_size))
images_b = tf.gather(images, indices)
labels_b = tf.gather(labels, indices)
# broadcast用に reshape
lam_img = tf.reshape(lam, [batch_size, 1, 1, 1])
lam_lbl = tf.reshape(lam, [batch_size, 1])
mixed_images = lam_img * images + (1.0 - lam_img) * images_b
mixed_labels = lam_lbl * labels + (1.0 - lam_lbl) * labels_b
return mixed_images, mixed_labels
def make_dataset(x, y, use_mixup, training=True):
"""tf.data パイプラインを構築する関数"""
ds = tf.data.Dataset.from_tensor_slices((x, y))
if training:
ds = ds.shuffle(len(x), seed=42)
ds = ds.batch(BATCH_SIZE, drop_remainder=True)
if use_mixup and training:
ds = ds.map(
lambda imgs, lbls: mixup_batch(imgs, lbls),
num_parallel_calls=tf.data.AUTOTUNE
)
return ds.prefetch(tf.data.AUTOTUNE)
モデル構築・学習関数
def build_model(name):
sanitized_name = name.replace(':', '_')
with tf.name_scope(sanitized_name):
model = keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(NUM_CLASSES, activation='softmax'),
], name=sanitized_name)
return model
def compile_and_fit(model, train_ds, val_ds):
model.compile(
optimizer='adam',
loss='categorical_crossentropy', # one-hotラベル必須
metrics=['accuracy']
)
start = time.time()
history = model.fit(train_ds, epochs=EPOCHS,
validation_data=val_ds, verbose=1)
return history, time.time() - start
2パターンの学習実行
# validation split(8:2)
val_split = int(len(x_train) * 0.8)
x_tr, x_val = x_train[:val_split], x_train[val_split:]
y_tr, y_val = y_train_oh[:val_split], y_train_oh[val_split:]
configs = [
('A_no_mixup', False),
('B_with_mixup', True),
]
histories, times, scores = {}, {}, {}
for label, use_mixup in configs:
keras.backend.clear_session()
print(f"\n=== {label} ===")
train_ds = make_dataset(x_tr, y_tr, use_mixup=use_mixup, training=True)
val_ds = make_dataset(x_val, y_val, use_mixup=False, training=False)
test_ds = make_dataset(x_test, y_test_oh, use_mixup=False, training=False)
model = build_model(label)
h, t = compile_and_fit(model, train_ds, val_ds)
s = model.evaluate(test_ds, verbose=0)
histories[label] = h
times[label] = t
scores[label] = s
print(f"学習時間:{t:.1f}秒 test_accuracy:{s[1]:.4f}")
実行結果をクリックして内容を開く
=== A_no_mixup === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.2714 - loss: 1.9063 - val_accuracy: 0.3385 - val_loss: 1.7469 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3832 - loss: 1.6564 - val_accuracy: 0.4315 - val_loss: 1.5629 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4363 - loss: 1.5291 - val_accuracy: 0.4568 - val_loss: 1.4752 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4762 - loss: 1.4323 - val_accuracy: 0.4990 - val_loss: 1.3610 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4984 - loss: 1.3759 - val_accuracy: 0.5219 - val_loss: 1.3146 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5159 - loss: 1.3274 - val_accuracy: 0.5189 - val_loss: 1.3120 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5326 - loss: 1.2891 - val_accuracy: 0.5302 - val_loss: 1.2748 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5464 - loss: 1.2551 - val_accuracy: 0.5468 - val_loss: 1.2428 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5568 - loss: 1.2264 - val_accuracy: 0.5772 - val_loss: 1.1784 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5649 - loss: 1.1995 - val_accuracy: 0.5724 - val_loss: 1.1840 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5747 - loss: 1.1732 - val_accuracy: 0.5778 - val_loss: 1.1529 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5861 - loss: 1.1481 - val_accuracy: 0.5966 - val_loss: 1.1217 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5944 - loss: 1.1254 - val_accuracy: 0.6004 - val_loss: 1.1037 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6020 - loss: 1.1098 - val_accuracy: 0.6113 - val_loss: 1.0736 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6062 - loss: 1.0847 - val_accuracy: 0.6140 - val_loss: 1.0577 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6182 - loss: 1.0619 - val_accuracy: 0.6115 - val_loss: 1.0909 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6227 - loss: 1.0496 - val_accuracy: 0.6231 - val_loss: 1.0340 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6308 - loss: 1.0272 - val_accuracy: 0.6339 - val_loss: 1.0213 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6314 - loss: 1.0225 - val_accuracy: 0.6357 - val_loss: 1.0060 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6387 - loss: 0.9996 - val_accuracy: 0.6116 - val_loss: 1.0655 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6470 - loss: 0.9869 - val_accuracy: 0.6483 - val_loss: 0.9777 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6510 - loss: 0.9732 - val_accuracy: 0.6532 - val_loss: 0.9679 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.6556 - loss: 0.9569 - val_accuracy: 0.6566 - val_loss: 0.9666 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6594 - loss: 0.9461 - val_accuracy: 0.6432 - val_loss: 0.9896 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6661 - loss: 0.9322 - val_accuracy: 0.6619 - val_loss: 0.9407 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6694 - loss: 0.9216 - val_accuracy: 0.6642 - val_loss: 0.9420 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6714 - loss: 0.9162 - val_accuracy: 0.6672 - val_loss: 0.9328 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6796 - loss: 0.8924 - val_accuracy: 0.6777 - val_loss: 0.9126 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6787 - loss: 0.8919 - val_accuracy: 0.6807 - val_loss: 0.9141 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6855 - loss: 0.8787 - val_accuracy: 0.6721 - val_loss: 0.9116 学習時間:124.5秒 test_accuracy:0.6656 === B_with_mixup === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 9s 8ms/step - accuracy: 0.2375 - loss: 2.0379 - val_accuracy: 0.3120 - val_loss: 1.7910 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - accuracy: 0.3367 - loss: 1.8265 - val_accuracy: 0.4025 - val_loss: 1.6209 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.3940 - loss: 1.7306 - val_accuracy: 0.4486 - val_loss: 1.5338 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4360 - loss: 1.6405 - val_accuracy: 0.4830 - val_loss: 1.4521 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.4548 - loss: 1.6014 - val_accuracy: 0.5051 - val_loss: 1.3776 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4729 - loss: 1.5597 - val_accuracy: 0.4900 - val_loss: 1.3881 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.4859 - loss: 1.5302 - val_accuracy: 0.5203 - val_loss: 1.3273 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.4970 - loss: 1.5060 - val_accuracy: 0.5269 - val_loss: 1.3004 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.5103 - loss: 1.4792 - val_accuracy: 0.5465 - val_loss: 1.2640 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5189 - loss: 1.4563 - val_accuracy: 0.5402 - val_loss: 1.2547 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5295 - loss: 1.4434 - val_accuracy: 0.5674 - val_loss: 1.2114 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5409 - loss: 1.4162 - val_accuracy: 0.5788 - val_loss: 1.1804 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5460 - loss: 1.4022 - val_accuracy: 0.5876 - val_loss: 1.1613 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5557 - loss: 1.3804 - val_accuracy: 0.5971 - val_loss: 1.1264 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5591 - loss: 1.3706 - val_accuracy: 0.5998 - val_loss: 1.1102 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.5689 - loss: 1.3504 - val_accuracy: 0.6061 - val_loss: 1.1115 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5771 - loss: 1.3378 - val_accuracy: 0.6156 - val_loss: 1.0794 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5824 - loss: 1.3217 - val_accuracy: 0.6211 - val_loss: 1.0707 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.5865 - loss: 1.3179 - val_accuracy: 0.6233 - val_loss: 1.0652 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5948 - loss: 1.3021 - val_accuracy: 0.6172 - val_loss: 1.0801 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6009 - loss: 1.2885 - val_accuracy: 0.6264 - val_loss: 1.0429 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6048 - loss: 1.2804 - val_accuracy: 0.6336 - val_loss: 1.0467 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6041 - loss: 1.2760 - val_accuracy: 0.6417 - val_loss: 1.0198 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6124 - loss: 1.2606 - val_accuracy: 0.6380 - val_loss: 1.0185 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6149 - loss: 1.2496 - val_accuracy: 0.6521 - val_loss: 0.9982 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6193 - loss: 1.2402 - val_accuracy: 0.6500 - val_loss: 1.0011 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - accuracy: 0.6236 - loss: 1.2368 - val_accuracy: 0.6403 - val_loss: 1.0250 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6243 - loss: 1.2292 - val_accuracy: 0.6584 - val_loss: 0.9845 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6298 - loss: 1.2226 - val_accuracy: 0.6702 - val_loss: 0.9617 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6390 - loss: 1.2089 - val_accuracy: 0.6555 - val_loss: 0.9612 学習時間:144.9秒 test_accuracy:0.6602
グラフ+サマリー
# ── val_accuracy / val_loss 比較グラフ ───────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label)
axes[1].plot(h.history['val_loss'], label=label)
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(fontsize=9); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('mixup_comparison.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離)────────────────
fig2, axes2 = plt.subplots(2, 1, figsize=(7, 10))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(label)
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('mixup_overfit.png', dpi=150)
plt.show()
# ── サマリー ─────────────────────────────────────────────
print("\n===== 最終結果サマリー =====")
print(f"{'Pattern':>16} | {'Val Acc':>8} | {'Test Acc':>9} | {'Val Loss':>9} | {'Time(s)':>8}")
print("-" * 62)
for label, h in histories.items():
val_acc = h.history['val_accuracy'][-1]
val_loss = h.history['val_loss'][-1]
test_acc = scores[label][1]
t = times[label]
print(f"{label:>16} | {val_acc:>8.4f} | {test_acc:>9.4f} | {val_loss:>9.4f} | {t:>8.1f}")
print("-" * 62)
最終結果サマリー
===== 最終結果サマリー =====
Pattern | Val Acc | Test Acc | Val Loss | Time(s)
--------------------------------------------------------------
A_no_mixup | 0.6721 | 0.6656 | 0.9116 | 124.5
B_with_mixup | 0.6555 | 0.6602 | 0.9612 | 144.9
--------------------------------------------------------------
実験結果
実験はGoogle Colab(T4 GPU)、CIFAR-10、エポック数30・バッチサイズ64の条件で実施しました。
精度グラフ
損失グラフ
A_no_mixup
B_with_mixup
| パターン | val_accuracy | test_accuracy | val_loss | 学習時間 |
|---|---|---|---|---|
| A:MixUpなし(ベースライン) | 0.6721 | 0.6656 | 0.9116 | 124.5秒 |
| B:MixUpあり(α=0.2) | 0.6555 | 0.6602 | 0.9612 | 144.9秒 |
考察
MixUpなしの方が上回った理由
今回の実験では、MixUpなし(A)がval_accuracy・test_accuracyともにMixUpあり(B)をわずかに上回りました。これは「MixUpは常に効く」という期待に反する結果ですが、理由を整理すると納得できます。
① モデルが浅すぎる
MixUpはResNetやEfficientNetといった深いモデルで効果を発揮することが多い手法です。今回使ったモデルはConv2D×2+Dense×2という軽量な構成で、そもそもの表現力が限られています。モデルが十分な容量を持っていないと、混合された曖昧な画像をうまく学習できず、むしろ混乱の原因になることがあります。
② エポック数が少ない
MixUpは学習初期に精度の上がりが遅くなる傾向があります。ソフトラベルによって損失が下がりにくくなるため、30エポックでは学習が収束しきれていない可能性があります。実際にval_lossを比較すると、MixUpありの方が0.9612と高く(悪く)、学習の途中段階であることが示唆されます。
③ データが小さくシンプルすぎる
CIFAR-10は32×32という低解像度で、クラス数も10クラスとシンプルです。MixUpはクラス間の境界が複雑なデータセットほど効果が出やすい手法であり、ImageNetのような大規模・高解像度データに向いています。CIFAR-10のような比較的シンプルな分布では、MixUpの恩恵が出にくい傾向があります。
val_lossとval_accuracyのギャップに注目
MixUpあり(B)はval_lossが0.9612と高いにもかかわらず、test_accuracy(0.6602)はベースラインと僅差(0.0054差)に留まっています。これはMixUpが「損失の最適化」よりも「汎化」を重視した学習をしていることの現れです。ソフトラベルで学習したモデルは、ハードラベルで評価されると損失が高く出やすい構造になっています。
学習時間のトレードオフ
MixUpありの学習時間は144.9秒で、なし(124.5秒)より約16%増加しました。tf.py_function によるNumPy処理のオーバーヘッドが主な原因です。精度向上が見込めない状況でこのコストを払うのは割に合いません。
まとめ:MixUpが効く条件
| 条件 | MixUpの効果 |
|---|---|
| 深いモデル(ResNet, EfficientNet等) | ✅ 効果が出やすい |
| 大規模・複雑なデータセット(ImageNet等) | ✅ 効果が出やすい |
| エポック数が多い(100エポック以上) | ✅ 収束して効果が出やすい |
| 浅いモデル(今回のような軽量CNN) | ❌ 効果が出にくい・逆効果の場合も |
| 小規模・低解像度データ(CIFAR-10等) | △ 効果が出にくい |
| エポック数が少ない(30エポック前後) | ❌ 収束不足で不利になりやすい |
まとめ
| 項目 | MixUpなし | MixUpあり(α=0.2) |
|---|---|---|
| 実装の複雑さ | シンプル | tf.py_function を使った改修が必要 |
| ラベル形式 | 整数 or one-hot | one-hot(ソフトラベル)必須 |
| 損失関数 | sparse_categorical も使える | categorical_crossentropy 必須 |
| test_accuracy(今回) | 0.6656 | 0.6602(−0.0054) |
| 学習時間(今回) | 124.5秒 | 144.9秒(+16%) |
MixUpは「使えば必ず精度が上がる」魔法の手法ではありません。深いモデル・大規模データ・十分なエポック数が揃って初めて効果が出やすくなります。今回のような軽量CNN+CIFAR-10+30エポックという条件では、実装コストと学習時間の増加に見合った精度向上は得られませんでした。使いどころを見極めることが大切です。
関連記事もあわせてどうぞ:
- Data Augmentation 5パターン比較 → Data Augmentationを重ねすぎると精度が下がる?flip〜cropの5パターンをCIFAR-10で比較
- RandomZoomの効果検証 → RandomZoomで精度は上がる?CIFAR-10で検証【Keras実験】
- 正規化方法の比較 → 画像の正規化方法で精度は変わる?/255・BatchNorm・LayerNormをCIFAR-10で比較
- 損失関数の比較 → 損失関数を変えると精度はどう変わる?sparse vs categorical vs Focal Loss【Keras×CIFAR-10実験】





0 件のコメント:
コメントを投稿