Use a TensorFlow Lite model for inference with ML Kit on Android

You can use ML Kit to perform on-device inference with a TensorFlow Lite model.

This API requires Android SDK level 16 (Jelly Bean) or newer.

See the ML Kit quickstart sample on GitHub for an example of this API in use, or try the codelab.

Before you begin

  1. If you haven't already, add Firebase to your Android project.
  2. In your project-level build.gradle file, make sure to include Google's Maven repository in both your buildscript and allprojects sections.
  3. Add the dependencies for the ML Kit Android libraries to your module (app-level) Gradle file (usually app/build.gradle):
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:21.0.0'
    }
    apply plugin: 'com.google.gms.google-services'
    
  4. Convert the TensorFlow model you want to use to TensorFlow Lite format. See TOCO: TensorFlow Lite Optimizing Converter. A minimal conversion sketch follows this list.
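
As a rough sketch (not part of the official steps), you can convert a model with the TensorFlow Lite Python converter. This assumes your model was exported as a SavedModel; the saved_model/ directory and my_model.tflite file names are placeholders for this example:

import tensorflow as tf

# Load the trained model (here, a SavedModel directory -- an assumption for
# this sketch) and convert it to TensorFlow Lite format.
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model/")
tflite_model = converter.convert()

# Write the converted model to a .tflite file that you can host on Firebase
# or bundle with your app.
with open("my_model.tflite", "wb") as f:
    f.write(tflite_model)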

Host or bundle your model

Before you can use a TensorFlow Lite model for inference in your app, you must make the model available to ML Kit. ML Kit can use TensorFlow Lite models hosted remotely using Firebase, bundled with the app binary, or both.

By hosting a model on Firebase, you can update the model without releasing a new app version, and you can use Remote Config and A/B Testing to dynamically serve different models to different sets of users.
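
For illustration only, the sketch below shows one way a hosted model name might be chosen with Remote Config. The custom_model_name parameter and the firebase-config dependency are assumptions of this sketch; neither is set up anywhere in this guide.

import com.google.firebase.remoteconfig.FirebaseRemoteConfig

// Sketch only: read the name of the hosted model from a Remote Config
// parameter (assumed here to be "custom_model_name"), so different user
// segments can be served different models via Remote Config and A/B Testing.
val remoteConfig = FirebaseRemoteConfig.getInstance()
remoteConfig.fetchAndActivate().addOnCompleteListener {
    val modelName = remoteConfig.getString("custom_model_name")
    // Pass modelName to FirebaseRemoteModel.Builder(...) as shown later in
    // this guide, instead of hard-coding the model name.
}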

If you only provide the model by hosting it with Firebase, and don't bundle it with your app, you can reduce your app's initial download size. Keep in mind, though, that if the model is not bundled with your app, any model-related functionality will not be available until your app downloads the model for the first time.

By bundling your model with your app, you can ensure your app's ML features still work when the Firebase-hosted model isn't available.

Host models on Firebase

To host your TensorFlow Lite model on Firebase:

  1. In the ML Kit section of the Firebase console, click the Custom tab.
  2. Click Add custom model (or Add another model).
  3. Specify a name that will be used to identify your model in your Firebase project, then upload the TensorFlow Lite model file (usually ending in .tflite or .lite).
  4. In your app's manifest, declare that INTERNET permission is required:
    <uses-permission android:name="android.permission.INTERNET" />
    

After you add a custom model to your Firebase project, you can reference the model in your apps using the name you specified. At any time, you can upload a new TensorFlow Lite model, and your app will download the new model and start using it when the app next restarts. You can define the device conditions required for your app to attempt to update the model (see below).

Bundle models with an app

To bundle your TensorFlow Lite model with your app, copy the model file (usually ending in .tflite or .lite) to your app's assets/ folder. (You might need to create the folder first by right-clicking the app/ folder, then clicking New > Folder > Assets Folder.)

Then, add the following to your app's build.gradle file to ensure Gradle doesn’t compress the models when building the app:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

The model file will be included in the app package and available to ML Kit as a raw asset.

Load the model

To use your TensorFlow Lite model in your app, first configure ML Kit with the locations where your model is available: remotely using Firebase, in local storage, or both. If you specify both a local and remote model, ML Kit will use the remote model if it is available, and fall back to the locally-stored model if the remote model isn't available.

Configure a Firebase-hosted model

If you hosted your model with Firebase, create a FirebaseRemoteModel object, specifying the name you assigned the model when you uploaded it, and the conditions under which ML Kit should download the model initially and when an update is available.

Java

FirebaseModelDownloadConditions.Builder conditionsBuilder =
        new FirebaseModelDownloadConditions.Builder().requireWifi();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    // Enable advanced conditions on Android Nougat and newer.
    conditionsBuilder = conditionsBuilder
            .requireCharging()
            .requireDeviceIdle();
}
FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

// Build a remote model source object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
FirebaseRemoteModel cloudSource = new FirebaseRemoteModel.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build();
FirebaseModelManager.getInstance().registerRemoteModel(cloudSource);

Kotlin

var conditionsBuilder: FirebaseModelDownloadConditions.Builder =
        FirebaseModelDownloadConditions.Builder().requireWifi()
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    // Enable advanced conditions on Android Nougat and newer.
    conditionsBuilder = conditionsBuilder
            .requireCharging()
            .requireDeviceIdle()
}
val conditions = conditionsBuilder.build()

// Build a remote model object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
val cloudSource = FirebaseRemoteModel.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build()
FirebaseModelManager.getInstance().registerRemoteModel(cloudSource)

Configure a local model

If you bundled the model with your app, create a FirebaseLocalModel object, specifying the filename of the TensorFlow Lite model and assigning the model a name you will use in the next step.

Java

FirebaseLocalModel localSource =
        new FirebaseLocalModel.Builder("my_local_model")  // Assign a name to this model
                .setAssetFilePath("my_model.tflite")
                .build();
FirebaseModelManager.getInstance().registerLocalModel(localSource);

Kotlin

val localSource = FirebaseLocalModel.Builder("my_local_model") // Assign a name to this model
        .setAssetFilePath("my_model.tflite")
        .build()
FirebaseModelManager.getInstance().registerLocalModel(localSource)

Create an interpreter from your model

After you configure your model locations, create a FirebaseModelOptions object with the names of your remote model, local model, or both, and use it to get an instance of FirebaseModelInterpreter:

Java

FirebaseModelOptions options = new FirebaseModelOptions.Builder()
        .setRemoteModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build();
FirebaseModelInterpreter firebaseInterpreter =
        FirebaseModelInterpreter.getInstance(options);

Kotlin

val options = FirebaseModelOptions.Builder()
        .setRemoteModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build()
val firebaseInterpreter = FirebaseModelInterpreter.getInstance(options)

Make sure the model is available on the device

Recommended: If you didn't configure a locally-bundled model, make sure the remote model has been downloaded to the device.

When you run a remotely-hosted model, if the model isn't yet available on the device, the call fails and the model is automatically downloaded to the device in the background. When the download is complete, you can run the model successfully.

If you want to handle the model downloading task more explicitly, you can start the model downloading task and check its status by calling downloadRemoteModelIfNeeded():

Java

FirebaseModelManager.getInstance().downloadRemoteModelIfNeeded(cloudSource)
        .addOnSuccessListener(
            new OnSuccessListener<Void>() {
              @Override
              public void onSuccess(Void v) {
                // Model downloaded successfully. Okay to use the model.
              }
            })
        .addOnFailureListener(
            new OnFailureListener() {
              @Override
              public void onFailure(@NonNull Exception e) {
                // Model couldn’t be downloaded or other internal error.
                // ...
              }
            });

Kotlin

FirebaseModelManager.getInstance().downloadRemoteModelIfNeeded(cloudSource)
        .addOnSuccessListener {
            // Model downloaded successfully. Okay to use the model.
        }
        .addOnFailureListener {
            // Model couldn’t be downloaded or other internal error.
            // ...
        }

The downloadRemoteModelIfNeeded() method starts downloading the model if necessary, and calls the success listener when the download finishes. If the model is already available, the method calls the success listener immediately.

Specify the model's input and output

Next, configure the model interpreter's input and output formats.

A TensorFlow Lite model takes as input and produces as output one or more multidimensional arrays. These arrays contain either byte, int, long, or float values. You must configure ML Kit with the number and dimensions ("shape") of the arrays your model uses.

If you don't know the shape and data type of your model's input and output, you can use the TensorFlow Lite Python interpreter to inspect your model. For example:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

After you have determined the format of your model's input and output, you can configure your app's model interpreter by creating a FirebaseModelInputOutputOptions object.

For example, a floating-point image classification model might take as input an Nx224x224x3 array of float values, representing a batch of N 224x224 three-channel (RGB) images, and produce as output a list of 1000 float values, each representing the probability that the image is a member of one of the 1000 categories the model predicts.

For such a model, you would configure the model interpreter's input and output as shown below:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
        .build()

Perform inference on input data

Finally, to perform inference using the model, get your input data and perform any transformations on the data that are necessary to get an input array of the right shape for your model.

For example, if you have an image classification model that takes a [1 224 224 3] array of floating-point values as input, you could generate an input array from a Bitmap object as shown in the following example:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f
    }
}

Then, create a FirebaseModelInputs object with your input data, and pass it and the model's input and output specification to the model interpreter's run method:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener {
            // Task failed with an exception
            // ...
        }

If the call succeeds, you can get the output by calling the getOutput() method of the object that is passed to the success listener. For example:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

How you use the output depends on the model you are using.

For example, if you are performing classification, as a next step, you might map the indexes of the result to the labels they represent:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}
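
If you only need the single best prediction, one possible follow-up sketch in Kotlin is shown below, reusing the probabilities value and label file from the example above (maxByOrNull requires Kotlin 1.4 or newer):

// Pair each probability with its index, pick the highest-scoring one, and
// look up the corresponding label.
val labels = assets.open("retrained_labels.txt").bufferedReader().readLines()
val best = probabilities.withIndex().maxByOrNull { it.value }
if (best != null && best.index < labels.size) {
    Log.i("MLKit", "Top result: ${labels[best.index]} (${best.value})")
}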

Appendix: Model security

Regardless of how you make your TensorFlow Lite models available to ML Kit, ML Kit stores them in their standard serialized format in local storage.

In theory, this means that anybody can copy your model. However, in practice, most models are so application-specific and obfuscated by optimizations that the risk is similar to that of competitors disassembling and reusing your code. Nevertheless, you should be aware of this risk before you use a custom model in your app.

On Android API level 21 (Lollipop) and newer, the model is downloaded to a directory that is excluded from automatic backup.

On Android API level 20 and older, the model is downloaded to a directory named com.google.firebase.ml.custom.models in app-private internal storage. If you enabled file backup using BackupAgent, you might choose to exclude this directory.