Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Verwenden Sie ein benutzerdefiniertes TensorFlow Lite-Modell für Android

Wenn Ihre App benutzerdefinierte TensorFlow Lite- Modelle verwendet, können Sie Ihre Modelle mit Firebase ML bereitstellen. Durch die Bereitstellung von Modellen mit Firebase können Sie die anfängliche Downloadgröße Ihrer App reduzieren und die ML-Modelle Ihrer App aktualisieren, ohne eine neue Version Ihrer App zu veröffentlichen. Mit Remote Config und A / B-Tests können Sie verschiedene Modelle dynamisch für verschiedene Benutzergruppen bereitstellen.

TensorFlow Lite Modelle

TensorFlow Lite-Modelle sind ML-Modelle, die für die Ausführung auf Mobilgeräten optimiert sind. So erhalten Sie ein TensorFlow Lite-Modell:

Bevor Sie beginnen

  1. Wenn Sie dies noch nicht getan haben, fügen Sie Firebase zu Ihrem Android-Projekt hinzu .
  2. build.gradle in Ihrer build.gradle Datei auf Projektebene sicher, dass das Maven-Repository von Google sowohl in Ihrem buildscript als auch in Ihrem allprojects Abschnitt enthalten ist.
  3. Fügen Sie die Android-Bibliotheken Firebase ML und TensorFlow Lite zu Ihrer Gradle-Datei (normalerweise app/build.gradle ) Ihres Moduls (App-Ebene) app/build.gradle :
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.4'
      implementation 'org.tensorflow:tensorflow-lite:2.0.0'
    }
    
  4. Erklären Sie im Manifest Ihrer App, dass eine INTERNET-Berechtigung erforderlich ist:
    <uses-permission android:name="android.permission.INTERNET" />

1. Stellen Sie Ihr Modell bereit

Stellen Sie Ihre benutzerdefinierten TensorFlow-Modelle entweder über die Firebase-Konsole oder über die Firebase Admin Python- und Node.js-SDKs bereit. Siehe Bereitstellen und Verwalten von benutzerdefinierten Modellen .

Nachdem Sie Ihrem Firebase-Projekt ein benutzerdefiniertes Modell hinzugefügt haben, können Sie das Modell in Ihren Apps unter dem von Ihnen angegebenen Namen referenzieren. Sie können jederzeit ein neues TensorFlow Lite-Modell hochladen. Ihre App lädt das neue Modell herunter und verwendet es beim nächsten Neustart der App. Sie können die Gerätebedingungen definieren, die erforderlich sind, damit Ihre App versucht, das Modell zu aktualisieren (siehe unten).

2. Laden Sie das Modell auf das Gerät herunter

Um Ihr TensorFlow Lite-Modell in Ihrer App zu verwenden, verwenden Sie zunächst das Firebase ML SDK, um die neueste Version des Modells auf das Gerät herunterzuladen.

Rufen Sie zum Starten des Modelldownloads die download() -Methode des Modellmanagers auf und geben Sie den Namen an, den Sie dem Modell beim Hochladen zugewiesen haben, sowie die Bedingungen, unter denen Sie das Herunterladen zulassen möchten. Wenn sich das Modell nicht auf dem Gerät befindet oder eine neuere Version des Modells verfügbar ist, lädt die Task das Modell asynchron von Firebase herunter.

Sie sollten die modellbezogene Funktionalität deaktivieren, z. B. einen Teil Ihrer Benutzeroberfläche ausblenden oder ausblenden, bis Sie bestätigen, dass das Modell heruntergeladen wurde.

Java

FirebaseCustomRemoteModel remoteModel =
      new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Viele Apps starten die Download-Aufgabe in ihrem Initialisierungscode. Sie können dies jedoch jederzeit tun, bevor Sie das Modell verwenden müssen.

3. Initialisieren Sie einen TensorFlow Lite-Interpreter

Nachdem Sie das Modell auf das Gerät heruntergeladen haben, können Sie den Speicherort der Modelldatei getLatestModelFile() indem Sie die getLatestModelFile() -Methode des getLatestModelFile() aufrufen. Verwenden Sie diesen Wert, um einen TensorFlow Lite-Interpreter zu instanziieren:

Java

FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        }
    }

4. Inferenz auf Eingabedaten durchführen

Holen Sie sich die Eingabe- und Ausgabeformen Ihres Modells

Der TensorFlow Lite-Modellinterpreter verwendet als Eingabe ein oder mehrere mehrdimensionale Arrays als Ausgabe. Diese Arrays enthalten entweder byte , int , long oder float Werte. Bevor Sie Daten an ein Modell übergeben oder dessen Ergebnis verwenden können, müssen Sie die Anzahl und die Abmessungen ("Form") der von Ihrem Modell verwendeten Arrays kennen.

