在 Android 上使用 TensorFlow Lite 模型進行推論

您可以使用 ML Kit 搭配 TensorFlow Lite 模型,在裝置上執行推論。

這個 API 需要 Android SDK 16 級別 (Jelly Bean) 以上版本。

事前準備

  1. 如果您尚未將 Firebase 新增至 Android 專案,請新增 Firebase
  2. 將 ML Kit Android 程式庫的依附元件新增至模組 (應用程式層級) Gradle 檔案 (通常為 app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. 將要使用的 TensorFlow 模型轉換為 TensorFlow Lite 格式。請參閱 TOCO:TensorFlow Lite 最佳化轉換工具

代管或封裝模型

您必須先將模型提供給 ML Kit,才能在應用程式中使用 TensorFlow Lite 模型進行推論。ML Kit 可使用透過 Firebase 遠端代管的 TensorFlow Lite 模型,或與應用程式二進位檔一起封裝的模型,甚至兩者皆可。

在 Firebase 上代管模型後,您就能在未發布新版應用程式的情況下更新模型,並使用 Remote ConfigA/B Testing,為不同使用者群組動態提供不同的模型。

如果您選擇只透過 Firebase 代管模型,而非與應用程式捆綁,就能縮減應用程式的初始下載大小。不過,請注意,如果模型未與應用程式捆綁,則必須等到應用程式首次下載模型,才能使用任何模型相關功能。

將模型與應用程式捆綁後,即使 Firebase 代管的模型無法使用,應用程式的 ML 功能仍可正常運作。

在 Firebase 上代管模型

如要在 Firebase 上代管 TensorFlow Lite 模型,請按照下列步驟操作:

  1. Firebase 控制台的「ML Kit」部分,按一下「自訂」分頁標籤。
  2. 按一下「新增自訂模型」 (或「新增其他模型」)。
  3. 指定在 Firebase 專案中用於識別模型的名稱,然後上傳 TensorFlow Lite 模型檔案 (通常結尾為 .tflite.lite)。
  4. 在應用程式的資訊清單中,宣告需要 INTERNET 權限:
    <uses-permission android:name="android.permission.INTERNET" />

將自訂模型新增至 Firebase 專案後,您就可以使用指定的名稱在應用程式中參照模型。您隨時可以上傳新的 TensorFlow Lite 模型,應用程式會在下次重新啟動時下載新模型並開始使用。您可以定義應用程式嘗試更新模型時所需的裝置條件 (請參閱下文)。

將模型與應用程式組合

如要將 TensorFlow Lite 模型與應用程式組合,請將模型檔案 (通常結尾為 .tflite.lite) 複製到應用程式的 assets/ 資料夾。(您可能需要先建立資料夾,方法是按一下 app/ 資料夾的滑鼠右鍵,然後依序點選「New」>「Folder」>「Assets Folder」)。

接著,請在應用程式的 build.gradle 檔案中新增以下內容,確保 Gradle 在建構應用程式時不會壓縮模型:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

模型檔案會納入應用程式套件,並可供 ML Kit 做為原始資產使用。

載入模型

如要在應用程式中使用 TensorFlow Lite 模型,請先根據模型可用的位置設定 ML Kit:使用 Firebase 進行遠端存取,或在本機儲存空間中進行存取,甚至兩者皆可。如果您同時指定本機和遠端模型,則可在遠端模型可用時使用該模型,如果遠端模型無法使用,則會改回使用本機儲存的模型。

設定 Firebase 代管的模型

如果您使用 Firebase 代管模型,請建立 FirebaseCustomRemoteModel 物件,指定您在上傳模型時指派的名稱:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

接著,啟動模型下載工作,並指定要允許下載的條件。如果裝置上沒有模型,或是有較新版本的模型可供使用,工作會從 Firebase 異步下載模型:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

許多應用程式會在初始化程式碼中啟動下載工作,但您可以在需要使用模型之前的任何時間點啟動下載工作。

設定本機模型

如果您已將模型與應用程式捆綁,請建立 FirebaseCustomLocalModel 物件,指定 TensorFlow Lite 模型的檔案名稱:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

使用模型建立轉譯器

設定模型來源後,請從其中一個來源建立 FirebaseModelInterpreter 物件。

如果您只有本機內建的模型,請直接透過 FirebaseCustomLocalModel 物件建立轉譯器:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

如果您使用的是遠端代管模型,請務必先確認模型已下載,再執行模型。您可以使用模型管理員的 isModelDownloaded() 方法,查看模型下載作業的狀態。

雖然您只需要在執行轉譯器前確認這項資訊,但如果您同時擁有遠端代管模型和本機內建模型,在例項化模型轉譯器時執行這項檢查可能會比較合理:如果已下載遠端模型,請從該模型建立轉譯器;如果未下載,請從本機模型建立轉譯器。

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

如果您只有遠端代管的模型,請在確認模型已下載前,停用模型相關功能 (例如將部分 UI 設為灰色或隱藏)。方法是將事件監聽器附加至模型管理員的 download() 方法:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

指定模型的輸入和輸出

接下來,請設定模型解譯器的輸入和輸出格式。

TensorFlow Lite 模型會將一或多個多維陣列做為輸入,並產生輸出。這些陣列包含 byteintlongfloat 值。您必須根據模型使用的陣列數量和維度 (「形狀」) 設定 ML Kit。

如果您不清楚模型輸入和輸出內容的形狀和資料類型,可以使用 TensorFlow Lite Python 轉譯器檢查模型。例如:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

確定模型輸入和輸出格式後,您可以建立 FirebaseModelInputOutputOptions 物件,設定應用程式的模型解譯器。

舉例來說,浮點圖像分類模型可能會將 Nx224x224x3 的 float 值陣列做為輸入值,代表一批 N 224x224 三通道 (RGB) 圖片,並產生 1000 個 float 值的清單,每個值代表圖片屬於模型預測的 1000 個類別之一的機率。

針對這類模型,您可以將模型解譯器的輸入和輸出設為如下所示:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

對輸入資料執行推論

最後,如要使用模型執行推論,請取得輸入資料,並對資料執行任何必要的轉換,以便取得模型正確形狀的輸入陣列。

舉例來說,如果您有圖片分類模型,且輸入形狀為 [1 224 224 3] 浮點值,您可以從 Bitmap 物件產生輸入陣列,如以下範例所示:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

接著,請使用輸入資料建立 FirebaseModelInputs 物件,並將該物件和模型的輸入和輸出規格傳遞至模型解譯器run 方法:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

如果呼叫成功,您可以呼叫傳遞至成功事件監聽器的物件 getOutput() 方法,取得輸出內容。例如:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

輸出內容的使用方式取決於您使用的模型。

舉例來說,如果您要執行分類作業,下一步可以將結果的索引對應至所代表的標籤:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

附錄:模型安全性

無論您如何讓 TensorFlow Lite 模型可供 ML Kit 使用,ML Kit 都會以標準序列化 protobuf 格式將模型儲存在本機儲存空間中。

理論上,這表示任何人都可以複製您的模型。不過,實際上,大多數模型都是應用程式專屬,且經過最佳化處理而變得難以解讀,因此風險與競爭對手將您的程式碼反組合及重複使用相似。不過,在應用程式中使用自訂模型前,請先瞭解這項風險。

在 Android API 級別 21 (Lollipop) 以上版本中,系統會將模型下載至 從自動備份中排除的目錄。

在 Android API 級別 20 以下版本中,模型會下載至應用程式私人內部儲存空間中的 com.google.firebase.ml.custom.models 目錄。如果您使用 BackupAgent 啟用檔案備份功能,可以選擇排除這個目錄。