TF.js 実践開発レシピ

TensorFlow.jsで画像認識モデルの予測根拠を可視化する:Grad-CAMによるヒートマップ生成

Tags: TensorFlow.js, 画像認識, Grad-CAM, 可視化, XAI

TensorFlow.jsで画像認識モデルの予測根拠を可視化する:Grad-CAMによるヒートマップ生成

機械学習モデルがどのように予測や判断を行っているのかを理解することは、モデルのデバッグ、改善、そしてユーザーからの信頼を得る上で非常に重要です。特に画像認識分野においては、「モデルが画像のどの部分に注目して判断を下したのか」を知ることは、その予測が妥当であるかを確認する上で有効な手段となります。このようなモデルの判断根拠を説明する技術は、説明可能なAI(XAI: Explainable AI)と呼ばれています。

Pythonで機械学習モデルを扱っている場合、LIMEやSHAPといった汎用的なXAIライブラリや、Grad-CAMのように画像に特化した可視化手法を比較的容易に利用できます。しかし、これらのモデルをTensorFlow.jsを用いてWebブラウザやNode.js上で実行する場合、Pythonのライブラリを直接利用することはできません。そこで本記事では、TensorFlow.jsを用いて、画像認識モデルの予測根拠を可視化する代表的な手法の一つであるGrad-CAM(Gradient-weighted Class Activation Mapping)を実装する方法について、具体的なコード例と共に解説します。

Grad-CAMとは

Grad-CAMは、畳み込みニューラルネットワーク(CNN)が特定のクラスを予測する際に、入力画像のどの部分に注目しているかを可視化する手法です。モデルの最後の畳み込み層の活性化マップと、その活性化マップに対する対象クラスのスコアの勾配を利用して、ヒートマップを生成します。このヒートマップを元の画像に重ね合わせることで、モデルが画像中のどこを根拠に判断したのかを直感的に理解することができます。

Grad-CAMは、モデルの再学習や構造変更を必要としないポストホック(事後的)な手法であり、幅広いCNNモデルに適用できるという利点があります。PythonでKerasやTensorFlowを使用している場合、tf.GradientTapeなどを利用して実装することが一般的ですが、TensorFlow.jsでも同様の概念とAPIを用いて実装が可能です。

TensorFlow.jsでのGrad-CAM実装の概要

TensorFlow.jsでGrad-CAMを実装するには、主に以下のステップが必要です。

  1. TensorFlow.js形式の画像認識モデルをロードします。
  2. モデルの推論を実行し、予測結果を取得します。
  3. Grad-CAMを計算するために必要な、最後の畳み込み層の活性化マップと、特定のクラス(通常はモデルが予測した、あるいは可視化したい任意のクラス)の出力に対するその活性化マップの勾配を取得します。
  4. 取得した活性化マップと勾配を組み合わせ、重み付けされた活性化マップを生成します。
  5. 生成したマップをリサイズして元の入力画像と同じサイズにし、ヒートマップとして可視化します。
  6. ヒートマップを元の画像に重ね合わせて表示します。

TensorFlow.jsでは、これらの処理をブラウザまたはNode.js環境で実行するためのAPIが提供されています。特に、勾配の計算にはtf.grad関数を使用します。

実装コード例

ここでは、事前学習済みのMobileNetV2モデル(TensorFlow.js形式に変換済み)を使用する例を示します。入力画像はHTMLの<img>要素から取得し、結果を<canvas>要素に描画することを想定します。

まず、必要なライブラリをインポートします。

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis'; // 可視化用(任意)

// 事前学習済みMobileNetV2モデルをロード
async function loadModel() {
    // モデルのパスは環境に合わせて変更してください
    const modelUrl = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v2_1.0_224/optimized/model.json';
    console.log('Loading model...');
    const model = await tf.loadGraphModel(modelUrl);
    console.log('Model loaded.');
    return model;
}

次に、画像の前処理と推論、そしてGrad-CAM計算のための準備を行います。Grad-CAMでは特定の層の活性化マップが必要なため、モデルの一部を取得または操作する必要があります。Keras APIで構築されたモデル(Layers API)であれば、層の名前でアクセスできますが、ここではGraph Modelをロードした場合の例として、executeAsyncを使って中間層の出力を取得する方法を採用します。

Grad-CAM計算のために、元のモデルから推論パスの一部を変更した新しいモデルを作成します。この新しいモデルは、入力から最後の畳み込み層の出力までを出力するようにします。

async function preprocessImage(imgElement) {
    return tf.browser.fromPixels(imgElement)
        .resizeBilinear([224, 224]) // モデルの入力サイズにリサイズ
        .toFloat()
        .sub(127.5) // MobileNetV2の標準的な前処理
        .div(127.5)
        .expandDims(0); // バッチ次元を追加
}

