AutoAugment vs TrivialAugment、どちらが精度向上に効く?CIFAR-10で比較実験【Keras】

投稿日:2026年6月15日月曜日 最終更新日:

CIFAR-10 CNN Data Augmentation Google Colab Keras 過学習 画像分類

X f B! P L
AutoAugment vs TrivialAugment、どちらが精度向上に効く?CIFAR-10で比較実験【Keras】 アイキャッチ画像

「Data Augmentationをもっと賢くしたい」——そんなときに候補に上がるのが AutoAugmentTrivialAugment です。

ランダムフリップや回転のような手動設定より高度ですが、「実装コストに見合うか」「CIFAR-10のような小画像で効くのか」は実験しないとわかりません。今回はなし・AutoAugment・TrivialAugmentの 3パターン をGoogle Colab(T4)で比較します。

📘 この記事でわかること
  • AutoAugmentとTrivialAugmentの仕組みと違い
  • Kerasでの実装方法(tf.image + tf.py_function)と実装上の注意点
  • CIFAR-10(32×32)では精度差がほぼ出ない理由
  • それでも TrivialAugment を使う理由、使うべき場面

AutoAugment と TrivialAugment とは

どちらも「どのAugmentationをどの強さで適用するか」を自動化するアプローチです。手動でパラメータを決める従来の方法より体系的に探索できます。

手法探索方法特徴
AutoAugment 強化学習でポリシーを最適化 データセット別に最適化済みポリシーを使用。探索コストは高いが既製のポリシーを流用できる
TrivialAugment ランダムに1種類を選び強度もランダム 探索なし。シンプルなのにAutoAugmentに匹敵する精度を達成することが多い

TrivialAugmentは2021年にNeurIPSで発表された手法で、「最もシンプルなランダム戦略が、複雑な探索ベース手法に勝てる」という主張が話題になりました。

TrivialAugmentの動作イメージ

\[ \text{augmented} = T_k(x,\; m),\quad k \sim \text{Uniform}(\mathcal{T}),\quad m \sim \text{Uniform}(0,\; M_{\max}) \]

変換の種類 \(k\) をプール \(\mathcal{T}\) からランダムに1つ選び、強度 \(m\) も一様乱数で決める——それだけです。探索コストはゼロです。

⚠️ ハマりポイント:tf.imageで使えない関数がある
tf.image.adjust_sharpness() は現行TensorFlowバージョンには存在しません。ラプラシアンカーネルを tf.nn.conv2d で適用することで代替実装できます。

また Equalize(ヒストグラム均一化)も tf.image に対応する関数がなく、tf.py_function 経由でNumPy実装が必要です。これらを no-op(何もしない)で代用すると、該当サブポリシーが機能せず精度が下がる原因になります。

実験コード

使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。Augmentation以外の条件は全て同一にして、手法の影響だけを取り出します。

① 環境準備(最初に一度だけ実行)

# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-ipafont-mincho
The following NEW packages will be installed:
  fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 53 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (20.3 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122403 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 43.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了

② import・データ準備・Augmentation関数・モデル構築

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time

# 再現性のためシード固定
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# データ読み込み・正規化
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test  = x_test.astype('float32')  / 255.0

# ── 変換の実装 ──────────────────────────────────────────

def _equalize_np(image):
    """ヒストグラム均一化(チャンネル別、NumPy実装)"""
    img = (image * 255).astype(np.uint8)
    out = np.zeros_like(img)
    for c in range(3):
        hist, _ = np.histogram(img[:, :, c].flatten(), 256, [0, 256])
        cdf = hist.cumsum()
        cdf_min = cdf[cdf > 0].min()
        denom = img.shape[0] * img.shape[1] - cdf_min
        lut = np.round((cdf - cdf_min) / denom * 255).astype(np.uint8)
        out[:, :, c] = lut[img[:, :, c]]
    return out.astype(np.float32) / 255.0

def apply_equalize(image):
    """tf.py_function経由でEqualizeを適用"""
    result = tf.py_function(
        func=lambda img: _equalize_np(img.numpy()),
        inp=[image],
        Tout=tf.float32
    )
    result.set_shape(image.shape)
    return result

def apply_sharpness(image, mag):
    """ラプラシアンカーネルによる鮮鋭化(tf.nn.conv2dで実装)"""
    sharpen = tf.constant([
        [ 0, -1,  0],
        [-1,  5, -1],
        [ 0, -1,  0],
    ], dtype=tf.float32) * mag
    identity = tf.constant([
        [0, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
    ], dtype=tf.float32) * (1.0 - mag)
    k = (sharpen + identity)[:, :, tf.newaxis, tf.newaxis]
    img = tf.expand_dims(image, 0)
    channels = tf.split(img, 3, axis=-1)
    out = tf.concat(
        [tf.nn.conv2d(c, k, strides=1, padding='SAME') for c in channels],
        axis=-1
    )
    return tf.clip_by_value(tf.squeeze(out, 0), 0.0, 1.0)

def apply_autocontrast(image):
    """チャンネル別に最小・最大を使ってコントラストを最大化"""
    mn = tf.reduce_min(image, axis=[0, 1], keepdims=True)
    mx = tf.reduce_max(image, axis=[0, 1], keepdims=True)
    scale = tf.where(mx > mn, 1.0 / (mx - mn), tf.ones_like(mx))
    return tf.clip_by_value((image - mn) * scale, 0.0, 1.0)

def apply_transform(image, transform_name, magnitude):
    """変換名と強度(0〜10)を受け取り変換を適用"""
    mag = tf.cast(magnitude, tf.float32) / 10.0  # 0〜1に正規化
    if transform_name == 'FlipLR':
        image = tf.image.flip_left_right(image)
    elif transform_name == 'Brightness':
        image = tf.image.adjust_brightness(image, delta=mag * 0.4)
    elif transform_name == 'Contrast':
        image = tf.image.adjust_contrast(image, contrast_factor=1.0 + mag * 1.8)
    elif transform_name == 'Saturation':
        image = tf.image.adjust_saturation(image, saturation_factor=1.0 + mag * 1.8)
    elif transform_name == 'Sharpness':
        image = apply_sharpness(image, mag * 0.5)
    elif transform_name == 'AutoContrast':
        image = apply_autocontrast(image)
    elif transform_name == 'Equalize':
        image = apply_equalize(image)
    elif transform_name == 'Posterize':
        bits = tf.maximum(tf.cast(4 - mag * 2, tf.int32), 1)
        image_int = tf.cast(image * 255, tf.int32)
        shift = 8 - bits
        image = tf.cast(
            tf.bitwise.left_shift(tf.bitwise.right_shift(image_int, shift), shift),
            tf.float32) / 255.0
    elif transform_name == 'Solarize':
        threshold = 1.0 - mag
        image = tf.where(image < threshold, image, 1.0 - image)
    return tf.clip_by_value(image, 0.0, 1.0)

# ── AutoAugment ─────────────────────────────────────────
# CIFAR-10向けサブポリシー(Google Brainの論文より主要パターン)
# 各サブポリシー:[(変換名, 適用確率, 強度), ...]
CIFAR10_POLICIES = [
    [('FlipLR',       0.5, 0), ('Brightness',   0.6, 7)],
    [('AutoContrast', 0.5, 0), ('Equalize',      0.9, 2)],
    [('Sharpness',    0.5, 1), ('Sharpness',     0.9, 3)],
    [('Brightness',   0.4, 8), ('AutoContrast',  0.6, 0)],
    [('Equalize',     0.8, 8), ('Equalize',      0.0, 3)],
    [('Contrast',     0.7, 0), ('Brightness',    0.3, 7)],
    [('Solarize',     0.2, 4), ('Posterize',     0.8, 6)],
    [('Posterize',    0.8, 6), ('Contrast',      0.5, 8)],
]

def augment_autoaugment(image, label):
    """AutoAugment(CIFAR-10ポリシー)"""
    policy_idx = tf.random.uniform([], 0, len(CIFAR10_POLICIES), dtype=tf.int32)
    for i, policy in enumerate(CIFAR10_POLICIES):
        def apply_policy(img, p=policy):
            for transform_name, prob, mag in p:
                r = tf.random.uniform([])
                img = tf.cond(
                    r < prob,
                    lambda i=img, t=transform_name, m=mag: apply_transform(i, t, m),
                    lambda i=img: i
                )
            return img
        image = tf.cond(
            tf.equal(policy_idx, i),
            lambda img=image, i=i: apply_policy(img, CIFAR10_POLICIES[i]),
            lambda img=image: img
        )
    return image, label

# ── TrivialAugment ──────────────────────────────────────
TRIVIAL_OPS = [
    'FlipLR', 'Brightness', 'Contrast', 'Saturation',
    'Sharpness', 'AutoContrast', 'Equalize', 'Posterize', 'Solarize',
]

def augment_trivialaugment(image, label):
    """TrivialAugment:1種類をランダム選択、強度もランダム"""
    op_idx    = tf.random.uniform([], 0, len(TRIVIAL_OPS), dtype=tf.int32)
    magnitude = tf.random.uniform([], 0, 10)
    for i, op in enumerate(TRIVIAL_OPS):
        image = tf.cond(
            tf.equal(op_idx, i),
            lambda img=image, t=op, m=magnitude: apply_transform(img, t, m),
            lambda img=image: img
        )
    return image, label

def augment_none(image, label):
    """Augmentationなし(ベースライン)"""
    return image, label

# ── データセット・モデル ────────────────────────────────
def build_dataset(x, y, augment_fn, batch_size=64, training=True):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        ds = ds.shuffle(len(x), seed=SEED)
        ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

def build_model():
    """共通ベースラインCNN"""
    return keras.Sequential([
        keras.layers.Input(shape=(32, 32, 3)),
        keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax'),
    ])
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 21s 0us/step

③ 3パターンの学習実行

# ── 実験設定:Augmentationの種類だけを変える ──────────
configs = [
    ('A_none',           augment_none),
    ('B_autoaugment',    augment_autoaugment),
    ('C_trivialaugment', augment_trivialaugment),
]

# validation用データセット(Augmentationなし)
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_tr  = x_train[:-10000]
y_tr  = y_train[:-10000]

val_ds  = build_dataset(x_val,  y_val,  augment_none, training=False)
test_ds = build_dataset(x_test, y_test, augment_none, training=False)

histories, times, scores = {}, {}, {}

for name, aug_fn in configs:
    print(f"\n=== {name} ===")
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    train_ds = build_dataset(x_tr, y_tr, aug_fn, training=True)
    model = build_model()
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    start = time.time()
    history = model.fit(train_ds, epochs=30, validation_data=val_ds, verbose=1)
    elapsed = time.time() - start
    score = model.evaluate(test_ds, verbose=0)
    label = name.split('_', 1)[1]
    histories[label] = history
    times[label]     = elapsed
    scores[label]    = score
    print(f"学習時間:{elapsed:.1f}秒 test_accuracy:{score[1]:.4f}")
実行結果をクリックして内容を開く
=== A_none ===
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 18s 14ms/step - accuracy: 0.2594 - loss: 1.9440 - val_accuracy: 0.3231 - val_loss: 1.8420
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3652 - loss: 1.6905 - val_accuracy: 0.3859 - val_loss: 1.6379
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4196 - loss: 1.5705 - val_accuracy: 0.4562 - val_loss: 1.4703
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4602 - loss: 1.4768 - val_accuracy: 0.4850 - val_loss: 1.4414
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4867 - loss: 1.4074 - val_accuracy: 0.5043 - val_loss: 1.3569
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5018 - loss: 1.3629 - val_accuracy: 0.5261 - val_loss: 1.3042
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5152 - loss: 1.3203 - val_accuracy: 0.5298 - val_loss: 1.2925
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5305 - loss: 1.2901 - val_accuracy: 0.5337 - val_loss: 1.2706
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5424 - loss: 1.2546 - val_accuracy: 0.5543 - val_loss: 1.2085
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5527 - loss: 1.2330 - val_accuracy: 0.5528 - val_loss: 1.2160
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5605 - loss: 1.2148 - val_accuracy: 0.5607 - val_loss: 1.2052
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5700 - loss: 1.1803 - val_accuracy: 0.5767 - val_loss: 1.1471
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5800 - loss: 1.1607 - val_accuracy: 0.5677 - val_loss: 1.1883
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5879 - loss: 1.1390 - val_accuracy: 0.5814 - val_loss: 1.1299
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5951 - loss: 1.1193 - val_accuracy: 0.5955 - val_loss: 1.1048
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6044 - loss: 1.0977 - val_accuracy: 0.6112 - val_loss: 1.0689
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6075 - loss: 1.0868 - val_accuracy: 0.6105 - val_loss: 1.0661
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6163 - loss: 1.0622 - val_accuracy: 0.6193 - val_loss: 1.0398
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6229 - loss: 1.0462 - val_accuracy: 0.6285 - val_loss: 1.0270
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6292 - loss: 1.0338 - val_accuracy: 0.6122 - val_loss: 1.0759
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6348 - loss: 1.0141 - val_accuracy: 0.6319 - val_loss: 1.0154
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6385 - loss: 1.0029 - val_accuracy: 0.6460 - val_loss: 0.9900
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6442 - loss: 0.9866 - val_accuracy: 0.6456 - val_loss: 0.9770
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6498 - loss: 0.9768 - val_accuracy: 0.6505 - val_loss: 0.9686
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6529 - loss: 0.9610 - val_accuracy: 0.6461 - val_loss: 0.9782
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6568 - loss: 0.9558 - val_accuracy: 0.6447 - val_loss: 0.9971
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6614 - loss: 0.9410 - val_accuracy: 0.6500 - val_loss: 0.9733
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6646 - loss: 0.9307 - val_accuracy: 0.6621 - val_loss: 0.9540
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6714 - loss: 0.9190 - val_accuracy: 0.6663 - val_loss: 0.9394
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6740 - loss: 0.9079 - val_accuracy: 0.6549 - val_loss: 0.9668
学習時間:134.6秒 test_accuracy:0.6517

=== B_autoaugment ===
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 29s 41ms/step - accuracy: 0.2277 - loss: 2.0272 - val_accuracy: 0.3121 - val_loss: 1.8418
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 40s 42ms/step - accuracy: 0.3296 - loss: 1.8050 - val_accuracy: 0.3852 - val_loss: 1.6680
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 27s 43ms/step - accuracy: 0.3870 - loss: 1.6762 - val_accuracy: 0.4609 - val_loss: 1.4865
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 39s 40ms/step - accuracy: 0.4309 - loss: 1.5727 - val_accuracy: 0.4816 - val_loss: 1.4763
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 39ms/step - accuracy: 0.4597 - loss: 1.4955 - val_accuracy: 0.5195 - val_loss: 1.3439
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.4761 - loss: 1.4449 - val_accuracy: 0.5341 - val_loss: 1.3158
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 41s 38ms/step - accuracy: 0.4936 - loss: 1.4072 - val_accuracy: 0.5272 - val_loss: 1.3003
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.5023 - loss: 1.3723 - val_accuracy: 0.5487 - val_loss: 1.2752
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 23s 37ms/step - accuracy: 0.5113 - loss: 1.3463 - val_accuracy: 0.5567 - val_loss: 1.2183
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 23s 37ms/step - accuracy: 0.5242 - loss: 1.3182 - val_accuracy: 0.5592 - val_loss: 1.2186
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.5339 - loss: 1.2971 - val_accuracy: 0.5839 - val_loss: 1.1728
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 42s 41ms/step - accuracy: 0.5426 - loss: 1.2681 - val_accuracy: 0.5804 - val_loss: 1.1873
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 40s 39ms/step - accuracy: 0.5520 - loss: 1.2469 - val_accuracy: 0.6021 - val_loss: 1.1268
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 40s 37ms/step - accuracy: 0.5561 - loss: 1.2340 - val_accuracy: 0.5897 - val_loss: 1.1371
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 26s 41ms/step - accuracy: 0.5638 - loss: 1.2090 - val_accuracy: 0.5937 - val_loss: 1.1224
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 26s 41ms/step - accuracy: 0.5689 - loss: 1.1968 - val_accuracy: 0.6040 - val_loss: 1.1046
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 26s 40ms/step - accuracy: 0.5803 - loss: 1.1742 - val_accuracy: 0.6192 - val_loss: 1.0608
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 40ms/step - accuracy: 0.5795 - loss: 1.1654 - val_accuracy: 0.6357 - val_loss: 1.0207
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.5902 - loss: 1.1461 - val_accuracy: 0.6339 - val_loss: 1.0249
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 37ms/step - accuracy: 0.5940 - loss: 1.1364 - val_accuracy: 0.6162 - val_loss: 1.0542
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.5988 - loss: 1.1233 - val_accuracy: 0.6427 - val_loss: 0.9829
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6025 - loss: 1.1106 - val_accuracy: 0.6330 - val_loss: 1.0132
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 40ms/step - accuracy: 0.6083 - loss: 1.1012 - val_accuracy: 0.6387 - val_loss: 0.9943
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6086 - loss: 1.0973 - val_accuracy: 0.6518 - val_loss: 0.9795
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6155 - loss: 1.0779 - val_accuracy: 0.6382 - val_loss: 1.0055
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 41s 40ms/step - accuracy: 0.6165 - loss: 1.0733 - val_accuracy: 0.6562 - val_loss: 0.9603
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 39ms/step - accuracy: 0.6242 - loss: 1.0559 - val_accuracy: 0.6473 - val_loss: 0.9781
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 40ms/step - accuracy: 0.6255 - loss: 1.0520 - val_accuracy: 0.6509 - val_loss: 0.9752
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6291 - loss: 1.0458 - val_accuracy: 0.6685 - val_loss: 0.9289
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.6316 - loss: 1.0331 - val_accuracy: 0.6535 - val_loss: 0.9671
学習時間:855.5秒 test_accuracy:0.6494

=== C_trivialaugment ===
Epoch 1/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 21s 28ms/step - accuracy: 0.2346 - loss: 2.0160 - val_accuracy: 0.3228 - val_loss: 1.8289
Epoch 2/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 18s 27ms/step - accuracy: 0.3391 - loss: 1.7923 - val_accuracy: 0.3997 - val_loss: 1.6350
Epoch 3/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.3972 - loss: 1.6423 - val_accuracy: 0.4741 - val_loss: 1.4512
Epoch 4/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.4355 - loss: 1.5526 - val_accuracy: 0.4808 - val_loss: 1.4357
Epoch 5/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.4545 - loss: 1.4934 - val_accuracy: 0.5077 - val_loss: 1.3468
Epoch 6/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.4756 - loss: 1.4489 - val_accuracy: 0.5310 - val_loss: 1.3133
Epoch 7/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.4856 - loss: 1.4136 - val_accuracy: 0.5386 - val_loss: 1.2842
Epoch 8/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.4998 - loss: 1.3806 - val_accuracy: 0.5423 - val_loss: 1.2990
Epoch 9/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 20s 26ms/step - accuracy: 0.5095 - loss: 1.3536 - val_accuracy: 0.5501 - val_loss: 1.2289
Epoch 10/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.5197 - loss: 1.3329 - val_accuracy: 0.5562 - val_loss: 1.2310
Epoch 11/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.5283 - loss: 1.3020 - val_accuracy: 0.5830 - val_loss: 1.1917
Epoch 12/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5386 - loss: 1.2819 - val_accuracy: 0.5830 - val_loss: 1.1589
Epoch 13/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5458 - loss: 1.2600 - val_accuracy: 0.5856 - val_loss: 1.1437
Epoch 14/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.5546 - loss: 1.2378 - val_accuracy: 0.5532 - val_loss: 1.2090
Epoch 15/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.5580 - loss: 1.2273 - val_accuracy: 0.5995 - val_loss: 1.1144
Epoch 16/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5659 - loss: 1.2060 - val_accuracy: 0.6169 - val_loss: 1.0742
Epoch 17/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.5736 - loss: 1.1929 - val_accuracy: 0.6269 - val_loss: 1.0382
Epoch 18/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 19s 25ms/step - accuracy: 0.5782 - loss: 1.1744 - val_accuracy: 0.6319 - val_loss: 1.0392
Epoch 19/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5824 - loss: 1.1657 - val_accuracy: 0.6388 - val_loss: 1.0173
Epoch 20/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5910 - loss: 1.1514 - val_accuracy: 0.6257 - val_loss: 1.0427
Epoch 21/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5955 - loss: 1.1378 - val_accuracy: 0.6430 - val_loss: 0.9918
Epoch 22/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6018 - loss: 1.1183 - val_accuracy: 0.6437 - val_loss: 0.9908
Epoch 23/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.6052 - loss: 1.1124 - val_accuracy: 0.6482 - val_loss: 0.9743
Epoch 24/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 19s 25ms/step - accuracy: 0.6078 - loss: 1.0988 - val_accuracy: 0.6554 - val_loss: 0.9640
Epoch 25/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.6103 - loss: 1.0915 - val_accuracy: 0.6634 - val_loss: 0.9454
Epoch 26/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6153 - loss: 1.0779 - val_accuracy: 0.6671 - val_loss: 0.9403
Epoch 27/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6197 - loss: 1.0637 - val_accuracy: 0.6644 - val_loss: 0.9508
Epoch 28/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 21s 26ms/step - accuracy: 0.6226 - loss: 1.0597 - val_accuracy: 0.6544 - val_loss: 0.9643
Epoch 29/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6272 - loss: 1.0496 - val_accuracy: 0.6830 - val_loss: 0.8974
Epoch 30/30
625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6336 - loss: 1.0313 - val_accuracy: 0.6548 - val_loss: 0.9598
学習時間:514.7秒 test_accuracy:0.6515

④ グラフ+サマリー出力

import pandas as pd

label_map = {
    'none':           'なし',
    'autoaugment':    'AutoAugment',
    'trivialaugment': 'TrivialAugment',
}

# ── val_accuracy / val_loss 比較グラフ ─────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
    axes[0].plot(h.history['val_accuracy'], label=label_map[label])
    axes[1].plot(h.history['val_loss'],     label=label_map[label])
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
    ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('augment_val.png', dpi=150)
plt.show()

# ── train_loss vs val_loss(過学習の乖離確認)──────────
fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5))
for i, (label, h) in enumerate(histories.items()):
    axes2[i].plot(h.history['loss'],     label='train_loss')
    axes2[i].plot(h.history['val_loss'], label='val_loss')
    axes2[i].set_title(label_map[label])
    axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('augment_overfit.png', dpi=150)
