TensorFlow.jsでAttention機構を用いた画像認識モデルの推論を実行する:Vision Transformer入門
画像認識分野では、長らくConvolutional Neural Network (CNN) が主流のアーキテクチャとして使用されてきました。しかし、近年では自然言語処理の分野で大きな成果を上げたTransformerモデルのAttention機構を画像認識に応用したVision Transformer (ViT) が注目を集めています。ViTはCNNとは異なるアプローチで画像の特徴を捉え、特に大規模データセットを用いた学習において高い性能を示すことが報告されています。
Pythonでの機械学習開発に慣れている方であれば、TransformerやAttentionといった概念は自然言語処理の文脈で耳にしたことがあるかもしれません。ViTは、この強力なAttention機構を画像データに適用することで、新たな視点から画像認識問題を解決しようとするものです。本記事では、このVision TransformerモデルをTensorFlow.jsで利用し、WebブラウザやNode.js環境で画像推論を実行する方法について、具体的なコード例を交えながら解説します。PythonでのViTの概念との比較や、TF.jsで扱う上での注意点にも触れていきます。
Vision Transformer (ViT) の概要
ViTの基本的なアイデアは、画像をパッチと呼ばれる小さな領域に分割し、それぞれのパッチを「トークン」として扱うことにあります。これは、自然言語処理で文を単語やサブワードのトークンに分割するアプローチと似ています。
- 画像のパッチ分割: 入力画像を、例えば16x16ピクセルといった固定サイズの小さなパッチに分割します。
- パッチの線形埋め込み (Linear Embedding): 各パッチを平坦化し、線形変換によって固定次元のベクトル表現(パッチ埋め込み)に変換します。
- 位置埋め込み (Positional Embedding) の追加: 自然言語処理のTransformerが単語の順番情報を失わないように位置埋め込みを使用するのと同様に、ViTでもパッチの位置情報を保持するために位置埋め込みを加算します。
- Transformerエンコーダー: パッチ埋め込みと位置埋め込みを合わせたシーケンスを、複数のTransformerエンコーダー層に入力します。各エンコーダー層は、Multi-Head Self-Attention機構とフィードフォワードネットワークから構成されます。Attention機構により、各パッチが画像内の他のパッチとの関連性を捉えることができます。
- 分類: Transformerエンコーダーの出力のうち、特別な分類トークン(CLSトークン)に対応する表現や、全てのパッチの表現を集約したものを使用して、最終的なクラス分類を行います。
Python/TensorFlow/KerasでViTを実装する場合、tf.keras.layers.MultiHeadAttention
のようなレイヤーを用いてTransformerエンコーダーを構築したり、tensorflow_hub
やHugging Faceのtransformers
ライブラリなどから事前学習済みモデルを利用することが一般的です。TensorFlow.jsでViTモデルを扱う場合も、これらのPythonライブラリでエクスポートされたSavedModelやKerasモデルをtfjs-converterで変換したモデルを使用することになります。
TensorFlow.jsで事前学習済みVision Transformerモデルを読み込む
ここでは、例としてTensorFlow Hubで公開されている事前学習済みVision TransformerモデルをTensorFlow.js形式に変換したものを使用することを想定します。モデルはすでにTF.js形式に変換されていると仮定し、その読み込みから始めます。
モデルはtf.loadGraphModel()
またはtf.loadLayersModel()
を使用して読み込むことができます。ViTモデルは通常、レイヤーベースの構造を持つため、tf.loadLayersModel()
が適している場合が多いでしょう。
import * as tf from '@tensorflow/tfjs';
// モデルのパス(WebサーバーからのURLまたはローカルパス)
const modelUrl = 'path/to/your/vit_model/model.json';
let model;
async function loadModel() {
try {
console.log('モデルをロード中...');
// tf.loadLayersModelを使用してモデルを読み込む
model = await tf.loadLayersModel(modelUrl);
console.log('モデルのロードが完了しました。');
// モデルのサマリーを表示(オプション)
// model.summary();
} catch (error) {
console.error('モデルのロードに失敗しました:', error);
}
}
// モデルロード関数を呼び出す
loadModel();
PythonでKerasモデルを保存し、tensorflowjs_converter
ツールでTF.js形式に変換した場合、通常はmodel.json
というファイルと重みファイル群が生成されます。tf.loadLayersModel()
には、このmodel.json
ファイルへのパスを指定します。
推論のための画像前処理
Vision Transformerは、入力画像に対して特定のサイズへのリサイズ、正規化、そしてパッチ分割といった前処理を必要とします。これらの前処理もTensorFlow.jsのテンソル操作で行うことができます。ViTモデルが期待する入力形状(例: [batch_size, height, width, channels]
)とデータ型(例: float32
)に合わせる必要があります。
通常、事前学習済みViTモデルは特定の入力サイズ(例: 224x224ピクセル)を期待します。また、ImageNetなどで事前学習されている場合、各ピクセル値は特定の平均と標準偏差で正規化されていることがあります。
以下に、画像要素 (<img>
または <canvas>
) からTensorFlow.jsのテンソルを作成し、ViTモデルが期待する形式に前処理する例を示します。
import * as tf from '@tensorflow/tfjs';
// モデルが期待する入力サイズと正規化パラメータ(例:ImageNetの平均と標準偏差)
const IMAGE_SIZE = 224;
const IMAGENET_MEAN = [0.485, 0.456, 0.406]; // RGB平均
const IMAGENET_STD = [0.229, 0.224, 0.225]; // RGB標準偏差
async function preprocessImage(imageElement) {
// HTML要素(<img>や<canvas>)からテンソルを作成
let tensor = tf.browser.fromPixels(imageElement);
// 画像をモデルの入力サイズにリサイズ
// preserveAspectRatio: true を指定するとアスペクト比を維持したままリサイズし、余白をパディングすることが多いが、
// ViTでは単純なリサイズ(変形を許容)やクロップが使われることもある。ここでは単純なリサイズ。
const resizedTensor = tf.image.resizeBilinear(tensor, [IMAGE_SIZE, IMAGE_SIZE]);
// 0-255の整数値を0-1の浮動小数点数に変換
const normalizedTensor = resizedTensor.toFloat().div(255.0);
// ImageNetの平均と標準偏差で正規化
// PyTorchなどでは[C, H, W]形式で正規化することが多いが、TF.jsの一般的な画像テンソルは[H, W, C]。
// ここでは[H, W, C]形式で各チャンネルごとに正規化を適用する。
const meanTensor = tf.tensor3d(IMAGENET_MEAN, [1, 1, 3]);
const stdTensor = tf.tensor3d(IMAGENET_STD, [1, 1, 3]);
const finalTensor = normalizedTensor.sub(meanTensor).div(stdTensor);
// モデルが期待するバッチ次元を追加 [height, width, channels] -> [1, height, width, channels]
const batchedTensor = finalTensor.expandDims(0);
// 元のテンソルをdisposeしてメモリを解放
tensor.dispose();
resizedTensor.dispose();
normalizedTensor.dispose();
meanTensor.dispose();
stdTensor.dispose();
finalTensor.dispose();
return batchedTensor;
}
// 使用例:
// const imgElement = document.getElementById('myImage'); // HTML上の<img>要素
// preprocessImage(imgElement).then(preprocessedTensor => {
// console.log('前処理後のテンソル形状:', preprocessedTensor.shape); // 例: [1, 224, 224, 3]
// // 次にこのテンソルをモデルに入力する
// // preprocessedTensor.dispose(); // 使用後にdisposeする
// });
PythonでOpenCV (cv2.resize
) やPIL (Image.resize
) を使ったリサイズ、NumPyを使った正規化に慣れている方にとって、TensorFlow.jsのtf.image.resizeBilinear
やtf.div
, tf.sub
などの操作は直感的に理解しやすいでしょう。特にtf.browser.fromPixels
はWebブラウザ環境での画像データ取り込みに特化した便利な関数です。Node.js環境の場合は、ファイルから画像を読み込み、適切なライブラリ(例: canvas
, @tensorflow/tfjs-node
と共にインストールされる画像処理ライブラリ)を使ってテンソルに変換する必要があります。
重要なのは、使用するViTモデルが学習時にどのような前処理を施されていたかを確認し、TensorFlow.jsでもそれに正確に合わせることです。特にリサイズの方法(アスペクト比維持の有無、補間方法)や正規化のパラメータ(平均、標準偏差)、ピクセル値のスケール(0-1か0-255かなど)はモデルの性能に大きく影響します。
モデルによる推論実行
前処理済みのテンソルが準備できたら、それをロードしたモデルに入力して推論を実行します。推論は非同期で行われるため、await
を使用します。
import * as tf from '@tensorflow/tfjs';
// モデルは事前にロードされていると仮定(loadModel関数参照)
// 前処理関数も定義されていると仮定(preprocessImage関数参照)
async function runInference(imageElement) {
// モデルがロードされているか確認
if (!model) {
console.error('モデルがロードされていません。');
return null;
}
// 画像の前処理を実行
const preprocessedTensor = await preprocessImage(imageElement);
// 推論の実行
console.log('推論を実行中...');
let predictions;
try {
// tf.tidyを使用して、このスコープ内で作成された中間テンソルを自動的に解放
predictions = tf.tidy(() => {
// モデルに前処理済みテンソルを入力して推論
const output = model.predict(preprocessedTensor);
// 出力が複数のテンソルである場合や、特定の出力層が必要な場合は適宜変更
// 例: return output[0];
return output;
});
console.log('推論が完了しました。');
} catch (error) {
console.error('推論中にエラーが発生しました:', error);
// エラー発生時も前処理済みテンソルを解放
preprocessedTensor.dispose();
return null;
}
// 前処理済みテンソルは推論後不要なので解放
preprocessedTensor.dispose();
// 推論結果のテンソルをJavaScriptの配列に変換して返す
// predictionsもtf.tidyスコープ外でdisposeが必要になる場合があるが、
// tf.tidyの戻り値は例外。しかし、明示的にawait predictions.array()などとする場合は
// predictions自体はdispose対象外となるため、array()呼び出し後のpredictions.dispose()は不要。
// array()はテンソルを解決するので、その後のテンソル自体は解放済みとなる。
const predictionArray = await predictions.array();
// 推論結果のテンソルを解放
predictions.dispose();
return predictionArray[0]; // バッチサイズ1なので最初の要素を返す
}
// 使用例:
// const imgElement = document.getElementById('myImage'); // HTML上の<img>要素
// runInference(imgElement).then(predictionResults => {
// if (predictionResults) {
// console.log('推論結果 (クラス確率):', predictionResults);
// // 結果を解釈してクラス名を特定するなどの後処理を行う
// }
// });
model.predict()
メソッドは、入力テンソルに対してモデルによる順伝播計算を実行します。ViTモデルの出力は通常、クラス分類の場合は各クラスに属する確率を示すテンソル(形状は [batch_size, num_classes]
)です。
推論処理においては、不要になった中間テンソルや入力/出力テンソルを適切に解放することが重要です。tf.tidy()
ブロックを使用すると、そのブロック内で生成された中間テンソルが自動的に解放されるため、メモリ管理が容易になります。ただし、tf.tidy()
からreturn
されたテンソルは自動解放の対象外となるため、その後の処理で不要になった時点で明示的にdispose()
を呼び出す必要があります。
推論結果の解釈
runInference
関数から返されるpredictionArray
は、通常、各クラスに対応する確率値の配列です。この配列のインデックスと実際のクラス名のマッピングは、モデルの学習に使用されたデータセット(例: ImageNet)に依存します。Pythonでモデルを扱っていた際も、imagenet_labels.json
のようなファイルを使ってインデックスからクラス名に変換していたのと同様です。
TensorFlow.jsでこれを扱う場合、クラス名ファイル(JSON形式など)を別途読み込み、推論結果の確率配列のインデックスとクラス名を紐付ける処理を実装します。確率が最も高いインデックスを見つけるには、JavaScriptの配列操作を使用します。
async function getTopKClasses(predictionArray, k = 5) {
// クラス名マッピングファイルをロード(別途実装が必要)
// const classLabels = await loadClassLabels('path/to/imagenet_labels.json');
// ここでは仮のクラス名配列を使用
const classLabels = ['class_0', 'class_1', 'class_2', 'class_3', 'class_4', /* ... */]; // 実際のラベル配列に置き換える
// 確率とインデックスのペアを作成
const predictionsWithIndex = predictionArray.map((probability, index) => ({ probability, index }));
// 確率の高い順にソート
predictionsWithIndex.sort((a, b) => b.probability - a.probability);
// トップKの結果を取得し、クラス名と確率のオブジェクト配列に変換
const topK = predictionsWithIndex.slice(0, k).map(item => ({
className: classLabels[item.index],
probability: item.probability
}));
return topK;
}
// 使用例(runInferenceの後)
// runInference(imgElement).then(async predictionResults => {
// if (predictionResults) {
// const top5 = await getTopKClasses(predictionResults, 5);
// console.log('トップ5の予測結果:', top5);
// // 例: [{ className: 'cat', probability: 0.92 }, { className: 'dog', probability: 0.05 }, ...]
// }
// });
Pythonでの推論結果解釈もnp.argmax
やソート、ラベルファイル読み込みなど、同様のロジックで実現できます。JavaScriptでもこれらの処理は比較的容易に実装できます。
Python経験者への補足:tfjs-converterとモデル変換
PythonでTensorFlow/Kerasを使用してViTモデルを学習またはファインチューニングした場合、それをTensorFlow.jsで利用するためにはtensorflowjs_converter
ツールを使用する必要があります。
# Python環境でtensorflowjsをインストール
pip install tensorflowjs
# SavedModel形式の場合
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_format=tfjs_layers_model \
--signature_name=serving_default \
--saved_model_tags=serve \
/path/to/your/saved_model \
/path/to/output_tfjs_model
# Keras HDF5形式の場合
tensorflowjs_converter \
--input_format=keras \
/path/to/your/keras_model.h5 \
/path/to/output_tfjs_model
ViTのようにカスタムレイヤー(例: Patch Embedding, Positional Embedding層)や複雑なアーキテクチャを含むモデルの場合、変換がスムーズに行えないことがあります。特に、カスタムなAttention機構や、標準的なKerasレイヤー以外の操作を使用している場合は注意が必要です。
変換時のよくある問題と対策:
* 未対応の操作/レイヤー: tfjs-converterがサポートしていないTensorFlow操作やカスタムレイヤーが含まれている場合、変換エラーが発生します。可能な場合は、対応している標準的なレイヤーや操作に置き換える、またはカスタムレイヤーをJavaScriptで再実装(TF.jsのカスタムレイヤー機能を使用)するなどの対応が必要になります。
* 入力形状の問題: モデルが期待する入力形状と、TF.jsで作成したテンソルの形状が一致しない場合にエラーが発生します。Pythonでのモデル定義を確認し、TF.jsでの前処理がモデルの期待する形状に合っているかを慎重に確認してください。特にバッチ次元の扱いに注意が必要です。
* データ型: TF.jsは通常float32
で動作しますが、量子化モデル(int8
など)を扱う場合は変換オプションや読み込み方法が異なります。ViTモデルを量子化してパフォーマンス最適化を図ることも可能ですが、変換および推論プロセスが少し複雑になります。
ViTモデルの場合、標準的なレイヤー(Conv2D
, Dense
, LayerNormalization
, MultiHeadAttention
など)の組み合わせで構築されていれば、比較的スムーズに変換できる可能性が高いです。TensorFlow Hubなどで公開されている公式のViTモデルは、変換ツールでの互換性が考慮されていることが多いです。
まとめ
本記事では、Attention機構を用いた画像認識モデルであるVision Transformer (ViT) をTensorFlow.jsで利用し、画像推論を実行するための基本的な流れとコード例を解説しました。モデルの読み込み、推論のための画像前処理、そして推論結果の解釈といった主要なステップについて、具体的なJavaScript/TensorFlow.jsのコードを示しました。
Pythonでの機械学習開発経験をお持ちの方にとっては、ViTの概念やモデル変換、前処理ロジックなどは既存の知識を応用しやすい部分が多いかと思います。TF.jsを使用することで、Pythonで開発した強力な画像認識モデルをWebブラウザやNode.jsといったJavaScript環境で手軽に活用できるようになります。
ViTは比較的大きなモデルサイズになりがちであるため、Webブラウザ環境での利用においてはモデルのロード時間や推論速度、メモリ使用量などが課題となることがあります。実運用においては、量子化モデルの利用、推論デバイスの選択(GPUの活用)、あるいはより軽量なViT派生モデルの検討なども重要になってくるでしょう。
今後、TransformerやAttention機構は画像認識分野だけでなく、様々なモダリティを扱うモデルにも応用されていくと考えられます。TensorFlow.jsでこれらの新しいアーキテクチャを扱えるようになることは、活用の幅を広げる上で非常に価値があります。本記事が、ViTをはじめとするAttentionベースのモデルをTensorFlow.jsで実践的に扱うための一助となれば幸いです。