「Data Augmentationをもっと賢くしたい」——そんなときに候補に上がるのが AutoAugment と TrivialAugment です。
ランダムフリップや回転のような手動設定より高度ですが、「実装コストに見合うか」「CIFAR-10のような小画像で効くのか」は実験しないとわかりません。今回はなし・AutoAugment・TrivialAugmentの 3パターン をGoogle Colab(T4)で比較します。
- AutoAugmentとTrivialAugmentの仕組みと違い
- Kerasでの実装方法(
tf.image+tf.py_function)と実装上の注意点 - CIFAR-10(32×32)では精度差がほぼ出ない理由
- それでも TrivialAugment を使う理由、使うべき場面
AutoAugment と TrivialAugment とは
どちらも「どのAugmentationをどの強さで適用するか」を自動化するアプローチです。手動でパラメータを決める従来の方法より体系的に探索できます。
| 手法 | 探索方法 | 特徴 |
|---|---|---|
| AutoAugment | 強化学習でポリシーを最適化 | データセット別に最適化済みポリシーを使用。探索コストは高いが既製のポリシーを流用できる |
| TrivialAugment | ランダムに1種類を選び強度もランダム | 探索なし。シンプルなのにAutoAugmentに匹敵する精度を達成することが多い |
TrivialAugmentは2021年にNeurIPSで発表された手法で、「最もシンプルなランダム戦略が、複雑な探索ベース手法に勝てる」という主張が話題になりました。
TrivialAugmentの動作イメージ
\[ \text{augmented} = T_k(x,\; m),\quad k \sim \text{Uniform}(\mathcal{T}),\quad m \sim \text{Uniform}(0,\; M_{\max}) \]
変換の種類 \(k\) をプール \(\mathcal{T}\) からランダムに1つ選び、強度 \(m\) も一様乱数で決める——それだけです。探索コストはゼロです。
tf.image.adjust_sharpness() は現行TensorFlowバージョンには存在しません。ラプラシアンカーネルを tf.nn.conv2d で適用することで代替実装できます。また
Equalize(ヒストグラム均一化)も tf.image に対応する関数がなく、tf.py_function 経由でNumPy実装が必要です。これらを no-op(何もしない)で代用すると、該当サブポリシーが機能せず精度が下がる原因になります。
実験コード
使用環境はGoogle Colab(GPU:T4)、データセットはCIFAR-10です。Augmentation以外の条件は全て同一にして、手法の影響だけを取り出します。
① 環境準備(最初に一度だけ実行)
# ── 環境準備(最初に一度だけ実行)──────────────────────
!apt-get -y install fonts-ipafont-gothic
!rm -rf /root/.cache/matplotlib
!pip install -q japanize_matplotlib
print("環境準備完了")
実行結果をクリックして内容を開く
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
fonts-ipafont-mincho
The following NEW packages will be installed:
fonts-ipafont-gothic fonts-ipafont-mincho
0 upgraded, 2 newly installed, 0 to remove and 53 not upgraded.
Need to get 8,237 kB of archives.
After this operation, 28.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-gothic all 00303-21ubuntu1 [3,513 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fonts-ipafont-mincho all 00303-21ubuntu1 [4,724 kB]
Fetched 8,237 kB in 0s (20.3 MB/s)
Selecting previously unselected package fonts-ipafont-gothic.
(Reading database ... 122403 files and directories currently installed.)
Preparing to unpack .../fonts-ipafont-gothic_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-gothic (00303-21ubuntu1) ...
Selecting previously unselected package fonts-ipafont-mincho.
Preparing to unpack .../fonts-ipafont-mincho_00303-21ubuntu1_all.deb ...
Unpacking fonts-ipafont-mincho (00303-21ubuntu1) ...
Setting up fonts-ipafont-mincho (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-mincho/ipam.ttf to provide /usr/share/fonts/truetype/fonts-japanese-mincho.ttf (fonts-japanese-mincho.ttf) in auto mode
Setting up fonts-ipafont-gothic (00303-21ubuntu1) ...
update-alternatives: using /usr/share/fonts/opentype/ipafont-gothic/ipag.ttf to provide /usr/share/fonts/truetype/fonts-japanese-gothic.ttf (fonts-japanese-gothic.ttf) in auto mode
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 43.8 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Building wheel for japanize_matplotlib (setup.py) ... done
環境準備完了
② import・データ準備・Augmentation関数・モデル構築
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import japanize_matplotlib
import time
# 再現性のためシード固定
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
# データ読み込み・正規化
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# ── 変換の実装 ──────────────────────────────────────────
def _equalize_np(image):
"""ヒストグラム均一化(チャンネル別、NumPy実装)"""
img = (image * 255).astype(np.uint8)
out = np.zeros_like(img)
for c in range(3):
hist, _ = np.histogram(img[:, :, c].flatten(), 256, [0, 256])
cdf = hist.cumsum()
cdf_min = cdf[cdf > 0].min()
denom = img.shape[0] * img.shape[1] - cdf_min
lut = np.round((cdf - cdf_min) / denom * 255).astype(np.uint8)
out[:, :, c] = lut[img[:, :, c]]
return out.astype(np.float32) / 255.0
def apply_equalize(image):
"""tf.py_function経由でEqualizeを適用"""
result = tf.py_function(
func=lambda img: _equalize_np(img.numpy()),
inp=[image],
Tout=tf.float32
)
result.set_shape(image.shape)
return result
def apply_sharpness(image, mag):
"""ラプラシアンカーネルによる鮮鋭化(tf.nn.conv2dで実装)"""
sharpen = tf.constant([
[ 0, -1, 0],
[-1, 5, -1],
[ 0, -1, 0],
], dtype=tf.float32) * mag
identity = tf.constant([
[0, 0, 0],
[0, 1, 0],
[0, 0, 0],
], dtype=tf.float32) * (1.0 - mag)
k = (sharpen + identity)[:, :, tf.newaxis, tf.newaxis]
img = tf.expand_dims(image, 0)
channels = tf.split(img, 3, axis=-1)
out = tf.concat(
[tf.nn.conv2d(c, k, strides=1, padding='SAME') for c in channels],
axis=-1
)
return tf.clip_by_value(tf.squeeze(out, 0), 0.0, 1.0)
def apply_autocontrast(image):
"""チャンネル別に最小・最大を使ってコントラストを最大化"""
mn = tf.reduce_min(image, axis=[0, 1], keepdims=True)
mx = tf.reduce_max(image, axis=[0, 1], keepdims=True)
scale = tf.where(mx > mn, 1.0 / (mx - mn), tf.ones_like(mx))
return tf.clip_by_value((image - mn) * scale, 0.0, 1.0)
def apply_transform(image, transform_name, magnitude):
"""変換名と強度(0〜10)を受け取り変換を適用"""
mag = tf.cast(magnitude, tf.float32) / 10.0 # 0〜1に正規化
if transform_name == 'FlipLR':
image = tf.image.flip_left_right(image)
elif transform_name == 'Brightness':
image = tf.image.adjust_brightness(image, delta=mag * 0.4)
elif transform_name == 'Contrast':
image = tf.image.adjust_contrast(image, contrast_factor=1.0 + mag * 1.8)
elif transform_name == 'Saturation':
image = tf.image.adjust_saturation(image, saturation_factor=1.0 + mag * 1.8)
elif transform_name == 'Sharpness':
image = apply_sharpness(image, mag * 0.5)
elif transform_name == 'AutoContrast':
image = apply_autocontrast(image)
elif transform_name == 'Equalize':
image = apply_equalize(image)
elif transform_name == 'Posterize':
bits = tf.maximum(tf.cast(4 - mag * 2, tf.int32), 1)
image_int = tf.cast(image * 255, tf.int32)
shift = 8 - bits
image = tf.cast(
tf.bitwise.left_shift(tf.bitwise.right_shift(image_int, shift), shift),
tf.float32) / 255.0
elif transform_name == 'Solarize':
threshold = 1.0 - mag
image = tf.where(image < threshold, image, 1.0 - image)
return tf.clip_by_value(image, 0.0, 1.0)
# ── AutoAugment ─────────────────────────────────────────
# CIFAR-10向けサブポリシー(Google Brainの論文より主要パターン)
# 各サブポリシー:[(変換名, 適用確率, 強度), ...]
CIFAR10_POLICIES = [
[('FlipLR', 0.5, 0), ('Brightness', 0.6, 7)],
[('AutoContrast', 0.5, 0), ('Equalize', 0.9, 2)],
[('Sharpness', 0.5, 1), ('Sharpness', 0.9, 3)],
[('Brightness', 0.4, 8), ('AutoContrast', 0.6, 0)],
[('Equalize', 0.8, 8), ('Equalize', 0.0, 3)],
[('Contrast', 0.7, 0), ('Brightness', 0.3, 7)],
[('Solarize', 0.2, 4), ('Posterize', 0.8, 6)],
[('Posterize', 0.8, 6), ('Contrast', 0.5, 8)],
]
def augment_autoaugment(image, label):
"""AutoAugment(CIFAR-10ポリシー)"""
policy_idx = tf.random.uniform([], 0, len(CIFAR10_POLICIES), dtype=tf.int32)
for i, policy in enumerate(CIFAR10_POLICIES):
def apply_policy(img, p=policy):
for transform_name, prob, mag in p:
r = tf.random.uniform([])
img = tf.cond(
r < prob,
lambda i=img, t=transform_name, m=mag: apply_transform(i, t, m),
lambda i=img: i
)
return img
image = tf.cond(
tf.equal(policy_idx, i),
lambda img=image, i=i: apply_policy(img, CIFAR10_POLICIES[i]),
lambda img=image: img
)
return image, label
# ── TrivialAugment ──────────────────────────────────────
TRIVIAL_OPS = [
'FlipLR', 'Brightness', 'Contrast', 'Saturation',
'Sharpness', 'AutoContrast', 'Equalize', 'Posterize', 'Solarize',
]
def augment_trivialaugment(image, label):
"""TrivialAugment:1種類をランダム選択、強度もランダム"""
op_idx = tf.random.uniform([], 0, len(TRIVIAL_OPS), dtype=tf.int32)
magnitude = tf.random.uniform([], 0, 10)
for i, op in enumerate(TRIVIAL_OPS):
image = tf.cond(
tf.equal(op_idx, i),
lambda img=image, t=op, m=magnitude: apply_transform(img, t, m),
lambda img=image: img
)
return image, label
def augment_none(image, label):
"""Augmentationなし(ベースライン)"""
return image, label
# ── データセット・モデル ────────────────────────────────
def build_dataset(x, y, augment_fn, batch_size=64, training=True):
ds = tf.data.Dataset.from_tensor_slices((x, y))
if training:
ds = ds.shuffle(len(x), seed=SEED)
ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
def build_model():
"""共通ベースラインCNN"""
return keras.Sequential([
keras.layers.Input(shape=(32, 32, 3)),
keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
keras.layers.MaxPooling2D((2, 2)),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax'),
])
実行結果をクリックして内容を開く
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 21s 0us/step
③ 3パターンの学習実行
# ── 実験設定:Augmentationの種類だけを変える ──────────
configs = [
('A_none', augment_none),
('B_autoaugment', augment_autoaugment),
('C_trivialaugment', augment_trivialaugment),
]
# validation用データセット(Augmentationなし)
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_tr = x_train[:-10000]
y_tr = y_train[:-10000]
val_ds = build_dataset(x_val, y_val, augment_none, training=False)
test_ds = build_dataset(x_test, y_test, augment_none, training=False)
histories, times, scores = {}, {}, {}
for name, aug_fn in configs:
print(f"\n=== {name} ===")
np.random.seed(SEED)
tf.random.set_seed(SEED)
train_ds = build_dataset(x_tr, y_tr, aug_fn, training=True)
model = build_model()
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
start = time.time()
history = model.fit(train_ds, epochs=30, validation_data=val_ds, verbose=1)
elapsed = time.time() - start
score = model.evaluate(test_ds, verbose=0)
label = name.split('_', 1)[1]
histories[label] = history
times[label] = elapsed
scores[label] = score
print(f"学習時間:{elapsed:.1f}秒 test_accuracy:{score[1]:.4f}")
実行結果をクリックして内容を開く
=== A_none === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 18s 14ms/step - accuracy: 0.2594 - loss: 1.9440 - val_accuracy: 0.3231 - val_loss: 1.8420 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3652 - loss: 1.6905 - val_accuracy: 0.3859 - val_loss: 1.6379 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4196 - loss: 1.5705 - val_accuracy: 0.4562 - val_loss: 1.4703 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4602 - loss: 1.4768 - val_accuracy: 0.4850 - val_loss: 1.4414 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.4867 - loss: 1.4074 - val_accuracy: 0.5043 - val_loss: 1.3569 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5018 - loss: 1.3629 - val_accuracy: 0.5261 - val_loss: 1.3042 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5152 - loss: 1.3203 - val_accuracy: 0.5298 - val_loss: 1.2925 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5305 - loss: 1.2901 - val_accuracy: 0.5337 - val_loss: 1.2706 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5424 - loss: 1.2546 - val_accuracy: 0.5543 - val_loss: 1.2085 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5527 - loss: 1.2330 - val_accuracy: 0.5528 - val_loss: 1.2160 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5605 - loss: 1.2148 - val_accuracy: 0.5607 - val_loss: 1.2052 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.5700 - loss: 1.1803 - val_accuracy: 0.5767 - val_loss: 1.1471 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.5800 - loss: 1.1607 - val_accuracy: 0.5677 - val_loss: 1.1883 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5879 - loss: 1.1390 - val_accuracy: 0.5814 - val_loss: 1.1299 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.5951 - loss: 1.1193 - val_accuracy: 0.5955 - val_loss: 1.1048 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6044 - loss: 1.0977 - val_accuracy: 0.6112 - val_loss: 1.0689 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6075 - loss: 1.0868 - val_accuracy: 0.6105 - val_loss: 1.0661 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6163 - loss: 1.0622 - val_accuracy: 0.6193 - val_loss: 1.0398 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6229 - loss: 1.0462 - val_accuracy: 0.6285 - val_loss: 1.0270 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6292 - loss: 1.0338 - val_accuracy: 0.6122 - val_loss: 1.0759 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6348 - loss: 1.0141 - val_accuracy: 0.6319 - val_loss: 1.0154 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6385 - loss: 1.0029 - val_accuracy: 0.6460 - val_loss: 0.9900 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6442 - loss: 0.9866 - val_accuracy: 0.6456 - val_loss: 0.9770 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6498 - loss: 0.9768 - val_accuracy: 0.6505 - val_loss: 0.9686 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6529 - loss: 0.9610 - val_accuracy: 0.6461 - val_loss: 0.9782 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.6568 - loss: 0.9558 - val_accuracy: 0.6447 - val_loss: 0.9971 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6614 - loss: 0.9410 - val_accuracy: 0.6500 - val_loss: 0.9733 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6646 - loss: 0.9307 - val_accuracy: 0.6621 - val_loss: 0.9540 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.6714 - loss: 0.9190 - val_accuracy: 0.6663 - val_loss: 0.9394 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.6740 - loss: 0.9079 - val_accuracy: 0.6549 - val_loss: 0.9668 学習時間:134.6秒 test_accuracy:0.6517 === B_autoaugment === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 29s 41ms/step - accuracy: 0.2277 - loss: 2.0272 - val_accuracy: 0.3121 - val_loss: 1.8418 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 40s 42ms/step - accuracy: 0.3296 - loss: 1.8050 - val_accuracy: 0.3852 - val_loss: 1.6680 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 27s 43ms/step - accuracy: 0.3870 - loss: 1.6762 - val_accuracy: 0.4609 - val_loss: 1.4865 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 39s 40ms/step - accuracy: 0.4309 - loss: 1.5727 - val_accuracy: 0.4816 - val_loss: 1.4763 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 39ms/step - accuracy: 0.4597 - loss: 1.4955 - val_accuracy: 0.5195 - val_loss: 1.3439 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.4761 - loss: 1.4449 - val_accuracy: 0.5341 - val_loss: 1.3158 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 41s 38ms/step - accuracy: 0.4936 - loss: 1.4072 - val_accuracy: 0.5272 - val_loss: 1.3003 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.5023 - loss: 1.3723 - val_accuracy: 0.5487 - val_loss: 1.2752 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 23s 37ms/step - accuracy: 0.5113 - loss: 1.3463 - val_accuracy: 0.5567 - val_loss: 1.2183 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 23s 37ms/step - accuracy: 0.5242 - loss: 1.3182 - val_accuracy: 0.5592 - val_loss: 1.2186 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.5339 - loss: 1.2971 - val_accuracy: 0.5839 - val_loss: 1.1728 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 42s 41ms/step - accuracy: 0.5426 - loss: 1.2681 - val_accuracy: 0.5804 - val_loss: 1.1873 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 40s 39ms/step - accuracy: 0.5520 - loss: 1.2469 - val_accuracy: 0.6021 - val_loss: 1.1268 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 40s 37ms/step - accuracy: 0.5561 - loss: 1.2340 - val_accuracy: 0.5897 - val_loss: 1.1371 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 26s 41ms/step - accuracy: 0.5638 - loss: 1.2090 - val_accuracy: 0.5937 - val_loss: 1.1224 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 26s 41ms/step - accuracy: 0.5689 - loss: 1.1968 - val_accuracy: 0.6040 - val_loss: 1.1046 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 26s 40ms/step - accuracy: 0.5803 - loss: 1.1742 - val_accuracy: 0.6192 - val_loss: 1.0608 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 40ms/step - accuracy: 0.5795 - loss: 1.1654 - val_accuracy: 0.6357 - val_loss: 1.0207 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.5902 - loss: 1.1461 - val_accuracy: 0.6339 - val_loss: 1.0249 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 37ms/step - accuracy: 0.5940 - loss: 1.1364 - val_accuracy: 0.6162 - val_loss: 1.0542 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.5988 - loss: 1.1233 - val_accuracy: 0.6427 - val_loss: 0.9829 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6025 - loss: 1.1106 - val_accuracy: 0.6330 - val_loss: 1.0132 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 40ms/step - accuracy: 0.6083 - loss: 1.1012 - val_accuracy: 0.6387 - val_loss: 0.9943 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6086 - loss: 1.0973 - val_accuracy: 0.6518 - val_loss: 0.9795 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6155 - loss: 1.0779 - val_accuracy: 0.6382 - val_loss: 1.0055 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 41s 40ms/step - accuracy: 0.6165 - loss: 1.0733 - val_accuracy: 0.6562 - val_loss: 0.9603 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 39ms/step - accuracy: 0.6242 - loss: 1.0559 - val_accuracy: 0.6473 - val_loss: 0.9781 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 40ms/step - accuracy: 0.6255 - loss: 1.0520 - val_accuracy: 0.6509 - val_loss: 0.9752 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 25s 39ms/step - accuracy: 0.6291 - loss: 1.0458 - val_accuracy: 0.6685 - val_loss: 0.9289 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 24s 38ms/step - accuracy: 0.6316 - loss: 1.0331 - val_accuracy: 0.6535 - val_loss: 0.9671 学習時間:855.5秒 test_accuracy:0.6494 === C_trivialaugment === Epoch 1/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 21s 28ms/step - accuracy: 0.2346 - loss: 2.0160 - val_accuracy: 0.3228 - val_loss: 1.8289 Epoch 2/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 18s 27ms/step - accuracy: 0.3391 - loss: 1.7923 - val_accuracy: 0.3997 - val_loss: 1.6350 Epoch 3/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.3972 - loss: 1.6423 - val_accuracy: 0.4741 - val_loss: 1.4512 Epoch 4/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.4355 - loss: 1.5526 - val_accuracy: 0.4808 - val_loss: 1.4357 Epoch 5/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.4545 - loss: 1.4934 - val_accuracy: 0.5077 - val_loss: 1.3468 Epoch 6/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.4756 - loss: 1.4489 - val_accuracy: 0.5310 - val_loss: 1.3133 Epoch 7/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.4856 - loss: 1.4136 - val_accuracy: 0.5386 - val_loss: 1.2842 Epoch 8/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.4998 - loss: 1.3806 - val_accuracy: 0.5423 - val_loss: 1.2990 Epoch 9/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 20s 26ms/step - accuracy: 0.5095 - loss: 1.3536 - val_accuracy: 0.5501 - val_loss: 1.2289 Epoch 10/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.5197 - loss: 1.3329 - val_accuracy: 0.5562 - val_loss: 1.2310 Epoch 11/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.5283 - loss: 1.3020 - val_accuracy: 0.5830 - val_loss: 1.1917 Epoch 12/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5386 - loss: 1.2819 - val_accuracy: 0.5830 - val_loss: 1.1589 Epoch 13/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5458 - loss: 1.2600 - val_accuracy: 0.5856 - val_loss: 1.1437 Epoch 14/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 26ms/step - accuracy: 0.5546 - loss: 1.2378 - val_accuracy: 0.5532 - val_loss: 1.2090 Epoch 15/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.5580 - loss: 1.2273 - val_accuracy: 0.5995 - val_loss: 1.1144 Epoch 16/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5659 - loss: 1.2060 - val_accuracy: 0.6169 - val_loss: 1.0742 Epoch 17/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.5736 - loss: 1.1929 - val_accuracy: 0.6269 - val_loss: 1.0382 Epoch 18/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 19s 25ms/step - accuracy: 0.5782 - loss: 1.1744 - val_accuracy: 0.6319 - val_loss: 1.0392 Epoch 19/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5824 - loss: 1.1657 - val_accuracy: 0.6388 - val_loss: 1.0173 Epoch 20/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5910 - loss: 1.1514 - val_accuracy: 0.6257 - val_loss: 1.0427 Epoch 21/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.5955 - loss: 1.1378 - val_accuracy: 0.6430 - val_loss: 0.9918 Epoch 22/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6018 - loss: 1.1183 - val_accuracy: 0.6437 - val_loss: 0.9908 Epoch 23/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 17s 27ms/step - accuracy: 0.6052 - loss: 1.1124 - val_accuracy: 0.6482 - val_loss: 0.9743 Epoch 24/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 19s 25ms/step - accuracy: 0.6078 - loss: 1.0988 - val_accuracy: 0.6554 - val_loss: 0.9640 Epoch 25/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 26ms/step - accuracy: 0.6103 - loss: 1.0915 - val_accuracy: 0.6634 - val_loss: 0.9454 Epoch 26/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6153 - loss: 1.0779 - val_accuracy: 0.6671 - val_loss: 0.9403 Epoch 27/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6197 - loss: 1.0637 - val_accuracy: 0.6644 - val_loss: 0.9508 Epoch 28/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 21s 26ms/step - accuracy: 0.6226 - loss: 1.0597 - val_accuracy: 0.6544 - val_loss: 0.9643 Epoch 29/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6272 - loss: 1.0496 - val_accuracy: 0.6830 - val_loss: 0.8974 Epoch 30/30 625/625 ━━━━━━━━━━━━━━━━━━━━ 16s 25ms/step - accuracy: 0.6336 - loss: 1.0313 - val_accuracy: 0.6548 - val_loss: 0.9598 学習時間:514.7秒 test_accuracy:0.6515
④ グラフ+サマリー出力
import pandas as pd
label_map = {
'none': 'なし',
'autoaugment': 'AutoAugment',
'trivialaugment': 'TrivialAugment',
}
# ── val_accuracy / val_loss 比較グラフ ─────────────────
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for label, h in histories.items():
axes[0].plot(h.history['val_accuracy'], label=label_map[label])
axes[1].plot(h.history['val_loss'], label=label_map[label])
axes[0].set_title('val_accuracy の比較(全30エポック)')
axes[1].set_title('val_loss の比較(全30エポック)')
for ax in axes:
ax.set_xlabel('Epoch'); ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('augment_val.png', dpi=150)
plt.show()
# ── train_loss vs val_loss(過学習の乖離確認)──────────
fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5))
for i, (label, h) in enumerate(histories.items()):
axes2[i].plot(h.history['loss'], label='train_loss')
axes2[i].plot(h.history['val_loss'], label='val_loss')
axes2[i].set_title(label_map[label])
axes2[i].set_xlabel('Epoch'); axes2[i].legend(); axes2[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('augment_overfit.png', dpi=150)
plt.show()
# ── 結果サマリー ────────────────────────────────────────
rows = []
for label in ['none', 'autoaugment', 'trivialaugment']:
val_acc = histories[label].history['val_accuracy'][-1]
test_acc = scores[label][1]
t = times[label]
rows.append({
'パターン': label_map[label],
'val_accuracy': f"{val_acc:.4f}",
'test_accuracy': f"{test_acc:.4f}",
'学習時間(s)': f"{t:.1f}",
})
df = pd.DataFrame(rows).sort_values('test_accuracy', ascending=False)
print(df.to_string(index=False))
最終結果サマリー
パターン val_accuracy test_accuracy 学習時間(s)
なし 0.6549 0.6517 134.6
TrivialAugment 0.6548 0.6515 514.7
AutoAugment 0.6535 0.6494 855.5
実験結果
精度グラフ
損失グラフ
なし
TrivialAugment
AutoAugment
| パターン | val_accuracy | test_accuracy | 学習時間 |
|---|---|---|---|
| なし | 0.6549 | 0.6517 | 134.6秒 |
| TrivialAugment | 0.6548 | 0.6515 | 514.7秒 |
| AutoAugment | 0.6535 | 0.6494 | 855.5秒 |
考察
① 精度差はほぼゼロ——3パターンが横並びの結果
最大精度差は 0.0023pt(0.23%) で、乱数シードを変えれば容易に順位が入れ替わる誤差範囲です。コードの修正前(Sharpness no-op・Equalize no-op・AutoContrast 固定値)と比べても結果は変わりませんでした。
つまり今回の条件下では、変換の実装精度より「そもそもAugmentationが有効かどうか」が支配的だったことになります。
② なぜ差が出なかったのか——CIFAR-10×30エポックの構造的な問題
このブログのAugmentation系実験では、CIFAR-10の小画像に対して高度なAugmentationが効きにくいケースが繰り返し観測されています。原因は2つです。
| 原因 | 内容 |
|---|---|
| 画像が小さすぎる | 32×32ピクセルでは色変換・鮮鋭化・ポスタリゼーションが画像の情報を破壊しやすく、有益な多様性より有害なノイズになりやすい |
| エポック数が少ない | Augmentationは学習を意図的に難しくする手法。30エポックでは多様な変換を十分に学習に活かしきれず、むしろ収束が遅れるだけになりやすい |
論文でAutoAugment・TrivialAugmentが効果を示した条件は、ImageNet(224×224)や数百エポックの学習です。スケール感が根本的に異なることが今回の結果に直結しています。
③ 学習時間のコストが見合わない
精度がほぼ同じにもかかわらず、学習時間は大きく異なりました。
| パターン | 学習時間 | なしとの比率 |
|---|---|---|
| なし | 134.6秒 | 1.0× |
| TrivialAugment | 514.7秒 | 3.8× |
| AutoAugment | 855.5秒 | 6.4× |
TrivialAugmentは tf.py_function(Equalize)の影響で約4倍、AutoAugmentは tf.cond を8段ネストするポリシー選択のオーバーヘッドでさらに遅くなっています。CIFAR-10×30エポックという条件では、このコストに見合う精度改善は得られませんでした。
④ それでも TrivialAugment を使う価値はあるか
今回の結果だけを見れば「CIFAR-10では不要」という結論になりますが、実務での判断は異なります。
| 条件 | TrivialAugmentの有効性 |
|---|---|
| エポック数を増やす(100エポック以上) | ◎ 多様な変換が学習に浸透しやすくなり精度向上が期待できる |
| ImageNet規模の画像(224×224以上) | ◎ 本来想定された条件。論文でも明確な効果が示されている |
| データ量が少ない(クラスあたり数百枚以下) | ○ 過学習防止として有効になりやすい |
| CIFAR-10×30エポック(今回) | △ 精度向上は期待しにくい。過学習が問題なら検討の余地あり |
⑤ Kerasで正確なAutoAugmentを実装したい場合
今回の実装は tf.image の制約から完全なAutoAugmentではありません。論文と同等の実装を使いたい場合は KerasCV が現実的な選択肢です。
# KerasCV を使った AutoAugment(参考)
# pip install keras-cv
import keras_cv
auto_aug = keras_cv.layers.AutoAugment(value_range=(0, 1))
# tf.data パイプラインで使用
ds = ds.map(lambda x, y: (auto_aug(x, training=True), y))
実務での推奨
| 状況 | 推奨 | 理由 |
|---|---|---|
| CIFAR-10規模・30エポック以下 | なし or フリップのみ | 高度なAugmentは精度改善なしにコストだけ増える |
| CIFAR-10で過学習が問題のとき | TrivialAugment(エポック数も増やす) | 汎化性能向上の余地がある |
| ImageNet規模・100エポック以上 | TrivialAugment | 実装シンプル・探索コストゼロで効果十分 |
| 精度を最大化したい本番モデル | KerasCV の AutoAugment | 完全実装済みポリシーで最高精度を狙える |
- CIFAR-10×30エポックでは3パターンの精度差は最大 0.23% で、誤差範囲内。高度なAugmentationの恩恵はほぼ得られなかった
- 原因は「32×32では変換が情報を破壊しやすい」「30エポックでは多様性を活かしきれない」の2点
- 学習時間はなし(135秒)に対して TrivialAugment が3.8倍、AutoAugment が6.4倍。精度が同等なら明らかにコスト過多
tf.image.adjust_sharpness()は現行TFに存在しない。ラプラシアンカーネル(tf.nn.conv2d)で代替実装できるEqualizeはtf.py_function経由のNumPy実装が必要。ただしこれが学習時間増加の一因になる- AutoAugment・TrivialAugmentの真価は 大きな画像・多いエポック数・少ないデータ量 の条件で発揮される
English Summary
We compared three augmentation strategies on CIFAR-10 (32×32) for 30 epochs: no augmentation, AutoAugment (CIFAR-10 policy), and TrivialAugment. The maximum accuracy difference was only 0.23%, well within noise range. The main finding is that on small images with few epochs, complex augmentation strategies offer no benefit while significantly increasing training time (3.8× for TrivialAugment, 6.4× for AutoAugment). AutoAugment and TrivialAugment show their true value with larger images, more epochs, or smaller datasets.
関連記事
- CutMix・MixUp比較 → CutMixとMixUpの効果を比較!なしと比べてどれが精度向上に効く?
- Random Erasing → CutOut / Random Erasingで過学習は防げるか?
- Augmentationの順番 → Augmentationと正規化、順番で精度は変わる?
- Data Augmentation重ねすぎ問題 → Data Augmentationを重ねすぎると精度が下がる?
- Gradient Clipping比較 → 勾配クリッピング(Gradient Clipping)で学習は安定する?






0 件のコメント:
コメントを投稿