Użyj modelu TensorFlow Lite do wnioskowania za pomocą ML Kit w systemie iOS

Możesz użyć ML Kit, aby przeprowadzić wnioskowanie na urządzeniu z modelem TensorFlow Lite .

ML Kit może używać modeli TensorFlow Lite tylko na urządzeniach z systemem iOS 9 i nowszym.

Zanim zaczniesz

  1. Jeśli nie dodałeś jeszcze Firebase do swojej aplikacji, zrób to, wykonując czynności opisane w przewodniku wprowadzającym .
  2. Dołącz biblioteki ML Kit do swojego Podfile:
    pod 'Firebase/MLModelInterpreter', '6.25.0'
    
    Po zainstalowaniu lub zaktualizowaniu Podów swojego projektu pamiętaj o otwarciu projektu Xcode przy użyciu jego .xcworkspace .
  3. W swojej aplikacji zaimportuj Firebase:

    Szybki

    import Firebase

    Cel C

    @import Firebase;
  4. Konwertuj model TensorFlow, którego chcesz użyć, do formatu TensorFlow Lite. Zobacz TOCO: Konwerter optymalizujący TensorFlow Lite .

Hostuj lub pakuj swój model

Zanim będziesz mógł użyć modelu TensorFlow Lite do wnioskowania w swojej aplikacji, musisz udostępnić model ML Kit. ML Kit może korzystać z modeli TensorFlow Lite hostowanych zdalnie przy użyciu Firebase, dołączonych do pliku binarnego aplikacji lub obu.

Hostując model w Firebase, możesz aktualizować model bez wydawania nowej wersji aplikacji, a także możesz używać zdalnej konfiguracji i testów A/B do dynamicznego udostępniania różnych modeli różnym grupom użytkowników.

Jeśli zdecydujesz się udostępnić model tylko poprzez hostowanie go w Firebase, a nie łączyć go z aplikacją, możesz zmniejszyć początkowy rozmiar pobieranej aplikacji. Pamiętaj jednak, że jeśli model nie jest dołączony do Twojej aplikacji, wszelkie funkcje związane z modelem nie będą dostępne, dopóki aplikacja nie pobierze modelu po raz pierwszy.

Łącząc model z aplikacją, możesz mieć pewność, że funkcje uczenia maszynowego w aplikacji będą nadal działać, gdy model hostowany przez Firebase nie będzie dostępny.

Hostuj modele w Firebase

Aby hostować model TensorFlow Lite w Firebase:

  1. W sekcji ML Kit konsoli Firebase kliknij kartę Niestandardowe .
  2. Kliknij opcję Dodaj model niestandardowy (lub Dodaj kolejny model ).
  3. Podaj nazwę, która będzie używana do identyfikacji Twojego modelu w projekcie Firebase, a następnie prześlij plik modelu TensorFlow Lite (zwykle kończący się na .tflite lub .lite ).

Po dodaniu niestandardowego modelu do projektu Firebase możesz odwoływać się do modelu w swoich aplikacjach, używając określonej nazwy. W dowolnym momencie możesz przesłać nowy model TensorFlow Lite, a Twoja aplikacja pobierze nowy model i zacznie go używać po ponownym uruchomieniu aplikacji. Możesz zdefiniować warunki urządzenia wymagane, aby Twoja aplikacja podjęła próbę aktualizacji modelu (patrz poniżej).

Połącz modele z aplikacją

Aby powiązać model TensorFlow Lite z aplikacją, dodaj plik modelu (zwykle kończący się na .tflite lub .lite ) do projektu Xcode, pamiętając o zaznaczeniu opcji Kopiuj zasoby pakietu. Plik modelu zostanie zawarty w pakiecie aplikacji i będzie dostępny dla ML Kit.

Załaduj model

Aby używać modelu TensorFlow Lite w swojej aplikacji, najpierw skonfiguruj ML Kit z lokalizacjami, w których Twój model jest dostępny: zdalnie przy użyciu Firebase, w pamięci lokalnej lub w obu przypadkach. Jeśli określisz zarówno model lokalny, jak i zdalny, możesz użyć modelu zdalnego, jeśli jest on dostępny, i wrócić do modelu przechowywanego lokalnie, jeśli model zdalny nie jest dostępny.

Skonfiguruj model hostowany w Firebase

Jeśli hostujesz swój model w Firebase, utwórz obiekt CustomRemoteModel , podając nazwę, którą przypisałeś modelowi podczas jego publikacji:

Szybki

let remoteModel = CustomRemoteModel(
  name: "your_remote_model"  // The name you assigned in the Firebase console.
)

Cel C

// Initialize using the name you assigned in the Firebase console.
FIRCustomRemoteModel *remoteModel =
    [[FIRCustomRemoteModel alloc] initWithName:@"your_remote_model"];

