TensorFlow Lite-Modell für Inferenz mit ML Kit unter Android verwenden

Mit ML Kit können Sie On-Device-Inferenzen mit einem TensorFlow Lite-Modell ausführen.

Für diese API ist Android SDK-Level 16 (Jelly Bean) oder höher erforderlich.

Hinweis

  1. Fügen Sie Ihrem Android-Projekt Firebase hinzu, falls noch nicht geschehen.
  2. Fügen Sie der Gradle-Datei des Moduls (auf Anwendungsebene, in der Regel app/build.gradle) die Abhängigkeiten für die ML Kit-Android-Bibliotheken hinzu:
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Konvertieren Sie das gewünschte TensorFlow-Modell in das TensorFlow Lite-Format. Weitere Informationen finden Sie unter TOCO: TensorFlow Lite Optimizing Converter.

Modell hosten oder bündeln

Bevor Sie ein TensorFlow Lite-Modell für die Inferenz in Ihrer App verwenden können, müssen Sie es für ML Kit verfügbar machen. ML Kit kann TensorFlow Lite-Modelle verwenden, die mit Firebase aus der Ferne gehostet, im App-Binärcode gebündelt oder beides sind.

Wenn Sie ein Modell in Firebase hosten, können Sie es aktualisieren, ohne eine neue App-Version zu veröffentlichen. Mit Remote Config und A/B Testing können Sie unterschiedliche Modelle dynamisch für unterschiedliche Nutzergruppen bereitstellen.

Wenn Sie das Modell nur bei Firebase hosten und nicht mit Ihrer App bündeln, können Sie die ursprüngliche Downloadgröße Ihrer App verringern. Beachten Sie jedoch, dass alle modellverbundenen Funktionen erst verfügbar sind, wenn Ihre App das Modell zum ersten Mal herunterlädt.

Wenn Sie Ihr Modell mit Ihrer App bündeln, können Sie dafür sorgen, dass die ML-Funktionen Ihrer App auch dann funktionieren, wenn das in Firebase gehostete Modell nicht verfügbar ist.

Modelle in Firebase hosten

So hosten Sie Ihr TensorFlow Lite-Modell auf Firebase:

  1. Klicken Sie in der Firebase-Konsole im Bereich ML Kit auf den Tab Benutzerdefiniert.
  2. Klicken Sie auf Benutzerdefiniertes Modell hinzufügen oder Weitere Modelle hinzufügen.
  3. Geben Sie einen Namen an, der zum Identifizieren Ihres Modells in Ihrem Firebase-Projekt verwendet wird, und laden Sie dann die TensorFlow Lite-Modelldatei hoch, die in der Regel auf .tflite oder .lite endet.
  4. Deklarieren Sie im Manifest Ihrer App, dass die INTERNET-Berechtigung erforderlich ist:
    <uses-permission android:name="android.permission.INTERNET" />

Nachdem Sie Ihrem Firebase-Projekt ein benutzerdefiniertes Modell hinzugefügt haben, können Sie in Ihren Apps auf das Modell mit dem angegebenen Namen verweisen. Sie können jederzeit ein neues TensorFlow Lite-Modell hochladen. Ihre App lädt das neue Modell herunter und verwendet es beim nächsten Neustart der App. Sie können die Gerätebedingungen festlegen, die erforderlich sind, damit Ihre App versucht, das Modell zu aktualisieren (siehe unten).

Modelle mit einer App bündeln

Wenn Sie Ihr TensorFlow Lite-Modell mit Ihrer App bündeln möchten, kopieren Sie die Modelldatei (normalerweise endet sie auf .tflite oder .lite) in den Ordner assets/ Ihrer App. Möglicherweise musst du den Ordner zuerst erstellen. Klicke dazu mit der rechten Maustaste auf den Ordner app/ und dann auf Neu > Ordner > Assets-Ordner.

Fügen Sie dann der build.gradle-Datei Ihrer App Folgendes hinzu, damit Gradle die Modelle beim Erstellen der App nicht komprimiert:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Die Modelldatei wird in das App-Paket aufgenommen und ML Kit als Roh-Asset zur Verfügung gestellt.

Modell laden

Wenn Sie Ihr TensorFlow Lite-Modell in Ihrer App verwenden möchten, müssen Sie zuerst ML Kit mit den Speicherorten konfigurieren, an denen Ihr Modell verfügbar ist: per Remotezugriff über Firebase, im lokalen Speicher oder beides. Wenn Sie sowohl ein lokales als auch ein Remote-Modell angeben, können Sie das Remote-Modell verwenden, wenn es verfügbar ist. Andernfalls wird das lokal gespeicherte Modell verwendet.

Ein von Firebase gehostetes Modell konfigurieren

Wenn Sie Ihr Modell bei Firebase gehostet haben, erstellen Sie ein FirebaseCustomRemoteModel-Objekt und geben Sie den Namen an, den Sie dem Modell beim Hochladen zugewiesen haben:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Starten Sie dann die Aufgabe zum Herunterladen des Modells und geben Sie die Bedingungen an, unter denen der Download zulässig sein soll. Wenn das Modell nicht auf dem Gerät vorhanden ist oder eine neuere Version des Modells verfügbar ist, wird es von der Aufgabe asynchron von Firebase heruntergeladen:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Viele Apps starten die Downloadaufgabe in ihrem Initialisierungscode, Sie können dies aber auch jederzeit tun, bevor Sie das Modell verwenden müssen.

Lokales Modell konfigurieren

Wenn Sie das Modell mit Ihrer App gebündelt haben, erstellen Sie ein FirebaseCustomLocalModel-Objekt und geben Sie den Dateinamen des TensorFlow Lite-Modells an:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Interpreter aus Ihrem Modell erstellen

