✅ モデル設計や実装レビューの相談はココナラ
へ(外部リンク)
CIFAR-10とは?
CIFAR-10は10種類のクラス(飛行機、自動車、鳥、猫、鹿、犬、カエル、馬、船、トラック)に分類された32×32ピクセルの小さなカラー画像60,000枚で構成されたデータセットです。 機械学習やディープラーニングの研究において、画像認識モデルの入門・比較・評価に使われる代表的な教材となっています。 特に、コンピュータビジョン分野の基礎を学ぶ際には欠かせない存在です。
CIFAR-10を題材に学習する際には、精度向上の工夫や実験の進め方に悩む人も多いです。本記事では実装コードに加え、過学習を防ぐ工夫や初心者がつまずきやすいポイントも解説していきます。基本的な手法はMNISTのCNN実装記事でも触れていますが、CIFAR-10はより難易度が高いため追加の工夫が必要です。
CIFAR-10の代表的な応用例
- 画像分類モデルの学習:CNNの構造比較や最新モデルのベンチマークに使用
- 転移学習の実験:ImageNetなどより大きなデータへの応用を見据えた研究
- ハイパーパラメータ調整:学習率・最適化手法の効果を検証する題材
- 教育教材:大学やオンライン講座での実習用データ
こうした応用は、研究や開発だけでなく、AIに初めて触れる学生やエンジニアが実際に手を動かして学ぶ際にも役立ちます。
CIFAR-10は研究用だけでなく、教材としても広く利用されています。大学の機械学習講義やオンライン講座の演習課題としても使われており、基礎から応用へ進むステップアップに最適です。私自身、最初はColabで学習を実行しましたが、途中でセッションが切れてしまい、学習結果が保存されず苦労しました。チェックポイントを保存する工夫を取り入れると、この問題を回避できます。
PythonとTensorFlowでの実装例
以下にCIFAR-10を用いた基本的なCNNモデルの実装例を示します。 Google Colabを利用すれば無料でGPUを使えるため、PCの性能に依存せず気軽に試すことができます。
ソースコード
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
# データセットの読み込み
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
# 正規化(0〜1にスケーリング)
x_train, x_test = x_train / 255.0, x_test / 255.0
# CNNモデル構築
model = models.Sequential([
layers.Conv2D(32, (3,3), activation='relu', input_shape=(32,32,3)),
layers.MaxPooling2D((2,2)),
layers.Conv2D(64, (3,3), activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10, activation='softmax')
])
# コンパイル
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 学習
history = model.fit(x_train, y_train, epochs=10,
validation_data=(x_test, y_test))
このコードを実行するだけで、CIFAR-10を用いた基本的な画像分類器を学習できます。 学習済みモデルの精度を確認したり、レイヤーを増やすことで性能改善を試みるなどの応用も可能です。
実行例
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 13s 0us/step
Epoch 1/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 6ms/step - accuracy: 0.3801 - loss: 1.6946 - val_accuracy: 0.5706 - val_loss: 1.2054
Epoch 2/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 14s 4ms/step - accuracy: 0.5993 - loss: 1.1415 - val_accuracy: 0.6386 - val_loss: 1.0430
Epoch 3/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.6580 - loss: 0.9763 - val_accuracy: 0.6572 - val_loss: 0.9830
Epoch 4/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.6941 - loss: 0.8835 - val_accuracy: 0.6713 - val_loss: 0.9615
Epoch 5/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7127 - loss: 0.8229 - val_accuracy: 0.6849 - val_loss: 0.9206
Epoch 6/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.7374 - loss: 0.7558 - val_accuracy: 0.6826 - val_loss: 0.9277
Epoch 7/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7528 - loss: 0.7072 - val_accuracy: 0.6804 - val_loss: 0.9615
Epoch 8/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 7ms/step - accuracy: 0.7633 - loss: 0.6706 - val_accuracy: 0.6909 - val_loss: 0.9275
Epoch 9/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 17s 4ms/step - accuracy: 0.7822 - loss: 0.6250 - val_accuracy: 0.6995 - val_loss: 0.9142
Epoch 10/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 4ms/step - accuracy: 0.7919 - loss: 0.5886 - val_accuracy: 0.6963 - val_loss: 0.9518
よくある課題と改善策
CIFAR-10でCNNを学習させると、10エポック程度では精度が60〜70%に留まることが多いです。これはデータが複雑なため、単純なモデルでは表現力が不足しているからです。そこで次のような工夫が役立ちます。
- データ拡張(回転・反転・ランダムクロップなど)
- Batch Normalizationによる安定化(解説記事はこちら)
- ドロップアウトで過学習を防止
- ResNetなどの深いモデルを利用
特に、Colab環境ではGPUリソースに限りがあるため、エポック数を増やすよりも「データ拡張」と「軽量なモデル設計」で効率よく学習を進めるのがおすすめです。
学習の評価方法
モデルを学習しただけでは性能が分かりません。そこで、テストデータで評価し、学習曲線を可視化して確認するのが一般的です。以下のコードを追加すると、精度や損失の推移をグラフで確認できます。
import matplotlib.pyplot as plt
import japanize_matplotlib
# テストデータで評価
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"テスト精度: {test_acc:.4f}, テスト損失: {test_loss:.4f}")
# 学習過程の可視化
plt.figure(figsize=(12,4))
# 精度の推移
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='訓練精度')
plt.plot(history.history['val_accuracy'], label='検証精度')
plt.title('精度の推移')
plt.xlabel('エポック')
plt.ylabel('精度')
plt.legend()
# 損失の推移
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='訓練損失')
plt.plot(history.history['val_loss'], label='検証損失')
plt.title('損失の推移')
plt.xlabel('エポック')
plt.ylabel('損失')
plt.legend()
plt.show()
上記のコードでは、学習過程の精度と損失をグラフ化して表示します。訓練データと検証データの推移を比較することで、過学習の兆候を見つけたり、学習率やエポック数を調整する判断材料にできます。
実行結果
313/313 - 1s - 5ms/step - accuracy: 0.6859 - loss: 0.9342 テスト精度: 0.6859, テスト損失: 0.9342
学習曲線を確認することで、モデルが正しく学習しているかを判断できます。例えば「訓練精度は上がるのに検証精度が伸びない」といった場合は、過学習が起きているサインです。このような場合はデータ拡張や正則化を強化することで改善が期待できます。
実践的な応用と発展
学習したモデルは、Webアプリに組み込んだり、IoTデバイスに搭載してリアルタイム分類を行うなど、さまざまな応用が考えられます。 また、転移学習を用いることで、医療画像や産業検査などの実データにも応用可能です。
関連記事(内部リンク)
まとめ
CIFAR-10は画像認識の入門に最適なデータセットであり、PythonとTensorFlowを用いれば誰でも簡単に実装できます。 教育や研究に加え、実務での応用にも繋がる基礎的な力を身につけることができます。 まずはColabでコードを実行し、精度改善やモデル改良に挑戦してみましょう。
追加
学習曲線を描画する際に日本語表記が文字化けしないようにする。
!pip install japanize_matplotlib Collecting japanize_matplotlib Downloading japanize-matplotlib-1.1.3.tar.gz (4.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.1/4.1 MB 73.9 MB/s eta 0:00:00 Preparing metadata (setup.py) ... done Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from japanize_matplotlib) (3.10.0) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (1.3.3) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (4.59.2) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (1.4.9) Requirement already satisfied: numpy>=1.23 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (2.0.2) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (25.0) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (11.3.0) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (3.2.3) Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib->japanize_matplotlib) (2.9.0.post0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib->japanize_matplotlib) (1.17.0) Building wheels for collected packages: japanize_matplotlib Building wheel for japanize_matplotlib (setup.py) ... done Created wheel for japanize_matplotlib: filename=japanize_matplotlib-1.1.3-py3-none-any.whl size=4120257 sha256=d30c4950e8d5959952ef787df1c8838ba80a6109fed21452db5fc5c152f7a512 Stored in directory: /root/.cache/pip/wheels/c1/f7/9b/418f19a7b9340fc16e071e89efc379aca68d40238b258df53d Successfully built japanize_matplotlib Installing collected packages: japanize_matplotlib Successfully installed japanize_matplotlib-1.1.3
0 件のコメント:
コメントを投稿