Wenn Sie das Modell selbst erstellt haben oder das Eingabe- und Ausgabeformat des Modells dokumentiert ist, verfügen Sie möglicherweise bereits über diese Informationen. Wenn Sie die Form und den Datentyp der Eingabe und Ausgabe Ihres Modells nicht kennen, können Sie Ihr Modell mit dem TensorFlow Lite-Interpreter überprüfen. Beispielsweise:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Beispielausgabe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Führen Sie den Interpreter aus

Nachdem Sie das Format der Eingabe und Ausgabe Ihres Modells festgelegt haben, rufen Sie Ihre Eingabedaten ab und führen Sie alle Transformationen für die Daten durch, die erforderlich sind, um eine Eingabe mit der richtigen Form für Ihr Modell zu erhalten.

Wenn Sie beispielsweise ein ByteBuffer mit einer Eingabeform von [1 224 224 3] Gleitkommawerten haben, können Sie einen Eingabe- ByteBuffer aus einem Bitmap Objekt generieren, wie im folgenden Beispiel gezeigt:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin + KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

ByteBuffer dann einen ByteBuffer groß genug ist, um die Ausgabe des Modells aufzunehmen, und übergeben Sie den Eingabepuffer und den Ausgabepuffer an die run() -Methode des TensorFlow Lite-Interpreters. Zum Beispiel für eine Ausgabeform von [1 1000] Gleitkommawerten:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin + KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Wie Sie die Ausgabe verwenden, hängt vom verwendeten Modell ab.

Wenn Sie beispielsweise als nächsten Schritt eine Klassifizierung durchführen, können Sie die Indizes des Ergebnisses den Beschriftungen zuordnen, die sie darstellen:

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin + KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Anhang: Auf ein lokal gebündeltes Modell zurückgreifen

Wenn Sie Ihr Modell mit Firebase hosten, sind keine modellbezogenen Funktionen verfügbar, bis Ihre App das Modell zum ersten Mal herunterlädt. Für einige Apps ist dies möglicherweise in Ordnung. Wenn Ihr Modell jedoch die Kernfunktionalität aktiviert, möchten Sie möglicherweise eine Version Ihres Modells mit Ihrer App bündeln und die beste verfügbare Version verwenden. Auf diese Weise können Sie sicherstellen, dass die ML-Funktionen Ihrer App funktionieren, wenn das von Firebase gehostete Modell nicht verfügbar ist.

So bündeln Sie Ihr TensorFlow Lite-Modell mit Ihrer App:

  1. Kopieren Sie die Modelldatei (normalerweise mit .tflite oder .lite ) in die assets/ Ordner Ihrer App. (Möglicherweise müssen Sie den Ordner zuerst erstellen, indem Sie mit der rechten Maustaste auf die app/ Ordner klicken und dann auf Neu> Ordner> Assets-Ordner klicken.)

  2. Fügen Sie der build.gradle Datei Ihrer App build.gradle , um sicherzustellen, dass Gradle die Modelle beim build.gradle der App nicht komprimiert:

    android {
    
        // ...
    
        aaptOptions {
            noCompress "tflite", "lite"
        }
    }
    

Verwenden Sie dann das lokal gebündelte Modell, wenn das gehostete Modell nicht verfügbar ist:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                } else {
                    try {
                        InputStream inputStream = getAssets().open("your_fallback_model.tflite");
                        byte[] model = new byte[inputStream.available()];
                        inputStream.read(model);
                        ByteBuffer buffer = ByteBuffer.allocateDirect(model.length)
                                .order(ByteOrder.nativeOrder());
                        buffer.put(model);
                        interpreter = new Interpreter(buffer);
                    } catch (IOException e) {
                        // File not found?
                    }
                }
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        } else {
            val model = assets.open("your_fallback_model.tflite").readBytes()
            val buffer = ByteBuffer.allocateDirect(model.size).order(ByteOrder.nativeOrder())
            buffer.put(model)
            interpreter = Interpreter(buffer)
        }
    }

Anhang: Modellsicherheit

Unabhängig davon, wie Sie Ihre TensorFlow Lite-Modelle Firebase ML zur Verfügung stellen, speichert Firebase ML sie im standardmäßigen serialisierten Protobuf-Format im lokalen Speicher.

Theoretisch bedeutet dies, dass jeder Ihr Modell kopieren kann. In der Praxis sind die meisten Modelle jedoch so anwendungsspezifisch und durch Optimierungen verschleiert, dass das Risiko dem von Wettbewerbern ähnelt, die Ihren Code zerlegen und wiederverwenden. Sie sollten sich dieses Risikos jedoch bewusst sein, bevor Sie ein benutzerdefiniertes Modell in Ihrer App verwenden.

Ab Android API Level 21 (Lollipop) und neuer wird das Modell in ein Verzeichnis heruntergeladen, das von der automatischen Sicherung ausgeschlossen ist .

com.google.firebase.ml.custom.models Android API Level 20 wird das Modell in ein Verzeichnis mit dem Namen com.google.firebase.ml.custom.models im app-privaten internen Speicher heruntergeladen. Wenn Sie die BackupAgent mit BackupAgent , können Sie dieses Verzeichnis ausschließen.