async function getGradCAM(model, imageTensor, targetClassIndex) {
    return tf.tidy(() => {
        // Grad-CAM計算のために必要な、最後の畳み込み層の出力までを取得するモデルを作成
        // MobileNetV2の場合、最後の畳み込み層は 'Conv_1_Conv/Relu6' (Graph Modelの場合)
        const targetLayerName = 'Conv_1_Conv/Relu6'; 
        const outputNodes = [targetLayerName, model.outputs[0].name]; // 活性化マップと最終出力を取得

        // tf.gradを使用するために、executeAsyncではなくexecuteメソッドを使用する
        // executeメソッドは中間層の出力も指定できるが、ここではGraph Modelの例としてexecuteAsyncで別途取得するアプローチをとる
        // Layers APIモデルであれば model.layers.find(l => l.name === targetLayerName).output を使う方が容易

        // モデル全体の出力に対するターゲットクラスのスコア
        const classScore = model.executeAsync(imageTensor).then(predictions => {
             return predictions.squeeze().gather([targetClassIndex]);
        });

        // ターゲットクラスのスコアに対するターゲットレイヤーの活性化マップの勾配を計算
        // executeAsyncの結果に対する勾配を直接計算することは難しいため、
        // モデルの一部または代替パスを用意する必要がある。
        // Layers APIモデルなら簡単に gradient = tf.grad(() => model.output[targetClassIndex])(targetLayer.output) とできるが、
        // Graph Modelでは中間層の出力テンソルに対して直接勾配を計算するAPIが限られる。
        // ここでは、デモンストレーションとして概念を説明し、Layers APIモデルでのアプローチも示唆する。
        // 実装の簡略化のため、ターゲットレイヤーの出力を取得し、それに対する勾配を計算する代替的なアプローチを取る。
        // 注意: この部分のTensorFlow.js Graph Modelにおける実装は複雑になるため、
        // より現実的な実装はLayers APIモデルへの変換後に行うか、モデルを分割するなどの工夫が必要です。
        // 例として、Layers APIモデルの場合の概念コードを示す:

        /*
        // Layers APIモデルの場合の概念的なGrad-CAM勾配計算
        const targetLayer = model.getLayer(targetLayerName);
        const [classScoreTensor, activations] = tf.tidy(() => {
            const {ys: outputTensor, activation: activationTensor} = model.execute(imageTensor, [model.outputs[0].name, targetLayer.name]);
            const score = outputTensor.squeeze().gather([targetClassIndex]);
            return [score, activationTensor];
        });

        const gradient = tf.grad((x) => model.execute(x, [model.outputs[0].name]).squeeze().gather([targetClassIndex]))(activations); // 活性化マップに対する勾配
        */

        // Graph Modelの場合の簡易的なデモ(実際の実装は複雑)
        // 実際には、`tf.executeAsync`の出力テンソルに対して`tf.grad`を直接適用することは困難な場合が多いです。
        // モデルをGrad-CAM計算用に分割するか、Layers APIモデルに変換して使用することを強く推奨します。
        // ここでは、概念を示すための擬似的なコードとして、計算が必要なテンソルを仮定して進めます。

        // 実際には、ターゲットレイヤーの活性化マップテンソル(shape: [1, H, W, C])と
        // ターゲットクラスのスコアに対するそのマップの勾配テンソル(shape: [1, H, W, C])が必要です。
        // これらを取得するためのモデル操作が、Layers APIモデルであれば model.execute(input, [outputNode1, outputNode2]) などで可能ですが、
        // Graph Modelではモデル定義に依存します。
        // 仮に、activationsTensor と gradientsTensor が取得できたとします。
        // const activationsTensor = ...; // ターゲットレイヤーの活性化マップ (形状: [1, H, W, C])
        // const gradientsTensor = ...; // ターゲットクラススコアに対する活性化マップの勾配 (形状: [1, H, W, C])

        // デモ用のダミーテンソル生成(実際の計算結果ではありません)
        const height = 7; // 例: MobileNetV2 Conv_1_Conv/Relu6の出力サイズ (224/32 = 7)
        const width = 7;
        const channels = 1280; // 例: MobileNetV2 Conv_1_Conv/Relu6のチャンネル数
        const activationsTensor = tf.randomUniform([1, height, width, channels]);
        const gradientsTensor = tf.randomUniform([1, height, width, channels]);


        // 各チャンネルの勾配のグローバル平均プーリング(重み)
        const weights = gradientsTensor.mean([1, 2], true); // 形状: [1, 1, 1, C]

        // 活性化マップに重みを乗算
        const weightedActivations = activationsTensor.mul(weights); // 形状: [1, H, W, C]

        // 全チャンネルで合計して活性化マップを作成
        let cam = weightedActivations.sum(-1, true); // 形状: [1, H, W, 1]

        // ReLUを適用(負の値をゼロにする)
        cam = tf.relu(cam);

        // マップを0-1の範囲に正規化
        const max = cam.max();
        const min = cam.min();
        const camNormalized = cam.sub(min).div(max.sub(min)).squeeze(); // 形状: [H, W]

        // 後処理用にテンソルを返す
        return camNormalized;
    }); // tf.tidy() end
}

