Używanie modelu TensorFlow Lite do wnioskowania za pomocą ML Kit na Androidzie

Możesz użyć ML Kit do wnioskowania na urządzeniu za pomocą Model TensorFlow Lite.

Ten interfejs API wymaga pakietu SDK na Androida na poziomie 16 (Jelly Bean) lub nowszym.

Zanim zaczniesz

  1. Jeśli jeszcze nie masz tego za sobą, dodaj Firebase do swojego projektu na Androida.
  2. Dodaj do modułu zależności między bibliotekami ML Kit na Androida Plik Gradle (na poziomie aplikacji) (zwykle app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Przekonwertuj model TensorFlow, którego chcesz używać, na format TensorFlow Lite. Zobacz TOCO: TensorFlow Lite Optimizing Converter

Hostowanie lub pakietowanie modelu

Zanim użyjesz modelu TensorFlow Lite do wnioskowania w aplikacji, musi udostępnić model ML Kit. ML Kit może używać TensorFlow Lite modele hostowane zdalnie przez Firebase, w pakiecie z plikiem binarnym aplikacji lub oba.

Hostując model w Firebase, możesz go aktualizować bez konieczności publikowania nowej wersji aplikacji. Możesz używać Remote Config i A/B Testing do dynamicznie udostępniać różne modele różnym grupom użytkowników.

Jeśli zdecydujesz się udostępniać model tylko poprzez hosting w Firebase, a nie pakietu z aplikacją, możesz zmniejszyć początkowy rozmiar pobieranej aplikacji. Pamiętaj jednak, że jeśli do aplikacji nie dołączony jest model, funkcje związane z modelem będą dostępne dopiero po pobraniu przez aplikację z użyciem modelu po raz pierwszy.

Jeśli połączysz model z aplikacją, będziesz mieć pewność, że funkcje ML w aplikacji będą działać. działają też wtedy, gdy model hostowany przez Firebase jest niedostępny.

Hostowanie modeli w Firebase

Aby hostować model TensorFlow Lite w Firebase:

  1. W sekcji ML Kit w konsoli Firebase kliknij kartę Niestandardowe.
  2. Kliknij Dodaj model niestandardowy (lub Dodaj kolejny model).
  3. Podaj nazwę, która będzie używana do identyfikowania Twojego modelu w Firebase projektu, a następnie prześlij plik modelu TensorFlow Lite (zwykle kończący się .tflite lub .lite).
  4. W pliku manifestu aplikacji zadeklaruj, że wymagane są uprawnienia INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

Po dodaniu do projektu Firebase modelu niestandardowego możesz się odwoływać do w swoich aplikacjach o podanej przez Ciebie nazwie. W każdej chwili możesz przesłać nowego modelu TensorFlow Lite, a aplikacja go pobierze będzie można go używać po ponownym uruchomieniu aplikacji. Możesz określić, warunki wymagane do aktualizacji modelu przez aplikację (patrz poniżej).

Połącz modele z aplikacją

Aby połączyć model TensorFlow Lite z aplikacją, skopiuj plik modelu (zwykle o numerze kończącym się cyframi .tflite lub .lite) do folderu assets/ aplikacji. (Może być konieczne aby najpierw utworzyć folder, klikając prawym przyciskiem myszy folder app/, a następnie klikając Nowe > Folder > Folder Zasoby).

Następnie dodaj ten kod do pliku build.gradle aplikacji, aby upewnić się, że Gradle Nie kompresuje modeli podczas tworzenia aplikacji:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Plik z modelem zostanie dołączony do pakietu aplikacji i będzie dostępny dla ML Kit jako nieprzetworzony zasób.

Wczytaj model

Aby użyć modelu TensorFlow Lite w aplikacji, najpierw skonfiguruj ML Kit za pomocą w lokalizacjach, w których Twój model jest dostępny: zdalnie za pomocą Firebase, pamięci lokalnej lub obu tych metod. Jeśli określisz model lokalny i zdalny, możesz użyć modelu zdalnego, o ile jest dostępny, i wrócić do model przechowywany lokalnie, jeśli model zdalny jest niedostępny.

Konfigurowanie modelu hostowanego w Firebase

Jeśli model był hostowany w Firebase, utwórz FirebaseCustomRemoteModel wraz z nazwą przypisaną do modelu podczas jego przesyłania:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Następnie rozpocznij zadanie pobierania modelu, określając warunki, które którzy chcą zezwolić na pobieranie. Jeśli nie ma modelu na urządzeniu lub jest on nowszy gdy dostępna będzie wersja modelu, zadanie asynchronicznie pobierze model z Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Wiele aplikacji rozpoczyna zadanie pobierania w kodzie inicjowania, ale możesz to zrobić. więc w dowolnym momencie przed użyciem modelu.

Konfigurowanie modelu lokalnego

Jeśli pakiet został połączony z aplikacją, utwórz FirebaseCustomLocalModel obiektu, który określa nazwę pliku modelu TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tworzenie interpretera na podstawie modelu

Po skonfigurowaniu źródeł modelu utwórz FirebaseModelInterpreter lub obiektu z jednego z nich.

Jeśli masz tylko model dołączony lokalnie, utwórz tłumacza na podstawie FirebaseCustomLocalModel obiekt:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Jeśli masz model hostowany zdalnie, musisz sprawdzić, czy został pobrane przed uruchomieniem. Stan pobierania modelu możesz sprawdzić za pomocą metody isModelDownloaded() menedżera modeli.

Mimo że przed uruchomieniem tłumaczenia rozmowy trzeba to potwierdzić, korzystają zarówno z modelu hostowanego zdalnie, jak i z pakietu lokalnego, może to sprawić, warto przeprowadzić tę kontrolę przy tworzeniu instancji interpretera modelu: utwórz z tłumacza zdalnego z modelu zdalnego, jeśli został on pobrany, oraz z lokalnego w inny sposób.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Jeśli masz tylko model hostowany zdalnie, wyłącz powiązany z nim model funkcji – np. wyszarzenia lub ukrycia części interfejsu – potwierdzasz, że model został pobrany. Aby to zrobić, dołącz detektor do metody download() menedżera modeli:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Określ dane wejściowe i wyjściowe modelu

Następnie skonfiguruj formaty wejściowe i wyjściowe interpretera modelu.

Model TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe co najmniej wielowymiarowych tablic. Te tablice zawierają: byte, Wartości int, long lub float. Musisz skonfiguruj w ML Kit liczbę i wymiary („kształt”) tablic przez model.

Jeśli nie znasz kształtu i typu danych wejściowych i wyjściowych modelu, możesz użyć interpretera TensorFlow Lite Python do sprawdzenia modelu. Przykład:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Po określeniu formatu danych wejściowych i wyjściowych modelu możesz skonfiguruj interpreter modelu aplikacji, tworząc FirebaseModelInputOutputOptions.

Na przykład model klasyfikacji obrazów zmiennoprzecinkowych może przyjąć jako dane wejściowe Tablica Nx224 x 224 x 3 z wartościami float, reprezentująca grupę N Obrazy 3-kanałowe (RGB), 224 x 224, jako wynik wyjściowy 1000 wartości float, każda reprezentująca prawdopodobieństwo, do którego należy obraz jedną z 1000 kategorii prognozowanych przez model.

W przypadku takiego modelu trzeba skonfigurować dane wejściowe i wyjściowe interpretera jak poniżej:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Przeprowadź wnioskowanie na danych wejściowych

Aby przeprowadzić wnioskowanie przy użyciu modelu, pobierz dane wejściowe i wykonaj wszystkich przekształceń danych, które są niezbędne do uzyskania tablicy wejściowej funkcji odpowiedni kształt do modelu.

Jeśli na przykład masz model klasyfikacji obrazów o wejściowym kształcie [1 224 224 3] wartości zmiennoprzecinkowych można wygenerować tablicę wejściową z Bitmap zgodnie z poniższym przykładem:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Następnie utwórz obiekt FirebaseModelInputs ze swoim danych wejściowych oraz przekazywania ich wraz ze specyfikacją danych wejściowych i wyjściowych modelu do funkcji Metoda run interpretatora modelu:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Jeśli wywołanie się powiedzie, możesz uzyskać dane wyjściowe, wywołując metodę getOutput() obiektu, który jest przekazywany do detektora sukcesu. Przykład:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Jeśli na przykład przeprowadzasz klasyfikację, kolejnym krokiem może być zmapuj indeksy wyników na etykiety, które reprezentują:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Dodatek: zabezpieczenia modelu

Niezależnie od tego, jak udostępnisz swoje modele TensorFlow Lite ML Kit przechowuje je w standardowym zserializowanym formacie protokołu w formacie pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. Pamiętaj jednak: W praktyce większość modeli jest specyficzna dla danej aplikacji i pod kątem podobnych optymalizacji, jakie stwarzają konkurencji, demontaż ponownego wykorzystania kodu. Musisz jednak wiedzieć o tym ryzyku, zanim zaczniesz niestandardowy model w swojej aplikacji.

W przypadku interfejsu API Androida na poziomie 21 (Lollipop) lub nowszym model jest pobierany do katalogu, który jest wykluczono z automatycznej kopii zapasowej.

W przypadku interfejsu Android API na poziomie 20 lub starszym model jest pobierany do katalogu z nazwą com.google.firebase.ml.custom.models w sekcji prywatnej w aplikacji pamięci wewnętrznej. Jeśli masz włączoną kopię zapasową plików za pomocą usługi BackupAgent, możesz wykluczyć ten katalog.