手書き数字画像を使ってAIを再学習!【Keras×Colab 転移学習】

2025年5月11日日曜日

Google Colab Keras TensorFlow 画像分類 機械学習 転移学習

X f B! P L

はじめに

こんにちは、SHOUです!
今回は「自分で用意した手書き数字画像をAIに覚えさせたい!」というテーマで、KerasとGoogle Colabを使って転移学習を試してみました。以前の記事でMNISTデータセットを使った分類をしましたが、今回はさらに一歩進め、自前の画像を訓練データに追加してAIの精度をアップさせることを目指します。

画像の準備(Google Drive)

まずは自分の手書き画像をGoogle Driveにアップロードします。今回は「0」~「9」の数字をフォルダごとに分類して保存しました。

  1. MyDrive/digits/
  2. ├── 0/
  3. ├── 1/
  4. ├── 2/
  5. ├── ...
  6. ├── 9/

それぞれのフォルダ内に、該当する数字の画像をまとめておきます。
MNISTの形式とマッチしやすくするために、画像は28×28のグレースケール(白黒)に揃えました。

モデルの構築と事前学習

Colabで次のコードを実行して、Google Driveをマウントします。

  1. from google.colab import drive
  2. drive.mount('/content/drive')

次に、KerasでMNISTを使ったベースモデルを作成・訓練していきます。

  1. import tensorflow as tf
  2. from tensorflow import keras
  3. (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
  4. x_train = x_train / 255.0
  5. x_test = x_test / 255.0
  6. model = keras.Sequential([
  7. keras.Input(shape=(28, 28, 1)),
  8. keras.layers.Flatten(),
  9. keras.layers.Dense(128, activation='relu'),
  10. keras.layers.Dropout(0.2),
  11. keras.layers.Dense(10, activation='softmax')
  12. ])
  13. model.compile(optimizer='adam',
  14. loss='sparse_categorical_crossentropy',
  15. metrics=['accuracy'])
  16. # MNISTで事前学習
  17. model.fit(x_train, y_train, epochs=5)

自前データで再学習(転移学習)

ここから、Google Drive内の手書き画像を使って転移学習を行います。

これらのフォルダに数字ごとの画像が保存されていることを確認してみます👇。

  1. import matplotlib.pyplot as plt
  2. import cv2
  3. import os
  4. base_folder = '/content/drive/MyDrive/Colab/digits'
  5. digit_folders = [str(i) for i in range(10)]
  6. plt.figure(figsize=(10, 4)) # 横長
  7. for idx, digit in enumerate(digit_folders):
  8. folder_path = os.path.join(base_folder, digit)
  9. image_files = [f for f in os.listdir(folder_path) if f.endswith(('.png', '.jpg', '.jpeg'))]
  10. if image_files:
  11. img_path = os.path.join(folder_path, image_files[0]) # 各フォルダの最初の画像を1枚だけ
  12. img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
  13. plt.subplot(3, 4, idx + 1)
  14. plt.imshow(img, cmap='gray')
  15. plt.title(f'Digit: {digit}', fontsize=10)
  16. plt.axis('off')
  17. plt.tight_layout()
  18. plt.show()

ポイント💡

  • base_folderのパスは、ご自分のGoogle Driveの構成に合わせて書き換えてください。
  • plt.subplot(3, 4, ...)と設定しているので、最大12枚分のマス目が作られます。今回は10枚(0〜9)なので、きれいに収まる感じです。

実行すると、↓のように数字画像がズラッと並んで確認できます:

転移学習を行います👇

  1. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  2. datagen = ImageDataGenerator(
  3. rescale=1./255 ,
  4. preprocessing_function=lambda x: 1.0 - x # 白黒反転
  5. )
  6. train_generator = datagen.flow_from_directory(
  7. '/content/drive/MyDrive/Colab/digits',
  8. target_size=(28, 28),
  9. color_mode='grayscale',
  10. batch_size=32,
  11. class_mode='sparse'
  12. )
  13. # 転移学習
  14. model.fit(train_generator, epochs=5)

Point💡
ImageDataGeneratorは、画像データの前処理を自動でやってくれる便利なクラスです。
・MNISTは黒地に白い数字のため、画像データ(白地に黒い数字)を反転させる必要があります。 ・flow_from_directoryは、フォルダ名をクラスラベルとして自動認識し、各画像にラベル付けしてくれるので便利です。

学習結果

  1. Found 10 images belonging to 10 classes.
  2. Epoch 1/5
  3. /usr/local/lib/python3.11/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your PyDataset class should call super().__init__(**kwargs) in its constructor. **kwargs can include workers, use_multiprocessing, max_queue_size. Do not pass these arguments to fit(), as they will be ignored.
  4. self._warn_if_super_not_called()
  5. 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 522ms/step - accuracy: 0.0000e+00 - loss: 52.3672
  6. Epoch 2/5
  7. 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - accuracy: 0.0000e+00 - loss: 35.1001
  8. Epoch 3/5
  9. 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 141ms/step - accuracy: 0.0000e+00 - loss: 25.2860
  10. Epoch 4/5
  11. 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 144ms/step - accuracy: 0.1000 - loss: 16.6126
  12. Epoch 5/5
  13. 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - accuracy: 0.0000e+00 - loss: 13.2358
  14. <keras.src.callbacks.history.History at 0x77feded21a50>

Point💡
学習データがたった10枚(=1クラス1枚)しかないため、
・各エポックでの損失(loss)は減少していますが、精度(accuracy)はほとんど上がらない
・ごく一部のタイミングで accuracy: 0.1(10%) という結果が出ていますが、これは10クラスの分類問題でランダムと同程度の精度であり、しっかり学習できていない状態です。

実際に判定してみる

Drive内にある以前誤認識した数字画像を読み込んで、どの数字か予測してみます。

  1. from tensorflow.keras.preprocessing import image
  2. import numpy as np
  3. img_path = '/content/drive/MyDrive/Colab/handwritten_digit.png'
  4. img = image.load_img(img_path, target_size=(28, 28), color_mode='grayscale')
  5. img_array = image.img_to_array(img)
  6. img_array = np.expand_dims(img_array, axis=0) / 255.0
  7. img_array = 1.0 - img_array # ←★ 追加:白黒反転
  8. prediction = model.predict(img_array)
  9. print('予測ラベル:', np.argmax(prediction))

実行結果

  1. 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step
  2. 予測ラベル: 8

私の場合、「4」の画像を使ったのですが、結果はなんと「8」と判定されました。
まだまだ学習データが少ないため、精度アップにはさらなるデータ投入が必要だと実感しました。

まとめ

今回は、自分で描いた手書き数字を使ってAIに再学習させる方法を試しました。
大きな流れとしては:

  • Google Driveに画像を準備
  • KerasのMNISTモデルをベースに構築
  • 自前データを追加して転移学習
  • 新しい画像を読み込んで判定

という手順でした。
精度を上げるためには、

  • 十分な量の画像データを準備する
  • 回転・ズームなどデータ拡張を活用する
  • より高性能なCNNモデルを使う

といった工夫が効果的なようです。

今回は、自分で描いた数字画像を使って転移学習を実施したため、学習データがまだ少ない段階でも、意外としっかりと認識されることがわかりました。 おそらく、自分で書いた数字は線の癖やスタイルが一貫しているため、AIにとっては「覚えやすい」特徴が強く表れているのだと思います。

もちろん、誤認識(例:4を3と判定するなど)する場合もありますが、追加データを増やす・データ拡張を行うことで、さらに精度アップが期待できそうです。

ぜひ皆さんも、自分のオリジナルな手書き数字でAIを育てる体験をしてみてください😊。

追記

事前学習の際に、以下の警告が出るため、コードの変更を実施した。
Kerasの新しいバージョンでは input_shapeを直接レイヤーに渡すのは非推奨 となり、代わりにInput()を使う方法を推奨しています。

警告

  1. /usr/local/lib/python3.11/dist-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an input_shape/input_dim argument to a layer. When using Sequential models, prefer using an Input(shape) object as the first layer in the model instead.
  2. super().__init__(**kwargs)

変更前

  1. model = keras.models.Sequential([
  2. keras.layers.Flatten(input_shape=(28, 28, 1)),
  3. keras.layers.Dense(128, activation='relu'),
  4. keras.layers.Dropout(0.2),
  5. keras.layers.Dense(10, activation='softmax')
  6. ])

変更後

  1. model = keras.Sequential([
  2. keras.Input(shape=(28, 28, 1)),
  3. keras.layers.Flatten(),
  4. keras.layers.Dense(128, activation='relu'),
  5. keras.layers.Dropout(0.2),
  6. keras.layers.Dense(10, activation='softmax')
  7. ])

このブログを検索

自己紹介

機械学習を学習中。虫も飼ってます。

お問い合わせ

名前

メール *

メッセージ *

プライバシーポリシー

QooQ