在 Android 上使用机器学习套件通过 TensorFlow Lite 模型进行推理

您可以使用机器学习套件通过 TensorFlow Lite 模型执行基于设备的推理。

此 API 需采用 Android SDK 级别 16 (Jelly Bean) 或更高版本。

如需查看此 API 的实际应用示例,请查看 GitHub 上的机器学习套件快速入门示例,也可以试用代码实验室

准备工作

  1. 如果您尚未将 Firebase 添加到自己的应用中,请按照入门指南中的步骤执行此操作。
  2. 在您的应用级 build.gradle 文件中添加机器学习套件的依赖项:
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:17.0.3'
    }
    
  3. 将您要使用的 TensorFlow 模型转换为 TensorFlow Lite 格式。请参阅 TOCO:TensorFlow Lite 优化转换器

托管或捆绑您的模型

要在您的应用中使用 TensorFlow Lite 模型进行推理,您必须先确保机器学习套件能够使用该模型。机器学习套件可以使用通过 Firebase 远程托管和/或与应用二进制文件捆绑的 TensorFlow Lite 模型。

通过在 Firebase 上托管模型,您可以在不发布新应用版本的情况下更新模型,并且可以使用远程配置和 A/B 测试为不同的用户组动态提供不同的模型。

如果您选择仅通过使用 Firebase 托管模型来提供模型,但不将其与应用捆绑,则可以缩小应用的初始下载体量。但请注意,如果模型未与您的应用捆绑,那么在应用首次下载模型之前,任何与模型相关的功能都将无法使用。

通过将您的模型与应用捆绑,您可以确保当 Firebase 托管的模型不可用时,应用的机器学习功能仍可正常运行。

在 Firebase 上托管模型

要在 Firebase 上托管您的 TensorFlow Lite 模型,请执行以下操作:

  1. Firebase 控制台机器学习套件 (ML Kit) 部分中,点击自定义标签。
  2. 点击添加自定义模型(或再添加一个模型)。
  3. 指定一个名称,用于在 Firebase 项目中识别您的模型,然后上传 TensorFlow Lite 模型文件(通常以 .tflite.lite 结尾)。
  4. 在您应用的清单中声明需具有 INTERNET 权限:
    <uses-permission android:name="android.permission.INTERNET" />
    

将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您随时可以上传新的 TensorFlow Lite 模型,并且您的应用会下载新模型,然后在应用下次重启时开始使用新模型。您可以定义应用尝试更新模型时所需满足的设备条件(请参见下文)。

将模型与应用捆绑

要将 TensorFlow Lite 模型与您的应用捆绑,请将模型文件(通常以 .tflite.lite 结尾)复制到您应用的 assets/ 文件夹。(若要先创建此文件夹,您可以右键点击 app/ 文件夹,然后依次点击新建 > 文件夹 > 资源文件夹。)

然后,将以下内容添加到应用的 build.gradle 文件中,以确保 Gradle 在构建应用时不会压缩模型:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

模型文件将包含在应用软件包中,并作为原始资源提供给机器学习套件使用。

加载模型

要在您的应用中使用 TensorFlow Lite 模型,请首先为机器学习套件配置模型所在的位置:在云端(使用 Firebase)和/或本地存储空间中。如果您同时指定了云端模型来源和本地模型来源,则机器学习套件将在云端模型来源可用时使用云端模型来源,并在云端模型来源不可用时回退为使用本地存储的模型。

配置 Firebase 托管的模型来源

如果您使用 Firebase 托管您的模型,请创建一个 FirebaseCloudModelSource 对象,并指明您在上传该模型时为其指定的名称,以及机器学习套件初次下载模型以及后续下载模型更新版本时需要满足的条件

Java
Android

FirebaseModelDownloadConditions.Builder conditionsBuilder =
        new FirebaseModelDownloadConditions.Builder().requireWifi();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    // Enable advanced conditions on Android Nougat and newer.
    conditionsBuilder = conditionsBuilder
            .requireCharging()
            .requireDeviceIdle();
}
FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

// Build a FirebaseCloudModelSource object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build();
FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);

Kotlin
Android

var conditionsBuilder: FirebaseModelDownloadConditions.Builder =
        FirebaseModelDownloadConditions.Builder().requireWifi()
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    // Enable advanced conditions on Android Nougat and newer.
    conditionsBuilder = conditionsBuilder
            .requireCharging()
            .requireDeviceIdle()
}
val conditions = conditionsBuilder.build()