plt.show()

# ── 結果サマリー ────────────────────────────────────────
rows = []
for label in ['none', 'autoaugment', 'trivialaugment']:
    val_acc  = histories[label].history['val_accuracy'][-1]
    test_acc = scores[label][1]
    t        = times[label]
    rows.append({
        'パターン':       label_map[label],
        'val_accuracy':  f"{val_acc:.4f}",
        'test_accuracy': f"{test_acc:.4f}",
        '学習時間(s)':   f"{t:.1f}",
    })
df = pd.DataFrame(rows).sort_values('test_accuracy', ascending=False)
print(df.to_string(index=False))

最終結果サマリー

          パターン val_accuracy test_accuracy 学習時間(s)
            なし       0.6549        0.6517   134.6
TrivialAugment       0.6548        0.6515   514.7
   AutoAugment       0.6535        0.6494   855.5

実験結果

精度グラフ

精度グラフ

損失グラフ

損失グラフ

なし

なし

TrivialAugment

TrivialAugment

AutoAugment

AutoAugment
パターン val_accuracy test_accuracy 学習時間
なし 0.6549 0.6517 134.6秒
TrivialAugment 0.6548 0.6515 514.7秒
AutoAugment 0.6535 0.6494 855.5秒

考察

① 精度差はほぼゼロ——3パターンが横並びの結果

