Używanie niestandardowego modelu TensorFlow Lite na Androidzie

Jeśli Twoja aplikacja korzysta z niestandardowych modeli TensorFlow Lite, możesz wdrażać je za pomocą Firebase ML. Wdrażając modele za pomocą Firebase, możesz zmniejszyć początkowy rozmiar pobierania aplikacji i zaktualizować modele ML aplikacji bez wydawania nowej wersji aplikacji. Dzięki Remote ConfigA/B Testing możesz dynamicznie dostarczać różne modele różnym grupom użytkowników.

modele TensorFlow Lite,

Modele TensorFlow Lite to modele ML zoptymalizowane pod kątem działania na urządzeniach mobilnych. Aby uzyskać model TensorFlow Lite:

Zanim zaczniesz

  1. Jeśli jeszcze tego nie zrobiono, dodaj Firebase do projektu na Androida.
  2. pliku Gradle modułu (na poziomie aplikacji) (zwykle <project>/<app-module>/build.gradle.kts lub <project>/<app-module>/build.gradle) dodaj zależność z biblioteką pobierania modeli Firebase ML na Androida. Zalecamy używanie Firebase Android BoM do kontrolowania wersji biblioteki.

    W ramach konfigurowania narzędzia do pobierania modeli Firebase ML musisz też dodać do aplikacji pakiet SDK TensorFlow Lite.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.7.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Dzięki użyciu Firebase Android BoMaplikacja zawsze będzie używać zgodnych wersji bibliotek Firebase na Androida.

    (Alternatywnie)  Dodaj zależności biblioteki Firebase bez używania pakietu BoM

    Jeśli zdecydujesz się nie używać Firebase BoM, musisz podać każdą wersję biblioteki Firebase w jej wierszu zależności.

    Jeśli w aplikacji używasz kilku bibliotek Firebase, zdecydowanie zalecamy korzystanie z BoM do zarządzania wersjami bibliotek. Dzięki temu wszystkie wersje będą ze sobą zgodne.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Szukasz modułu biblioteki dla Kotlina? Od października 2023 r. (Firebase BoM 32.5.0) deweloperzy Kotlina i Java mogą korzystać z głównego modułu biblioteki (szczegółowe informacje znajdziesz w często zadawanych pytaniach dotyczących tej inicjatywy).
  3. W pliku manifestu aplikacji zadeklaruj, że wymagane jest uprawnienie INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Wdrażanie modelu

Wdróż niestandardowe modele TensorFlow za pomocą konsoli Firebase lub pakietów Firebase Admin SDK w Pythonie i Node.js. Zobacz Wdrażanie modeli niestandardowych i zarządzanie nimi.

Po dodaniu niestandardowego modelu do projektu Firebase możesz odwoływać się do niego w swoich aplikacjach, podając jego nazwę. W dowolnym momencie możesz wdrożyć nowy model TensorFlow Lite i pobrać go na urządzenia użytkowników, wywołując funkcję getModel() (patrz poniżej).

2. Pobierz model na urządzenie i inicjuj interpreter TensorFlow Lite

Aby używać modelu TensorFlow Lite w aplikacji, najpierw pobierz najnowszą wersję modelu na urządzenie za pomocą pakietu SDK Firebase ML. Następnie utwórz instancję interpretera TensorFlow Lite z modelem.

Aby rozpocząć pobieranie modelu, wywołaj metodę getModel() pobierania modelu, podając nazwę przypisaną do modelu podczas jego przesyłania, określając, czy chcesz zawsze pobierać najnowszy model, oraz warunki, na jakich chcesz zezwolić na pobieranie.

Możesz wybrać jeden z 3 tych sposobów pobierania:

Typ pobierania Opis
LOCAL_MODEL Pobierz lokalny model z urządzenia. Jeśli nie ma dostępnego modelu lokalnego, ta funkcja działa jak LATEST_MODEL. Użyj tego typu pobierania, jeśli nie chcesz sprawdzać aktualizacji modelu. Na przykład: używasz Zdalnej konfiguracji do pobierania nazw modeli i zawsze przesyłasz modele pod nowymi nazwami (zalecane).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Pobierz lokalny model z urządzenia i zacznij aktualizować go w tle. Jeśli nie ma dostępnego modelu lokalnego, ta funkcja działa jak LATEST_MODEL.
LATEST_MODEL Pobierz najnowszy model. Jeśli lokalny model jest najnowszej wersji, zwraca lokalny model. W przeciwnym razie pobierz najnowszy model. Ta funkcja będzie blokować dostęp do aplikacji, dopóki nie zostanie pobrana najnowsza wersja (nie jest to zalecane). Używaj tego zachowania tylko w przypadkach, gdy wyraźnie potrzebujesz najnowszej wersji.

Do czasu potwierdzenia pobrania modelu należy wyłączyć funkcje związane z modelem, np. wyłączyć lub ukryć część interfejsu użytkownika.

Kotlin

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Wiele aplikacji uruchamia zadanie pobierania w kodzie inicjującym, ale możesz to zrobić w dowolnym momencie, zanim zaczniesz używać modelu.

3. Wykonywanie wnioskowania na podstawie danych wejściowych

Pobieranie kształtów danych wejściowych i wyjściowych modelu

Interpreter modelu TensorFlow Lite przyjmuje jako dane wejściowe i zwraca jako dane wyjściowe co najmniej 1 wielowymiarową tablicę. Te tablice zawierają wartości byte, int, long lub float. Zanim przekażesz dane do modelu lub użyjesz jego wyniku, musisz znać liczbę i wymiary („kształt”) tablic, których używa Twój model.

Jeśli model został utworzony przez Ciebie lub jeśli formaty danych wejściowych i wyjściowych są udokumentowane, te informacje mogą być już dostępne. Jeśli nie znasz kształtu ani typu danych wejściowych i wyjściowych modelu, możesz użyć interpretera TensorFlow Lite, aby sprawdzić model. Przykład:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Przykładowe dane wyjściowe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Uruchamianie interpretera

Po określeniu formatu danych wejściowych i wyjściowych modelu pobierz dane wejściowe i przeprowadź na nich wszystkie niezbędne transformacje, aby uzyskać dane wejściowe o odpowiednim kształcie dla modelu.

Jeśli np. masz model klasyfikacji obrazów o kształcie danych wejściowych [1 224 224 3] o wartościach zmiennoprzecinkowych, możesz wygenerować dane wejściowe ByteBuffer z obiektu Bitmap, jak pokazano w tym przykładzie:

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Następnie przydziel ByteBuffer wystarczająco duży, aby pomieścić dane wyjściowe modelu, i przekaż bufor wejściowy oraz bufor wyjściowy do metody run() interpretera TensorFlow Lite. Na przykład w przypadku kształtu wyjściowego [1 1000] wartości zmiennoprzecinkowych:

Kotlin

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Sposób wykorzystania danych wyjściowych zależy od modelu.

Jeśli na przykład wykonujesz klasyfikację, możesz w następnym kroku przypisać indeksy wyników do reprezentowanych przez nie etykiet:

Kotlin

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Załącznik: Bezpieczeństwo modeli

Niezależnie od tego, jak udostępniasz modele TensorFlow Lite aplikacji Firebase ML, Firebase ML zapisuje je w standardowym formacie serializacji protobuf w pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. W praktyce jednak większość modeli jest tak skomplikowana i zaszyfrowana przez optymalizacje, że ryzyko jest podobne do ryzyka związanego z rozkładaniem i wykorzystywaniem kodu przez konkurencję. Zanim jednak użyjesz w aplikacji modelu niestandardowego, pamiętaj o tym ryzyku.

Na Androidzie API 21 (Lollipop) i nowszych model jest pobierany do katalogu, który jest wykluczony z automatycznego tworzenia kopii zapasowej.

W przypadku poziomu interfejsu API 20 i starszych model jest pobierany do katalogu com.google.firebase.ml.custom.models w pamięci wewnętrznej prywatnej aplikacji. Jeśli włączysz tworzenie kopii zapasowych plików za pomocą BackupAgent, możesz wykluczyć ten katalog.