// Build a FirebaseCloudModelSource object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
val cloudSource = FirebaseCloudModelSource.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build()
FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource)

配置本地模型来源

如果您已将模型与应用捆绑,请创建一个 FirebaseLocalModelSource 对象,指定 TensorFlow Lite 模型的文件名,并为模型指定一个您将在下一步中使用的名称。

Java
Android

FirebaseLocalModelSource localSource =
        new FirebaseLocalModelSource.Builder("my_local_model")  // Assign a name to this model
                .setAssetFilePath("my_model.tflite")
                .build();
FirebaseModelManager.getInstance().registerLocalModelSource(localSource);

Kotlin
Android

val localSource = FirebaseLocalModelSource.Builder("my_local_model") // Assign a name to this model
        .setAssetFilePath("my_model.tflite")
        .build()
FirebaseModelManager.getInstance().registerLocalModelSource(localSource)

根据模型来源创建解析器

配置模型来源后,创建一个含云端模型来源名称和/或本地模型来源名称的 FirebaseModelOptions 对象,并使用该对象获取 FirebaseModelInterpreter 的实例:

Java
Android

FirebaseModelOptions options = new FirebaseModelOptions.Builder()
        .setCloudModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build();
FirebaseModelInterpreter firebaseInterpreter =
        FirebaseModelInterpreter.getInstance(options);

Kotlin
Android

val options = FirebaseModelOptions.Builder()
        .setCloudModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

指定模型的输入和输出

接下来,配置模型解析器的输入和输出格式。

TensorFlow Lite 模型支持输入一个或多个多维数组,并可在输出时生成一个或多个多维数组。这些数组包含 byteintlongfloat 值。您必须根据模型采用的数组个数和维度(“形状”)配置机器学习套件。

如果您不知道模型输入和输出的形状和数据类型,可以使用 TensorFlow Lite Python 解析器检查模型。例如:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

确定模型的输入和输出格式后,可以通过创建 FirebaseModelInputOutputOptions 对象来配置应用的模型解析器。

例如,一个浮点图片分类模型可能会被输入一个 float 值 Nx224x224x3 数组(这表示一批 224x224 的三通道 (RGB) 图片,包含 N 个),并在输出时生成一个包含 1000 个 float 值的列表(每个值表示该图片属于此模型预测的 1000 个类别中的某一个类别的概率)。

对于此类模型,您需要按如下所示配置模型解析器的输入和输出:

Java
Android

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin
Android

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

利用输入数据进行推理

最后,要使用模型执行推理,请获取输入数据并对数据执行任何必要的转换,以获得模型的正确形状的输入数组。

例如,如果您的图片分类模型的输入形状为 [1 224 224 3] 浮点值,则可以从 Bitmap 对象生成输入数组,如以下示例所示:

Java
Android

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin
Android

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

然后,根据输入数据创建一个 FirebaseModelInputs 对象,并将该对象以及此模型的输入和输出规范传递给模型解析器run 方法:

Java
Android

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin
Android

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener(
                object : OnFailureListener {
                    override fun onFailure(e: Exception) {
                        // Task failed with an exception
                        // ...
                    }
                })

如果调用成功,您可以通过调用传递给成功侦听器的对象的 getOutput() 方法来获取输出。例如:

Java
Android

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin
Android

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

如何使用输出取决于您所使用的模型。

例如,如果您要执行分类,那么您可以在下一步中将结果的索引映射到它们所代表的标签。

Java
Android

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin
Android

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

附录:模型的安全性

无论您是以何种方式在机器学习套件中添加您的 TensorFlow Lite 模型,机器学习套件会以标准序列化的 protobuf 格式将所有模型存储到本地存储空间中。

这意味着,从理论上说,任何人都可以复制您的模型。但实际上,大多数模型都是应用专用的,且通过优化进行了混淆处理,因此,风险类似于竞争对手反汇编并再利用您的代码。无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。

在 Android API 级别 21 (Lollipop) 及更高版本中,模型会下载到从自动备份中排除的目录当中。

在 Android API 级别 20 和更低版本中,模型会下载到应用专用的内部存储中名为 com.google.firebase.ml.custom.models 的目录当中。如果您使用 BackupAgent 启用了文件备份,则可以选择排除此目录。

发送以下问题的反馈:

此网页
需要帮助?请访问我们的支持页面