Używanie modelu TensorFlow Lite do wnioskowania za pomocą ML Kit na Androidzie

Za pomocą ML Kit możesz wykonywać wnioskowanie na urządzeniu za pomocą modelu TensorFlow Lite.

Ten interfejs API wymaga pakietu SDK Androida na poziomie 16 (Jelly Bean) lub nowszego.

Zanim zaczniesz

  1. Jeśli jeszcze tego nie zrobiono, dodaj Firebase do projektu na Androida.
  2. Dodaj zależności do bibliotek ML Kit na Androida do pliku Gradle modułu (na poziomie aplikacji) (zwykle app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Przekształć model TensorFlow, którego chcesz użyć, do formatu TensorFlow Lite. Zobacz: TOCO: narzędzie do optymalizacji TensorFlow Lite.

Hostowanie lub grupowanie modelu

Zanim użyjesz modelu TensorFlow Lite do wnioskowania w aplikacji, musisz udostępnić go pakietowi ML Kit. ML Kit może używać modeli TensorFlow Lite hostowanych zdalnie za pomocą Firebase, w pakiecie z binarną aplikacją lub w obu tych miejscach.

Przechowywanie modelu w Firebase umożliwia jego aktualizowanie bez wydawania nowej wersji aplikacji. Możesz też używać funkcji Remote ConfigA/B Testing, aby dynamicznie dostarczać różne modele różnym grupom użytkowników.

Jeśli zdecydujesz się udostępnić tylko model, hostując go w Firebase, a nie dołączając go do aplikacji, możesz zmniejszyć początkowy rozmiar pobierania aplikacji. Pamiętaj jednak, że jeśli model nie jest dołączony do aplikacji, żadne funkcje związane z modelem nie będą dostępne, dopóki aplikacja nie pobierze modelu po raz pierwszy.

Dzięki temu możesz mieć pewność, że funkcje ML w aplikacji będą działać, nawet jeśli model hostowany w Firebase jest niedostępny.

Hostowanie modeli w Firebase

Aby hostować model TensorFlow Lite w Firebase:

  1. W sekcji Zestaw MLFirebase konsoli kliknij kartę Niestandardowe.
  2. Kliknij Dodaj model niestandardowy (lub Dodaj inny model).
  3. Podaj nazwę, która będzie używana do identyfikowania modelu w projekcie Firebase, a potem prześlij plik modelu TensorFlow Lite (zazwyczaj kończy się na .tflite lub .lite).
  4. W pliku manifestu aplikacji zadeklaruj, że wymagane jest uprawnienie INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

Po dodaniu niestandardowego modelu do projektu Firebase możesz odwoływać się do niego w swoich aplikacjach, podając jego nazwę. W każdej chwili możesz przesłać nowy model TensorFlow Lite. Aplikacja pobierze go i zacznie go używać przy następnym uruchomieniu. Możesz zdefiniować warunki na urządzeniu, które aplikacja musi spełnić, aby spróbować zaktualizować model (patrz poniżej).

Pakowanie modeli z aplikacją

Aby dołączyć model TensorFlow Lite do aplikacji, skopiuj plik modelu (zwykle kończy się na .tflite lub .lite) do folderu assets/ aplikacji. (Możesz najpierw utworzyć folder, klikając app/ prawym przyciskiem myszy, a następnie Nowy > Folder > Folder zasobów).

Następnie dodaj do pliku build.gradle aplikacji ten ciąg, aby Gradle nie kompresował modeli podczas kompilowania aplikacji:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Plik modelu zostanie dodany do pakietu aplikacji i będzie dostępny dla ML Kit jako surowy zasób.

Wczytaj model

Aby korzystać z modelu TensorFlow Lite w aplikacji, najpierw skonfiguruj pakiet ML Kit, określając miejsca, w których model jest dostępny: zdalnie za pomocą Firebase, w pamięci lokalnej lub w obu tych miejscach. Jeśli określisz model lokalny i zdalny, możesz użyć modelu zdalnego, jeśli jest dostępny, albo wrócić do modelu przechowywanego lokalnie, jeśli model zdalny jest niedostępny.

Konfigurowanie modelu hostowanego w Firebase

Jeśli model jest hostowany w Firebase, utwórz obiekt FirebaseCustomRemoteModel, podając nazwę przypisaną do modelu podczas jego przesyłania:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Następnie uruchom zadanie pobierania modelu, określając warunki, w jakich chcesz zezwolić na pobieranie. Jeśli model nie jest dostępny na urządzeniu lub jeśli dostępna jest nowsza wersja modelu, zadanie asynchronicznie pobierze model z Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Wiele aplikacji uruchamia zadanie pobierania w kodzie inicjującym, ale możesz to zrobić w dowolnym momencie, zanim zaczniesz używać modelu.

Konfigurowanie modelu lokalnego

Jeśli model jest w pakiecie z aplikacją, utwórz obiekt FirebaseCustomLocalModel, podając nazwę pliku modelu TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tworzenie interpretera na podstawie modelu

Po skonfigurowaniu źródeł modelu utwórz obiekt FirebaseModelInterpreter na podstawie jednego z nich.

Jeśli masz tylko model w pakiecie lokalnym, utwórz interpretera na podstawie obiektu FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Jeśli model jest hostowany zdalnie, przed jego uruchomieniem musisz sprawdzić, czy został pobrany. Stan pobierania modelu możesz sprawdzić, korzystając z metody isModelDownloaded() menedżera modeli.

Musisz to potwierdzić tylko przed uruchomieniem interpretera, ale jeśli masz model hostowany zdalnie i model wbudowany lokalnie, warto wykonać tę weryfikację podczas tworzenia instancji interpretera modelu: utwórz interpretera z modelu zdalnego, jeśli został pobrany, a w przeciwnym razie – z modelu lokalnego.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Jeśli masz tylko model hostowany zdalnie, wyłącz funkcje związane z modelem (np. wygaszaj lub ukryj część interfejsu użytkownika), dopóki nie potwierdzisz, że model został pobrany. Możesz to zrobić, dołączając listenera do metody download() menedżera modelu:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Określ dane wejściowe i wyjściowe modelu

Następnie skonfiguruj formaty wejściowe i wyjściowe interpretera modelu.

Model TensorFlow Lite otrzymuje jako dane wejściowe i wydaje jako dane wyjściowe co najmniej jeden wielowymiarowy tablic. Te tablice zawierają wartości byte, int, long lub float. Musisz skonfigurować ML Kit, podając liczbę i wymiary („kształt”) macierzy, których używa Twój model.

Jeśli nie znasz kształtu i typu danych wejściowych i wyjściowych modelu, możesz użyć interpretera Pythona TensorFlow Lite, aby sprawdzić model. Przykład:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Po określeniu formatu danych wejściowych i wyjściowych modelu możesz skonfigurować interpreter modelu w aplikacji, tworząc obiekt FirebaseModelInputOutputOptions.

Na przykład model klasyfikacji obrazu o typie zmiennoprzecinkowym może otrzymać jako dane wejściowe tablicę Nx224x224x3 wartości float, która reprezentuje zbiór N obrazów 224 x 224 w 3 kanałach (RGB), i wygenerować jako dane wyjściowe listę 1000 wartości float, z których każda reprezentuje prawdopodobieństwo, że obraz należy do jednej z 1000 kategorii przewidywanych przez model.

W przypadku takiego modelu skonfigurujesz dane wejściowe i wyjściowe interpretera modelu w ten sposób:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Wykonywanie wnioskowania na podstawie danych wejściowych

Aby przeprowadzić wnioskowanie za pomocą modelu, pobierz dane wejściowe i przeprowadź na nich wszystkie przekształcenia niezbędne do uzyskania tablicy wejściowej o odpowiednim kształcie dla Twojego modelu.

Jeśli na przykład masz model klasyfikacji obrazów o kształcie wejściowym [1 224 224 3] wartości zmiennoprzecinkowych, możesz wygenerować tablicę wejściową z obiektu Bitmap, jak pokazano w tym przykładzie:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Następnie utwórz obiekt FirebaseModelInputs z danymi wejściowymi i przekaż go wraz ze specyfikacją danych wejściowych i wyjściowych modelu do metody runmodel interpreter:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Jeśli wywołanie się powiedzie, możesz uzyskać dane wyjściowe, wywołując metodę getOutput() obiektu przekazanego do odbiornika sukcesu. Przykład:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Sposób wykorzystania danych wyjściowych zależy od modelu.

Jeśli na przykład wykonujesz klasyfikację, możesz w następnym kroku przypisać indeksy wyników do reprezentowanych przez nie etykiet:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Załącznik: Bezpieczeństwo modeli

Niezależnie od tego, jak udostępniasz modele TensorFlow Lite interfejsowi ML Kit, ML Kit przechowuje je w standardowym formacie serializacji protobuf w pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. W praktyce jednak większość modeli jest tak skomplikowana i zaszyfrowana przez optymalizacje, że ryzyko jest podobne do ryzyka związanego z rozkładaniem i wykorzystywaniem kodu przez konkurencję. Zanim jednak użyjesz w aplikacji modelu niestandardowego, pamiętaj o tym ryzyku.

Na Androidzie API 21 (Lollipop) i nowszych model jest pobierany do katalogu, który jest wykluczony z automatycznego tworzenia kopii zapasowej.

W przypadku poziomu interfejsu API 20 i starszych model jest pobierany do katalogu com.google.firebase.ml.custom.models w pamięci wewnętrznej prywatnej aplikacji. Jeśli włączysz tworzenie kopii zapasowych plików za pomocą BackupAgent, możesz wykluczyć ten katalog.