How to Speed Up Training with Keras Mixed Precision (Half-Precision Training) | Up to 2-3x Faster on GPU

Posted: Tuesday, November 25, 2025

Tags: CIFAR-10, CNN, Google Colab, Keras Mixed Precision, MNIST


What Is Mixed Precision?

Mixed Precision is a technique that combines the float32 (single-precision) and float16 (half-precision) data types in deep learning computations.

By taking advantage of dedicated hardware such as NVIDIA Tensor Cores, it can speed up training of large models by 2-3x while also reducing GPU memory usage.

When Mixed Precision Is Effective

  • An NVIDIA GPU (RTX 20xx / 30xx / 40xx)
  • Google Colab (T4 / L4 / A100)
  • Training with TensorFlow or Keras
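
To check whether your runtime actually has a Tensor-Core-capable GPU, the following sketch (assuming TensorFlow 2.x) queries each device's compute capability; Tensor Cores require compute capability 7.0 or higher (a T4 reports 7.5).

import tensorflow as tf

# Query each visible GPU's name and compute capability (available in TF 2.4+).
for gpu in tf.config.list_physical_devices("GPU"):
    details = tf.config.experimental.get_device_details(gpu)
    name = details.get("device_name", "unknown")
    cc = details.get("compute_capability")  # e.g. (7, 5) on a T4
    has_tensor_cores = cc is not None and cc >= (7, 0)
    print(f"{name}: compute capability {cc}, Tensor Cores: {has_tensor_cores}")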

Setup and Training Code in Keras

Setup (3 Required Steps)

To use Mixed Precision stably in Keras, the following three steps are required.

  1. Set the global policy: run most computations in float16.
  2. Apply Loss Scaling: prevents gradient underflow (see the underflow sketch after this list).
  3. Set the output layer's dtype: prevents numerical instability in the final output.
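Why step 2 matters: in float16, values below roughly 6e-8 underflow to zero, so small gradients simply vanish. A minimal NumPy sketch of the underflow that loss scaling works around (the scale factor 1024 here is illustrative; the Keras optimizer picks one dynamically):

import numpy as np

small_grad = np.float32(1e-8)           # a tiny gradient value, fine in float32
print(np.float16(small_grad))           # 0.0 -- underflows in float16
print(np.float16(small_grad * 1024.0))  # ~1e-05 -- survives once scaled up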
# Imports assumed by this snippet
from tensorflow import keras
from tensorflow.keras import layers, mixed_precision

# 1. Set the global policy
mixed_precision.set_global_policy("mixed_float16")

# 2. Apply Loss Scaling
optimizer = keras.optimizers.Adam()
optimizer = mixed_precision.LossScaleOptimizer(optimizer)

# 3. Set the output layer's dtype
model = keras.Sequential([
    # ... (hidden layers) ...
    layers.Dense(10, activation="softmax", dtype="float32")
])
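
As a quick sanity check (a small sketch assuming the imports above), you can confirm the policy took effect: under mixed_float16 a layer computes in float16 but keeps its weights in float32.

print(mixed_precision.global_policy())  # <Policy "mixed_float16">

dense = layers.Dense(8)
dense.build((None, 4))       # create the layer's variables
print(dense.compute_dtype)   # float16 -- the math runs in half precision
print(dense.variable_dtype)  # float32 -- weights are kept in full precision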

Training Code Example (MNIST)

from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# ===== Dataset =====
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# ===== Model =====
model = keras.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax", dtype="float32")   # output in float32
])

# ===== Compile =====
optimizer = keras.optimizers.Adam()
optimizer = mixed_precision.LossScaleOptimizer(optimizer)

model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# ===== Train =====
model.fit(
    x_train, y_train,
    epochs=3,
    validation_data=(x_test, y_test)
)
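
Because the output layer was pinned to float32, predictions come back in full precision even though the hidden layers computed in float16. A quick check on the trained model:

preds = model.predict(x_test[:5])
print(preds.dtype)  # float32, thanks to dtype="float32" on the output layer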

Benchmark Results and Performance Analysis

Results Summary (T4 GPU)

The pure training-time comparison came out as follows.

Test case                            | float32 time (s) | mixed_float16 time (s) | Speed-up ratio | Trend
MNIST + Dense (small)                | 15.88            | 18.16                  | 0.87x          | Slower (overhead dominates)
CIFAR-10 + CNN (medium, Batch=1024)  | 57.73            | 54.86                  | 1.05x          | Slightly faster

Why Do Small Models Get Slower?

The MNIST + Dense model was slower for the following reasons.

  • Low compute volume: the model is small and the computation itself is very fast, so the cost of casting data to float16 outweighed the speedup from Tensor Cores.
  • The typical slow pattern: when "small model", "Dense-heavy", and "small batch size" all line up, mixed precision tends to offer little benefit (the sketch below makes the per-layer casts visible).
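
To make the cast overhead concrete, here is a minimal sketch (assuming TensorFlow 2.x): under mixed_float16, every layer casts its float32 inputs and master weights to float16 on each forward pass, a fixed per-step cost that a tiny, fast model cannot amortize.

import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

mixed_precision.set_global_policy("mixed_float16")

dense = layers.Dense(128)
x = tf.random.normal((32, 64))   # float32 input batch
y = dense(x)
print(y.dtype)             # float16: input and weights were cast before the matmul
print(dense.kernel.dtype)  # float32: the master weights stay in full precision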

Caveats and When to Use Mixed Precision

Cases Where Speedups Are Expected

Under the following conditions, a clear 1.5-3x speedup can be expected.

  • Compute-heavy models: CNNs (convolutional layers), Transformers, and other models dominated by matrix math.
  • Large model size.
  • Batch size of 64 or more (ideally 1024 or as large as GPU memory allows).
  • A recent-generation GPU (T4, L4, A100, RTX 30/40 series, etc.).

Caveats and Avoiding Instability

  • CPU environments: essentially no speedup.
  • Older GPUs: GPUs without Tensor Cores, such as the GTX 10xx generation, are not supported.
  • Output layer: always set the output layer to dtype="float32" to prevent unstable gradients.
  • Loss Scaling: if the loss becomes NaN, check that LossScaleOptimizer is applied correctly (a monitoring sketch follows this list).
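
As one way to catch the NaN case early, the sketch below (reusing model, optimizer, x_train, and y_train from the MNIST example above) combines the built-in TerminateOnNaN callback with the dynamic loss scale that LossScaleOptimizer exposes:

from tensorflow import keras

# Stop training as soon as the loss becomes NaN (built-in Keras callback).
nan_guard = keras.callbacks.TerminateOnNaN()
model.fit(x_train, y_train, epochs=3, callbacks=[nan_guard])

# The current dynamic loss scale is readable on the wrapped optimizer.
print(optimizer.loss_scale)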

Summary

  • mixed_precision.set_global_policy("mixed_float16") enables the speedup with a single line.
  • By using the GPU's Tensor Cores, large models can train 2-3x faster.
  • Keep the output layer in float32 for stability.
  • To see real gains, you need a compute-heavy model such as a CNN or Transformer and a large batch size.

Appendix: Source Code and Results from the Comparison Benchmark

MNIST + Dense (small)

import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from tensorflow.keras import mixed_precision
from tensorflow.keras.mixed_precision import LossScaleOptimizer
import os

# Suppress warnings and logs for clean timing
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 


def build_model():
    """Defines a simple Sequential model for MNIST classification."""
    return keras.Sequential([
        layers.Input(shape=(28, 28)),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10, activation="softmax")
    ])

def train_and_measure(x_train_data, y_train_data, x_test_data, y_test_data, policy="float32"):
    """
    Trains the model with a specified precision policy and measures the training time.
    
    Args:
        x_train_data, y_train_data, ...: Pre-loaded NumPy arrays.
        policy (str): The precision policy ("float32" or "mixed_float16").

    Returns:
        float: The training time in seconds.
    """
    # 1. Set the global mixed precision policy
    if policy == "mixed_float16":
        # Set to mixed_float16 for GPU/TPU speedup
        mixed_precision.set_global_policy("mixed_float16")
    else:
        # Set to standard float32 precision
        tf.keras.mixed_precision.set_global_policy("float32")

    # 2. Create model instance (must be done AFTER setting the policy)
    model = build_model()

    # 3. Configure optimizer, applying Loss Scaling if needed
    optimizer = keras.optimizers.Adam()
    # Apply Loss Scale Optimizer for mixed precision
    if policy == "mixed_float16":
        optimizer = LossScaleOptimizer(optimizer)

    model.compile(optimizer=optimizer,
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

    # 4. Start measurement and training
    start = time.time()
    # Pass pre-loaded data (I/O is outside of the timer)
    model.fit(x_train_data, y_train_data,
              epochs=3,
              batch_size=128,
              validation_data=(x_test_data, y_test_data),
              verbose=0) 
    end = time.time()

    # 5. Reset the policy to float32 after training
    tf.keras.mixed_precision.set_global_policy("float32")
    
    return end - start


# =================================================================
# ===== Execution and Comparison (I/O separated) =====
# =================================================================

# --- Data Loading and Preprocessing (OUTSIDE the timer) ---
print("--- Loading and pre-processing MNIST data (Outside of timer)... ---")
# Load data (download completes here)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Data normalization
x_train, x_test = x_train / 255.0, x_test / 255.0
print("--- Pre-processing complete. Measuring pure training time. ---")


# --- Measure Pure Training Time ---
print("\nRunning float32 (Standard Precision)...")
# Pass data as arguments
t1 = train_and_measure(x_train, y_train, x_test, y_test, "float32")
print(f"float32 training time: {t1:.2f} seconds")

print("\nRunning mixed_float16 (Accelerated Precision)...")
# Pass data as arguments
t2 = train_and_measure(x_train, y_train, x_test, y_test, "mixed_float16")
print(f"mixed_float16 training time: {t2:.2f} seconds")

# Calculate and display the speed-up ratio
print("\nSpeed-up Ratio:", f"{t1 / t2:.2f}x")

Output

--- Loading and pre-processing MNIST data (Outside of timer)... ---
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
--- Pre-processing complete. Measuring pure training time. ---

Running float32 (Standard Precision)...
float32 training time: 15.88 seconds

Running mixed_float16 (Accelerated Precision)...
mixed_float16 training time: 18.16 seconds

Speed-up Ratio: 0.87x

CIFAR-10 + CNN (medium, Batch=1024)

The benchmark was run on a T4 GPU in Google Colab.

import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import mixed_precision
from tensorflow.keras.mixed_precision import LossScaleOptimizer
from tensorflow.keras.utils import to_categorical
import os

# Suppress warnings and logs for clean timing
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

# --- Model Definition: CNN for CIFAR-10 (High Compute) ---
def build_cnn_model():
    """Defines a high-compute CNN model for CIFAR-10."""
    return keras.Sequential([
        layers.Input(shape=(32, 32, 3)),
        
        # Convolutional Block 1
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),

        # Convolutional Block 2
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
        
        # Dense Layers
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dropout(0.5),
        # Output layer for 10 classes
        layers.Dense(10, activation="softmax")
    ])


# --- Data Preprocessing (Outside of the timer) ---
print("--- Loading and pre-processing CIFAR-10 data... ---")
# Load data (download is completed here)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize data to 0-1 range
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Convert labels to One-hot encoding
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# Set batch size high for T4 GPU utilization
BATCH_SIZE = 1024
print(f"--- Pre-processing complete. Batch Size: {BATCH_SIZE} ---")


# --- Training and Measurement Function (NumPy Array Input) ---
def train_and_measure(x_train_data, y_train_data, x_test_data, y_test_data, policy="float32"):
    """
    Trains the CNN model using NumPy arrays with the specified precision policy and measures time.
    """
    
    # 1. Set global precision policy
    if policy == "mixed_float16":
        mixed_precision.set_global_policy("mixed_float16")
    else:
        tf.keras.mixed_precision.set_global_policy("float32")

    # 2. Build model and configure optimizer
    model = build_cnn_model()
    optimizer = keras.optimizers.Adam(learning_rate=0.0005)
    
    # Apply Loss Scale Optimizer for mixed precision
    if policy == "mixed_float16":
        optimizer = LossScaleOptimizer(optimizer)

    model.compile(optimizer=optimizer,
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])

    # 3. Start timing and training
    start = time.time()
    # Pass NumPy arrays directly to model.fit
    model.fit(x_train_data, y_train_data,
              epochs=5,
              batch_size=BATCH_SIZE,
              validation_data=(x_test_data, y_test_data),
              verbose=0) # Suppress log output for accurate timing
    end = time.time()
    
    # 4. Reset policy to float32 after measurement
    tf.keras.mixed_precision.set_global_policy("float32")
    
    return end - start


# =================================================================
# ===== Execution and Comparison (Pure calculation test) =====
# =================================================================

# --- 1. Measure float32 time ---
print("\n" + "="*60)
print("--- Running float32 (Standard Precision) ---")
print("="*60)
# Pass pre-processed data to the function
t1 = train_and_measure(x_train, y_train, x_test, y_test, "float32")
print(f"float32 Pure Training Time: {t1:.2f} seconds")

# --- 2. Measure mixed_float16 time ---
print("\n" + "="*60)
print("--- Running mixed_float16 (Optimized Precision) ---")
print("="*60)
t2 = train_and_measure(x_train, y_train, x_test, y_test, "mixed_float16")
print(f"mixed_float16 Pure Training Time: {t2:.2f} seconds")

# --- 3. Final Comparison ---
print("\n" + "--- Final Comparison Result (Pure Calculation Focus) ---")
print(f"float32 Time: {t1:.2f} seconds")
print(f"mixed_float16 Time: {t2:.2f} seconds")
print(f"Speed-up Ratio: {t1 / t2:.2f}x")

Output

With CIFAR-10 + CNN, training was slightly faster.

--- Loading and pre-processing CIFAR-10 data... ---
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 15s 0us/step
--- Pre-processing complete. Batch Size: 1024 ---

============================================================
--- Running float32 (Standard Precision) ---
============================================================
float32 Pure Training Time: 57.73 seconds

============================================================
--- Running mixed_float16 (Optimized Precision) ---
============================================================
mixed_float16 Pure Training Time: 54.86 seconds

--- Final Comparison Result (Pure Calculation Focus) ---
float32 Time: 57.73 seconds
mixed_float16 Time: 54.86 seconds
Speed-up Ratio: 1.05x