在 iOS 上使用机器学习套件通过 TensorFlow Lite 模型进行推理

您可以使用机器学习套件通过 TensorFlow Lite 模型执行基于设备的推理。

机器学习套件只能在运行 iOS 9 或更高版本的设备上使用 TensorFlow Lite 模型。

如需了解此 API 的实际应用示例,请查看 GitHub 上的机器学习套件快速入门示例

准备工作

  1. 如果您尚未将 Firebase 添加到自己的应用中,请按照入门指南中的步骤执行此操作。
  2. 在 Podfile 中添加机器学习套件库:
    pod 'Firebase/Core'
    pod 'Firebase/MLModelInterpreter'
    
    安装或更新项目的 Pod 之后,请务必使用 Xcode 项目的 .xcworkspace 来打开项目。
  3. 在您的应用中导入 Firebase:

    Swift

    import Firebase

    Objective-C

    @import Firebase;
  4. 将您要使用的 TensorFlow 模型转换为 TensorFlow Lite (tflite) 格式。请参阅 TOCO:TensorFlow Lite 优化转换器 (TOCO: TensorFlow Lite Optimizing Converter)。

托管或捆绑您的模型

要在您的应用中使用 TensorFlow Lite 模型进行推理,您必须先确保机器学习套件能够使用该模型。机器学习套件可以使用通过 Firebase 远程托管和/或存储在本地设备上的 TensorFlow Lite 模型。

如果既在 Firebase 上托管又在本地存储该模型,那么您可以确保在该模型推出最新版本时使用其最新版本;而当 Firebase 托管的模型不可用时,应用的机器学习功能仍可正常运行。

模型的安全性

无论您是以何种方式在机器学习套件中添加您的 TensorFlow Lite 模型,机器学习套件会以标准序列化的 protobuf 格式将所有模型存储到本地存储空间中。

理论上来说,这意味着任何人都可以复制您的模型。但实际上,大多数模型都是应用专用的,且通过优化进行了混淆处理,因此,风险类似于竞争对手反汇编并再利用您的代码。无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。

在 Firebase 上托管模型

要在 Firebase 上托管您的 TensorFlow Lite 模型,请执行以下操作:

  1. Firebase 控制台ML Kit(机器学习套件)部分中,点击 自定义标签。
  2. 点击添加自定义模型(或再添加一个模型)。
  3. 指定一个名称,用于在 Firebase 项目中识别您的模型,然后上传 .tflite 文件。

将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您随时可以为模型上传新的 .tflite 文件,并且您的应用会下载新模型,然后在应用下次重启时开始使用此新模型。您可以定义应用尝试更新模型时所需满足的设备条件(请参见下文)。

在本地提供模型

要在本地提供 TensorFlow Lite 模型,您可以将该模型与应用捆绑在一起,或者在运行时从您自己的服务器下载该模型。

要将 TensorFlow Lite 模型与您的应用捆绑在一起,请将 .tflite 文件添加到您的 Xcode 项目中,并在执行此操作时注意选择 Copy bundle resources.tflite 文件将包含在应用软件包中,并提供给机器学习套件使用。

如果您选择将该模型托管在自己的服务器上,则可以适时在应用中将该模型下载到本地存储空间。然后,该模型将以本地文件的形式提供给机器学习套件。

加载模型

要使用 TensorFlow Lite 模型进行推理,请先指定 .tflite 文件的位置。

如果您使用 Firebase 托管您的模型,请注册一个 CloudModelSource 对象,指明您在上传该模型时为其指定的名称,以及机器学习套件在什么条件下初次下载模型,后续在什么条件下下载模型更新版本。

Swift

let conditions = ModelDownloadConditions(wiFiRequired: true, idleRequired: true)
let cloudModelSource = CloudModelSource(
  modelName: "my_cloud_model",
  enableModelUpdates: true,
  initialConditions: conditions,
  updateConditions: conditions
)
let registrationSuccessful = ModelManager.modelManager().register(cloudModelSource)

Objective-C

FIRModelDownloadConditions *conditions =
    [[FIRModelDownloadConditions alloc] initWithWiFiRequired:YES
                                                idleRequired:YES];
FIRCloudModelSource *cloudModelSource =
    [[FIRCloudModelSource alloc] initWithModelName:@"my_cloud_model"
                                enableModelUpdates:YES
                                 initialConditions:conditions
                                  updateConditions:conditions];
  BOOL registrationSuccess =
      [[FIRModelManager modelManager] registerCloudModelSource:cloudModelSource];

如果您将模型与应用捆绑在了一起,或者在运行时从您自己的主机上下载了模型,请注册一个 LocalModelSource 对象,指定 .tflite 模型的本地路径并为本地模型来源分配一个在应用中识别它的唯一名称。

Swift