Następnie rozpocznij zadanie pobierania modelu, określając warunki, na jakich chcesz zezwolić na pobieranie. Jeśli modelu nie ma na urządzeniu lub jeśli dostępna jest nowsza wersja modelu, zadanie asynchronicznie pobierze model z Firebase:

Szybki

let downloadConditions = ModelDownloadConditions(
  allowsCellularAccess: true,
  allowsBackgroundDownloading: true
)

let downloadProgress = ModelManager.modelManager().download(
  remoteModel,
  conditions: downloadConditions
)

Cel C

FIRModelDownloadConditions *downloadConditions =
    [[FIRModelDownloadConditions alloc] initWithAllowsCellularAccess:YES
                                         allowsBackgroundDownloading:YES];

NSProgress *downloadProgress =
    [[FIRModelManager modelManager] downloadRemoteModel:remoteModel
                                             conditions:downloadConditions];

Wiele aplikacji rozpoczyna zadanie pobierania w kodzie inicjującym, ale można to zrobić w dowolnym momencie, zanim będzie konieczne użycie modelu.

Skonfiguruj model lokalny

Jeśli połączyłeś model ze swoją aplikacją, utwórz obiekt CustomLocalModel , określając nazwę pliku modelu TensorFlow Lite:

Szybki

guard let modelPath = Bundle.main.path(
  forResource: "your_model",
  ofType: "tflite",
  inDirectory: "your_model_directory"
) else { /* Handle error. */ }
let localModel = CustomLocalModel(modelPath: modelPath)

Cel C

NSString *modelPath = [NSBundle.mainBundle pathForResource:@"your_model"
                                                    ofType:@"tflite"
                                               inDirectory:@"your_model_directory"];
FIRCustomLocalModel *localModel =
    [[FIRCustomLocalModel alloc] initWithModelPath:modelPath];

Utwórz interpreter na podstawie swojego modelu

Po skonfigurowaniu źródeł modelu utwórz obiekt ModelInterpreter na podstawie jednego z nich.

Jeśli masz tylko model powiązany lokalnie, po prostu przekaż obiekt CustomLocalModel do modelInterpreter(localModel:) :

Szybki

let interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)

Cel C

FIRModelInterpreter *interpreter =
    [FIRModelInterpreter modelInterpreterForLocalModel:localModel];

Jeśli masz model hostowany zdalnie, przed uruchomieniem musisz sprawdzić, czy został pobrany. Możesz sprawdzić status zadania pobierania modelu, korzystając z metody isModelDownloaded(remoteModel:) menedżera modeli.

Chociaż musisz to tylko potwierdzić przed uruchomieniem interpretera, jeśli masz zarówno model hostowany zdalnie, jak i model pakowany lokalnie, sensowne może być wykonanie tej kontroli podczas tworzenia instancji ModelInterpreter : utwórz interpreter na podstawie modelu zdalnego, jeśli jest został pobrany, a w przeciwnym razie z modelu lokalnego.

Szybki

var interpreter: ModelInterpreter
if ModelManager.modelManager().isModelDownloaded(remoteModel) {
  interpreter = ModelInterpreter.modelInterpreter(remoteModel: remoteModel)
} else {
  interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
}

Cel C

FIRModelInterpreter *interpreter;
if ([[FIRModelManager modelManager] isModelDownloaded:remoteModel]) {
  interpreter = [FIRModelInterpreter modelInterpreterForRemoteModel:remoteModel];
} else {
  interpreter = [FIRModelInterpreter modelInterpreterForLocalModel:localModel];
}

Jeśli masz tylko model hostowany zdalnie, wyłącz funkcje związane z modelem — na przykład wyszarz lub ukryj część interfejsu użytkownika — do czasu potwierdzenia, że ​​model został pobrany.

Stan pobierania modelu możesz uzyskać podłączając obserwatorów do domyślnego Centrum powiadomień. Pamiętaj, aby użyć słabego odniesienia do self w bloku obserwatora, ponieważ pobieranie może zająć trochę czasu, a obiekt źródłowy może zostać zwolniony do czasu zakończenia pobierania. Na przykład:

Szybki

NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidSucceed,
    object: nil,
    queue: nil
) { [weak self] notification in
    guard let strongSelf = self,
        let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue]
            as? RemoteModel,
        model.name == "your_remote_model"
        else { return }
    // The model was downloaded and is available on the device
}

NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidFail,
    object: nil,
    queue: nil
) { [weak self] notification in
    guard let strongSelf = self,
        let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue]
            as? RemoteModel
        else { return }
    let error = userInfo[ModelDownloadUserInfoKey.error.rawValue]
    // ...
}