最大精度差は 0.0023pt(0.23%) で、乱数シードを変えれば容易に順位が入れ替わる誤差範囲です。コードの修正前(Sharpness no-op・Equalize no-op・AutoContrast 固定値)と比べても結果は変わりませんでした。

つまり今回の条件下では、変換の実装精度より「そもそもAugmentationが有効かどうか」が支配的だったことになります。

② なぜ差が出なかったのか——CIFAR-10×30エポックの構造的な問題

このブログのAugmentation系実験では、CIFAR-10の小画像に対して高度なAugmentationが効きにくいケースが繰り返し観測されています。原因は2つです。

原因内容
画像が小さすぎる 32×32ピクセルでは色変換・鮮鋭化・ポスタリゼーションが画像の情報を破壊しやすく、有益な多様性より有害なノイズになりやすい
エポック数が少ない Augmentationは学習を意図的に難しくする手法。30エポックでは多様な変換を十分に学習に活かしきれず、むしろ収束が遅れるだけになりやすい

論文でAutoAugment・TrivialAugmentが効果を示した条件は、ImageNet(224×224)や数百エポックの学習です。スケール感が根本的に異なることが今回の結果に直結しています。

③ 学習時間のコストが見合わない

精度がほぼ同じにもかかわらず、学習時間は大きく異なりました。

