Używanie modelu TensorFlow Lite do wnioskowania za pomocą ML Kit na Androidzie

Możesz użyć ML Kit do wnioskowania na urządzeniu za pomocą modelu TensorFlow Lite.

Ten interfejs API wymaga pakietu SDK na Androida na poziomie 16 (Jelly Bean) lub nowszym.

Zanim zaczniesz

  1. Dodaj Firebase do swojego projektu Android, chyba że masz to już za sobą.
  2. Dodaj zależności bibliotek ML Kit na Androida do pliku Gradle modułu (na poziomie aplikacji) (zwykle app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. Przekonwertuj model TensorFlow, którego chcesz używać, na format TensorFlow Lite. Patrz TOCO: TensorFlow Lite Optimizing Converter.

Hostowanie lub pakietowanie modelu

Zanim użyjesz w aplikacji modelu TensorFlow Lite do wnioskowania, musisz go udostępnić ML Kit. ML Kit może używać modeli TensorFlow Lite hostowanych zdalnie za pomocą Firebase, w pakiecie z plikiem binarnym aplikacji lub z użyciem obu tych metod.

Po hostowaniu modelu w Firebase możesz go aktualizować bez publikowania nowej wersji aplikacji. Możesz też korzystać ze Zdalnej konfiguracji i Testów A/B do dynamicznego udostępniania różnych modeli różnym zbiorom użytkowników.

Jeśli zdecydujesz się udostępniać model tylko przez hosting w Firebase, a nie spakować go z aplikacją, możesz zmniejszyć początkowy rozmiar pobieranej aplikacji. Pamiętaj jednak, że jeśli model nie jest połączony z aplikacją, funkcje związane z modelem nie będą dostępne, dopóki aplikacja nie pobierze modelu po raz pierwszy.

Jeśli połączysz model z aplikacją, będziesz mieć pewność, że jej funkcje systemów uczących się będą nadal działać, gdy model hostowany w Firebase będzie niedostępny.

Hostowanie modeli w Firebase

Aby hostować model TensorFlow Lite w Firebase:

  1. W sekcji ML Kit w konsoli Firebase kliknij kartę Niestandardowe.
  2. Kliknij Dodaj model niestandardowy (lub Dodaj kolejny model).
  3. Podaj nazwę, która będzie używana do identyfikowania Twojego modelu w projekcie Firebase, a następnie prześlij plik modelu TensorFlow Lite (zwykle kończący się na .tflite lub .lite).
  4. W pliku manifestu aplikacji zadeklaruj, że wymagane są uprawnienia INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Po dodaniu modelu niestandardowego do projektu Firebase możesz odwoływać się do niego w swoich aplikacjach, korzystając z określonej przez Ciebie nazwy. W każdej chwili możesz przesłać nowy model TensorFlow Lite, a aplikacja go pobierze i zacznie z niego korzystać po ponownym uruchomieniu. Możesz określić warunki urządzenia, które muszą zostać spełnione, aby aplikacja podjęła próbę aktualizacji modelu (patrz poniżej).

Połącz modele z aplikacją

Aby połączyć model TensorFlow Lite z aplikacją, skopiuj plik modelu (który zwykle kończy się na .tflite lub .lite) do folderu assets/ aplikacji. Być może trzeba będzie najpierw utworzyć folder – w tym celu kliknij prawym przyciskiem myszy folder app/, a potem kliknij Nowy > Folder > Folder zasobów.

Następnie dodaj do pliku build.gradle aplikacji ten kod, aby mieć pewność, że Gradle nie skompresuje modeli podczas jej tworzenia:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Plik modelu zostanie dołączony do pakietu aplikacji i dostępny dla ML Kit jako zasób nieprzetworzony.

Wczytaj model

Aby użyć modelu TensorFlow Lite w aplikacji, najpierw skonfiguruj ML Kit z lokalizacjami, w których model jest dostępny: zdalnie za pomocą Firebase, w pamięci lokalnej lub obu tych miejscach. Jeśli określisz model lokalny i zdalny, możesz go używać (o ile jest dostępny) i w razie potrzeby wrócić do modelu przechowywanego lokalnie.

Konfigurowanie modelu hostowanego w Firebase

Jeśli model był hostowany w Firebase, utwórz obiekt FirebaseCustomRemoteModel z nazwą przypisaną do modelu podczas jego przesyłania:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Następnie rozpocznij zadanie pobierania modelu, określając warunki, które muszą zostać spełnione, aby można było pobierać dane. Jeśli modelu nie ma na urządzeniu lub jeśli jest dostępna jego nowsza wersja, zadanie pobierze go asynchronicznie z Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Wiele aplikacji rozpoczyna zadanie pobierania w kodzie inicjowania, ale możesz to zrobić w dowolnym momencie, zanim trzeba będzie użyć modelu.

Konfigurowanie modelu lokalnego

Jeśli Twój model został połączony z aplikacją, utwórz obiekt FirebaseCustomLocalModel z nazwą pliku modelu TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tworzenie interpretera na podstawie modelu

Po skonfigurowaniu źródeł modeli utwórz na ich podstawie obiekt FirebaseModelInterpreter.

Jeśli masz tylko model dołączony lokalnie, po prostu utwórz interpreter na podstawie obiektu FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Jeśli masz model hostowany zdalnie, przed uruchomieniem musisz sprawdzić, czy został pobrany. Stan zadania pobierania modelu możesz sprawdzić za pomocą metody isModelDownloaded() menedżera modeli.

Chociaż musisz to potwierdzić przed uruchomieniem interpretera, jeśli masz zarówno model hostowany zdalnie, jak i model umieszczony lokalnie, warto przeprowadzić tę kontrolę podczas tworzenia instancji interpretera modelu: utwórz interpreter z modelu zdalnego (jeśli został pobrany), a z modelu lokalnego – w przeciwnym razie.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Jeśli masz tylko model hostowany zdalnie, wyłącz związane z nim funkcje – na przykład wyszarzoną lub ukryj część interfejsu użytkownika – do czasu potwierdzenia pobrania modelu. Możesz to zrobić, dołączając detektor do metody download() menedżera modeli:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Określ dane wejściowe i wyjściowe modelu

Następnie skonfiguruj formaty wejściowe i wyjściowe interpretera modelu.

Model TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe jedną lub więcej wielowymiarowych tablic. Te tablice zawierają wartości byte, int, long lub float. Musisz skonfigurować ML Kit, podając liczbę i wymiary („kształt”) tablic, których używa model.

Jeśli nie znasz kształtu i typu danych wejściowych i wyjściowych modelu, możesz sprawdzić model za pomocą interpretera Pythona TensorFlow Lite. Przykład:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Po określeniu formatu danych wejściowych i wyjściowych modelu możesz skonfigurować interpreter modelu aplikacji, tworząc obiekt FirebaseModelInputOutputOptions.

Na przykład model klasyfikacji obrazów zmiennoprzecinkowych może przyjmować jako dane wejściowe tablicę Nx224 x 224 x 3 wartości float reprezentującą grupę N 3-kanałowych obrazów (RGB) 224 x 224 i wygenerować jako dane wyjściowe listę 1000 wartości float, z których każda reprezentuje prawdopodobieństwo, że obraz należy do jednej z 1000 przewidywanych przez model kategorii.

W takim modelu trzeba skonfigurować dane wejściowe i wyjściowe interpretera modelu w następujący sposób:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Przeprowadź wnioskowanie na danych wejściowych

Aby przeprowadzić wnioskowanie za pomocą modelu, pobierz dane wejściowe i wykonaj wszelkie przekształcenia danych, które są niezbędne do uzyskania tablicy wejściowej o kształcie odpowiednim dla modelu.

Jeśli na przykład masz model klasyfikacji obrazów o wejściowym kształcie wartości zmiennoprzecinkowych [1 224 224 3], możesz wygenerować tablicę wejściową z obiektu Bitmap, jak pokazano w tym przykładzie:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Następnie utwórz obiekt FirebaseModelInputs z danymi wejściowymi i przekaż go oraz specyfikację danych wejściowych i wyjściowych modelu do metody run interpretatora modelu:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Jeśli wywołanie się powiedzie, możesz uzyskać dane wyjściowe, wywołując metodę getOutput() obiektu przekazywanego do detektora sukcesu. Przykład:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Jeśli na przykład przeprowadzasz klasyfikację, w następnym kroku możesz zmapować indeksy wyników na etykiety, które reprezentują:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Dodatek: zabezpieczenia modelu

Niezależnie od tego, jak udostępnisz modele TensorFlow Lite w ML Kit, ML Kit będzie przechowywać je w standardowym zserializowanym formacie protokołu w pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. W praktyce większość modeli jest jednak tak związana z określoną aplikacją i ukrywana przez optymalizacje, że ryzyko jest podobne do ryzyka związanego z demontażem i ponownym wykorzystaniem Twojego kodu przez konkurencję. Musisz jednak wiedzieć o tym ryzyku, zanim użyjesz w aplikacji modelu niestandardowego.

Na poziomie interfejsu Android API na poziomie 21 (Lollipop) lub nowszym model jest pobierany do katalogu, który jest wykluczony z automatycznego tworzenia kopii zapasowych.

W wersji interfejsu Android API na poziomie 20 lub starszym model jest pobierany do katalogu o nazwie com.google.firebase.ml.custom.models w pamięci wewnętrznej prywatnej aplikacji. Jeśli kopia zapasowa plików została włączona przy użyciu BackupAgent, możesz wykluczyć ten katalog.