Nachdem Sie Ihre Modellquellen konfiguriert haben, erstellen Sie ein FirebaseModelInterpreter-Objekt aus einer der Quellen.

Wenn Sie nur ein lokal gebündeltes Modell haben, erstellen Sie einfach einen Interpreter aus Ihrem FirebaseCustomLocalModel-Objekt:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Wenn Sie ein extern gehostetes Modell haben, müssen Sie prüfen, ob es heruntergeladen wurde, bevor Sie es ausführen. Sie können den Status der Modelldownloadaufgabe mit der isModelDownloaded()-Methode des Modellmanagers prüfen.

Sie müssen dies zwar nur vor dem Ausführen des Interpreters bestätigen, wenn Sie jedoch sowohl ein remote gehostetes Modell als auch ein lokal gebündeltes Modell haben, kann es sinnvoll sein, diese Prüfung beim Instanziieren des Modell-Interpreters durchzuführen: Erstellen Sie einen Interpreter aus dem Remote-Modell, wenn es heruntergeladen wurde, andernfalls aus dem lokalen Modell.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Wenn Sie nur ein extern gehostetes Modell haben, sollten Sie modellverwandte Funktionen deaktivieren, z. B. einen Teil der Benutzeroberfläche grau ausblenden oder ausblenden, bis Sie bestätigen, dass das Modell heruntergeladen wurde. Dazu fügen Sie der download()-Methode des Modellmanagers einen Listener hinzu:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Eingabe und Ausgabe des Modells angeben

Konfigurieren Sie als Nächstes die Eingabe- und Ausgabeformate des Modellinterpreters.

Ein TensorFlow Lite-Modell nimmt ein oder mehrere mehrdimensionale Arrays als Eingabe und gibt ein oder mehrere mehrdimensionale Arrays als Ausgabe zurück. Diese Arrays enthalten entweder byte-, int-, long- oder float-Werte. Sie müssen ML Kit mit der Anzahl und den Dimensionen („Form“) der Arrays konfigurieren, die in Ihrem Modell verwendet werden.

Wenn Sie die Form und den Datentyp der Eingabe und Ausgabe Ihres Modells nicht kennen, können Sie Ihr Modell mit dem TensorFlow Lite Python-Interpreter untersuchen. Beispiel:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Nachdem Sie das Format der Eingabe und Ausgabe Ihres Modells festgelegt haben, können Sie den Modellinterpreter Ihrer App konfigurieren, indem Sie ein FirebaseModelInputOutputOptions-Objekt erstellen.

Ein Modell zur Bildklassifizierung mit Gleitkommazahlen könnte beispielsweise ein N × 224 × 224 × 3-Array von float-Werten als Eingabe annehmen, das einen Batch von N × 224 × 224 dreikanaligen (RGB) Bildern darstellt, und als Ausgabe eine Liste von 1.000 float-Werten generieren, die jeweils die Wahrscheinlichkeit darstellen, dass das Bild zu einer der 1.000 Kategorien gehört, die das Modell vorhersagt.

Für ein solches Modell konfigurieren Sie die Eingabe und Ausgabe des Modellinterpreters wie unten gezeigt:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Inferenzen auf Eingabedaten durchführen

Um schließlich Inferenzen mit dem Modell durchzuführen, rufen Sie Ihre Eingabedaten ab und führen Sie alle Transformationen an den Daten durch, die erforderlich sind, um ein Eingabearray der richtigen Form für Ihr Modell zu erhalten.

Wenn Sie beispielsweise ein Bildklassifizierungsmodell mit der Eingabeform [1 224 224 3] mit Gleitkommawerten haben, können Sie ein Eingabearray aus einem Bitmap-Objekt generieren, wie im folgenden Beispiel gezeigt:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Erstellen Sie dann ein FirebaseModelInputs-Objekt mit Ihren Eingabedaten und übergeben Sie es zusammen mit der Eingabe- und Ausgabespezifikation des Modells an die run-Methode des Modellinterpreters:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Wenn der Aufruf erfolgreich ist, kannst du die Ausgabe durch Aufrufen der getOutput()-Methode des Objekts abrufen, das an den Erfolgs-Listener übergeben wird. Beispiel:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Wie Sie die Ausgabe verwenden, hängt vom verwendeten Modell ab.

Wenn Sie beispielsweise eine Klassifizierung durchführen, können Sie als nächsten Schritt die Indizes des Ergebnisses den entsprechenden Labels zuordnen:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Anhang: Modellsicherheit

Unabhängig davon, wie Sie Ihre TensorFlow Lite-Modelle für ML Kit verfügbar machen, speichert ML Kit sie im standardmäßigen serialisierten Protobuf-Format im lokalen Speicher.

Theoretisch kann also jeder Ihr Modell kopieren. In der Praxis sind die meisten Modelle jedoch so anwendungsspezifisch und durch Optimierungen verschleiert, dass das Risiko ähnlich hoch ist wie bei der Deaktivierung und Wiederverwendung Ihres Codes durch Mitbewerber. Sie sollten sich jedoch dieses Risiko bewusst machen, bevor Sie ein benutzerdefiniertes Modell in Ihrer App verwenden.

Ab Android API-Level 21 (Lollipop) wird das Modell in ein Verzeichnis heruntergeladen, das von der automatischen Sicherung ausgeschlossen ist.

Bei der Android API-Ebene 20 und niedriger wird das Modell in den privaten internen Speicher der App in ein Verzeichnis mit dem Namen com.google.firebase.ml.custom.models heruntergeladen. Wenn Sie die Dateisicherung mit BackupAgent aktiviert haben, können Sie dieses Verzeichnis ausschließen.