パターン学習時間なしとの比率
なし134.6秒1.0×
TrivialAugment514.7秒3.8×
AutoAugment855.5秒6.4×

TrivialAugmentは tf.py_function(Equalize)の影響で約4倍、AutoAugmentは tf.cond を8段ネストするポリシー選択のオーバーヘッドでさらに遅くなっています。CIFAR-10×30エポックという条件では、このコストに見合う精度改善は得られませんでした。

④ それでも TrivialAugment を使う価値はあるか

今回の結果だけを見れば「CIFAR-10では不要」という結論になりますが、実務での判断は異なります。

条件TrivialAugmentの有効性
エポック数を増やす(100エポック以上)◎ 多様な変換が学習に浸透しやすくなり精度向上が期待できる
ImageNet規模の画像(224×224以上)◎ 本来想定された条件。論文でも明確な効果が示されている
データ量が少ない(クラスあたり数百枚以下)○ 過学習防止として有効になりやすい
CIFAR-10×30エポック(今回)△ 精度向上は期待しにくい。過学習が問題なら検討の余地あり

⑤ Kerasで正確なAutoAugmentを実装したい場合

今回の実装は tf.image の制約から完全なAutoAugmentではありません。論文と同等の実装を使いたい場合は KerasCV が現実的な選択肢です。

