使用自定义 TensorFlow Lite 模型 (Android)

如果您的应用使用自定义 TensorFlow Lite 模型,那么您可以使用 Firebase ML 来部署模型。这样一来,您可以缩减应用的初始下载大小,而且无需发布应用的新版本即可更新应用的机器学习模型。此外,借助 Remote Config 和 A/B Testing,您可以为不同的用户组动态提供不同的模型。

TensorFlow Lite 模型

TensorFlow Lite 模型是经过优化、可在移动设备上高效运行的机器学习模型。如需获取 TensorFlow Lite 模型,可采取下列两种方式之一:

准备工作

  1. 将 Firebase 添加到您的 Android 项目(如果尚未添加)。
  2. 在您的模块(应用级)Gradle 文件(通常是 <project>/<app-module>/build.gradle.kts<project>/<app-module>/build.gradle)中,添加 Firebase ML 模型下载程序 Android 库的依赖项。我们建议使用 Firebase Android BoM 来实现库版本控制。

    此外,在设置 Firebase ML 模型下载程序时,您需要将 TensorFlow Lite SDK 添加到您的应用中。

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.5.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    借助 Firebase Android BoM,可确保您的应用使用的始终是 Firebase Android 库的兼容版本。

    (替代方法) 在不使用 BoM 的情况下添加 Firebase 库依赖项

    如果您选择不使用 Firebase BoM,则必须在每个 Firebase 库的依赖项行中指定相应的库版本。

    请注意,如果您在应用中使用多个 Firebase 库,我们强烈建议您使用 BoM 来管理库版本,从而确保所有版本都兼容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx:24.2.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.5.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    借助 Firebase Android BoM,可确保您的应用使用的始终是 Firebase Android 库的兼容版本。

    (替代方法) 在不使用 BoM 的情况下添加 Firebase 库依赖项

    如果您选择不使用 Firebase BoM,则必须在每个 Firebase 库的依赖项行中指定相应的库版本。

    请注意,如果您在应用中使用多个 Firebase 库,我们强烈建议您使用 BoM 来管理库版本,从而确保所有版本都兼容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:24.2.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. 在应用的清单中声明需具有 INTERNET 权限:
    <uses-permission android:name="android.permission.INTERNET" />

1. 部署模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署自定义 TensorFlow 模型。请参阅部署和管理自定义模型

将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您可以随时部署新的 TensorFlow Lite 模型,并调用 getModel() 来将新模型下载到用户的设备上(见下文)。

2. 将模型下载到设备并初始化一个 TensorFlow Lite 解释器

如需在您的应用中使用 TensorFlow Lite 模型,请先使用 Firebase ML SDK 将模型的最新版本下载到设备上。然后,使用该模型实例化 TensorFlow Lite 解释器。

如需启动模型下载,请调用模型下载程序的 getModel() 方法,指定您在上传该模型时为其分配的名称、是否总是下载最新模型,以及您希望在什么条件下允许下载。

您可以从以下三种下载方式中进行选择:

下载类型 说明
LOCAL_MODEL 从设备获取本地模型。如果没有本地模型,则其行为类似于 LATEST_MODEL。如果您不想检查模型更新,请使用此下载类型。例如,您使用 Remote Config 检索模型名称,并始终以新名称上传模型(推荐)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND 从设备获取本地模型,并在后台开始更新模型。如果没有本地模型,则其行为类似于 LATEST_MODEL
LATEST_MODEL 获取最新模型。如果本地模型是最新版本,将返回本地模型。否则,下载最新模型。此方式会阻塞进程,直到最新版本下载完毕(不推荐)。请仅在您明确需要最新版本的情况下才使用此方式。

您应该停用与模型相关的功能(例如使界面的一部分变灰或将其隐藏),直到您确认模型已下载。

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

许多应用会通过其初始化代码启动下载任务,您也可以在需要使用该模型之前随时启动下载任务。

3. 利用输入数据进行推断

获取模型的输入和输出形状

TensorFlow Lite 模型解释器采用一个或多个多维数组作为输入和输出。这些数组可以包含 byteintlongfloat 值。在将数据传递到模型或使用其结果之前,您必须知道模型使用的数组的数量和维度(“形状”)。

如果模型是您自己构建的,或者模型的输入和输出格式有文档说明,那么您可能已经知道这些信息。如果您不知道模型输入和输出的形状和数据类型,可以使用 TensorFlow Lite 解释器检查模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

输出示例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

运行解释器

在确定模型的输入和输出格式后,需获取输入数据,然后对数据执行任何必要的转换,以便为模型生成形状正确的输入。

例如,如果您的图片分类模型的输入形状为 [1 224 224 3] 浮点值,则可以基于 Bitmap 对象生成输入 ByteBuffer,如以下示例所示:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

然后,分配一个足以容纳模型输出的 ByteBuffer,并将输入缓冲区和输出缓冲区传递给 TensorFlow Lite 解释器的 run() 方法。例如,如果输出形状为 [1 1000] 浮点值:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

如何使用输出取决于您所使用的模型。

例如,如果您要执行分类,那么您可以在下一步中将结果的索引映射到它们所代表的标签:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

附录:模型的安全性

无论您以何种方式在 Firebase ML 中添加自己的 TensorFlow Lite 模型,Firebase ML 都会以标准序列化的 protobuf 格式将这些模型存储到本地存储空间中。

从理论上说,这意味着任何人都可以复制您的模型。但实际上,大多数模型都是针对具体的应用,且通过优化进行了混淆处理,因此,这一风险与竞争对手对您的代码进行反汇编和再利用的风险类似。但无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。

在 Android API 级别 21 (Lollipop) 及更高版本中,模型会下载到从自动备份中排除的目录当中。

在 Android API 级别 20 和更低版本中,模型会下载到应用专用的内部存储空间中名为 com.google.firebase.ml.custom.models 的目录当中。如果您使用 BackupAgent 启用了文件备份,则可以选择排除此目录。