Cel C

__weak typeof(self) weakSelf = self;

[NSNotificationCenter.defaultCenter
    addObserverForName:FIRModelDownloadDidSucceedNotification
                object:nil
                 queue:nil
            usingBlock:^(NSNotification *_Nonnull note) {
              if (weakSelf == nil | note.userInfo == nil) {
                return;
              }
              __strong typeof(self) strongSelf = weakSelf;

              FIRRemoteModel *model = note.userInfo[FIRModelDownloadUserInfoKeyRemoteModel];
              if ([model.name isEqualToString:@"your_remote_model"]) {
                // The model was downloaded and is available on the device
              }
            }];

[NSNotificationCenter.defaultCenter
    addObserverForName:FIRModelDownloadDidFailNotification
                object:nil
                 queue:nil
            usingBlock:^(NSNotification *_Nonnull note) {
              if (weakSelf == nil | note.userInfo == nil) {
                return;
              }
              __strong typeof(self) strongSelf = weakSelf;

              NSError *error = note.userInfo[FIRModelDownloadUserInfoKeyError];
            }];

Określ dane wejściowe i wyjściowe modelu

Następnie skonfiguruj formaty wejściowe i wyjściowe interpretera modelu.

Model TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe jedną lub więcej tablic wielowymiarowych. Tablice te zawierają wartości byte , int , long lub float . Musisz skonfigurować ML Kit podając liczbę i wymiary („kształt”) tablic używanych przez Twój model.

Jeśli nie znasz kształtu i typu danych wejściowych i wyjściowych swojego modelu, możesz użyć interpretera TensorFlow Lite Python, aby sprawdzić swój model. Na przykład:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Po określeniu formatu danych wejściowych i wyjściowych modelu skonfiguruj interpreter modelu aplikacji, tworząc obiekt ModelInputOutputOptions .

Na przykład zmiennoprzecinkowy model klasyfikacji obrazów może przyjmować jako dane wejściowe tablicę N x224x224x3 wartości Float , reprezentującą partię N 224x224 obrazów trójkanałowych (RGB), i generować jako wynik listę 1000 wartości Float , z których każda reprezentuje prawdopodobieństwo, że obraz należy do jednej z 1000 kategorii przewidywanych przez model.

W przypadku takiego modelu należy skonfigurować wejście i wyjście interpretera modelu, jak pokazano poniżej:

Szybki

let ioOptions = ModelInputOutputOptions()
do {
    try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 224, 224, 3])
    try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 1000])
} catch let error as NSError {
    print("Failed to set input or output format with error: \(error.localizedDescription)")
}

Cel C

FIRModelInputOutputOptions *ioOptions = [[FIRModelInputOutputOptions alloc] init];
NSError *error;
[ioOptions setInputFormatForIndex:0
                             type:FIRModelElementTypeFloat32
                       dimensions:@[@1, @224, @224, @3]
                            error:&error];
if (error != nil) { return; }
[ioOptions setOutputFormatForIndex:0
                              type:FIRModelElementTypeFloat32
                        dimensions:@[@1, @1000]
                             error:&error];
if (error != nil) { return; }

Wykonaj wnioskowanie na danych wejściowych

Na koniec, aby przeprowadzić wnioskowanie przy użyciu modelu, pobierz dane wejściowe, wykonaj wszelkie przekształcenia danych, które mogą być niezbędne dla Twojego modelu, i zbuduj obiekt Data zawierający dane.

Na przykład, jeśli model przetwarza obrazy, a jego wymiary wejściowe to [BATCH_SIZE, 224, 224, 3] wartości zmiennoprzecinkowe, może być konieczne przeskalowanie wartości kolorów obrazu do zakresu zmiennoprzecinkowego, jak w poniższym przykładzie :

Szybki

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

let inputs = ModelInputs()
var inputData = Data()
do {
  for row in 0 ..< 224 {
    for col in 0 ..< 224 {
      let offset = 4 * (col * context.width + row)
      // (Ignore offset 0, the unused alpha channel)
      let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
      let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
      let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

      // Normalize channel values to [0.0, 1.0]. This requirement varies
      // by model. For example, some models might require values to be
      // normalized to the range [-1.0, 1.0] instead, and others might
      // require fixed-point values or the original bytes.
      var normalizedRed = Float32(red) / 255.0
      var normalizedGreen = Float32(green) / 255.0
      var normalizedBlue = Float32(blue) / 255.0

      // Append normalized values to Data object in RGB order.
      let elementSize = MemoryLayout.size(ofValue: normalizedRed)
      var bytes = [UInt8](repeating: 0, count: elementSize)
      memcpy(&bytes, &normalizedRed, elementSize)
      inputData.append(&bytes, count: elementSize)
      memcpy(&bytes, &normalizedGreen, elementSize)
      inputData.append(&bytes, count: elementSize)
      memcpy(&ammp;bytes, &normalizedBlue, elementSize)
      inputData.append(&bytes, count: elementSize)
    }
  }
  try inputs.addInput(inputData)
} catch let error {
  print("Failed to add input: \(error)")
}

