
Use a custom TensorFlow Lite model on iOS

If your app uses custom TensorFlow Lite models, you can use Firebase ML to deploy your models. By deploying models with Firebase, you can reduce the initial download size of your app and update your app's ML models without releasing a new version of your app. And, with Remote Config and A/B Testing, you can dynamically serve different models to different sets of users.

TensorFlow Lite runs only on devices running iOS 9 or newer.
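For example, you could combine this with Remote Config to decide which hosted model each user downloads. The following is a minimal sketch, not part of the steps below: it assumes Remote Config is also set up in your app and that you have defined a hypothetical model_name parameter in the Firebase console.

Swift

// Read the name of the hosted model to use from a Remote Config parameter.
// "model_name" is a hypothetical parameter; define it in the Firebase console.
let remoteConfig = RemoteConfig.remoteConfig()
remoteConfig.setDefaults(["model_name": "your_remote_model" as NSObject])
remoteConfig.fetchAndActivate { _, _ in
    let modelName = remoteConfig.configValue(forKey: "model_name").stringValue
        ?? "your_remote_model"
    // Download whichever model Remote Config selected for this user
    // (see steps 1 and 2 below for details on these APIs).
    let remoteModel = CustomRemoteModel(name: modelName)
    let conditions = ModelDownloadConditions(allowsCellularAccess: true,
                                             allowsBackgroundDownloading: true)
    _ = ModelManager.modelManager().download(remoteModel, conditions: conditions)
}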

TensorFlow Lite models

TensorFlow Lite models are ML models that are optimized to run on mobile devices. To get a TensorFlow Lite model, you can use a pre-built model or convert an existing TensorFlow model to the TensorFlow Lite format.

Before you begin

  1. If you have not already added Firebase to your app, do so by following the steps in the getting started guide.
  2. Include Firebase ML in your Podfile:

    Swift

    pod 'Firebase/MLModelInterpreter'
    pod 'TensorFlowLiteSwift'
    

    Objective-C

    pod 'Firebase/MLModelInterpreter'
    pod 'TensorFlowLiteObjC'
    
    After you install or update your project's Pods, be sure to open your Xcode project using its .xcworkspace.
  3. In your app, import Firebase:

    Swift

    import Firebase
    import TensorFlowLite
    

    Objective-C

    @import Firebase;
    @import TFLTensorFlowLite;
    

1. Deploy your model

Deploy your custom TensorFlow Lite models using either the Firebase console or the Firebase Admin Python and Node.js SDKs. See Deploy and manage custom models.

After you add a custom model to your Firebase project, you can reference the model in your apps using the name you specified. At any time, you can upload a new TensorFlow Lite model, and your app will download the new model and start using it when the app next restarts. You can define the device conditions required for your app to attempt to update the model (see below).

2. Download the model to the device

To use your TensorFlow Lite model in your app, first use the Firebase ML SDK to download the latest version of the model to the device.

To start the model download, call the model manager's download() method, specifying the name you assigned the model when you uploaded it and the conditions under which you want to allow downloading. If the model isn't on the device, or if a newer version of the model is available, the task will asynchronously download the model from Firebase.

You should disable model-related functionality (for example, grey out or hide part of your UI) until you confirm the model has been downloaded.

Swift

let remoteModel = CustomRemoteModel(
  name: "your_remote_model"  // The name you assigned in the Firebase console.
)
let downloadConditions = ModelDownloadConditions(
  allowsCellularAccess: true,
  allowsBackgroundDownloading: true
)
let downloadProgress = ModelManager.modelManager().download(
  remoteModel,
  conditions: downloadConditions
)

Objective-C

// Initialize using the name you assigned in the Firebase console.
FIRCustomRemoteModel *remoteModel =
    [[FIRCustomRemoteModel alloc] initWithName:@"your_remote_model"];
FIRModelDownloadConditions *downloadConditions =
    [[FIRModelDownloadConditions alloc] initWithAllowsCellularAccess:YES
                                         allowsBackgroundDownloading:YES];

NSProgress *downloadProgress =
    [[FIRModelManager modelManager] downloadRemoteModel:remoteModel
                                             conditions:downloadConditions];

Many apps start the download task in their initialization code, but you can do so at any point before you need to use the model.
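For example, a minimal sketch of starting the download when the app launches (placing it in the app delegate is just one option; any point before you need the model works):

Swift

import UIKit
import Firebase

@UIApplicationMain
class AppDelegate: UIResponder, UIApplicationDelegate {

    func application(_ application: UIApplication,
                     didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) -> Bool {
        FirebaseApp.configure()

        // Kick off the model download early so it is more likely to be ready
        // by the time the model is first needed.
        let remoteModel = CustomRemoteModel(name: "your_remote_model")
        let conditions = ModelDownloadConditions(allowsCellularAccess: true,
                                                 allowsBackgroundDownloading: true)
        _ = ModelManager.modelManager().download(remoteModel, conditions: conditions)

        return true
    }
}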

You can get the model download status by attaching observers to the default Notification Center. Be sure to use a weak reference to self in the observer block, since downloads can take some time, and the originating object can be freed by the time the download finishes. For example:

Swift

NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidSucceed,
    object: nil,
    queue: nil
) { [weak self] notification in
    guard let strongSelf = self,
        let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue]
            as? RemoteModel,
        model.name == "your_remote_model"
        else { return }
    // The model was downloaded and is available on the device
}

NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidFail,
    object: nil,
    queue: nil
) { [weak self] notification in
    guard let strongSelf = self,
        let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue]
            as? RemoteModel
        else { return }
    let error = userInfo[ModelDownloadUserInfoKey.error.rawValue]
    // ...
}

Objective-C

__weak typeof(self) weakSelf = self;

[NSNotificationCenter.defaultCenter
    addObserverForName:FIRModelDownloadDidSucceedNotification
                object:nil
                 queue:nil
            usingBlock:^(NSNotification *_Nonnull note) {
              if (weakSelf == nil || note.userInfo == nil) {
                return;
              }
              __strong typeof(self) strongSelf = weakSelf;

              FIRRemoteModel *model = note.userInfo[FIRModelDownloadUserInfoKeyRemoteModel];
              if ([model.name isEqualToString:@"your_remote_model"]) {
                // The model was downloaded and is available on the device
              }
            }];

[NSNotificationCenter.defaultCenter
    addObserverForName:FIRModelDownloadDidFailNotification
                object:nil
                 queue:nil
            usingBlock:^(NSNotification *_Nonnull note) {
              if (weakSelf == nil || note.userInfo == nil) {
                return;
              }
              __strong typeof(self) strongSelf = weakSelf;

              NSError *error = note.userInfo[FIRModelDownloadUserInfoKeyError];
            }];
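The download call also returns a Progress object (NSProgress in Objective-C), so you can drive a progress indicator from it and keep model-dependent controls disabled until the success notification fires, as suggested earlier. A minimal Swift sketch (downloadProgressView and classifyButton are hypothetical outlets in your UI):

Swift

// Show download progress using the Progress object returned by
// download(_:conditions:).
downloadProgressView.observedProgress = downloadProgress

// Re-enable model-dependent UI once the download succeeds. If you host more
// than one model, also check the model name as in the observers above.
NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidSucceed,
    object: nil,
    queue: OperationQueue.main  // Update UI only on the main queue.
) { [weak self] _ in
    self?.classifyButton.isEnabled = true
}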

3. Initialize a TensorFlow Lite interpreter

After you download the model to the device, you can get the model file location by calling the model manager's getLatestModelFilePath() method. Use this value to instantiate a TensorFlow Lite interpreter:

Swift

var interpreter: Interpreter
ModelManager.modelManager().getLatestModelFilePath(remoteModel) { (remoteModelPath, error) in
    guard error == nil else { return }
    guard let remoteModelPath = remoteModelPath else { return }
    do {
        interpreter = try Interpreter(modelPath: remoteModelPath)
    } catch {
        // Handle the error here, for example by falling back to a locally
        // bundled model (see the appendix below).
    }
}

Objective-C

TFLInterpreter *interpreter;
[FIRModelManager.modelManager
 getLatestModelFilePath:remoteModel
             completion:^(NSString *_Nullable remoteModelPath, NSError *error) {
  if (remoteModelPath != nil && error == nil) {
    NSError *tfliteError;
    interpreter = [[TFLInterpreter alloc] initWithModelPath:remoteModelPath
                                                      error:&tfliteError];
  }
}];

4. Perform inference on input data

Get your model's input and output shapes

The TensorFlow Lite model interpreter takes as input and produces as output one or more multidimensional arrays. These arrays contain either byte, int, long, or float values. Before you can pass data to a model or use its result, you must know the number and dimensions ("shape") of the arrays your model uses.

If you built the model yourself, or if the model's input and output format is documented, you might already have this information. If you don't know the shape and data type of your model's input and output, you can use the TensorFlow Lite interpreter to inspect your model. For example:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Example output:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>
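You can also inspect the shapes on the device with the TensorFlow Lite Swift API once you have an interpreter (see the previous step); a minimal sketch:

Swift

do {
    try interpreter.allocateTensors()
    // Tensor.shape.dimensions is an [Int] such as [1, 224, 224, 3], and
    // dataType describes the element type (for example, .float32).
    let inputTensor = try interpreter.input(at: 0)
    print("Input shape: \(inputTensor.shape.dimensions), type: \(inputTensor.dataType)")
    // Output tensors can be inspected the same way with interpreter.output(at:)
    // after the interpreter has run.
} catch {
    print("Failed to inspect tensors: \(error)")
}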

Run the interpreter

After you have determined the format of your model's input and output, get your input data and perform any transformations on the data that are necessary to get an input of the right shape for your model.

For example, if your model processes images, and your model has input dimensions of [1, 224, 224, 3] floating-point values, you might have to scale the image's color values to a floating-point range as in the following example:

Swift

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

var inputData = Data()
for row in 0 ..< 224 {
  for col in 0 ..< 224 {
    let offset = 4 * (row * context.width + col)
    // (Ignore offset 0, the unused alpha channel)
    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    var normalizedRed = Float32(red) / 255.0
    var normalizedGreen = Float32(green) / 255.0
    var normalizedBlue = Float32(blue) / 255.0

    // Append normalized values to Data object in RGB order.
    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
    var bytes = [UInt8](repeating: 0, count: elementSize)
    memcpy(&bytes, &normalizedRed, elementSize)
    inputData.append(&bytes, count: elementSize)
    memcpy(&bytes, &normalizedGreen, elementSize)
    inputData.append(&bytes, count: elementSize)
    memcpy(&bytes, &normalizedBlue, elementSize)
    inputData.append(&bytes, count: elementSize)
  }
}

Objective-C

CGImageRef image = // Your input image
long imageWidth = CGImageGetWidth(image);
long imageHeight = CGImageGetHeight(image);
CGContextRef context = CGBitmapContextCreate(nil,
                                             imageWidth, imageHeight,
                                             8,
                                             imageWidth * 4,
                                             CGColorSpaceCreateDeviceRGB(),
                                             kCGImageAlphaNoneSkipFirst);
CGContextDrawImage(context, CGRectMake(0, 0, imageWidth, imageHeight), image);
UInt8 *imageData = CGBitmapContextGetData(context);

NSMutableData *inputData = [[NSMutableData alloc] initWithCapacity:0];

for (int row = 0; row < 224; row++) {
  for (int col = 0; col < 224; col++) {
    long offset = 4 * (row * imageWidth + col);
    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    // (Ignore offset 0, the unused alpha channel)
    Float32 red = imageData[offset+1] / 255.0f;
    Float32 green = imageData[offset+2] / 255.0f;
    Float32 blue = imageData[offset+3] / 255.0f;

    [inputData appendBytes:&red length:sizeof(red)];
    [inputData appendBytes:&green length:sizeof(green)];
    [inputData appendBytes:&blue length:sizeof(blue)];
  }
}
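Note that these loops read exactly 224x224 pixels, so they assume the source image already has those dimensions. If it might not, you can draw it into a fixed-size bitmap context first and read the pixels from there instead; a minimal Swift sketch (224x224 matches the example model's input shape):

Swift

// Scale the source image into a 224x224 RGBA bitmap before reading pixels.
let targetSize = 224
guard let resizeContext = CGContext(
  data: nil,
  width: targetSize, height: targetSize,
  bitsPerComponent: 8, bytesPerRow: targetSize * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

resizeContext.draw(image, in: CGRect(x: 0, y: 0, width: targetSize, height: targetSize))
// resizeContext.data now holds a 224x224 image; use it in place of imageData
// above, with resizeContext.width as the row stride when computing offsets.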

Then, copy your input Data (NSData in Objective-C) to the interpreter and run it:

Swift

try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()

Objective-C

NSError *error = nil;

[interpreter allocateTensorsWithError:&error];
if (error != nil) { return; }

TFLTensor *input = [interpreter inputTensorAtIndex:0 error:&error];
if (error != nil) { return; }

[input copyData:inputData error:&error];
if (error != nil) { return; }

[interpreter invokeWithError:&error];
if (error != nil) { return; }

You can get the model's output by calling the interpreter's output(at:) method. How you use the output depends on the model you are using.

For example, if you are performing classification, as a next step, you might map the indexes of the result to the labels they represent:

Swift

let output = try interpreter.output(at: 0)
let probabilities =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in labels.indices {
    print("\(labels[i]): \(probabilities[i])")
}

Objective-C

NSError *error = nil;

TFLTensor *output = [interpreter outputTensorAtIndex:0 error:&error];
if (error != nil) { return; }

NSData *outputData = [output dataWithError:&error];
if (error != nil) { return; }

NSString *labelPath = [NSBundle.mainBundle pathForResource:@"retrained_labels"
                                                    ofType:@"txt"];
NSString *fileContents = [NSString stringWithContentsOfFile:labelPath
                                                   encoding:NSUTF8StringEncoding
                                                      error:&error];
if (error != nil || fileContents == nil) { return; }
NSArray<NSString *> *labels = [fileContents componentsSeparatedByString:@"\n"];
for (int i = 0; i < labels.count; i++) {
    NSString *label = labels[i];
    float probability;
    [outputData getBytes:&probability range:NSMakeRange(i * 4, 4)];
    NSLog(@"%@: %f", label, probability);
}
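If you only need the most likely label, you can pair each label with its probability and take the maximum. A minimal Swift sketch continuing from the classification example above:

Swift

// Pick the label with the highest probability.
if let (label, probability) = zip(labels, probabilities).max(by: { $0.1 < $1.1 }) {
    print("Best match: \(label) (probability \(probability))")
}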

Appendix: Fall back to a locally bundled model

When you host your model with Firebase, model-related functionality isn't available until your app downloads the model for the first time. For some apps, this might be fine, but if your model enables core functionality, you might want to bundle a version of your model with your app and use the best available version. By doing so, you can ensure your app's ML features work when the Firebase-hosted model isn't available.

To bundle your TensorFlow Lite model with your app, add the model file (usually ending in .tflite or .lite) to your Xcode project, taking care to add the file to the app's build target when you do so. The model file will be included in the app bundle and available to your app.

Then, use the locally bundled model when the hosted model isn't available:

Swift

var interpreter: Interpreter
ModelManager.modelManager().getLatestModelFilePath(remoteModel) { (remoteModelPath, error) in
  guard error == nil else { return }
  do {
      if let remoteModelPath = remoteModelPath {
          interpreter = try Interpreter(modelPath: remoteModelPath)
      } else {
          // Fall back to the model bundled with the app.
          guard let localModelPath = Bundle.main.path(
              forResource: "your_model",
              ofType: "tflite"
          ) else { return }
          interpreter = try Interpreter(modelPath: localModelPath)
      }
  } catch {
      print("Error initializing TensorFlow Lite: \(error.localizedDescription)")
      return
  }
}

Objective-C

TFLInterpreter *interpreter;
[FIRModelManager.modelManager
 getLatestModelFilePath:remoteModel
             completion:^(NSString *_Nullable remoteModelPath, NSError *error) {
  if (remoteModelPath != nil && error == nil) {
    NSError *tfliteError;
    interpreter = [[TFLInterpreter alloc] initWithModelPath:remoteModelPath
                                                      error:&tfliteError];
  } else {
    NSError *tfliteError;
    NSString *localModelPath = [NSBundle.mainBundle pathForResource:@"model"
                                                             ofType:@"tflite"];
    interpreter = [[TFLInterpreter alloc] initWithModelPath:localModelPath
                                                      error:&tfliteError];
  }
}];

Appendix: Model security

Regardless of how you make your TensorFlow Lite models available to Firebase ML, Firebase ML stores them in the standard serialized protobuf format in local storage.

In theory, this means that anybody can copy your model. However, in practice, most models are so application-specific and obfuscated by optimizations that the risk is similar to that of competitors disassembling and reusing your code. Nevertheless, you should be aware of this risk before you use a custom model in your app.