在 Android 上使用机器学习套件通过 TensorFlow Lite 模型进行推理

您可以使用机器学习套件通过 TensorFlow Lite 模型执行基于设备的推理。

此 API 需采用 Android SDK 级别 16 (Jelly Bean) 或更高版本。

如需了解此 API 的实际应用示例,请查看 GitHub 上的机器学习套件快速入门示例,您还可以试用代码实验室

准备工作

  1. 如果您尚未将 Firebase 添加到自己的应用中,请按照入门指南中的步骤执行此操作。
  2. 在您的应用级 build.gradle 文件中添加机器学习套件的依赖项:
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:16.0.0'
    }
    
  3. 将您要使用的 TensorFlow 模型转换为 TensorFlow Lite (tflite) 格式。请参阅 TOCO:TensorFlow Lite 优化转换器 (TOCO: TensorFlow Lite Optimizing Converter)。

托管或捆绑您的模型

要在您的应用中使用 TensorFlow Lite 模型进行推理,您必须先确保机器学习套件能够使用该模型。机器学习套件可以使用通过 Firebase 远程托管和/或存储在本地设备上的 TensorFlow Lite 模型。

如果既在 Firebase 上托管又在本地存储该模型,那么您可以确保在该模型推出最新版本时使用其最新版本;而当 Firebase 托管的模型不可用时,应用的机器学习功能仍可正常运行。

模型的安全性

无论您是以何种方式在机器学习套件中添加您的 TensorFlow Lite 模型,机器学习套件会以标准序列化的 protobuf 格式将所有模型存储到本地存储空间中。

理论上来说,这意味着任何人都可以复制您的模型。但实际上,大多数模型都是应用专用的,且通过优化进行了混淆处理,因此,风险类似于竞争对手反汇编并再利用您的代码。无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。

在 Android API 级别 21 (Lollipop) 及更高版本中,模型会下载到从自动备份中排除的目录当中。

在 Android API 级别 20 和更低版本中,模型会下载到应用专用的内部存储中名为 com.google.firebase.ml.custom.models 的目录当中。如果您使用 BackupAgent 启用了文件备份,则可以选择排除此目录。

在 Firebase 上托管模型

要在 Firebase 上托管您的 TensorFlow Lite 模型,请执行以下操作:

  1. Firebase 控制台ML Kit(机器学习套件)部分中,点击自定义标签。
  2. 点击添加自定义模型(或再添加一个模型)。
  3. 指定一个名称,用于在 Firebase 项目中识别您的模型,然后上传 .tflite 文件。
  4. 在您应用的清单中声明需具有 INTERNET 权限:
    <uses-permission android:name="android.permission.INTERNET" />
    
  5. 如果您的模型针对的是 Android SDK 级别 18 (Jellybean) 或更低版本添加,则还需声明具有以下权限:
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"
                     android:maxSdkVersion="18" />
    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"
                     android:maxSdkVersion="18" />
    

将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您随时可以为模型上传新的 .tflite 文件,并且您的应用会下载新模型,然后在应用下次重启时开始使用此新模型。您可以定义应用尝试更新模型时所需满足的设备条件(请参见下文)。

在本地提供模型

要在本地提供 TensorFlow Lite 模型,您可以将该模型与应用捆绑在一起,或者在应用中从您自己的服务器下载该模型。

要将 TensorFlow Lite 模型与您的应用捆绑在一起,请将 .tflite 文件复制到您应用的 assets/ 文件夹(您可能需要先创建此文件夹,方法是右键点击 app/ 文件夹,然后点击新建 > 文件夹 > 资源文件夹。)

然后,将以下内容添加到您的项目的 build.gradle 文件中:

android {

    // ...

    aaptOptions {
        noCompress "tflite"
    }
}

.tflite 文件将包含在应用软件包中,并作为原始资源提供给机器学习套件使用。

如果您选择将该模型托管在自己的服务器上,则可以适时在应用中将该模型下载到本地存储空间。然后,该模型将以本地文件的形式提供给机器学习套件。

加载模型

要使用 TensorFlow Lite 模型进行推理,请先指定 .tflite 文件的位置。

如果您使用 Firebase 托管您的模型,请创建一个 FirebaseCloudModelSource 对象,指明您在上传该模型时为其指定的名称,以及机器学习套件初次下载模型以及后续当有模型更新推出时下载更新需要满足的条件

FirebaseModelDownloadConditions.Builder conditionsBuilder =
        new FirebaseModelDownloadConditions.Builder().requireWifi();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    // Enable advanced conditions on Android Nougat and newer.
    conditionsBuilder = conditionsBuilder
            .requireCharging()
            .requireDeviceIdle();
}
FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

// Build a FirebaseCloudModelSource object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build();
FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);

如果您将模型与应用捆绑在了一起,或者在运行时从您自己的主机上下载了模型,请创建一个 FirebaseLocalModelSource 对象,指定 .tflite 模型的文件名以及该文件是原始资源(捆绑的模型)还是存储在本地的资源(在运行时下载的模型)。

FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model")
        .setAssetFilePath("mymodel.tflite")  // Or setFilePath if you downloaded from your host
        .build();
FirebaseModelManager.getInstance().registerLocalModelSource(localSource);

然后,使用云端模型来源和/或本地模型来源的名称创建一个 FirebaseModelOptions 对象,并使用该对象获取 FirebaseModelInterpreter 的一个实例:

FirebaseModelOptions options = new FirebaseModelOptions.Builder()
        .setCloudModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build();
FirebaseModelInterpreter firebaseInterpreter =
        FirebaseModelInterpreter.getInstance(options);

如果您同时指定了云端模型来源和本地模型来源,则模型解析器将在云端模型可用时使用云端模型,并在云端模型不可用时转为使用本地模型。

指定模型的输入和输出

接下来,您必须通过创建一个 FirebaseModelInputOutputOptions 对象来指定模型的输入和输出格式。

TensorFlow Lite 模型支持输入一个或多个多维数组,并可在输出时生成一个或多个多维数组。这些数组包含 byteintlongfloat 值。您必须根据模型采用的数组个数和维度(“形状”)配置机器学习套件。

例如,一个图片分类模型可能会被输入一个 1x640x480x3 字节数组(这表示一张 640x480 的 24 位全彩图片),并在输出时生成一个包含 1000 个 float 值的列表(每个值表示该图片属于此模型预测的 1000 个类别中的某一个类别的概率)。

FirebaseModelInputOutputOptions inputOutputOptions =
    new FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3})
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
        .build();

利用输入数据进行推理

最后,要使用此模型执行推理,请利用模型输入创建一个 FirebaseModelInputs 对象,并将该对象以及模型的输入和输出规范传递给您的模型解析器run 方法:

byte[][][][] input = new byte[1][640][480][3];
input = getYourInputData();
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
    .add(input)  // add() as many input arrays as your model requires
    .build();
Task<FirebaseModelOutputs> result =
    firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
          new OnSuccessListener<FirebaseModelOutputs>() {
            @Override
            public void onSuccess(FirebaseModelOutputs result) {
              // ...
            }
          })
        .addOnFailureListener(
          new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception e) {
              // Task failed with an exception
              // ...
            }
          });

您可以通过调用传递给成功侦听器的对象的 getOutput() 方法来获取输出。例如:

float[][] output = result.<float[][]>getOutput(0);
float[] probabilities = output[0];

如何使用输出取决于您使用什么模型。例如,如果您要执行分类,那么您可以在下一步中将结果的索引映射到它们所代表的标签。

发送以下问题的反馈:

此网页
需要帮助?请访问我们的支持页面