Usa un modelo personalizado de TensorFlow Lite en Android

Si tu app usa modelos personalizados de TensorFlow Lite, puedes usar Firebase ML para implementarlos. Si implementas modelos con Firebase, puedes reducir el tamaño de la descarga inicial de tu app y actualizar sus modelos de AA sin lanzar una nueva versión. Además, con Remote Config y A/B Testing, puedes entregar de manera dinámica diferentes modelos a conjuntos distintos de usuarios.

Modelos de TensorFlow Lite

Los modelos de TensorFlow Lite son modelos de AA optimizados para ejecutarse en dispositivos móviles. Deberás realizar lo siguiente para obtener un modelo de TensorFlow Lite:

Antes de comenzar

  1. Si aún no lo has hecho, agrega Firebase a tu proyecto de Android.
  2. En el archivo de Gradle del módulo (nivel de app) (generalmente <project>/<app-module>/build.gradle.kts o <project>/<app-module>/build.gradle), agrega la dependencia de la biblioteca de Firebase ML Model Downloader para Android. Te recomendamos usar Firebase Android BoM para controlar las versiones de las bibliotecas.

    Además, como parte de la configuración de Firebase ML Model Downloader, debes agregar el SDK de TensorFlow Lite a tu app.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.6.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Cuando usas Firebase Android BoM, tu app siempre usará versiones compatibles de las bibliotecas de Firebase para Android.

    (Alternativa)  Agrega dependencias de la biblioteca de Firebase sin usar la BoM

    Si eliges no usar la Firebase BoM, debes especificar cada versión de la biblioteca de Firebase en su línea de dependencia.

    Ten en cuenta que, si usas múltiples bibliotecas de Firebase en tu app, es muy recomendable que uses la BoM para administrar las versiones de las bibliotecas para garantizar que todas las versiones sean compatibles.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    ¿Buscas un módulo de biblioteca específico de Kotlin? A partir de octubre de 2023 (Firebase BoM 32.5.0), tanto los desarrolladores de Kotlin como los de Java pueden depender del módulo de la biblioteca principal (para obtener más información, consulta la Preguntas frecuentes sobre esta iniciativa).
  3. En el manifiesto de tu app, declara que se requiera el permiso de INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Implementa tu modelo

Implementa tus modelos personalizados de TensorFlow con Firebase console o los SDK de Firebase Admin para Python y Node.js. Consulta la sección sobre cómo implementar y administrar modelos personalizados.

Después de agregar un modelo personalizado al proyecto de Firebase, podrás usar el nombre que especificaste para hacer referencia al modelo en tus apps. En cualquier momento, puedes implementar un nuevo modelo de TensorFlow Lite y descargarlo en los dispositivos de los usuarios llamando a getModel() (consulta a continuación).

2. Descarga el modelo en el dispositivo y, luego, inicializa un intérprete de TensorFlow Lite

Si quieres usar tu modelo de TensorFlow Lite en la app, primero utiliza el SDK de Firebase ML para descargar la versión más reciente del modelo en el dispositivo. Luego, crea una instancia de un intérprete de TensorFlow Lite con el modelo.

Para iniciar la descarga del modelo, llama al método getModel() del usuario que descargó el modelo, especifica el nombre que le asignaste al modelo cuando lo subiste, si quieres descargar siempre el último modelo y las condiciones en las que deseas permitir la descarga.

Puedes elegir entre tres comportamientos de descarga:

Tipo de descarga Descripción
LOCAL_MODEL Obtén el modelo local del dispositivo. Si no hay un modelo local disponible, el método se comporta como LATEST_MODEL. Usa este tipo de descarga si no quieres buscar las actualizaciones del modelo. Por ejemplo, si quieres usar Remote Config para recuperar nombres de modelos y siempre subes modelos con nombres nuevos (recomendado).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Obtén el modelo local del dispositivo y comienza a actualizarlo en segundo plano. Si no hay un modelo local disponible, el método se comporta como LATEST_MODEL.
LATEST_MODEL Obtén el modelo más reciente. Si el modelo local es la versión más reciente, se muestra ese modelo. De lo contrario, descarga el modelo más reciente. Este comportamiento se bloqueará hasta que se descargue la versión más reciente (no se recomienda). Usa este comportamiento solo si necesitas de forma explícita la versión más reciente.

Debes inhabilitar la funcionalidad relacionada con el modelo, por ejemplo, ocultar o inhabilitar parte de la IU, hasta que confirmes que se descargó el modelo.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Muchas apps comienzan la tarea de descarga en su código de inicialización, pero puedes hacerlo en cualquier momento antes de usar el modelo.

3. Realiza inferencias sobre los datos de entrada

Obtén las formas de entrada y salida de tu modelo

El intérprete del modelo de TensorFlow Lite toma como entrada y produce como salida uno o más arrays multidimensionales. Estos arrays contienen valores byte, int, long o float. Antes de pasar datos a un modelo o utilizar su resultado, debes conocer el número y las dimensiones (“forma”) de los arrays que usa tu modelo.

Si creaste el modelo tú mismo o si el formato de entrada y salida del modelo está documentado, es posible que ya tengas esta información. Si no conoces la forma y el tipo de datos de la entrada y la salida del modelo, puedes usar el intérprete de TensorFlow Lite para inspeccionarlo. Por ejemplo:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Resultado de ejemplo:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Ejecuta el intérprete

Después de determinar el formato de entrada y salida del modelo, obtén los datos de entrada y realiza las transformaciones necesarias a fin de obtener una entrada con la forma correcta para tu modelo.

Por ejemplo, si tienes un modelo de clasificación de imágenes con una forma de entrada de valores de punto flotante [1 224 224 3], podrías generar una entrada ByteBuffer de un objeto Bitmap como se muestra en el siguiente ejemplo:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Luego, asigna un ByteBuffer lo suficientemente grande para contener los datos de salida del modelo y pasar el búfer de entrada y de salida al método run() del intérprete de TensorFlow Lite. Por ejemplo, para una forma de salida de valores de punto flotante [1 1000]:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

La manera de utilizar el resultado depende del modelo que uses.

Por ejemplo, si realizas una clasificación, el paso siguiente puede ser asignar los índices del resultado a las etiquetas que representan:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Apéndice: Seguridad del modelo

Sin importar cómo pones a disposición tus modelos de TensorFlow Lite para Firebase ML, Firebase ML los almacena en el formato protobuf serializado estándar en almacenamiento local.

En teoría, eso significa que cualquier persona puede copiar tu modelo. Sin embargo, en la práctica, la mayoría de los modelos son tan específicos para la aplicación y ofuscados por las optimizaciones que el riesgo es comparable a que alguien de la competencia desensamble y vuelva a usar tu código. No obstante, debes estar al tanto de ese riesgo antes de usar un modelo personalizado en tu app.

En la API de Android nivel 21 (Lollipop) o posterior, el modelo se descarga en un directorio excluido de las copias de seguridad automáticas.

En una API de Android nivel 20 o anterior, el modelo se descarga en un directorio llamado com.google.firebase.ml.custom.models en el almacenamiento interno privado de la app. Si habilitas la copia de seguridad con BackupAgent, tienes la opción de excluir este directorio.