使用 AutoML 训练的模型检测图片中的对象 (Android)

使用 AutoML Vision Edge 训练您自己的模型后,您就可以在自己的应用中利用该模型检测图片中的对象。

您可以通过以下两种方式集成经过 AutoML Vision Edge 训练的模型:可以将模型嵌入应用的资源文件夹中以捆绑模型,也可以从 Firebase 动态下载模型。

模型捆绑方式
捆绑在您的应用中
  • 模型是应用的 APK 的一部分
  • 即使 Android 设备处于离线状态,模型也可立即使用
  • 不需要 Firebase 项目
使用 Firebase 进行托管

准备工作

  1. 如果您想下载模型,请务必将 Firebase 添加到您的 Android 项目(如果尚未添加)。捆绑模型时不需要这样做。

  2. 将 TensorFlow Lite 任务库的依赖项添加到您的模块的应用级 Gradle 文件(通常为 app/build.gradle):

    如需将模型与您的应用捆绑在一起,请执行以下操作:

    dependencies {
      // ...
      // Object detection with a bundled Auto ML model
      implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly-SNAPSHOT'
    }
    

    如需从 Firebase 动态下载模型,还要添加 Firebase ML 依赖项:

    dependencies {
      // ...
      // Object detection with an Auto ML model deployed to Firebase
      implementation platform('com.google.firebase:firebase-bom:26.1.1')
      implementation 'com.google.firebase:firebase-ml-model-interpreter'
    
      implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
    }
    

1.加载模型

配置本地模型来源

如需将模型与您的应用捆绑在一起,请执行以下操作:

  1. 将模型自您从 Google Cloud 控制台下载的 zip 归档文件解压缩。
  2. 将模型添加到应用软件包中:
    1. 如果您的项目中没有资源文件夹,请创建一个,方法是右键点击 app/ 文件夹,然后依次点击新建 > 文件夹 > 资源文件夹 (New > Folder > Assets Folder)。
    2. 将包含嵌入元数据的 tflite 模型文件复制到资源文件夹。
  3. 将以下内容添加到应用的 build.gradle 文件中,以确保 Gradle 在构建应用时不会压缩模型文件:

    android {
        // ...
        aaptOptions {
            noCompress "tflite"
        }
    }
    

    模型文件将包含在应用软件包中,并作为原始资源提供。

配置 Firebase 托管的模型来源

如需使用远程托管的模型,请创建一个 RemoteModel 对象,指明您在发布该模型时为其分配的名称:

Java

// Specify the name you assigned when you deployed the model.
FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

// Specify the name you assigned when you deployed the model.
val remoteModel =
    FirebaseCustomRemoteModel.Builder("your_model_name").build()

然后,启动模型下载任务,指定允许下载模型的条件。如果模型不在设备上,或模型有较新的版本,则该任务将从 Firebase 异步下载模型:

Java

DownloadConditions downloadConditions = new DownloadConditions.Builder()
        .requireWifi()
        .build();
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val downloadConditions = DownloadConditions.Builder()
    .requireWifi()
    .build()
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
    .addOnSuccessListener {
        // Success.
    }

许多应用会通过其初始化代码启动下载任务,您也可以在需要使用该模型之前随时启动下载任务。

根据模型创建对象检测器

配置模型来源后,根据其中一个模型创建 ObjectDetector 对象。

如果您只有本地捆绑的模型,只需根据模型文件创建对象检测器,然后配置您需要的置信度得分阈值(请参阅评估您的模型):

Java

// Initialization
ObjectDetectorOptions options = ObjectDetectorOptions.builder()
    .setScoreThreshold(0)  // Evaluate your model in the Google Cloud Console
                           // to determine an appropriate value.
    .build();
ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);

Kotlin

// Initialization
val options = ObjectDetectorOptions.builder()
    .setScoreThreshold(0)  // Evaluate your model in the Google Cloud Console
                           // to determine an appropriate value.
    .build()
val objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options)

如果您使用的是远程托管的模型,则必须在运行之前检查该模型是否已下载。您可以使用模型管理器的 isModelDownloaded() 方法检查模型下载任务的状态。

虽然您只需在运行对象检测器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,则可以考虑在实例化对象检测器时执行此检查:如果已下载,则根据远程模型创建对象检测器,否则根据本地模型进行创建。

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener { success ->

        }

如果您只有远程托管的模型,则应停用与模型相关的功能(例如使界面的一部分变灰或将其隐藏),直到您确认模型已下载。这可以通过将监听器附加到模型管理器的 download() 方法来实现。

确定模型已下载后,请根据模型文件创建对象检测器:

Java

FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    ObjectDetectorOptions options = ObjectDetectorOptions.builder()
                            .setScoreThreshold(0)
                            .build();
                    objectDetector = ObjectDetector.createFromFileAndOptions(
                            getApplicationContext(), modelFile.getPath(), options);
                }
            }
        });

Kotlin

FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnSuccessListener { modelFile ->
            val options = ObjectDetectorOptions.builder()
                    .setScoreThreshold(0f)
                    .build()
            objectDetector = ObjectDetector.createFromFileAndOptions(
                    applicationContext, modelFile.path, options)
        }

2. 准备输入图片

接下来,对于每张您想要加标签的图片,基于图片创建一个 TensorImage 对象。您可以使用 fromBitmap 方法根据 Bitmap 创建 TensorImage 对象:

Java

TensorImage image = TensorImage.fromBitmap(bitmap);

Kotlin

val image = TensorImage.fromBitmap(bitmap)

如果您的图片数据不在 Bitmap 中,您可以加载像素数组(如 TensorFlow Lite 文档中所示)。

3. 运行对象检测器

如需检测图片中的对象,请将 TensorImage 对象传递给 ObjectDetectordetect() 方法。

Java

List<Detection> results = objectDetector.detect(image);

Kotlin

val results = objectDetector.detect(image)

4. 获取已加标签的对象的相关信息

如果对象检测操作成功,则会返回 Detection 对象的列表。每个 Detection 对象代表在图片中检测到的内容。您可以获取每个对象的边界框及其标签。

例如:

Java

for (Detection result : results) {
    RectF bounds = result.getBoundingBox();
    List<Category> labels = result.getCategories();
}

Kotlin

for (result in results) {
    val bounds = result.getBoundingBox()
    val labels = result.getCategories()
}

提高实时性能的相关提示

如果要在实时应用中为图片加标签,请遵循以下准则以实现最佳帧速率:

  • 限制图片标记器的调用次数。如果在图片标记器运行时有新的视频帧可用,请丢弃该帧。如需查看示例,请参阅快速入门示例应用中的 VisionProcessorBase 类。
  • 如果要将图片标记器的输出作为图形叠加在输入图片上,请先获取结果,然后在一个步骤中完成图片的呈现和叠加。采用这一方法,每个输入帧只需在显示表面呈现一次。如需查看示例,请参阅快速入门示例应用中的 CameraSourcePreviewGraphicOverlay 类。
  • 如果您使用 Camera2 API,请以 ImageFormat.YUV_420_888 格式捕获图片。

    如果您使用旧版 Camera API,请以 ImageFormat.NV21 格式捕获图片。