Cel C

CGImageRef image = // Your input image
long imageWidth = CGImageGetWidth(image);
long imageHeight = CGImageGetHeight(image);
CGContextRef context = CGBitmapContextCreate(nil,
                                             imageWidth, imageHeight,
                                             8,
                                             imageWidth * 4,
                                             CGColorSpaceCreateDeviceRGB(),
                                             kCGImageAlphaNoneSkipFirst);
CGContextDrawImage(context, CGRectMake(0, 0, imageWidth, imageHeight), image);
UInt8 *imageData = CGBitmapContextGetData(context);

FIRModelInputs *inputs = [[FIRModelInputs alloc] init];
NSMutableData *inputData = [[NSMutableData alloc] initWithCapacity:0];

for (int row = 0; row < 224; row++) {
  for (int col = 0; col < 224; col++) {
    long offset = 4 * (col * imageWidth + row);
    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    // (Ignore offset 0, the unused alpha channel)
    Float32 red = imageData[offset+1] / 255.0f;
    Float32 green = imageData[offset+2] / 255.0f;
    Float32 blue = imageData[offset+3] / 255.0f;

    [inputData appendBytes:&red length:sizeof(red)];
    [inputData appendBytes:&green length:sizeof(green)];
    [inputData appendBytes:&blue length:sizeof(blue)];
  }
}

[inputs addInput:inputData error:&error];
if (error != nil) { return nil; }

Po przygotowaniu danych wejściowych modelu (i potwierdzeniu, że model jest dostępny), przekaż opcje wejścia i wejścia/wyjścia do metody run(inputs:options:completion:) interpretera modelu .

Szybki

interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in
    guard error == nil, let outputs = outputs else { return }
    // Process outputs
    // ...
}

Cel C

[interpreter runWithInputs:inputs
                   options:ioOptions
                completion:^(FIRModelOutputs * _Nullable outputs,
                             NSError * _Nullable error) {
  if (error != nil || outputs == nil) {
    return;
  }
  // Process outputs
  // ...
}];

Dane wyjściowe można uzyskać, wywołując metodę output(index:) zwracanego obiektu. Na przykład:

Szybki

// Get first and only output of inference with a batch size of 1
let output = try? outputs.output(index: 0) as? [[NSNumber]]
let probabilities = output??[0]

Cel C

// Get first and only output of inference with a batch size of 1
NSError *outputError;
NSArray *probabilites = [outputs outputAtIndex:0 error:&outputError][0];

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Na przykład, jeśli przeprowadzasz klasyfikację, w następnym kroku możesz zmapować indeksy wyniku na reprezentowane przez nie etykiety. Załóżmy, że masz plik tekstowy z ciągami etykiet dla każdej kategorii modelu; możesz zmapować ciągi etykiet na prawdopodobieństwa wyjściowe, wykonując coś takiego:

Szybki

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in 0 ..< labels.count {
  if let probability = probabilities?[i] {
    print("\(labels[i]): \(probability)")
  }
}

Cel C

NSError *labelReadError = nil;
NSString *labelPath = [NSBundle.mainBundle pathForResource:@"retrained_labels"
                                                    ofType:@"txt"];
NSString *fileContents = [NSString stringWithContentsOfFile:labelPath
                                                   encoding:NSUTF8StringEncoding
                                                      error:&labelReadError];
if (labelReadError != nil || fileContents == NULL) { return; }
NSArray<NSString *> *labels = [fileContents componentsSeparatedByString:@"\n"];
for (int i = 0; i < labels.count; i++) {
    NSString *label = labels[i];
    NSNumber *probability = probabilites[i];
    NSLog(@"%@: %f", label, probability.floatValue);
}

Dodatek: Bezpieczeństwo modelu

Niezależnie od tego, w jaki sposób udostępnisz modele TensorFlow Lite w ML Kit, ML Kit przechowuje je w standardowym serializowanym formacie protobuf w pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. Jednak w praktyce większość modeli jest tak specyficzna dla aplikacji i zaciemniona optymalizacjami, że ryzyko jest podobne do ryzyka, jakie stwarza konkurencja, która demontuje i ponownie wykorzystuje Twój kod. Niemniej jednak powinieneś zdawać sobie sprawę z tego ryzyka, zanim użyjesz niestandardowego modelu w swojej aplikacji.