您可以使用机器学习套件,通过 TensorFlow Lite 模型在设备上执行推理。
此 API 需采用 Android SDK 级别 16 (Jelly Bean) 或更高版本。
准备工作
- 将 Firebase 添加到您的 Android 项目(如果尚未添加)。
- 将 Android 版机器学习套件库的依赖项添加到您的模块(应用级层)Gradle 文件(通常为
app/build.gradle
):apply plugin: 'com.android.application' apply plugin: 'com.google.gms.google-services' dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3' }
- 将您要使用的 TensorFlow 模型转换为 TensorFlow Lite 格式。 请参阅 TOCO:TensorFlow Lite 优化转换器。
托管或捆绑您的模型
如需在您的应用中使用 TensorFlow Lite 模型进行推理,您必须先确保机器学习套件能够使用该模型。机器学习套件可以使用在 Firebase 中远程托管和/或与应用二进制文件捆绑的 TensorFlow Lite 模型。
通过在 Firebase 上托管模型,您可以在不发布新应用版本的情况下更新模型,并且可以使用 Remote Config 和 A/B Testing 为不同的用户组动态运用不同的模型。
如果您选择仅通过在 Firebase 中托管而不是与应用捆绑的方式来提供模型,可以缩小应用的初始下载文件大小。但请注意,如果模型未与您的应用捆绑,那么在应用首次下载模型之前,任何与模型相关的功能都将无法使用。
将您的模型与应用捆绑,可以确保当 Firebase 托管的模型不可用时,应用的机器学习功能仍可正常运行。
在 Firebase 上托管模型
如需在 Firebase 上托管您的 TensorFlow Lite 模型,请执行以下操作:
- 在 Firebase 控制台的机器学习套件部分中,点击自定义标签页。
- 点击添加自定义模型(或再添加一个模型)。
- 指定一个名称,用于在 Firebase 项目中识别您的模型,然后上传 TensorFlow Lite 模型文件(通常以
.tflite
或.lite
结尾)。 - 在您应用的清单中声明需具有 INTERNET 权限:
<uses-permission android:name="android.permission.INTERNET" />
将自定义模型添加到 Firebase 项目后,您可以使用自己指定的名称在应用中引用该模型。您随时可以上传新的 TensorFlow Lite 模型,您的应用会下载新模型,然后在下次重启时开始使用新模型。您可以定义应用尝试更新模型时所需满足的设备条件(请参见下文)。
将模型与应用捆绑
如需将 TensorFlow Lite 模型与您的应用捆绑,请将模型文件(通常以 .tflite
或 .lite
结尾)复制到您应用的 assets/
文件夹。(您可能需要先创建此文件夹,方法是右键点击 app/
文件夹,然后依次点击新建 > 文件夹 > Assets 文件夹 (New > Folder > Assets Folder)。)
然后,将以下内容添加到应用的 build.gradle
文件中,以确保 Gradle 在构建应用时不会压缩模型:
android {
// ...
aaptOptions {
noCompress "tflite" // Your model's file extension: "tflite", "lite", etc.
}
}
模型文件将包含在应用软件包中,并作为原始资源提供给机器学习套件使用。
加载模型
如需在您的应用中使用 TensorFlow Lite 模型,请首先为机器学习套件配置模型所在的位置:使用 Firebase 远程托管、在本地存储空间,或者两者同时。如果您同时指定了本地模型和远程模型,则可以在远程模型可用时使用远程模型,并在远程模型不可用时回退为使用本地存储的模型。配置 Firebase 托管的模型
如果您使用 Firebase 托管您的模型,请创建一个 FirebaseCustomRemoteModel
对象,并在上传模型时指定您分配给模型的名称:
Java
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
Kotlin+KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
然后,启动模型下载任务,指定允许下载的条件。如果模型不在设备上,或模型有较新的版本,则该任务将从 Firebase 异步下载模型:
Java
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
.requireWifi()
.build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
.addOnCompleteListener(new OnCompleteListener<Void>() {
@Override
public void onComplete(@NonNull Task<Void> task) {
// Success.
}
});
Kotlin+KTX
val conditions = FirebaseModelDownloadConditions.Builder()
.requireWifi()
.build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
.addOnCompleteListener {
// Success.
}
许多应用会通过其初始化代码启动下载任务,您也可以在需要使用该模型之前随时启动下载任务。
配置本地模型
如果您将模型与应用捆绑在一起,请创建 FirebaseCustomLocalModel
对象,并指定 TensorFlow Lite 模型的文件名:
Java
FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
.setAssetFilePath("your_model.tflite")
.build();
Kotlin+KTX
val localModel = FirebaseCustomLocalModel.Builder()
.setAssetFilePath("your_model.tflite")
.build()
根据模型创建解释器
配置模型来源后,根据其中一个模型创建 FirebaseModelInterpreter
对象。
如果您只有本地捆绑的模型,则只需根据 FirebaseCustomLocalModel
对象创建一个解释器即可:
Java
FirebaseModelInterpreter interpreter;
try {
FirebaseModelInterpreterOptions options =
new FirebaseModelInterpreterOptions.Builder(localModel).build();
interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
// ...
}
Kotlin+KTX
val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)
如果您使用远程托管的模型,则必须在运行之前检查确认该模型已下载。您可以使用模型管理器的 isModelDownloaded()
方法检查模型下载任务的状态。
虽然您只需在运行解释器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,则在实例化模型解释器时执行此检查可能是有意义的:如果已下载远程模型,则根据该模型创建解释器,否则根据本地模型创建。
Java
FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
.addOnSuccessListener(new OnSuccessListener<Boolean>() {
@Override
public void onSuccess(Boolean isDownloaded) {
FirebaseModelInterpreterOptions options;
if (isDownloaded) {
options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
} else {
options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
}
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
// ...
}
});
Kotlin+KTX
FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
.addOnSuccessListener { isDownloaded ->
val options =
if (isDownloaded) {
FirebaseModelInterpreterOptions.Builder(remoteModel).build()
} else {
FirebaseModelInterpreterOptions.Builder(localModel).build()
}
val interpreter = FirebaseModelInterpreter.getInstance(options)
}
如果您只有远程托管的模型,应停用与模型相关的功能(例如使界面的一部分变灰或将其隐藏),直到您确认模型已下载。这可以通过将监听器附加到模型管理器的 download()
方法来实现:
Java
FirebaseModelManager.getInstance().download(remoteModel, conditions)
.addOnSuccessListener(new OnSuccessListener<Void>() {
@Override
public void onSuccess(Void v) {
// Download complete. Depending on your app, you could enable
// the ML feature, or switch from the local model to the remote
// model, etc.
}
});
Kotlin+KTX
FirebaseModelManager.getInstance().download(remoteModel, conditions)
.addOnCompleteListener {
// Download complete. Depending on your app, you could enable the ML
// feature, or switch from the local model to the remote model, etc.
}
指定模型的输入和输出
接下来,配置模型解释器的输入和输出格式。
TensorFlow Lite 模型采用一个或多个多维数组作为输入和生成的输出。这些数组可以包含 byte
、int
、long
或 float
值。您必须根据模型采用的数组个数和维度(“形状”)配置机器学习套件。
如果您不知道模型输入和输出的形状和数据类型,可以使用 TensorFlow Lite Python 解释器检查模型。例如:
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="my_model.tflite") interpreter.allocate_tensors() # Print input shape and type print(interpreter.get_input_details()[0]['shape']) # Example: [1 224 224 3] print(interpreter.get_input_details()[0]['dtype']) # Example: <class 'numpy.float32'> # Print output shape and type print(interpreter.get_output_details()[0]['shape']) # Example: [1 1000] print(interpreter.get_output_details()[0]['dtype']) # Example: <class 'numpy.float32'>
确定模型的输入和输出格式后,可以通过创建 FirebaseModelInputOutputOptions
对象来配置应用的模型解释器。
例如,浮点图片分类模型可能会接受 float
值的 Nx224x224x3 数组作为输入(表示一批 N 个 224x224 的三通道 (RGB) 图片),并会输出包含 1000 个 float
值的列表(每个值表示该图片属于此模型预测的 1000 个类别中其中一个类别的概率)。
对于此类模型,您需要按如下所示配置模型解释器的输入和输出:
Java
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5}) .build();
Kotlin+KTX
val inputOutputOptions = FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3)) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5)) .build()
利用输入数据进行推断
最后,要使用模型执行推理,请获取输入数据并对数据执行任何必要的转换,以获得模型的正确形状的输入数组。例如,如果您的图片分类模型的输入形状为 [1 224 224 3] 浮点值,则可以基于 Bitmap
对象生成输入数组,如以下示例所示:
Java
Bitmap bitmap = getYourInputImage(); bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true); int batchNum = 0; float[][][][] input = new float[1][224][224][3]; for (int x = 0; x < 224; x++) { for (int y = 0; y < 224; y++) { int pixel = bitmap.getPixel(x, y); // Normalize channel values to [-1.0, 1.0]. This requirement varies by // model. For example, some models might require values to be normalized // to the range [0.0, 1.0] instead. input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f; input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f; input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f; } }
Kotlin+KTX
val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true) val batchNum = 0 val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } } for (x in 0..223) { for (y in 0..223) { val pixel = bitmap.getPixel(x, y) // Normalize channel values to [-1.0, 1.0]. This requirement varies by // model. For example, some models might require values to be normalized // to the range [0.0, 1.0] instead. input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f } }
然后,根据输入数据创建一个 FirebaseModelInputs
对象,并将该对象以及此模型的输入和输出规范传递给模型解释器的 run
方法:
Java
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build(); firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener( new OnSuccessListener<FirebaseModelOutputs>() { @Override public void onSuccess(FirebaseModelOutputs result) { // ... } }) .addOnFailureListener( new OnFailureListener() { @Override public void onFailure(@NonNull Exception e) { // Task failed with an exception // ... } });
Kotlin+KTX
val inputs = FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build() firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener { result -> // ... } .addOnFailureListener { e -> // Task failed with an exception // ... }
如果该调用成功,您可以通过调用传递给成功监听器的对象的 getOutput()
方法来获取输出。例如:
Java
float[][] output = result.getOutput(0); float[] probabilities = output[0];
Kotlin+KTX
val output = result.getOutput<Array<FloatArray>>(0) val probabilities = output[0]
如何使用输出取决于您所使用的模型。
例如,如果您要执行分类,则可以在下一步中将结果的索引映射到它们所代表的标签:
Java
BufferedReader reader = new BufferedReader( new InputStreamReader(getAssets().open("retrained_labels.txt"))); for (int i = 0; i < probabilities.length; i++) { String label = reader.readLine(); Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i])); }
Kotlin+KTX
val reader = BufferedReader( InputStreamReader(assets.open("retrained_labels.txt"))) for (i in probabilities.indices) { val label = reader.readLine() Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i])) }
附录:模型的安全性
无论您以何种方式在机器学习套件中添加自己的 TensorFlow Lite 模型,机器学习套件会以标准序列化的 protobuf 格式将所有模型存储到本地存储空间中。
这意味着,从理论上说,任何人都可以复制您的模型。但实际上,大多数模型都是针对具体的应用,且通过优化进行了混淆处理,因此,这一风险与竞争对手对您的代码进行反汇编和再利用的风险类似。无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。
在 Android API 级别 21 (Lollipop) 及更高版本中,模型会下载到从自动备份中排除的目录当中。
在 Android API 级别 20 和更低版本中,模型会下载到应用专用的内部存储空间中名为 com.google.firebase.ml.custom.models
的目录当中。如果您使用 BackupAgent
启用了文件备份,则可以选择排除此目录。