async function applyHeatmapToImage(originalImgElement, camTensor) {
    const canvas = document.createElement('canvas');
    canvas.width = originalImgElement.width;
    canvas.height = originalImgElement.height;
    const ctx = canvas.getContext('2d');

    // 元の画像を描画
    ctx.drawImage(originalImgElement, 0, 0, canvas.width, canvas.height);

    // CAMを元の画像サイズにリサイズ
    const camResized = camTensor.resizeBilinear([originalImgElement.height, originalImgElement.width]);

    // CAMテンソルをImageDataに変換
    const camData = await tf.browser.toPixels(camResized);
    const heatmapImageData = new ImageData(new Uint8ClampedArray(camData.length), originalImgElement.width, originalImgElement.height);

    // CAM値を色のヒートマップに変換(ここでは簡易的にグレースケールとして使用)
    // 実際にはカラーマップ(例: ジェットカラーマップ)を適用する
    for (let i = 0; i < camData.length; i++) {
        const intensity = camData[i]; // 0-255の値
        // 簡易的なグレースケール(高い値ほど明るく)
        heatmapImageData.data[i * 4] = intensity;     // R
        heatmapImageData.data[i * 4 + 1] = intensity; // G
        heatmapImageData.data[i * 4 + 2] = intensity; // B
        heatmapImageData.data[i * 4 + 3] = 128; // アルファ値(半透明にする)
    }

    // ヒートマップを画像に重ねて描画
    ctx.globalAlpha = 0.5; // 重ねる際の透明度
    ctx.drawImage(heatmapImageData, 0, 0);
    ctx.globalAlpha = 1.0;

    return canvas; // ヒートマップ付きのCanvas要素を返す
}

// メイン処理の例
async function runGradCAM() {
    const model = await loadModel();
    const imgElement = document.getElementById('your-image-element-id'); // HTMLのimg要素のIDを指定
    const imageTensor = await preprocessImage(imgElement);

    // 推論を実行して予測クラスを取得 (Grad-CAM計算前に必要)
    const predictions = await model.predict(imageTensor);
    const { values, indices } = tf.topk(predictions, 1);
    const predictedClassIndex = indices.dataSync()[0]; // 最も確率の高いクラスのインデックス

    console.log(`Predicted class index: ${predictedClassIndex}`);

    // ターゲットクラス(予測されたクラス)のGrad-CAMを計算
    const camTensor = await getGradCAM(model, imageTensor, predictedClassIndex);

    // 結果をCanvasに描画
    const resultCanvas = await applyHeatmapToImage(imgElement, camTensor);

    // 結果のCanvasをページに追加 (例: bodyに追加)
    document.body.appendChild(resultCanvas);

    // 使用済みテンソルを解放
    imageTensor.dispose();
    predictions.dispose();
    values.dispose();
    indices.dispose();
    camTensor.dispose(); // applyHeatmapToImage内でリサイズされたテンソルは別途解放が必要
}

// ページロード時に実行
// runGradCAM().catch(console.error);

コード解説とPythonとの比較

上記のコードは、TensorFlow.jsを用いてGrad-CAMを計算する概念を示しています。特に重要な点を以下に解説します。

開発における注意点と考慮事項

まとめ

本記事では、TensorFlow.jsを用いて画像認識モデルの予測根拠をGrad-CAMにより可視化する方法について解説しました。Pythonで機械学習モデルの開発経験がある読者の方々にとって、TensorFlow.jsでのXAI実装の一例として、Grad-CAMがどのように実現できるかをご理解いただけたかと思います。Layers APIモデルを利用することで、PythonのKerasでの記述感に近い形で中間層へのアクセスや勾配計算を行うことが可能です。

Grad-CAMは画像認識モデルのデバッグや理解に役立つ強力な手法です。ブラウザやNode.js環境でモデルをデプロイする際に、ユーザーインターフェース上でモデルの判断根拠を示すことで、アプリケーションの信頼性や使いやすさを向上させることができます。TensorFlow.jsを活用して、より実践的な画像認識アプリケーション開発に取り組んでいただければ幸いです。