TensorFlow.jsでONNXモデルをロード・推論する方法:Pythonユーザーのための実践ガイド
はじめに
本記事では、TensorFlow.js環境において、ONNX(Open Neural Network Exchange)形式の機械学習モデルをロードし、画像認識タスクの推論を実行する方法について解説します。特に、PythonでONNXモデルを扱った経験がある技術者の方々が、その知識をTensorFlow.jsでのWebブラウザやNode.js環境での実行に繋げられるよう、具体的なコード例と技術的な詳細を提供することを目的としています。
Pythonエコシステムでは、様々なフレームワーク(PyTorch, TensorFlow/Keras, scikit-learnなど)でモデルを開発し、それをONNX形式に変換して異なる環境で利用するワークフローが広く用いられています。WebブラウザやNode.jsでこれらのONNXモデルを直接実行できれば、Pythonバックエンドへの依存を減らし、クライアントサイドでの処理や、Node.jsでの軽量な推論サーバー構築などが可能になります。TensorFlow.jsエコシステムで提供されている tfjs-onnx
ライブラリを利用することで、これが実現できます。
ONNXモデルとは
ONNXは、深層学習モデルを表現するためのオープンフォーマットです。異なるフレームワーク間でモデルを相互運用可能にすることを目的としています。Python環境では、PyTorch、TensorFlow/Keras、MXNetなど、多くの主要な深層学習フレームワークがONNX形式でのモデルのエクスポートや、ONNXモデルのインポートをサポートしています。これにより、例えばPyTorchで学習したモデルをONNXに変換し、TensorFlowで読み込んで推論を実行するといったことが可能になります。
ONNXモデルは、計算グラフの構造と、モデルの重み(パラメータ)を含んでいます。計算グラフは、モデルが行う演算(畳み込み、活性化関数、行列乗算など)のシーケンスを定義します。
Pythonユーザーにとって、ONNXはモデルのエクスポート・インポートにおける標準的な選択肢の一つとして位置づけられます。TensorFlow SavedModelやKeras形式と同様に、学習済みのモデルをデプロイ可能な形式で保存する際に利用されます。
PythonでのONNXモデル準備
TensorFlow.jsでONNXモデルをロードするためには、まずPython側でモデルをONNX形式でエクスポートする必要があります。ここでは、TensorFlow/KerasとPyTorchの簡単な例を挙げます。
KerasモデルをONNXに変換する
TensorFlow/KerasモデルをONNXに変換するには、tf2onnx
というライブラリを利用するのが一般的です。
まず、必要なライブラリをインストールします。
pip install tensorflow onnx tf2onnx
次に、Kerasモデルを定義し、ダミーデータを使って推論を実行可能な状態にしてからONNX形式で保存します。
import tensorflow as tf
import tf2onnx
import onnx
# シンプルなKerasモデルを定義
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10, activation='softmax')
])
# ダミー入力テンソルを定義(形状はモデルのinput_shapeに合わせる)
input_signature = [tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='input')]
# KerasモデルをONNXに変換
onnx_model, external_tensor_storage = tf2onnx.convert.from_keras(
model, input_signature=input_signature, opset=13
)
# ONNXモデルをファイルに保存
onnx_model_path = "keras_model.onnx"
onnx.save(onnx_model, onnx_model_path)
print(f"Keras model saved to {onnx_model_path}")
ここで、input_signature
で入力テンソルの形状と名前を指定しています。None
はバッチサイズが可変であることを示します。opset
はONNXオペレーターセットのバージョンを指定します。互換性の問題がある場合があるため、使用する tfjs-onnx
のバージョンがサポートしているopsetを確認することが推奨されます。
PyTorchモデルをONNXに変換する
PyTorchモデルをONNXに変換するには、PyTorchに内蔵されている torch.onnx.export
関数を利用します。
まず、必要なライブラリをインストールします。
pip install torch torchvision onnx
次に、PyTorchモデルを定義し、ダミー入力を使ってONNX形式で保存します。
import torch
import torch.nn as nn
# シンプルなPyTorchモデルを定義
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2D(1, 32, kernel_size=3)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(32 * 13 * 13, 10) # Conv + Pool後の出力サイズに合わせて調整が必要
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.flatten(x)
x = self.fc1(x)
return x
model = SimpleCNN()
# ダミー入力テンソルを定義(形状はモデルの入力に合わせる)
# PyTorchのデフォルトは (Batch, Channels, Height, Width)
dummy_input = torch.randn(1, 1, 28, 28)
# PyTorchモデルをONNXに変換
onnx_model_path = "pytorch_model.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_model_path,
export_params=True,
opset_version=13, # 使用するONNXオペレーターセットのバージョン
do_constant_folding=True,
input_names=['input'], # モデルの入力の名前
output_names=['output'] # モデルの出力の名前
)
print(f"PyTorch model saved to {onnx_model_path}")
PyTorchの場合も、ダミー入力を用意し、input_names
とoutput_names
を指定することが重要です。これらの名前は、TensorFlow.js側でモデルの入出力を操作する際に必要になる場合があります。
TensorFlow.jsでのONNXモデルロードと推論
Python側でONNX形式のモデルファイル(.onnx
)が用意できたら、TensorFlow.js環境でこれをロードして利用します。これには @tensorflow/tfjs-onnx
ライブラリを使用します。
ライブラリのインストール
Webブラウザ環境またはNode.js環境に応じて、TensorFlow.js本体と tfjs-onnx
をインストールします。
Webブラウザの場合:
npm install @tensorflow/tfjs @tensorflow/tfjs-onnx
# または yarn add ...
Node.jsの場合:
npm install @tensorflow/tfjs-node @tensorflow/tfjs-onnx
# または yarn add ...
または、Webページに <script>
タグでCDNからロードすることも可能です。
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-onnx@latest/dist/tfjs-onnx.min.js"></script>
モデルファイルの準備
Pythonでエクスポートした .onnx
ファイルを、Webサーバー経由でアクセスできる場所(Webブラウザの場合)、またはファイルシステム上のパス(Node.jsの場合)に配置します。
JavaScriptコード例: モデルのロード
@tensorflow/tfjs-onnx
を使ってONNXモデルをロードします。非同期処理となるため、async/await
を使用します。
import * as tf from '@tensorflow/tfjs';
import * as onnx from '@tensorflow/tfjs-onnx';
async function loadOnnxModel(modelPath) {
try {
// ONNXモデルのロード
const model = await onnx.load(modelPath);
console.log('ONNX model loaded successfully.');
return model;
} catch (error) {
console.error('Error loading ONNX model:', error);
throw error; // エラーハンドリング
}
}
// モデルファイルのパスを指定
// Webブラウザの場合、Webサーバー上のURL
// Node.jsの場合、ファイルシステムのパス (例: 'file://path/to/your/model.onnx')
const modelPath = 'path/to/your/model.onnx';
loadOnnxModel(modelPath)
.then(model => {
// モデルロード成功後の処理
// モデルオブジェクトは推論などに利用可能
console.log('Model object:', model);
})
.catch(err => {
// モデルロード失敗時の処理
console.error('Failed to load model.');
});
onnx.load()
関数にモデルファイルのパスまたはURLを渡すことで、モデルがロードされます。この関数は Promise
を返すため、非同期で処理を行います。
JavaScriptコード例: 画像前処理と推論実行
画像認識モデルで推論を行う前には、入力画像をモデルが期待する形式(テンソルの形状、データ型、値の範囲など)に前処理する必要があります。PythonでのNumPyや画像ライブラリ(PIL, OpenCVなど)を使った経験があれば、その考え方をTensorFlow.jsのテンソル操作に置き換えて実装します。
多くの場合、以下の前処理が必要です。 - 画像のリサイズ - チャンネル数の調整(ONNXモデルがRGBまたはグレースケール期待かによる) - ピクセル値のスケーリング(0-1または-1-1の範囲に正規化など) - テンソルの形状変更(HWCからCHWなど、ONNXモデルの入力形状に合わせる)
一般的な画像ファイル(例: JPEG, PNG)からテンソルへの変換には、@tensorflow/tfjs-converter
と組み合わせて利用可能な画像処理ライブラリ(例: canvas
やブラウザの Image
オブジェクト)を使用します。
import * as tf from '@tensorflow/tfjs';
import * as onnx from '@tensorflow/tfjs-onnx';
// Assuming the ONNX model is loaded in the 'model' variable
// const model = await onnx.load(modelPath);
async function preprocessImage(imageElement) {
// imageElement は HTMLImageElement または CanvasElement など
// モデルが期待する入力形状 (例: [1, 28, 28, 1] for Keras ONNX example)
const targetWidth = 28;
const targetHeight = 28;
const channels = 1; // または 3 (RGB)
// 画像をTF.jsテンソルに変換
let tensor = tf.browser.fromPixels(imageElement, channels);
// リサイズ
tensor = tf.image.resizeBilinear(tensor, [targetHeight, targetWidth]);
// データ型と値の範囲を調整 (例: float32, 0-1)
tensor = tensor.toFloat().div(255.0);
// ONNXモデルが期待する形状にreshape
// Kerasの場合: [Height, Width, Channels] -> [Batch, Height, Width, Channels]
// PyTorchの場合: [Height, Width, Channels] -> [Batch, Channels, Height, Width]
// ここではKeras ONNX例の形状 [1, 28, 28, 1] に合わせる
tensor = tensor.expandDims(0); // バッチ次元を追加
// もしPyTorch ONNXモデル ([Batch, Channels, Height, Width]) なら以下のようにpermuteを使います
// if (channels === 3) {
// tensor = tensor.transpose([0, 3, 1, 2]); // [Batch, H, W, C] -> [Batch, C, H, W]
// } else if (channels === 1) {
// // Grayscaleの場合はHWCもCHWも1チャンネルなので形状は同じですが、
// // ONNXランタイムによっては形状を明示的にCHWにする必要がある場合もあります
// // 例: tensor = tensor.transpose([0, 3, 1, 2]); // [Batch, H, W, 1] -> [Batch, 1, H, W]
// // モデルの入力形状に合わせて調整してください
// }
return tensor;
}
async function runInference(model, imageTensor) {
let outputTensor;
try {
// 推論の実行
// ONNXモデルの入出力名を確認する必要があります
// Pythonのエクスポート時に指定した input_names / output_names
const inputs = { 'input': imageTensor }; // ONNXモデルの入力名に合わせてキー名を変更
const outputs = await model.run(inputs);
// 推論結果を取得
// ONNXモデルの出力名に合わせてキー名を変更
outputTensor = outputs['output']; // ONNXモデルの出力名に合わせてキー名を変更
console.log('Inference executed successfully.');
return outputTensor;
} catch (error) {
console.error('Error during inference:', error);
throw error;
} finally {
// 不要になったテンソルのメモリを解放
// 推論に必要な中間テンソルは run 関数が内部で解放しますが、
// 入力テンソルは自分で解放する必要があります
if (imageTensor) imageTensor.dispose();
// outputTensorは返り値として使うため、ここでは解放しません
}
}
// 使用例(HTMLCanvasElementを想定)
// const imgElement = document.getElementById('myImage'); // または <canvas> 要素
// loadOnnxModel(modelPath).then(async (model) => {
// const processedTensor = await preprocessImage(imgElement);
// const resultTensor = await runInference(model, processedTensor);
// // 推論結果の処理(例: softmax後の確率値を取得)
// const predictions = await resultTensor.data();
// console.log('Predictions:', predictions);
// // 結果テンソルのメモリを解放
// resultTensor.dispose();
// }).catch(err => {
// console.error('An error occurred in the inference pipeline.');
// });
preprocessImage
関数は、HTMLの Image
要素や Canvas
要素からピクセルデータを読み込み、リサイズや正規化といった前処理を行ってTensorFlow.jsのテンソルを生成します。Pythonでの画像処理(NumPy配列操作やPIL/OpenCVの関数)に慣れている方であれば、対応するTF.jsの画像処理API(tf.image.resizeBilinear
, tensor.toFloat()
, tensor.div()
, tensor.expandDims()
, tensor.transpose()
など)を利用することで、同様の前処理パイプラインを構築できます。
runInference
関数では、前処理済みの入力テンソルを model.run()
メソッドに渡して推論を実行します。ここで重要なのは、model.run()
に渡す入力オブジェクトのキー名です。これは、PythonでONNXモデルをエクスポートする際に input_names
で指定した名前と一致させる必要があります。同様に、返ってくる出力オブジェクトのキー名も、output_names
で指定した名前と一致します。もし入出力名を指定せずにエクスポートした場合、デフォルトの名前(例: input.1
, output.1
など)が付けられることがありますので、ONNXモデルの構造を確認する必要があります(ONNX Netronなどのツールが役立ちます)。
推論結果はテンソルとして返されます。結果をJavaScriptのTypedArrayとして取得するには、非同期の tensor.data()
または同期の tensor.dataSync()
メソッドを使用します。画像分類タスクであれば、この配列から最も高い確率値に対応するクラスを取得するなどの後処理を行います。
finally
ブロックで imageTensor.dispose()
を呼び出しているのは、前処理で生成した入力テンソルが推論後は不要になるため、メモリを解放するためです。推論によって内部的に生成される中間テンソルは model.run()
が自動的に解放しますが、自分で生成・管理したテンソルは明示的に解放する必要があります。メモリ管理はパフォーマンス維持のために非常に重要です。
実装上の注意点と制約
- オペレーターのサポート:
tfjs-onnx
は、ONNXで定義されている全てのオペレーターをサポートしているわけではありません。使用しているモデルが、tfjs-onnx
がサポートしていないオペレーターを含んでいる場合、ロードまたは推論実行時にエラーが発生します。事前にtfjs-onnx
のドキュメントでサポート状況を確認するか、ONNXモデルのオペレーターを確認する必要があります。 - ONNX opset バージョン: ONNXオペレーターセットのバージョンによって利用できるオペレーターや挙動が異なります。Pythonでエクスポートする際の
opset_version
と、tfjs-onnx
がサポートするバージョンとの互換性に注意が必要です。 - モデル構造の確認: PythonでエクスポートしたONNXモデルの正確な入出力名、形状、データ型を確認しておくことが重要です。ONNX Netronなどのツールを使うと、モデルの計算グラフを視覚的に確認できます。
- パフォーマンス: ONNXモデルの実行パフォーマンスは、モデルの複雑さ、入力データのサイズ、そして使用するTensorFlow.jsバックエンド(WebGPU, WebGL, CPU)に依存します。特に大規模なモデルの場合、WebGPUまたはWebGLバックエンドを利用することで大幅な高速化が期待できます。Node.js環境であれば
tfjs-node-gpu
の利用も検討できます。 - メモリ管理: 前処理で生成したテンソルや推論結果のテンソルは、不要になったら必ず
dispose()
またはtf.dispose()
/tf.tidy()
を使ってメモリを解放してください。解放を怠ると、メモリリークが発生し、アプリケーションのパフォーマンス低下やクラッシュにつながる可能性があります。これはPythonでGPUメモリ管理を意識するのと同様の考え方です。
まとめ
本記事では、Pythonで一般的に利用されるONNX形式の機械学習モデルを、TensorFlow.js環境(Webブラウザ、Node.js)でロードし、画像認識タスクの推論を実行する一連の流れを解説しました。PythonでのONNXモデルのエクスポート方法を確認し、JavaScriptコードを用いて、モデルのロード、前処理、推論実行、結果取得の方法を具体的に示しました。
tfjs-onnx
ライブラリを利用することで、Pythonエコシステムで構築された多様なモデル資産をWebやNode.js環境へ展開することが可能になります。これにより、例えばブラウザ上で軽量な推論を実行してユーザー体験を向上させたり、Node.jsでPythonへの依存がない推論サービスを構築したりといった、新たな可能性が開かれます。
ONNXモデルの互換性やパフォーマンス、適切なメモリ管理など、実践的な開発における注意点にも触れました。これらの情報を活用して、Pythonで培った機械学習の知識と経験を、TensorFlow.jsによるクロスプラットフォームなAI開発に活かしていただければ幸いです。