# KerasCV を使った AutoAugment(参考)
# pip install keras-cv
import keras_cv
auto_aug = keras_cv.layers.AutoAugment(value_range=(0, 1))
# tf.data パイプラインで使用
ds = ds.map(lambda x, y: (auto_aug(x, training=True), y))

実務での推奨

状況推奨理由
CIFAR-10規模・30エポック以下なし or フリップのみ高度なAugmentは精度改善なしにコストだけ増える
CIFAR-10で過学習が問題のときTrivialAugment(エポック数も増やす)汎化性能向上の余地がある
ImageNet規模・100エポック以上TrivialAugment実装シンプル・探索コストゼロで効果十分
精度を最大化したい本番モデルKerasCV の AutoAugment完全実装済みポリシーで最高精度を狙える
まとめ
  • CIFAR-10×30エポックでは3パターンの精度差は最大 0.23% で、誤差範囲内。高度なAugmentationの恩恵はほぼ得られなかった
  • 原因は「32×32では変換が情報を破壊しやすい」「30エポックでは多様性を活かしきれない」の2点
  • 学習時間はなし(135秒)に対して TrivialAugment が3.8倍、AutoAugment が6.4倍。精度が同等なら明らかにコスト過多
  • tf.image.adjust_sharpness() は現行TFに存在しない。ラプラシアンカーネル(tf.nn.conv2d)で代替実装できる
  • Equalizetf.py_function 経由のNumPy実装が必要。ただしこれが学習時間増加の一因になる
  • AutoAugment・TrivialAugmentの真価は 大きな画像・多いエポック数・少ないデータ量 の条件で発揮される

English Summary

We compared three augmentation strategies on CIFAR-10 (32×32) for 30 epochs: no augmentation, AutoAugment (CIFAR-10 policy), and TrivialAugment. The maximum accuracy difference was only 0.23%, well within noise range. The main finding is that on small images with few epochs, complex augmentation strategies offer no benefit while significantly increasing training time (3.8× for TrivialAugment, 6.4× for AutoAugment). AutoAugment and TrivialAugment show their true value with larger images, more epochs, or smaller datasets.

関連記事