guard let modelPath = Bundle.main.path(
  forResource: "my_model",
  ofType: "tflite"
) else {
  // Invalid model path
  return
}
let localModelSource = LocalModelSource(modelName: "my_local_model",
                                        path: modelPath)
let registrationSuccessful = ModelManager.modelManager().register(localModelSource)

Objective-C

NSString *modelPath = [NSBundle.mainBundle pathForResource:@"my_model"
                                                    ofType:@"tflite"];
FIRLocalModelSource *localModelSource =
    [[FIRLocalModelSource alloc] initWithModelName:@"my_local_model"
                                              path:modelPath];
BOOL registrationSuccess =
      [[FIRModelManager modelManager] registerLocalModelSource:localModelSource];

然后,创建一个含云端模型来源和/或本地模型来源的 ModelOptions 对象,并使用该对象获取 ModelInterpreter 的一个示例。如果您只有一个来源,请为您不使用的来源类型指定 nil

Swift

let options = ModelOptions(
  cloudModelName: "my_cloud_model",
  localModelName: "my_local_model"
)
let interpreter = ModelInterpreter(options: options)

Objective-C

FIRModelOptions *options = [[FIRModelOptions alloc] initWithCloudModelName:@"my_cloud_model"
                                                            localModelName:@"my_local_model"];
FIRModelInterpreter *interpreter = [FIRModelInterpreter modelInterpreterWithOptions:options];

如果您同时指定了云端模型来源和本地模型来源,则模型解析器将在云端模型可用时使用云端模型,并在云端模型不可用时转为使用本地模型。

指定模型的输入和输出

接下来,您必须创建一个 ModelInputOutputOptions 对象以指定模型的输入和输出格式。

TensorFlow Lite 模型支持输入一个或多个多维数组,并可在输出时生成一个或多个多维数组。这些数组包含 UInt8Int32Int64Float32。您必须根据模型采用的数组个数和维度(“形状”)配置机器学习套件。

例如,一个图片分类模型可能会被输入一个 1x640x480x3 字节数组(这表示一张 640x480 的 24 位全彩图片),并在输出时生成一个包含 1000 个 Float32 值的列表(每个值表示该图片属于此模型预测的 1000 个类别中的某一个类别的概率)。

Swift

let ioOptions = ModelInputOutputOptions()
do {
  try ioOptions.setInputFormat(index: 0, type: .uInt8, dimensions: [1, 640, 480, 3])
  try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 1000])
} catch let error as NSError {
  print("Failed to set input or output format with error: \(error.localizedDescription)")
}

Objective-C

FIRModelInputOutputOptions *ioOptions = [[FIRModelInputOutputOptions alloc] init];
NSError *error;
[ioOptions setInputFormatForIndex:0
                             type:FIRModelElementTypeUInt8
                       dimensions:@[@1, @640, @480, @3]
                            error:&error];
if (error != nil) { return; }
[ioOptions setOutputFormatForIndex:0
                              type:FIRModelElementTypeFloat32
                        dimensions:@[@1, @1000]
                             error:&error];
if (error != nil) { return; }

根据输入数据进行推理

最后,要使用此模型执行推理,请利用模型输入创建一个 ModelInputs 对象,并将该对象以及此模型的输入和输出规范传递给模型解析器run(inputs:options:) 方法。为了获得理想效果,请将模型输入作为 Data (NSData) 对象传递。

Swift

let input = ModelInputs()
do {
  var data: Data  // or var data: Array
  // Store input data in `data`
  // ...
  try input.addInput(data)
  // Repeat as necessary for each input index
} catch let error as NSError {
  print("Failed to add input: \(error.localizedDescription)")
}

interpreter.run(inputs: input, options: ioOptions) { outputs, error in
  guard error == nil, let outputs = outputs else { return }
  // Process outputs
  // ...
}

Objective-C

FIRModelInputs *inputs = [[FIRModelInputs alloc] init];
NSData *data;  // Or NSArray *data;
// ...
[inputs addInput:data error:&error];  // Repeat as necessary.
if (error != nil) { return; }
[interpreter runWithInputs:inputs
                   options:ioOptions
                completion:^(FIRModelOutputs * _Nullable outputs,
                             NSError * _Nullable error) {
  if (error != nil || outputs == nil) {
    return;
  }
  // Process outputs
  // ...
}];

您可以通过调用返回的对象的 output(index:) 方法来获取输出。例如:

Swift

// Get first and only output of inference with a batch size of 1
let probabilities = try? outputs.output(index: 0)

Objective-C

// Get first and only output of inference with a batch size of 1
NSError *outputError;
[outputs outputAtIndex:0 error:&outputError];

如何使用输出取决于您使用什么模型。例如,如果您要执行分类,那么您可以在下一步中将结果的索引映射到它们所代表的标签。

发送以下问题的反馈:

此网页
需要帮助?请访问我们的支持页面