TensorFlow.jsにおける画像推論のバッチ処理:パフォーマンス向上と実装パターン
TensorFlow.jsを用いた画像認識AIの開発において、単一の画像に対する推論だけでなく、複数の画像をまとめて処理したいケースは少なくありません。例えば、画像データセット全体に対する推論結果を取得したり、リアルタイム処理ではないものの、短時間で多くの画像を処理したいといった場面です。
このような場合、画像を一枚ずつモデルに入力して推論を行うよりも、複数の画像をまとめて「バッチ」として処理する方が、多くの場合で高いパフォーマンスを得ることができます。この記事では、TensorFlow.jsで画像推論を行う際にバッチ処理を適用する方法について、そのメリットや具体的なコード例、Pythonでの機械学習開発経験者が理解しやすいような解説を交えながらご紹介します。
画像推論におけるバッチ処理のメリット
画像認識モデルは、通常、入力として特定の形状を持ったテンソルを期待します。単一の画像の場合、この形状は [height, width, channels]
となります。しかし、多くの機械学習フレームワークやモデルは、入力テンソルの先頭に「バッチサイズ」の次元を持つことを前提として設計されています。つまり、入力形状は [batch_size, height, width, channels]
となります。
PythonでTensorFlowやKerasを使った開発に慣れている方であれば、モデルにデータを渡す際に、単一の画像でも tf.expand_dims(image, axis=0)
のようにしてバッチ次元を追加した経験があるかと思います。これは、モデルが [1, height, width, channels]
という形状の入力を期待しているためです。
複数の画像をまとめてバッチとして処理する主なメリットは、ハードウェアリソース、特にGPUの並列処理能力を効率的に利用できる点にあります。GPUは大量の計算を並列で実行するのに特化しており、バッチサイズが大きいほど、この並列処理の恩恵を最大限に受けやすくなります。結果として、全体のスループット(単位時間あたりに処理できる画像の枚数)が向上し、複数の画像を処理する合計時間を短縮できます。
これは、Webブラウザ環境やNode.js環境でTensorFlow.jsを使用する場合も同様です。GPUバックエンド(WebGLやWebGPUなど)が利用可能な環境であれば、バッチ処理によるパフォーマンス向上が期待できます。
TensorFlow.jsでのバッチ入力テンソルの準備
TensorFlow.jsでバッチ処理を行うためには、複数の画像データを一つのバッチ入力テンソル [batch_size, height, width, channels]
として準備する必要があります。基本的な流れは以下のようになります。
- 複数の画像データを読み込む。
- 各画像に対して、モデルが期待するサイズへのリサイズや正規化などの前処理を適用し、個別のテンソル
[height, width, channels]
に変換する。 - 変換された複数のテンソルを、バッチ次元で結合して一つのテンソル
[batch_size, height, width, channels]
を作成する。
ここでは、ステップ3の「複数のテンソルを結合してバッチテンソルを作成する」方法に焦点を当てます。TensorFlow.jsには、複数のテンソルを指定した軸で結合するための tf.stack()
という便利な関数があります。
例えば、3つの画像からそれぞれ個別のテンソル imageTensor1
, imageTensor2
, imageTensor3
([height, width, channels]
形状) が作成できたとします。これらをバッチテンソルとして結合するには、以下のように tf.stack()
を使用します。
// imageTensor1, imageTensor2, imageTensor3 はそれぞれ [height, width, channels] 形状のテンソルとする
// テンソルの配列を作成
const imageTensors = [imageTensor1, imageTensor2, imageTensor3];
// tf.stack() を使用してバッチ次元 (axis=0) で結合
// 生成される batchedTensor は [3, height, width, channels] 形状となる
const batchedTensor = tf.stack(imageTensors, 0);
// モデルにバッチテンソルを入力して推論
// const predictions = model.predict(batchedTensor);
// 使用済みテンソルのメモリ解放を忘れずに行う
// tf.dispose([imageTensor1, imageTensor2, imageTensor3, batchedTensor]);
tf.stack(tensors, axis)
は、指定されたテンソルの配列 tensors
を、指定された axis
に沿って結合し、新しい次元を追加します。画像の場合、新しいバッチ次元は通常先頭(axis=0
)に追加されるため、この例のように axis=0
を指定します。PythonのTensorFlow/NumPyにおける tf.stack()
や np.stack()
と同様の挙動をします。
具体的なコード例:複数の画像要素からのバッチ推論
Webブラウザ環境で、複数の <img>
要素から画像データを読み込み、バッチ処理で推論を行う例を考えます。ここでは、簡略化のため、画像の前処理(リサイズ、正規化)は各画像に対して個別に行い、その後に tf.stack()
で結合するアプローチを取ります。
<!DOCTYPE html>
<html>
<head>
<title>TF.js Batch Image Inference Example</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<style>
img { max-width: 100px; height: auto; margin: 5px; }
</style>
</head>
<body>
<h1>TF.js バッチ推論デモ</h1>
<img id="img1" src="image1.jpg" crossorigin="anonymous">
<img id="img2" src="image2.jpg" crossorigin="anonymous">
<img id="img3" src="image3.jpg" crossorigin="anonymous">
<img id="img4" src="image4.jpg" crossorigin="anonymous">
<button id="runInference">バッチ推論実行</button>
<div id="output"></div>
<script>
async function runBatchInference() {
const imgElements = document.querySelectorAll('img');
const outputDiv = document.getElementById('output');
outputDiv.innerText = 'モデル読み込み中...';
// ここでは MobileNetV2 を例として使用します
const modelUrl = 'https://tfhub.dev/google/tfjs-model/mobilenet_v2/1.0/classification/3/default/1/model.json';
let model;
try {
model = await tf.loadGraphModel(modelUrl);
outputDiv.innerText = 'モデル読み込み完了。推論実行中...';
} catch (error) {
outputDiv.innerText = 'モデルの読み込みに失敗しました: ' + error.message;
console.error(error);
return;
}
const imageTensors = [];
const imageSize = 224; // MobileNetV2 の入力サイズ
// 各画像要素からテンソルを生成し、前処理を行う
for (const imgEl of imgElements) {
try {
// 画像要素からテンソルを生成 ([height, width, channels])
const imageTensor = tf.browser.fromPixels(imgEl);
// モデルの入力サイズに合わせてリサイズ
const resizedTensor = tf.image.resizeBilinear(imageTensor, [imageSize, imageSize]);
// MobileNetV2 は -1 から 1 の範囲に正規化を期待する場合が多い
// (モデルによっては 0-1 や 0-255 の場合もあります。使用するモデルのドキュメントを確認してください)
const normalizedTensor = resizedTensor.div(255.0).sub(0.5).mul(2.0);
// 使用済みテンソルのメモリを解放
imageTensor.dispose();
resizedTensor.dispose();
imageTensors.push(normalizedTensor);
} catch (error) {
console.error('画像の処理中にエラー:', error);
// エラーが発生した画像はスキップするなど、適切なエラーハンドリングを行います
}
}
let batchedInput = null;
let predictions = null;
try {
// 複数のテンソルをバッチ次元でスタックして結合 ([batch_size, height, width, channels])
batchedInput = tf.stack(imageTensors, 0);
console.log('Batched Input Shape:', batchedInput.shape); // 例: [4, 224, 224, 3]
// バッチ入力テンソルを使ってモデル推論を実行
const startTime = performance.now();
predictions = model.predict(batchedInput);
const endTime = performance.now();
console.log('推論時間 (バッチ):', endTime - startTime, 'ms');
// 推論結果のテンソルを取得
// モデルの出力形状はモデルによって異なります。分類モデルの場合は [batch_size, num_classes] が一般的です。
const predictionsArray = await predictions.array();
outputDiv.innerText = `バッチ推論完了。処理時間: ${(endTime - startTime).toFixed(2)} ms`;
console.log('Batch Predictions:', predictionsArray);
// 各画像に対する結果を個別に処理する場合
// predictionsArray は [batch_size, ...] の配列になっているので、ループで処理
// for (let i = 0; i < predictionsArray.length; i++) {
// console.log(`画像 ${i+1} の推論結果:`, predictionsArray[i]);
// // ここで各画像の結果に応じた処理を行う
// }
} catch (error) {
outputDiv.innerText = '推論中にエラーが発生しました: ' + error.message;
console.error(error);
} finally {
// 使用したテンソルを解放する
if (batchedInput) batchedInput.dispose();
if (predictions) predictions.dispose();
imageTensors.forEach(t => t.dispose()); // 個別の画像テンソルも解放
model.dispose(); // モデルも使用後は解放
}
}
document.getElementById('runInference').addEventListener('click', runBatchInference);
// ダミー画像のファイルパスは適切に置き換えてください。
// または base64 データURI を使用することも可能です。
</script>
</body>
</html>
このコードでは、ページ上の複数の <img>
要素を取得し、それぞれを tf.browser.fromPixels()
でテンソルに変換しています。その後、tf.image.resizeBilinear()
でモデルの入力サイズにリサイズし、[-1, 1] の範囲に正規化しています。重要なのは、この個別に処理されたテンソルを imageTensors
という配列に格納し、最後に tf.stack(imageTensors, 0)
で一つのバッチテンソルを作成している点です。
作成されたバッチテンソル batchedInput
を model.predict()
に渡すことで、バッチ推論が実行されます。推論結果 predictions
もバッチ次元を含むテンソル(分類モデルであれば [batch_size, num_classes]
)となりますので、.array()
や .data()
でJavaScriptの配列/TypedArrayに変換した後、各要素(各画像に対応する推論結果)を処理する必要があります。
Pythonとの比較と応用
PythonでTensorFlow/Kerasを使っている場合、model.predict()
にNumPy配列やTensorFlow Tensorのリストではなく、[batch_size, height, width, channels]
形状の単一の配列/テンソルを渡すのが一般的です。このバッチ入力テンソルを作成するために、Pythonでは np.stack()
や tf.stack()
といった関数をよく使用します。今回のTensorFlow.jsでの tf.stack()
の使い方は、Pythonでの経験と非常に類似しており、スムーズに理解できるかと思います。
また、PythonでImageDataGeneratorなどを使ってバッチ単位でデータを生成・処理するパイプラインを構築した経験があれば、TensorFlow.jsでも同様に、複数の画像ソース(例えばサーバーからフェッチした画像ファイルや、IndexedDBに保存された画像データなど)から非同期的に画像を読み込み、前処理を行い、一定数たまったらバッチとして処理する、といったパイプラインを構築することが可能です。非同期処理には Promise.all()
などが役立ちます。
実践的な考慮事項
- メモリ管理: バッチサイズを大きくすると、それに応じて使用するテンソルのメモリ量も増加します。特にブラウザ環境では利用可能なメモリが限られているため、適切なバッチサイズを選択することが重要です。
tf.dispose()
やtf.tidy()
を積極的に利用して、不要になったテンソルをこまめに解放してください。上記のコード例でもdispose()
を使用しています。 - 非同期処理: 複数の画像を読み込んだり前処理したりする過程は非同期になることが多いです。これらの非同期処理を効率的に行うために、
async/await
やPromise.all()
を活用して、すべての画像の準備が完了してからバッチテンソルを作成・推論を実行するような制御を行います。 - 入力形状の統一:
tf.stack()
で結合するテンソルは、バッチ次元以外の形状([height, width, channels]
)が全て一致している必要があります。もし入力画像のサイズが異なる場合は、tf.image.resize*()
関数などを使って、結合前に同じサイズにリサイズする前処理が必須です。 - パフォーマンス計測: 単一画像推論とバッチ推論のパフォーマンスを比較する際は、モデルの読み込み時間や画像読み込み・前処理時間は除外し、純粋な
model.predict()
の実行時間で比較すると、バッチ処理自体の効果を正確に把握できます。performance.now()
を使うとミリ秒単位で計測が可能です。
まとめ
この記事では、TensorFlow.jsを使って複数の画像データをまとめてバッチとして扱い、画像認識モデルで推論を実行する方法について解説しました。tf.stack()
を用いて個別の画像テンソルをバッチ入力テンソルに結合する基本的な手法を示し、ブラウザ環境での具体的なコード例をご紹介しました。
バッチ処理は、特にGPU環境においてモデルの並列計算能力を最大限に引き出し、全体的な推論スループットを向上させるための重要なテクニックです。Pythonでの機械学習開発におけるバッチ処理の概念と共通している部分が多く、Python経験者の方もスムーズにTF.jsでのバッチ処理を導入できるでしょう。
実践においては、利用可能なメモリ量や処理時間、非同期処理の制御などを考慮し、最適なバッチサイズや処理パイプラインを設計することが求められます。TensorFlow.jsでの画像認識開発において、大量の画像を効率的に処理する必要がある場面で、ぜひバッチ処理の活用を検討してみてください。