このページは Cloud Translation API によって翻訳されました。
Switch to English

AndroidでカスタムTensorFlow Liteモデルを使用する

アプリでカスタムTensorFlow Liteモデルを使用している場合は、Firebase MLを使用してモデルをデプロイできます。 Firebaseでモデルをデプロイすることにより、アプリの初期バージョンを縮小し、アプリの新しいバージョンをリリースすることなくアプリのMLモデルを更新できます。また、Remote ConfigとA / Bテストを使用すると、さまざまなモデルをさまざまなユーザーセットに動的に提供できます。

TensorFlow Liteモデル

TensorFlow Liteモデルは、モバイルデバイスで実行するように最適化されたMLモデルです。 TensorFlow Liteモデルを取得するには:

あなたが始める前に

  1. まだの場合は、 FirebaseをAndroidプロジェクトに追加します
  2. プロジェクト・レベルのではbuild.gradleファイル、あなたの両方でGoogleのMavenのリポジトリが含まれていることを確認してくださいbuildscriptallprojectsセクション。
  3. Firebase MLおよびTensorFlow Lite Androidライブラリをモジュール(アプリレベル)のGradleファイル(通常はapp/build.gradle )に追加します:
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.4'
      implementation 'org.tensorflow:tensorflow-lite:2.0.0'
    }
    
  4. アプリのマニフェストで、インターネット権限が必要であることを宣言します:
    <uses-permission android:name="android.permission.INTERNET" />

1.モデルをデプロイする

FirebaseコンソールまたはFirebase Admin PythonおよびNode.js SDKを使用して、カスタムTensorFlowモデルをデプロイします。 カスタムモデルの展開と管理を参照してください。

Firebaseプロジェクトにカスタムモデルを追加したら、指定した名前を使用してアプリでモデルを参照できます。いつでも新しいTensorFlow Liteモデルをアップロードできます。アプリは新しいモデルをダウンロードし、アプリが次に再起動したときに使用を開始します。アプリがモデルの更新を試みるために必要なデバイス条件を定義できます(以下を参照)。

2.モデルをデバイスにダウンロードします

アプリでTensorFlow Liteモデルを使用するには、まずFirebase ML SDKを使用して、モデルの最新バージョンをデバイスにダウンロードします。

モデルのダウンロードを開始するには、モデルマネージャーのdownload()メソッドを呼び出し、アップロード時にモデルに割り当てた名前と、ダウンロードを許可する条件を指定します。モデルがデバイス上にない場合、または新しいバージョンのモデルが利用可能な場合、タスクはFirebaseからモデルを非同期でダウンロードします。

モデルがダウンロードされたことを確認するまで、モデル関連の機能(たとえば、UIのグレー表示や非表示)を無効にする必要があります。

ジャワ

FirebaseCustomRemoteModel remoteModel =
      new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

多くのアプリは、初期化コードでダウンロードタスクを開始しますが、モデルを使用する必要がある前であればいつでも開始できます。

3. TensorFlow Liteインタープリターを初期化する

モデルをデバイスにダウンロードしたら、モデルマネージャーのgetLatestModelFile()メソッドを呼び出して、モデルファイルの場所を取得できます。この値を使用して、TensorFlow Liteインタープリターをインスタンス化します。

ジャワ

FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        }
    }

4.入力データに対して推論を実行する

モデルの入力および出力形状を取得する

TensorFlow Liteモデルインタープリターは、1つ以上の多次元配列を入力として受け取り、出力として生成します。これらの配列には、 byteintlong 、またはfloat値のいずれかが含まれています。モデルにデータを渡したり、その結果を使用したりする前に、モデルが使用する配列の数と次元(「形状」)を知っておく必要があります。

自分でモデルを作成した場合、またはモデルの入力および出力形式が文書化されている場合は、すでにこの情報を持っている可能性があります。モデルの入力と出力の形状とデータ型がわからない場合は、TensorFlow Liteインタープリターを使用してモデルを検査できます。例えば:

