TensorFlow.jsで事前学習済みモデルを使ったカスタム画像分類(転移学習)を実装する
はじめに
本記事では、TensorFlow.js(以下、TF.js)を使用して、事前学習済みの畳み込みニューラルネットワーク(CNN)モデルを基にしたカスタム画像分類モデルを構築する手法、すなわち転移学習の実装方法について解説します。
読者の皆様は、Python環境でTensorFlowやKerasを用いて機械学習モデルを開発された経験をお持ちのことと存じます。特に画像認識タスクにおいて、ImageNetなどの大規模データセットで学習済みのモデル(例: MobileNet, VGG, ResNet)の、強力な特徴抽出能力を活用する転移学習は、限られたデータで高い精度を実現する非常に効果的な手法です。
TF.jsを用いることで、このような強力な画像認識モデルをWebブラウザやNode.js環境で動作させることが可能になります。Pythonで培った機械学習の知識を活かしつつ、クライアントサイドやサーバーサイドJavaScript環境で動作するカスタム画像分類モデルを開発したいというニーズに対し、本記事が具体的な解決策を提供します。
この記事では、TF.jsを使った転移学習の基本的な考え方から、具体的なコード例、そして実装上の注意点までを詳細に解説します。
転移学習の概念とTensorFlow.jsでのアプローチ
転移学習とは、あるタスクで学習済みのモデルを別の関連タスクに応用する手法です。画像認識の分野では、ImageNetのような汎用的なデータセットで学習されたモデルが持つ、汎用的な特徴抽出能力(エッジ、テクスチャ、形状などを捉える能力)を、新しい特定の分類タスクに「転移」させます。
具体的な手順としては、以下のようになります。
- 事前学習済みモデルの利用: 大規模データセットで学習済みのモデル(ベースモデル)をロードします。このモデルは、画像の特徴を効果的に捉えるための多数の層(畳み込み層、プーリング層など)で構成されています。
- 層の凍結(Freeze): ベースモデルの大部分の層は、汎用的な特徴を抽出するために十分に学習されていると考えられます。これらの層の重みを、新しいタスクの学習中に変化させないように「凍結」します。これにより、事前学習で得られた知識を保持しつつ、学習計算コストも削減できます。
- 新しい分類層の追加: ベースモデルの最終層(通常、元のタスクの分類を行う全結合層)を取り除き、新しいタスクのクラス数に応じた新しい分類層(全結合層など)を追加します。
- 新しい層のみの学習: 新しいタスクのデータを用いて、追加した新しい分類層の重みのみを学習させます。ベースモデルの凍結された層は、特徴抽出器として機能します。
- (オプション)ファインチューニング: 新しい層がある程度学習された後、ベースモデルの一部の層(通常は出力層に近い層)の凍結を解除し、新しい層と共に非常に小さな学習率で再学習させることがあります。これにより、ベースモデルの特徴抽出器を新しいタスクにより適合させることができます。
PythonのKerasでは、model.trainable = False
として層を凍結し、model.add()
などで新しい層を追加し、model.compile()
、model.fit()
といったおなじみのAPIを使って転移学習を実装します。TF.jsでも、Keras APIライクな@tensorflow/tfjs-layers
を使用することで、同様の直感的な方法で転移学習を実装することが可能です。
TF.jsでの転移学習実装コード例(ブラウザ環境)
ここでは、Webブラウザ環境でMobileNetV2をベースモデルとして使用し、新しいカスタム分類タスクを行う簡単な転移学習のコード例を示します。
まず、HTMLファイルでTF.jsライブラリを読み込みます。
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>TF.js Transfer Learning Example</title>
<!-- TF.jsライブラリとMobileNetモデルのCDNを読み込み -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@2.1.0"></script>
<script src="script.js"></script>
</head>
<body>
<h1>TensorFlow.js 転移学習サンプル</h1>
<div id="status">モデルロード中...</div>
<!-- 画像を表示する要素とボタン -->
<img id="sample-img" src="path/to/your/sample_image.jpg" style="display:none;">
<button id="train-button" disabled>学習開始</button>
<button id="predict-button" disabled>予測実行</button>
<div id="result"></div>
</body>
</html>
次に、転移学習の中核となるJavaScriptコード(script.js
)を記述します。
let mobilenet; // ロードしたMobileNetモデルを保持する変数
let model; // 構築したカスタムモデルを保持する変数
let classCount = 2; // 新しいタスクのクラス数 (例: 猫と犬の2クラス分類)
let IMAGE_SIZE = 224; // モデルが期待する画像サイズ
// ステータス表示を更新する関数
function updateStatus(message) {
document.getElementById('status').innerText = message;
}
// ページのロードが完了したら実行
window.onload = async () => {
updateStatus('MobileNetモデルをロード中...');
// MobileNetV2モデルをロード。バージョン2を使用し、分類層を含まない設定にする
mobilenet = await mobilenet.load({version: 2, truncated: true});
updateStatus('MobileNetモデルのロード完了。カスタムモデルを構築します。');
// MobileNetの出力層の手前(特徴抽出層の最後)を取得
// MobileNetV2の場合、`conv_pw_13_relu` が特徴抽出の最終層に近い層です
// モデル構造を確認し、適切な層を選択してください
const layer = mobilenet.model.getLayer('conv_pw_13_relu');
// 新しいモデル(Keras Sequentialモデルに相当)を構築
// 特徴抽出層としてMobileNetの一部を使用し、その後に新しい分類層を追加
model = tf.sequential({
layers: [
// MobileNetの特徴抽出層を新しいモデルに追加
// inputShapeは、特徴抽出層の出力形状に合わせる必要があります
tf.layers.reShape({targetShape: [layer.outputShape[1], layer.outputShape[2], layer.outputShape[3]]}), // MobileNetの出力形状をフラット化しない形式にする
tf.layers.conv2d({
// 特徴抽出層を模倣する Conv2D レイヤーを直接追加する例。
// MobileNetの層を直接使う場合は以下のコメントアウトしたコードを参考に。
// MobileNetの特定の層までを取り出して新しいモデルの層として使うのがより一般的です。
// 以下のように、MobileNetのmodelプロパティから層を取得して使用することもできますが、
// モデルの構造理解が必要です。
// model.add(tf.layers.model({inputs: mobilenet.model.input, outputs: mobilenet.model.getLayer('conv_pw_13_relu').output}));
// ここでは、MobileNetの特徴抽出部分全体を新しいモデルの最初の層として利用します。
// このレイヤーのinputShapeは、MobileNetが受け取る画像の形状(IMAGE_SIZE, IMAGE_SIZE, 3)です。
// 出力は MobileNet の getLayer('conv_pw_13_relu').outputShape になります。
// しかし、Sequentialモデルに他のモデルの層を直接追加する簡単なAPIは提供されていません。
// より一般的なアプローチは、Functional APIを使うか、MobileNetの特徴抽出部分を
// 利用するカスタムモデルを作成することです。
// シンプルな例として、MobileNetの特徴抽出器を独立して使用し、その出力を
// 新しいSequentialモデルの入力として扱う方法を示します。
// 実際には、以下のようにMobileNetのmodelプロパティをFunctional modelとして扱い、
// 特定の出力層を持つ新しいモデルを構築することが多いです。
// const baseModel = tf.model({
// inputs: mobilenet.model.input,
// outputs: mobilenet.model.getLayer('conv_pw_13_relu').output
// });
// baseModel.trainable = false; // ベースモデルの重みを凍結
// model = tf.sequential();
// model.add(baseModel); // ベースモデルを層として追加 (これはFunctional Modelの組み込み方法の例です)
// model.add(tf.layers.flatten()); // 特徴マップをフラット化
// model.add(tf.layers.dense({ units: classCount, activation: 'softmax' })); // 新しい分類層
// **より簡単なアプローチ(特徴ベクトルを再利用):**
// MobileNetから特徴ベクトルを抽出し、その特徴ベクトルを入力とする新しいモデルを別途学習させる方法も一般的です。
// この方法では、MobileNetモデル自体を直接新しいモデルに組み込む必要はありません。
// まずMobileNetで各画像の特徴ベクトルを取得し、その特徴ベクトルとラベルのペアで
// 新しい小さなモデル(例: 全結合層のみのモデル)を学習させます。
// このアプローチのコード例を示します。
// 新しいモデル(特徴ベクトルを入力とするモデル)を定義
// MobileNetV2の 'conv_pw_13_relu' 層の出力形状を取得
const featureMapShape = layer.outputShape.slice(1); // バッチサイズを除く形状
const flattenedFeatureSize = featureMapShape[0] * featureMapShape[1] * featureMapShape[2]; // フラット化後のサイズ
model = tf.sequential({
layers: [
tf.layers.flatten({inputShape: featureMapShape}), // MobileNetの特徴マップをフラット化
tf.layers.dense({ units: 128, activation: 'relu' }), // 中間層 (オプション)
tf.layers.dense({ units: classCount, activation: 'softmax' }) // 新しい分類層
]
});
// モデルをコンパイル
model.compile({
optimizer: tf.train.adam(), // または sgd, rmsprop など
loss: 'categoricalCrossentropy', // 分類タスクの場合
metrics: ['accuracy']
});
updateStatus('カスタムモデルの構築完了。学習データ準備へ。');
document.getElementById('train-button').disabled = false; // 学習ボタンを有効化
}) // ここまでが MobileNet を使った特徴抽出器としての部分の考え方
]
});
// 上記コードはMobileNetの特徴抽出層を新しいモデルに組み込む方法の難しさを示唆しています。
// 実践的には、MobileNetから特徴ベクトルを抽出し、その特徴ベクトルに対して別途学習を行うのがシンプルです。
// 以下のコードは、画像からMobileNetで特徴ベクトルを抽出し、それを使って新しいモデルを学習・予測する流れを示します。
// --- 実践的な転移学習のアプローチ(特徴抽出+別途学習) ---
// 新しいモデル(特徴ベクトルを入力とするモデル)の定義は上記で完了しています。
// MobileNetモデル自体は特徴抽出のために使います。
mobilenet = await mobilenet.load({version: 2, truncated: true}); // 分類層を含まないMobileNetv2ロード
// 特徴抽出に使うMobileNetのモデル部分を取得
// 'conv_pw_13_relu' 層までの出力を得るFunctional Modelを作成
const featureExtractor = tf.model({
inputs: mobilenet.model.input,
outputs: mobilenet.model.getLayer('conv_pw_13_relu').output
});
// 特徴抽出器は学習させないので凍結
featureExtractor.trainable = false;
// 新しい分類モデル(特徴ベクトルを入力とするSequentialモデル)の定義は上記で完了しています。
// こちらのモデルを学習させます。
// ダミーデータ生成 (本来は実際の画像データとラベルを使用)
// 実際のデータは画像ファイルやWebカメラから取得し、前処理が必要です。
// ここでは概念を示すためのダミーとして、特徴ベクトルを直接生成します。
// MobileNetV2の 'conv_pw_13_relu' 層の出力形状は (1, 7, 7, 1280) です。
// バッチサイズ1で、高さ7, 幅7, チャンネル数1280のテンソルが出力されます。
const dummyFeature = tf.zeros([1, 7, 7, 1280]);
const dummyLabel = tf.tensor2d([[1, 0]]); // クラス0のOne-hotラベル
// ダミーデータセットを作成(本来は画像データとラベルから特徴ベクトルを抽出し、データセット化)
const dummyFeaturesTensor = tf.zeros([10, 7, 7, 1280]); // 10個のダミー特徴ベクトル
const dummyLabelsTensor = tf.zeros([10, classCount]); // 10個のダミーラベル (One-hot)
// TODO: 実際のデータローディングと前処理のコードをここに記述する必要があります。
// 学習ボタンクリック時の処理
document.getElementById('train-button').onclick = async () => {
updateStatus('学習を開始します...');
// TODO: 実際の学習データ (features, labels) を用意する必要があります
// features: tf.Tensor (shape: [num_samples, 7, 7, 1280])
// labels: tf.Tensor (shape: [num_samples, classCount])
// ダミーデータで学習を実行 (実際は用意したデータで実行)
const history = await model.fit(dummyFeaturesTensor, dummyLabelsTensor, {
epochs: 10, // エポック数
batchSize: 5, // バッチサイズ
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(`Epoch ${epoch}: loss = ${logs.loss.toFixed(4)}, accuracy = ${logs.acc.toFixed(4)}`);
updateStatus(`Epoch ${epoch} 完了: Accuracy = ${logs.acc.toFixed(4)}`);
}
}
});
updateStatus('学習完了。');
document.getElementById('predict-button').disabled = false; // 予測ボタンを有効化
};
// 予測ボタンクリック時の処理
document.getElementById('predict-button').onclick = async () => {
updateStatus('予測を実行します...');
// 予測したい画像データをロードし、前処理します。
const imgElement = document.getElementById('sample-img');
imgElement.style.display = 'block'; // 画像を表示 (パスを適切なものに変更)
// 画像をTensorに変換し、リサイズ、正規化
const imageTensor = tf.browser.fromPixels(imgElement)
.resizeNearestNeighbor([IMAGE_SIZE, IMAGE_SIZE]) // リサイズ
.toFloat() // float型に変換
.expandDims(0); // バッチ次元を追加 (shape: [1, IMAGE_SIZE, IMAGE_SIZE, 3])
// MobileNetの特徴抽出器を使って特徴ベクトルを取得
const imageFeatures = featureExtractor.predict(imageTensor);
// 構築した分類モデルで予測を実行
const prediction = model.predict(imageFeatures);
// 予測結果(確率分布)を取得
const probabilities = await prediction.data();
const predictedClassIndex = prediction.argMax(-1).dataSync()[0]; // 最も確率が高いクラスのインデックス
updateStatus(`予測結果: クラス ${predictedClassIndex} (確率: ${probabilities[predictedClassIndex].toFixed(4)})`);
document.getElementById('result').innerText = `予測クラス: ${predictedClassIndex}, 確率: ${probabilities[predictedClassIndex].toFixed(4)}`;
// メモリ解放
imageTensor.dispose();
imageFeatures.dispose();
prediction.dispose();
};
};
// TODO: 実際の学習データ (画像とラベル) を読み込み、前処理を行い、
// MobileNetで特徴ベクトルを抽出して tf.Tensor にまとめる処理を実装する必要があります。
// 例: 画像ファイルをFileReaderやFetch APIで読み込み、tf.browser.fromPixelsでTensor化、
// featureExtractor.predict() で特徴ベクトルを取得、それらを結合して学習用データセットとする。
上記のコードは、概念を示すための骨組みであり、実際のデータローディングと前処理、およびMobileNetで抽出した特徴ベクトルをまとめて学習データとする部分(TODOコメント箇所)は実装する必要があります。
コードのポイントは以下の通りです。
mobilenet.load({version: 2, truncated: true})
: MobileNetV2モデルをロードします。truncated: true
とすることで、ImageNet分類用の最終層を含まないモデルを取得し、特徴抽出器として利用しやすくします。tf.model({inputs: mobilenet.model.input, outputs: mobilenet.model.getLayer('conv_pw_13_relu').output})
: ロードしたMobileNetモデルの一部('conv_pw_13_relu'層まで)を新しいFunctional Modelとして取得します。これが特徴抽出器となります。featureExtractor.trainable = false;
: 特徴抽出器の重みを学習中に更新しないように凍結します。tf.sequential([...])
: 特徴抽出器の出力形状を入力とする新しいSequentialモデルを定義します。このモデルは、フラット化層と新しい全結合(Dense)分類層から構成されます。model.compile(...)
: 新しいモデルを学習用にコンパイルします。オプティマイザ、損失関数、評価指標を指定します。featureExtractor.predict(imageTensor)
: 入力画像から特徴抽出器を用いて特徴ベクトルを取得します。model.fit(features, labels, ...)
: 特徴ベクトルと対応するラベルを用いて、新しい分類モデルを学習させます。model.predict(imageFeatures)
: 抽出した特徴ベクトルを入力として、学習済み分類モデルで予測を行います。
Python/Kerasの経験をお持ちであれば、tf.layers.dense
やmodel.compile
, model.fit
といったAPIがKerasのものと非常によく似ていることに気づかれるでしょう。TF.js Layers APIは、Keras APIを強く意識して設計されており、Pythonでの経験をスムーズに活かすことができます。
PythonのKerasとの対応と違い
TF.js Layers APIは、PythonのKerasと多くの共通点を持っていますが、JavaScript環境特有の考慮事項も存在します。
| 機能・概念 | Python (Keras) | JavaScript (TF.js Layers) | 備考 |
| :-------------------- | :-------------------------------------------------- | :-------------------------------------------------------- | :------------------------------------------------------------------- |
| モデル定義 | tf.keras.Sequential
, tf.keras.Model
| tf.sequential
, tf.model
| API名が若干異なりますが、概念はほぼ同じです。 |
| 層の追加 | model.add(...)
| tf.sequential({ layers: [...] })
または model.add(...)
| Sequentialモデルの定義時にlayers配列でまとめて指定するのが一般的です。 |
| 事前学習済みモデル | tf.keras.applications.MobileNetV2(...)
| @tensorflow-models/mobilenet
などのライブラリを使用 | TF.jsでは別途モデルライブラリとして提供されることが多いです。 |
| 層の凍結 | layer.trainable = False
| layer.trainable = false
| プロパティ名は同じです。 |
| モデルのコンパイル | model.compile(...)
| model.compile(...)
| API名は同じです。 |
| 学習の実行 | model.fit(...)
| await model.fit(...)
| 非同期処理のためawait
が必要です。callbacks
も指定可能です。 |
| 予測の実行 | model.predict(...)
| model.predict(...)
| tf.Tensor
を返します。 |
| データの表現 | numpy.ndarray
, tf.Tensor
| tf.Tensor
| データはtf.Tensor
で扱います。NumPyは直接使えません。 |
| 画像の前処理 | OpenCV, PIL, tf.image
など | tf.browser.fromPixels
, tf.image
(部分的に利用可能) | ブラウザのCanvas APIやTF.jsの画像関連APIを使用します。 |
| オプティマイザ | tf.keras.optimizers.Adam
など | tf.train.adam
など (tf.optimizers
も利用可能) | tf.train
またはtf.optimizers
から取得します。 |
| 損失関数・評価指標 | tf.keras.losses.CategoricalCrossentropy
など | tf.losses.categoricalCrossentropy
など (tf.metrics
も利用可能) | tf.losses
またはtf.metrics
から取得します。 |
| 非同期処理 | 基本的に同期 | 多くのAPIが非同期 (Promiseを返す) | async
/await
の使用が必須です。 |
| ハードウェアアクセラレーション | GPU (CUDA/cuDNN) | WebGL, WebGPU (実験的), WASM | ブラウザやNode.js環境で利用可能なバックエンドが異なります。 |
| メモリ管理 | ガベージコレクション任せ | 明示的なメモリ解放(tensor.dispose()
)が推奨される場合あり | 大規模なテンソルを扱う際は注意が必要です。 |
大きな違いは、JavaScriptの非同期処理と、TF.js独自のデータ読み込み・前処理API(特にブラウザ環境)です。また、Pythonの広範な画像処理ライブラリ(OpenCV, Pillowなど)はそのまま利用できないため、TF.jsやCanvas APIを用いた代替手段を習得する必要があります。しかし、モデルの構築、コンパイル、学習、予測といった核心的なAPIはKerasと類似しているため、概念的なハードルは低いと考えられます。
実践的な考慮事項
TF.jsで転移学習を実装する際に考慮すべき点をいくつか挙げます。
- データ準備と前処理: TF.jsで画像を扱う場合、ブラウザ環境であれば
tf.browser.fromPixels
、Node.js環境であればファイルシステムから画像を読み込みTensorに変換する必要があります。また、ベースモデルが要求する入力サイズへのリサイズ、ピクセル値のスケーリング(例: 0-255を0-1や-1-1へ正規化)など、適切な前処理が不可欠です。Pythonのデータ拡張ライブラリのような機能は直接利用できませんが、TF.jsでも同様の画像変換操作を行うことは可能です。 - 特徴ベクトルの抽出: 効率的な転移学習のためには、学習データ全体の画像に対して一度だけベースモデルで特徴ベクトルを抽出し、その特徴ベクトルとラベルのペアで新しい分類モデルを学習させるのが一般的です。これにより、エポックごとにベースモデルを順伝播させる必要がなくなり、学習時間を大幅に短縮できます。抽出した特徴ベクトルはメモリ上に保持するか、ファイルなどに保存して再利用できるようにすると良いでしょう。
- ハードウェアバックエンドの選択とパフォーマンス: TF.jsはデフォルトで利用可能な最適なバックエンド(WebGL, WASM, CPUなど)を選択しますが、特定のタスクや環境では明示的に指定することでパフォーマンスが向上する場合があります。画像認識のような計算負荷の高いタスクでは、WebGLやWebGPUバックエンドがGPUを活用するため高速です。
tf.setBackend('webgl')
のように設定できます。 - メモリ管理: 特にブラウザ環境では、不要になった
tf.Tensor
オブジェクトがメモリリークの原因となることがあります。tensor.dispose()
メソッドを呼び出すか、tf.tidy()
関数を使用してテンソルを自動的に解放することを意識してください。 - モデルの保存とロード: 学習したカスタムモデルは、TF.js独自の形式で保存・ロードできます。
model.save()
メソッドでIndexedDBやローカルストレージ、サーバーなどに保存し、tf.loadLayersModel()
などでロードして再利用できます。これは、一度学習したモデルをユーザーのブラウザに保存しておき、オフラインでも利用可能にするなどの応用が考えられます。
まとめ
本記事では、TensorFlow.jsを用いた転移学習によるカスタム画像分類モデルの実装方法について解説しました。Pythonでの転移学習の概念を基に、TF.jsのLayers APIを使った具体的なコード例と、Python版Kerasとの対応関係、そして実践的な開発における注意点を示しました。
TF.jsを活用することで、Pythonで培った機械学習の知識を活かしつつ、WebブラウザやNode.jsといったJavaScript環境で動作する強力な画像認識アプリケーションを開発することが可能です。事前学習済みモデルの特徴抽出能力を利用する転移学習は、特にデータが限られている場合に有効な手法であり、Webサービスやエッジデバイス上での画像認識機能実装において強力な選択肢となります。
今後のステップとして、以下の点を検討されると良いでしょう。
- 実際に独自のカスタム画像データセットを用意し、本記事で示したコードを基に転移学習を実行してみる。
- データの読み込み、前処理、特徴ベクトル抽出を効率的に行うためのコードを実装する。
- ブラウザ環境だけでなく、Node.js環境での実装や、Webカメラからのリアルタイム予測に挑戦してみる。
- モデルのパフォーマンスチューニングや、より複雑なモデル構造(例: ファインチューニング)の実装について学ぶ。
TF.jsは進化し続けており、画像認識AI開発における可能性を広げています。本記事が、皆様のTF.jsを使った実践的なAI開発の一助となれば幸いです。