您可以使用 ML Kit 通過TensorFlow Lite模型執行設備上推理。
ML Kit 只能在運行 iOS 9 及更高版本的設備上使用 TensorFlow Lite 模型。
在你開始之前
- 如果您尚未將 Firebase 添加到您的應用,請按照入門指南中的步驟進行操作。
- 在您的 Podfile 中包含 ML Kit 庫:
pod 'Firebase/MLModelInterpreter', '6.25.0'
安裝或更新項目的 Pod 後,請務必使用其.xcworkspace
打開您的 Xcode 項目。 - 在您的應用中,導入 Firebase:
迅速
import Firebase
Objective-C
@import Firebase;
- 將您要使用的 TensorFlow 模型轉換為 TensorFlow Lite 格式。請參閱TOCO:TensorFlow Lite 優化轉換器。
託管或捆綁您的模型
在您可以在應用程序中使用 TensorFlow Lite 模型進行推理之前,您必須使該模型可用於 ML Kit。 ML Kit 可以使用通過 Firebase 遠程託管的 TensorFlow Lite 模型,與應用程序二進製文件捆綁,或兩者兼而有之。
通過在 Firebase 上託管模型,您可以在不發布新應用版本的情況下更新模型,並且可以使用遠程配置和 A/B 測試為不同的用戶組動態提供不同的模型。
如果您選擇僅通過使用 Firebase 託管模型來提供模型,而不是將其與您的應用捆綁在一起,則可以減少應用的初始下載大小。但請記住,如果模型未與您的應用程序捆綁在一起,則在您的應用程序第一次下載模型之前,任何與模型相關的功能都將不可用。
通過將您的模型與您的應用捆綁在一起,您可以確保在 Firebase 託管的模型不可用時,您的應用的 ML 功能仍然有效。
在 Firebase 上託管模型
在 Firebase 上託管您的 TensorFlow Lite 模型:
- 在Firebase 控制台的ML Kit部分中,單擊自定義選項卡。
- 單擊添加自定義模型(或添加另一個模型)。
- 指定將用於在 Firebase 項目中識別模型的名稱,然後上傳 TensorFlow Lite 模型文件(通常以
.tflite
或.lite
)。
將自定義模型添加到 Firebase 項目後,您可以使用您指定的名稱在應用中引用該模型。您可以隨時上傳新的 TensorFlow Lite 模型,您的應用將下載新模型並在應用下次重新啟動時開始使用它。您可以定義應用程序嘗試更新模型所需的設備條件(見下文)。
將模型與應用程序捆綁在一起
要將您的 TensorFlow Lite 模型與您的應用程序捆綁在一起,請將模型文件(通常以.tflite
或.lite
)添加到您的 Xcode 項目中,這樣做時請注意選擇複製捆綁資源。模型文件將包含在應用程序包中,並可用於 ML Kit。
加載模型
要在您的應用中使用您的 TensorFlow Lite 模型,請首先使用您的模型可用的位置配置 ML Kit:遠程使用 Firebase、在本地存儲中或兩者兼而有之。如果同時指定本地和遠程模型,則可以使用遠程模型(如果可用),如果遠程模型不可用,則回退到本地存儲的模型。
配置 Firebase 託管的模型
如果您使用 Firebase 託管模型,請創建一個CustomRemoteModel
對象,並指定您在發布模型時為其分配的名稱:
迅速
let remoteModel = CustomRemoteModel(
name: "your_remote_model" // The name you assigned in the Firebase console.
)
Objective-C
// Initialize using the name you assigned in the Firebase console.
FIRCustomRemoteModel *remoteModel =
[[FIRCustomRemoteModel alloc] initWithName:@"your_remote_model"];
然後,啟動模型下載任務,指定您希望允許下載的條件。如果模型不在設備上,或者有更新版本的模型可用,任務將從 Firebase 異步下載模型:
迅速
let downloadConditions = ModelDownloadConditions(
allowsCellularAccess: true,
allowsBackgroundDownloading: true
)
let downloadProgress = ModelManager.modelManager().download(
remoteModel,
conditions: downloadConditions
)
Objective-C
FIRModelDownloadConditions *downloadConditions =
[[FIRModelDownloadConditions alloc] initWithAllowsCellularAccess:YES
allowsBackgroundDownloading:YES];
NSProgress *downloadProgress =
[[FIRModelManager modelManager] downloadRemoteModel:remoteModel
conditions:downloadConditions];
許多應用程序在其初始化代碼中啟動下載任務,但您可以在需要使用模型之前的任何時候這樣做。
配置本地模型
如果您將模型與您的應用程序捆綁在一起,請創建一個CustomLocalModel
對象,並指定 TensorFlow Lite 模型的文件名:
迅速
guard let modelPath = Bundle.main.path(
forResource: "your_model",
ofType: "tflite",
inDirectory: "your_model_directory"
) else { /* Handle error. */ }
let localModel = CustomLocalModel(modelPath: modelPath)
Objective-C
NSString *modelPath = [NSBundle.mainBundle pathForResource:@"your_model"
ofType:@"tflite"
inDirectory:@"your_model_directory"];
FIRCustomLocalModel *localModel =
[[FIRCustomLocalModel alloc] initWithModelPath:modelPath];
從您的模型創建解釋器
配置模型源後,從其中之一創建ModelInterpreter
對象。
如果您只有一個本地捆綁模型,只需將CustomLocalModel
對像傳遞給modelInterpreter(localModel:)
:
迅速
let interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
Objective-C
FIRModelInterpreter *interpreter =
[FIRModelInterpreter modelInterpreterForLocalModel:localModel];
如果您有遠程託管模型,則必須在運行之前檢查它是否已下載。您可以使用模型管理器的isModelDownloaded(remoteModel:)
方法檢查模型下載任務的狀態。
儘管您只需要在運行解釋器之前確認這一點,但如果您同時擁有遠程託管模型和本地捆綁模型,則在實例化ModelInterpreter
時執行此檢查可能是有意義的:如果它是,則從遠程模型創建解釋器已下載,否則從本地模型下載。
迅速
var interpreter: ModelInterpreter
if ModelManager.modelManager().isModelDownloaded(remoteModel) {
interpreter = ModelInterpreter.modelInterpreter(remoteModel: remoteModel)
} else {
interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
}
Objective-C
FIRModelInterpreter *interpreter;
if ([[FIRModelManager modelManager] isModelDownloaded:remoteModel]) {
interpreter = [FIRModelInterpreter modelInterpreterForRemoteModel:remoteModel];
} else {
interpreter = [FIRModelInterpreter modelInterpreterForLocalModel:localModel];
}
如果您只有一個遠程託管的模型,您應該禁用與模型相關的功能——例如,灰顯或隱藏部分 UI——直到您確認模型已下載。
您可以通過將觀察者附加到默認通知中心來獲取模型下載狀態。確保在觀察者塊中使用對self
的弱引用,因為下載可能需要一些時間,並且在下載完成時可以釋放原始對象。例如:
迅速
NotificationCenter.default.addObserver( forName: .firebaseMLModelDownloadDidSucceed, object: nil, queue: nil ) { [weak self] notification in guard let strongSelf = self, let userInfo = notification.userInfo, let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue] as? RemoteModel, model.name == "your_remote_model" else { return } // The model was downloaded and is available on the device } NotificationCenter.default.addObserver( forName: .firebaseMLModelDownloadDidFail, object: nil, queue: nil ) { [weak self] notification in guard let strongSelf = self, let userInfo = notification.userInfo, let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue] as? RemoteModel else { return } let error = userInfo[ModelDownloadUserInfoKey.error.rawValue] // ... }
Objective-C
__weak typeof(self) weakSelf = self; [NSNotificationCenter.defaultCenter addObserverForName:FIRModelDownloadDidSucceedNotification object:nil queue:nil usingBlock:^(NSNotification *_Nonnull note) { if (weakSelf == nil | note.userInfo == nil) { return; } __strong typeof(self) strongSelf = weakSelf; FIRRemoteModel *model = note.userInfo[FIRModelDownloadUserInfoKeyRemoteModel]; if ([model.name isEqualToString:@"your_remote_model"]) { // The model was downloaded and is available on the device } }]; [NSNotificationCenter.defaultCenter addObserverForName:FIRModelDownloadDidFailNotification object:nil queue:nil usingBlock:^(NSNotification *_Nonnull note) { if (weakSelf == nil | note.userInfo == nil) { return; } __strong typeof(self) strongSelf = weakSelf; NSError *error = note.userInfo[FIRModelDownloadUserInfoKeyError]; }];
指定模型的輸入和輸出
接下來,配置模型解釋器的輸入和輸出格式。
TensorFlow Lite 模型將一個或多個多維數組作為輸入並生成輸出。這些數組包含byte
、 int
、 long
或float
值。您必須使用模型使用的數組的數量和維度(“形狀”)來配置 ML Kit。
如果您不知道模型輸入和輸出的形狀和數據類型,您可以使用 TensorFlow Lite Python 解釋器來檢查您的模型。例如:
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="my_model.tflite") interpreter.allocate_tensors() # Print input shape and type print(interpreter.get_input_details()[0]['shape']) # Example: [1 224 224 3] print(interpreter.get_input_details()[0]['dtype']) # Example: <class 'numpy.float32'> # Print output shape and type print(interpreter.get_output_details()[0]['shape']) # Example: [1 1000] print(interpreter.get_output_details()[0]['dtype']) # Example: <class 'numpy.float32'>
確定模型輸入和輸出的格式後,通過創建ModelInputOutputOptions
對象來配置應用程序的模型解釋器。
例如,浮點圖像分類模型可能將N x224x224x3 Float
值數組作為輸入,代表一批N 224x224 三通道 (RGB) 圖像,並生成 1000 個Float
值列表作為輸出,每個浮點值代表圖像是模型預測的 1000 個類別之一的概率。
對於這樣的模型,您將配置模型解釋器的輸入和輸出,如下所示:
迅速
let ioOptions = ModelInputOutputOptions() do { try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 224, 224, 3]) try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 1000]) } catch let error as NSError { print("Failed to set input or output format with error: \(error.localizedDescription)") }
Objective-C
FIRModelInputOutputOptions *ioOptions = [[FIRModelInputOutputOptions alloc] init]; NSError *error; [ioOptions setInputFormatForIndex:0 type:FIRModelElementTypeFloat32 dimensions:@[@1, @224, @224, @3] error:&error]; if (error != nil) { return; } [ioOptions setOutputFormatForIndex:0 type:FIRModelElementTypeFloat32 dimensions:@[@1, @1000] error:&error]; if (error != nil) { return; }
對輸入數據執行推理
最後,要使用模型執行推理,獲取輸入數據,對模型可能需要的數據執行任何轉換,並構建包含數據的Data
對象。
例如,如果您的模型處理圖像,並且您的模型具有[BATCH_SIZE, 224, 224, 3]
浮點值的輸入尺寸,則您可能必須將圖像的顏色值縮放到浮點範圍,如下例所示:
迅速
let image: CGImage = // Your input image guard let context = CGContext( data: nil, width: image.width, height: image.height, bitsPerComponent: 8, bytesPerRow: image.width * 4, space: CGColorSpaceCreateDeviceRGB(), bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue ) else { return false } context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height)) guard let imageData = context.data else { return false } let inputs = ModelInputs() var inputData = Data() do { for row in 0 ..< 224 { for col in 0 ..< 224 { let offset = 4 * (col * context.width + row) // (Ignore offset 0, the unused alpha channel) let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self) let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self) let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self) // Normalize channel values to [0.0, 1.0]. This requirement varies // by model. For example, some models might require values to be // normalized to the range [-1.0, 1.0] instead, and others might // require fixed-point values or the original bytes. var normalizedRed = Float32(red) / 255.0 var normalizedGreen = Float32(green) / 255.0 var normalizedBlue = Float32(blue) / 255.0 // Append normalized values to Data object in RGB order. let elementSize = MemoryLayout.size(ofValue: normalizedRed) var bytes = [UInt8](repeating: 0, count: elementSize) memcpy(&bytes, &normalizedRed, elementSize) inputData.append(&bytes, count: elementSize) memcpy(&bytes, &normalizedGreen, elementSize) inputData.append(&bytes, count: elementSize) memcpy(&ammp;bytes, &normalizedBlue, elementSize) inputData.append(&bytes, count: elementSize) } } try inputs.addInput(inputData) } catch let error { print("Failed to add input: \(error)") }
Objective-C
CGImageRef image = // Your input image long imageWidth = CGImageGetWidth(image); long imageHeight = CGImageGetHeight(image); CGContextRef context = CGBitmapContextCreate(nil, imageWidth, imageHeight, 8, imageWidth * 4, CGColorSpaceCreateDeviceRGB(), kCGImageAlphaNoneSkipFirst); CGContextDrawImage(context, CGRectMake(0, 0, imageWidth, imageHeight), image); UInt8 *imageData = CGBitmapContextGetData(context); FIRModelInputs *inputs = [[FIRModelInputs alloc] init]; NSMutableData *inputData = [[NSMutableData alloc] initWithCapacity:0]; for (int row = 0; row < 224; row++) { for (int col = 0; col < 224; col++) { long offset = 4 * (col * imageWidth + row); // Normalize channel values to [0.0, 1.0]. This requirement varies // by model. For example, some models might require values to be // normalized to the range [-1.0, 1.0] instead, and others might // require fixed-point values or the original bytes. // (Ignore offset 0, the unused alpha channel) Float32 red = imageData[offset+1] / 255.0f; Float32 green = imageData[offset+2] / 255.0f; Float32 blue = imageData[offset+3] / 255.0f; [inputData appendBytes:&red length:sizeof(red)]; [inputData appendBytes:&green length:sizeof(green)]; [inputData appendBytes:&blue length:sizeof(blue)]; } } [inputs addInput:inputData error:&error]; if (error != nil) { return nil; }
準備好模型輸入後(並確認模型可用),將輸入和輸入/輸出選項傳遞給模型解釋器的run(inputs:options:completion:)
方法。
迅速
interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in guard error == nil, let outputs = outputs else { return } // Process outputs // ... }
Objective-C
[interpreter runWithInputs:inputs options:ioOptions completion:^(FIRModelOutputs * _Nullable outputs, NSError * _Nullable error) { if (error != nil || outputs == nil) { return; } // Process outputs // ... }];
您可以通過調用返回對象的output(index:)
方法來獲取輸出。例如:
迅速
// Get first and only output of inference with a batch size of 1 let output = try? outputs.output(index: 0) as? [[NSNumber]] let probabilities = output??[0]
Objective-C
// Get first and only output of inference with a batch size of 1 NSError *outputError; NSArray *probabilites = [outputs outputAtIndex:0 error:&outputError][0];
您如何使用輸出取決於您使用的模型。
例如,如果您正在執行分類,作為下一步,您可能會將結果的索引映射到它們所代表的標籤。假設您有一個文本文件,其中包含每個模型類別的標籤字符串;您可以通過執行以下操作將標籤字符串映射到輸出概率:
迅速
guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return } let fileContents = try? String(contentsOfFile: labelPath) guard let labels = fileContents?.components(separatedBy: "\n") else { return } for i in 0 ..< labels.count { if let probability = probabilities?[i] { print("\(labels[i]): \(probability)") } }
Objective-C
NSError *labelReadError = nil; NSString *labelPath = [NSBundle.mainBundle pathForResource:@"retrained_labels" ofType:@"txt"]; NSString *fileContents = [NSString stringWithContentsOfFile:labelPath encoding:NSUTF8StringEncoding error:&labelReadError]; if (labelReadError != nil || fileContents == NULL) { return; } NSArray<NSString *> *labels = [fileContents componentsSeparatedByString:@"\n"]; for (int i = 0; i < labels.count; i++) { NSString *label = labels[i]; NSNumber *probability = probabilites[i]; NSLog(@"%@: %f", label, probability.floatValue); }
附錄:模型安全
無論您如何使 TensorFlow Lite 模型可用於 ML Kit,ML Kit 都會以標準序列化 protobuf 格式將它們存儲在本地存儲中。
理論上,這意味著任何人都可以復制您的模型。然而,在實踐中,大多數模型都是特定於應用程序的,並且被優化混淆了,其風險類似於競爭對手反彙編和重用代碼的風險。不過,在您的應用程序中使用自定義模型之前,您應該意識到這種風險。