パイソン

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

出力例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

通訳を実行する

モデルの入力と出力の形式を決定したら、入力データを取得し、モデルに適切な形状の入力を取得するために必要なデータの変換を実行します。

たとえば、 [1 224 224 3]浮動小数点値の入力形状を持つ画像分類モデルがある場合、次の例に示すように、 Bitmapオブジェクトから入力ByteBufferを生成できます。

ジャワ

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin + KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

次に、モデルの出力を格納するのに十分な大きさのByteBuffer割り当て、入力バッファーと出力バッファーをTensorFlow Liteインタープリターのrun()メソッドに渡します。たとえば、 [1 1000]浮動小数点値の出力形状の場合:

ジャワ

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin + KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

出力の使用方法は、使用しているモデルによって異なります。

たとえば、分類を実行している場合、次のステップとして、結果のインデックスをそれらが表すラベルにマッピングすることができます。

ジャワ

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin + KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

付録:ローカルにバンドルされたモデルにフォールバック

Firebaseでモデルをホストする場合、アプリがモデルを初めてダウンロードするまで、モデル関連の機能は利用できません。一部のアプリではこれで問題ない場合がありますが、モデルでコア機能が有効になっている場合は、モデルのバージョンをアプリにバンドルして、利用可能な最良のバージョンを使用することをお勧めします。そうすることで、Firebaseでホストされているモデルが利用できない場合でもアプリのML機能が動作することを確認できます。

TensorFlow Liteモデルをアプリにバンドルするには:

  1. モデルファイル(通常は.tfliteまたは.lite終わるファイル)をアプリのassets/フォルダーにコピーします。 (最初にapp/フォルダーを右クリックし、次に[ 新規 ] > [フォルダー]> [アセットフォルダー ]をクリックしてフォルダーを作成する必要がある場合があります。)

  2. 次のbuild.gradleアプリのbuild.gradleファイルに追加して、アプリのビルド時にGradleがモデルを圧縮しないようにします。

    android {
    
        // ...
    
        aaptOptions {
            noCompress "tflite", "lite"
        }
    }
    

次に、ホストされたモデルが利用できない場合は、ローカルにバンドルされたモデルを使用します。

ジャワ

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                } else {
                    try {
                        InputStream inputStream = getAssets().open("your_fallback_model.tflite");
                        byte[] model = new byte[inputStream.available()];
                        inputStream.read(model);
                        ByteBuffer buffer = ByteBuffer.allocateDirect(model.length)
                                .order(ByteOrder.nativeOrder());
                        buffer.put(model);
                        interpreter = new Interpreter(buffer);
                    } catch (IOException e) {
                        // File not found?
                    }
                }
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        } else {
            val model = assets.open("your_fallback_model.tflite").readBytes()
            val buffer = ByteBuffer.allocateDirect(model.size).order(ByteOrder.nativeOrder())
            buffer.put(model)
            interpreter = Interpreter(buffer)
        }
    }

付録:モデルのセキュリティ

TensorFlow LiteモデルをFirebase MLで利用できるようにする方法に関係なく、Firebase MLはローカルストレージの標準のシリアル化されたprotobuf形式でモデルを保存します。

理論的には、これは誰でもモデルをコピーできることを意味します。ただし、実際には、ほとんどのモデルはアプリケーション固有であり、最適化によって難読化されているため、競合他社がコードを分解して再利用するリスクと同様です。それでも、アプリでカスタムモデルを使用する前に、このリスクに注意する必要があります。

Android APIレベル21(Lollipop)以降では、モデルは自動バックアップから除外されたディレクトリにダウンロードされます

Android APIレベル20以前では、モデルはapp-private内部ストレージのcom.google.firebase.ml.custom.modelsという名前のディレクトリにダウンロードされます。 BackupAgentを使用してファイルのバックアップを有効にした場合、このディレクトリを除外することを選択できます。