您可以使用机器学习套件,通过 TensorFlow Lite 模型在设备上执行推理。
机器学习套件只能在运行 iOS 9 或更高版本的设备上使用 TensorFlow Lite 模型。
准备工作
- 如果您尚未将 Firebase 添加到自己的应用中,请按照入门指南中的步骤执行此操作。
- 在 Podfile 中添加机器学习套件库:
pod 'Firebase/MLModelInterpreter', '6.25.0'
在安装或更新项目的 Pod 之后,请务必使用 Xcode 项目的.xcworkspace
打开该项目。 - 在您的应用中导入 Firebase:
Swift
import Firebase
Objective-C
@import Firebase;
- 将您要使用的 TensorFlow 模型转换为 TensorFlow Lite 格式。 请参阅 TOCO:TensorFlow Lite 优化转换器。
托管或捆绑您的模型
如需在您的应用中使用 TensorFlow Lite 模型进行推理,您必须先确保机器学习套件能够使用该模型。机器学习套件可以使用在 Firebase 中远程托管和/或与应用二进制文件捆绑的 TensorFlow Lite 模型。
通过在 Firebase 上托管模型,您可以在不发布新应用版本的情况下更新模型,并且可以使用 Remote Config 和 A/B Testing 为不同的用户组动态运用不同的模型。
如果您选择仅通过在 Firebase 中托管而不是与应用捆绑的方式来提供模型,可以缩小应用的初始下载文件大小。但请注意,如果模型未与您的应用捆绑,那么在应用首次下载模型之前,任何与模型相关的功能都将无法使用。
将您的模型与应用捆绑,可以确保当 Firebase 托管的模型不可用时,应用的机器学习功能仍可正常运行。
在 Firebase 上托管模型
如需在 Firebase 上托管您的 TensorFlow Lite 模型,请执行以下操作:
- 在 Firebase 控制台的机器学习套件部分中,点击自定义标签页。
- 点击添加自定义模型(或再添加一个模型)。
- 指定一个名称,用于在 Firebase 项目中识别您的模型,然后上传 TensorFlow Lite 模型文件(通常以
.tflite
或.lite
结尾)。
将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您随时可以上传新的 TensorFlow Lite 模型,您的应用会下载新模型,然后在下次重启时开始使用新模型。您可以定义应用尝试更新模型时所需满足的设备条件(请参见下文)。
将模型与应用捆绑在一起
如需将 TensorFlow Lite 模型与您的应用捆绑在一起,请将模型文件(通常以 .tflite
或 .lite
结尾)添加到您的 Xcode 项目中,并在执行此操作时注意选择复制软件包资源 (Copy bundle resources)。模型文件将包含在应用软件包中,并提供给机器学习套件使用。
加载模型
如需在您的应用中使用 TensorFlow Lite 模型,请首先为机器学习套件配置模型所在的位置:使用 Firebase 远程托管、在本地存储空间,或者两者同时。如果您同时指定了本地和远程模型,则可以在远程模型可用时使用远程模型,并在其不可用时回退为使用本地存储的模型。
配置 Firebase 托管的模型
如果您使用 Firebase 托管模型,请创建一个 CustomRemoteModel
对象,并在发布模型时指定您分配给模型的名称:
Swift
let remoteModel = CustomRemoteModel(
name: "your_remote_model" // The name you assigned in the Firebase console.
)
Objective-C
// Initialize using the name you assigned in the Firebase console.
FIRCustomRemoteModel *remoteModel =
[[FIRCustomRemoteModel alloc] initWithName:@"your_remote_model"];
然后,启动模型下载任务,指定允许下载的条件。如果模型不在设备上,或模型有较新的版本,则任务将从 Firebase 异步下载模型:
Swift
let downloadConditions = ModelDownloadConditions(
allowsCellularAccess: true,
allowsBackgroundDownloading: true
)
let downloadProgress = ModelManager.modelManager().download(
remoteModel,
conditions: downloadConditions
)
Objective-C
FIRModelDownloadConditions *downloadConditions =
[[FIRModelDownloadConditions alloc] initWithAllowsCellularAccess:YES
allowsBackgroundDownloading:YES];
NSProgress *downloadProgress =
[[FIRModelManager modelManager] downloadRemoteModel:remoteModel
conditions:downloadConditions];
许多应用会通过其初始化代码启动下载任务,但您可以在需要使用该模型之前随时启动下载任务。
配置本地模型
如果您将模型与应用捆绑在一起,请创建 CustomLocalModel
对象,并指定 TensorFlow Lite 模型的文件名:
Swift
guard let modelPath = Bundle.main.path(
forResource: "your_model",
ofType: "tflite",
inDirectory: "your_model_directory"
) else { /* Handle error. */ }
let localModel = CustomLocalModel(modelPath: modelPath)
Objective-C
NSString *modelPath = [NSBundle.mainBundle pathForResource:@"your_model"
ofType:@"tflite"
inDirectory:@"your_model_directory"];
FIRCustomLocalModel *localModel =
[[FIRCustomLocalModel alloc] initWithModelPath:modelPath];
根据模型创建解释器
配置模型来源后,根据其中一个模型创建 ModelInterpreter
对象。
如果您只有本地捆绑的模型,则只需将 CustomLocalModel
对象传递给 modelInterpreter(localModel:)
:
Swift
let interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
Objective-C
FIRModelInterpreter *interpreter =
[FIRModelInterpreter modelInterpreterForLocalModel:localModel];
如果您使用的是远程托管的模型,则必须在运行之前检查该模型是否已下载。您可以使用模型管理器的 isModelDownloaded(remoteModel:)
方法检查模型下载任务的状态。
虽然您只需在运行解释器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,则在实例化 ModelInterpreter
时执行此检查可能是有意义的:如果已下载,则根据远程模型创建解释器,否则根据本地模型进行创建。
Swift
var interpreter: ModelInterpreter
if ModelManager.modelManager().isModelDownloaded(remoteModel) {
interpreter = ModelInterpreter.modelInterpreter(remoteModel: remoteModel)
} else {
interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
}
Objective-C
FIRModelInterpreter *interpreter;
if ([[FIRModelManager modelManager] isModelDownloaded:remoteModel]) {
interpreter = [FIRModelInterpreter modelInterpreterForRemoteModel:remoteModel];
} else {
interpreter = [FIRModelInterpreter modelInterpreterForLocalModel:localModel];
}
如果您只有远程托管的模型,则应停用与模型相关的功能(例如灰显或隐藏部分界面),直到您确认模型已下载。
您可以将观察者附加到默认通知中心,以获取模型下载状态。请务必在观察者块中使用对 self
的弱引用,因为下载可能需要一些时间,并且源对象可能到下载完成才会被释放。例如:
Swift
NotificationCenter.default.addObserver( forName: .firebaseMLModelDownloadDidSucceed, object: nil, queue: nil ) { [weak self] notification in guard let strongSelf = self, let userInfo = notification.userInfo, let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue] as? RemoteModel, model.name == "your_remote_model" else { return } // The model was downloaded and is available on the device } NotificationCenter.default.addObserver( forName: .firebaseMLModelDownloadDidFail, object: nil, queue: nil ) { [weak self] notification in guard let strongSelf = self, let userInfo = notification.userInfo, let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue] as? RemoteModel else { return } let error = userInfo[ModelDownloadUserInfoKey.error.rawValue] // ... }
Objective-C
__weak typeof(self) weakSelf = self; [NSNotificationCenter.defaultCenter addObserverForName:FIRModelDownloadDidSucceedNotification object:nil queue:nil usingBlock:^(NSNotification *_Nonnull note) { if (weakSelf == nil | note.userInfo == nil) { return; } __strong typeof(self) strongSelf = weakSelf; FIRRemoteModel *model = note.userInfo[FIRModelDownloadUserInfoKeyRemoteModel]; if ([model.name isEqualToString:@"your_remote_model"]) { // The model was downloaded and is available on the device } }]; [NSNotificationCenter.defaultCenter addObserverForName:FIRModelDownloadDidFailNotification object:nil queue:nil usingBlock:^(NSNotification *_Nonnull note) { if (weakSelf == nil | note.userInfo == nil) { return; } __strong typeof(self) strongSelf = weakSelf; NSError *error = note.userInfo[FIRModelDownloadUserInfoKeyError]; }];
指定模型的输入和输出
接下来,配置模型解释器的输入和输出格式。
TensorFlow Lite 模型采用一个或多个多维数组作为输入和生成的输出。这些数组可以包含 byte
、int
、long
或 float
值。您必须根据模型采用的数组个数和维度(“形状”)配置机器学习套件。
如果您不知道模型输入和输出的形状和数据类型,可以使用 TensorFlow Lite Python 解释器检查模型。例如:
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="my_model.tflite") interpreter.allocate_tensors() # Print input shape and type print(interpreter.get_input_details()[0]['shape']) # Example: [1 224 224 3] print(interpreter.get_input_details()[0]['dtype']) # Example: <class 'numpy.float32'> # Print output shape and type print(interpreter.get_output_details()[0]['shape']) # Example: [1 1000] print(interpreter.get_output_details()[0]['dtype']) # Example: <class 'numpy.float32'>
确定模型的输入和输出格式后,创建一个 ModelInputOutputOptions
对象来配置应用的模型解释器。
例如,一个浮点图片分类模型可能接受一个 Float
值的 Nx224x224x3 数组(这表示一批 224x224 的三通道 [RGB] 图片,包含 N 个),并且输出是一个包含 1,000 个 Float
值的列表(每个值表示该图片属于此模型预测的 1,000 个类别中的某一个类别的概率)。
对于此类模型,您需要按如下所示配置模型解析器的输入和输出:
Swift
let ioOptions = ModelInputOutputOptions() do { try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 224, 224, 3]) try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 1000]) } catch let error as NSError { print("Failed to set input or output format with error: \(error.localizedDescription)") }
Objective-C
FIRModelInputOutputOptions *ioOptions = [[FIRModelInputOutputOptions alloc] init]; NSError *error; [ioOptions setInputFormatForIndex:0 type:FIRModelElementTypeFloat32 dimensions:@[@1, @224, @224, @3] error:&error]; if (error != nil) { return; } [ioOptions setOutputFormatForIndex:0 type:FIRModelElementTypeFloat32 dimensions:@[@1, @1000] error:&error]; if (error != nil) { return; }
利用输入数据进行推理
最后,要使用模型进行推理,请获取输入数据,对数据执行模型可能需要的转换,然后构建包含这些数据的 Data
对象。
例如,如果模型要处理图片,并且其输入维度为 [BATCH_SIZE, 224, 224, 3]
浮点值,则可能必须将相应图片的颜色值调整为浮点范围,如以下示例所示:
Swift
let image: CGImage = // Your input image guard let context = CGContext( data: nil, width: image.width, height: image.height, bitsPerComponent: 8, bytesPerRow: image.width * 4, space: CGColorSpaceCreateDeviceRGB(), bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue ) else { return false } context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height)) guard let imageData = context.data else { return false } let inputs = ModelInputs() var inputData = Data() do { for row in 0 ..< 224 { for col in 0 ..< 224 { let offset = 4 * (col * context.width + row) // (Ignore offset 0, the unused alpha channel) let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self) let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self) let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self) // Normalize channel values to [0.0, 1.0]. This requirement varies // by model. For example, some models might require values to be // normalized to the range [-1.0, 1.0] instead, and others might // require fixed-point values or the original bytes. var normalizedRed = Float32(red) / 255.0 var normalizedGreen = Float32(green) / 255.0 var normalizedBlue = Float32(blue) / 255.0 // Append normalized values to Data object in RGB order. let elementSize = MemoryLayout.size(ofValue: normalizedRed) var bytes = [UInt8](repeating: 0, count: elementSize) memcpy(&bytes, &normalizedRed, elementSize) inputData.append(&bytes, count: elementSize) memcpy(&bytes, &normalizedGreen, elementSize) inputData.append(&bytes, count: elementSize) memcpy(&ammp;bytes, &normalizedBlue, elementSize) inputData.append(&bytes, count: elementSize) } } try inputs.addInput(inputData) } catch let error { print("Failed to add input: \(error)") }
Objective-C
CGImageRef image = // Your input image long imageWidth = CGImageGetWidth(image); long imageHeight = CGImageGetHeight(image); CGContextRef context = CGBitmapContextCreate(nil, imageWidth, imageHeight, 8, imageWidth * 4, CGColorSpaceCreateDeviceRGB(), kCGImageAlphaNoneSkipFirst); CGContextDrawImage(context, CGRectMake(0, 0, imageWidth, imageHeight), image); UInt8 *imageData = CGBitmapContextGetData(context); FIRModelInputs *inputs = [[FIRModelInputs alloc] init]; NSMutableData *inputData = [[NSMutableData alloc] initWithCapacity:0]; for (int row = 0; row < 224; row++) { for (int col = 0; col < 224; col++) { long offset = 4 * (col * imageWidth + row); // Normalize channel values to [0.0, 1.0]. This requirement varies // by model. For example, some models might require values to be // normalized to the range [-1.0, 1.0] instead, and others might // require fixed-point values or the original bytes. // (Ignore offset 0, the unused alpha channel) Float32 red = imageData[offset+1] / 255.0f; Float32 green = imageData[offset+2] / 255.0f; Float32 blue = imageData[offset+3] / 255.0f; [inputData appendBytes:&red length:sizeof(red)]; [inputData appendBytes:&green length:sizeof(green)]; [inputData appendBytes:&blue length:sizeof(blue)]; } } [inputs addInput:inputData error:&error]; if (error != nil) { return nil; }
准备好模型输入后(并确认模型可用后),将输入和输入/输出选项传递给模型解释器的 run(inputs:options:completion:)
方法。
Swift
interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in guard error == nil, let outputs = outputs else { return } // Process outputs // ... }
Objective-C
[interpreter runWithInputs:inputs options:ioOptions completion:^(FIRModelOutputs * _Nullable outputs, NSError * _Nullable error) { if (error != nil || outputs == nil) { return; } // Process outputs // ... }];
您可以通过调用返回的对象的 output(index:)
方法来获取输出。例如:
Swift
// Get first and only output of inference with a batch size of 1 let output = try? outputs.output(index: 0) as? [[NSNumber]] let probabilities = output??[0]
Objective-C
// Get first and only output of inference with a batch size of 1 NSError *outputError; NSArray *probabilites = [outputs outputAtIndex:0 error:&outputError][0];
如何使用输出取决于您所使用的模型。
例如,如果您要执行分类,那么您可以在下一步中将结果的索引映射到它们所代表的标签。假设您有一个文本文件,其中包含模型的各个类别的标签字符串;您可以通过执行如下操作将这些标签字符串映射到输出概率:
Swift
guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return } let fileContents = try? String(contentsOfFile: labelPath) guard let labels = fileContents?.components(separatedBy: "\n") else { return } for i in 0 ..< labels.count { if let probability = probabilities?[i] { print("\(labels[i]): \(probability)") } }
Objective-C
NSError *labelReadError = nil; NSString *labelPath = [NSBundle.mainBundle pathForResource:@"retrained_labels" ofType:@"txt"]; NSString *fileContents = [NSString stringWithContentsOfFile:labelPath encoding:NSUTF8StringEncoding error:&labelReadError]; if (labelReadError != nil || fileContents == NULL) { return; } NSArray<NSString *> *labels = [fileContents componentsSeparatedByString:@"\n"]; for (int i = 0; i < labels.count; i++) { NSString *label = labels[i]; NSNumber *probability = probabilites[i]; NSLog(@"%@: %f", label, probability.floatValue); }
附录:模型的安全性
无论您以何种方式在机器学习套件中添加自己的 TensorFlow Lite 模型,机器学习套件会以标准序列化的 protobuf 格式将所有模型存储到本地存储空间中。
这意味着,从理论上说,任何人都可以复制您的模型。但实际上,大多数模型都是针对具体的应用,且通过优化进行了混淆处理,因此,这一风险与竞争对手对您的代码进行反汇编和再利用的风险类似。但无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。