如果您的应用使用自定义 TensorFlow Lite 模型,那么您可以使用 Firebase ML 来部署模型。这样一来,您可以缩减应用的初始下载大小,而且无需发布应用的新版本即可更新应用的机器学习模型。此外,借助 Remote Config 和 A/B Testing,您可以为不同的用户组动态提供不同的模型。
前提条件
MLModelDownloader
库仅适用于 Swift。- TensorFlow Lite 只能在使用 iOS 9 及更高版本的设备上运行。
TensorFlow Lite 模型
TensorFlow Lite 模型是经过优化、可在移动设备上高效运行的机器学习模型。如需获取 TensorFlow Lite 模型,可采取下列两种方式之一:
准备工作
如需将 TensorFlowLite 与 Firebase 搭配使用,您必须使用 CocoaPods,因为 TensorFlowLite 目前不支持使用 Swift Package Manager 安装。如需了解如何安装 MLModelDownloader
,请参阅 CocoaPods 安装指南。
安装完成后,导入 Firebase 和 TensorFlowLite 以使用它们。
Swift
import FirebaseMLModelDownloader
import TensorFlowLite
1.部署模型
使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署自定义 TensorFlow 模型。请参阅部署和管理自定义模型。
将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您可以随时部署新的 TensorFlow Lite 模型,并调用 getModel()
来将新模型下载到用户的设备上(见下文)。
2. 将模型下载到设备并初始化一个 TensorFlow Lite 解释器
如需在您的应用中使用 TensorFlow Lite 模型,请先使用 Firebase ML SDK 将模型的最新版本下载到设备上。如需启动模型下载,请调用模型下载程序的 getModel()
方法,指定您在上传该模型时为其指定的名称、是否总是下载最新模型,以及您希望在什么条件下允许下载。
您可以从以下三种下载方式中进行选择:
下载类型 | 说明 |
---|---|
localModel
|
从设备获取本地模型。如果没有本地模型,则其行为类似于 latestModel 。如果您不想检查模型更新,请使用此下载类型。例如,您使用 Remote Config 检索模型名称,并始终以新名称上传模型(推荐)。 |
localModelUpdateInBackground
|
从设备获取本地模型,并在后台开始更新模型。如果没有本地模型,则其行为类似于 latestModel 。 |
latestModel
|
获取最新模型。如果本地模型是最新版本,将返回本地模型。否则,下载最新模型。此方式会阻塞进程,直到最新版本下载完毕(不推荐)。请仅在您明确需要最新版本的情况下才使用此方式。 |
您应该停用与模型相关的功能(例如使界面的一部分变灰或将其隐藏),直到您确认模型已下载。
Swift
let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
.getModel(name: "your_model",
downloadType: .localModelUpdateInBackground,
conditions: conditions) { result in
switch (result) {
case .success(let customModel):
do {
// Download complete. Depending on your app, you could enable the ML
// feature, or switch from the local model to the remote model, etc.
// The CustomModel object contains the local path of the model file,
// which you can use to instantiate a TensorFlow Lite interpreter.
let interpreter = try Interpreter(modelPath: customModel.path)
} catch {
// Error. Bad model file?
}
case .failure(let error):
// Download was unsuccessful. Don't enable ML features.
print(error)
}
}
许多应用会通过其初始化代码启动下载任务,您也可以在需要使用该模型之前随时启动下载任务。
3. 利用输入数据进行推断
获取模型的输入和输出形状
TensorFlow Lite 模型解释器采用一个或多个多维数组作为输入和输出。这些数组可以包含 byte
、int
、long
或 float
值。在将数据传递到模型或使用其结果之前,您必须知道模型使用的数组的数量和维度(“形状”)。
如果模型是您自己构建的,或者模型的输入和输出格式有文档说明,那么您可能已经知道这些信息。如果您不知道模型输入和输出的形状和数据类型,可以使用 TensorFlow Lite 解释器检查模型。例如:
Python
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="your_model.tflite") interpreter.allocate_tensors() # Print input shape and type inputs = interpreter.get_input_details() print('{} input(s):'.format(len(inputs))) for i in range(0, len(inputs)): print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype'])) # Print output shape and type outputs = interpreter.get_output_details() print('\n{} output(s):'.format(len(outputs))) for i in range(0, len(outputs)): print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
输出示例:
1 input(s): [ 1 224 224 3] <class 'numpy.float32'> 1 output(s): [1 1000] <class 'numpy.float32'>
运行解释器
在确定模型的输入和输出格式后,需获取输入数据,然后对数据执行任何必要的转换,以便为模型生成形状正确的输入。例如,如果模型要处理图片,并且其输入维度为 [1, 224, 224, 3]
浮点值,则可能必须将相应图片的颜色值调整为浮点范围,如以下示例所示:
Swift
let image: CGImage = // Your input image
guard let context = CGContext(
data: nil,
width: image.width, height: image.height,
bitsPerComponent: 8, bytesPerRow: image.width * 4,
space: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
return false
}
context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }
var inputData = Data()
for row in 0 ..< 224 {
for col in 0 ..< 224 {
let offset = 4 * (row * context.width + col)
// (Ignore offset 0, the unused alpha channel)
let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)
// Normalize channel values to [0.0, 1.0]. This requirement varies
// by model. For example, some models might require values to be
// normalized to the range [-1.0, 1.0] instead, and others might
// require fixed-point values or the original bytes.
var normalizedRed = Float32(red) / 255.0
var normalizedGreen = Float32(green) / 255.0
var normalizedBlue = Float32(blue) / 255.0
// Append normalized values to Data object in RGB order.
let elementSize = MemoryLayout.size(ofValue: normalizedRed)
var bytes = [UInt8](repeating: 0, count: elementSize)
memcpy(&bytes, &normalizedRed, elementSize)
inputData.append(&bytes, count: elementSize)
memcpy(&bytes, &normalizedGreen, elementSize)
inputData.append(&bytes, count: elementSize)
memcpy(&ammp;bytes, &normalizedBlue, elementSize)
inputData.append(&bytes, count: elementSize)
}
}
然后,将输入 NSData
复制到解释器中并运行:
Swift
try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()
您可以通过调用解释器的 output(at:)
方法来获取模型的输出。如何使用输出取决于您所使用的模型。
例如,如果您要执行分类,那么您可以在下一步中将结果的索引映射到它们所代表的标签:
Swift
let output = try interpreter.output(at: 0)
let probabilities =
UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)
guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }
for i in labels.indices {
print("\(labels[i]): \(probabilities[i])")
}
附录:模型的安全性
无论您以何种方式在 Firebase ML 中添加自己的 TensorFlow Lite 模型,Firebase ML 都会以标准序列化的 protobuf 格式将这些模型存储到本地存储空间中。
从理论上说,这意味着任何人都可以复制您的模型。但实际上,大多数模型都是针对具体的应用,且通过优化进行了混淆处理,因此,这一风险与竞争对手对您的代码进行反汇编和再利用的风险类似。但无论怎样,在您的应用中使用自定义模型之前,您应该了解这种风险。