获取我们在 Firebase 峰会上发布的所有信息,了解 Firebase 可如何帮助您加快应用开发速度并满怀信心地运行应用。了解详情

Użyj modelu TensorFlow Lite do wnioskowania z ML Kit na Androidzie

Za pomocą zestawu ML Kit można przeprowadzać wnioskowanie na urządzeniu za pomocą modelu TensorFlow Lite .

Ten interfejs API wymaga Android SDK poziomu 16 (Jelly Bean) lub nowszego.

Zanim zaczniesz

  1. Jeśli jeszcze tego nie zrobiłeś, dodaj Firebase do swojego projektu na Androida .
  2. Dodaj zależności dla bibliotek ML Kit Android do swojego modułu (na poziomie aplikacji) Plik Gradle (zwykle app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. Przekonwertuj model TensorFlow, którego chcesz użyć, na format TensorFlow Lite. Zobacz TOCO: Konwerter optymalizujący TensorFlow Lite .

Hostuj lub pakuj swój model

Zanim będzie można użyć modelu TensorFlow Lite do wnioskowania w aplikacji, należy udostępnić ten model w ML Kit. ML Kit może korzystać z modeli TensorFlow Lite hostowanych zdalnie za pomocą Firebase, w pakiecie z plikiem binarnym aplikacji lub obu.

Hostując model w Firebase, możesz aktualizować model bez wydawania nowej wersji aplikacji, a także używać zdalnej konfiguracji i testów A/B do dynamicznego udostępniania różnych modeli różnym grupom użytkowników.

Jeśli zdecydujesz się udostępnić model tylko przez hostowanie go w Firebase, a nie w pakiecie z aplikacją, możesz zmniejszyć początkowy rozmiar pobieranej aplikacji. Pamiętaj jednak, że jeśli model nie jest dołączony do Twojej aplikacji, wszelkie funkcje związane z modelem nie będą dostępne, dopóki aplikacja nie pobierze modelu po raz pierwszy.

Łącząc swój model z aplikacją, możesz mieć pewność, że funkcje uczenia maszynowego będą nadal działać, gdy model hostowany przez Firebase nie będzie dostępny.

Hostuj modele w Firebase

Aby hostować swój model TensorFlow Lite w Firebase:

  1. W sekcji ML Kit konsoli Firebase kliknij kartę Niestandardowe .
  2. Kliknij Dodaj model niestandardowy (lub Dodaj inny model ).
  3. Podaj nazwę, która będzie używana do identyfikacji Twojego modelu w projekcie Firebase, a następnie prześlij plik modelu TensorFlow Lite (zwykle kończący się na .tflite lub .lite ).
  4. W manifeście swojej aplikacji zadeklaruj, że wymagane jest uprawnienie INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Po dodaniu niestandardowego modelu do projektu Firebase możesz odwoływać się do modelu w swoich aplikacjach, używając określonej nazwy. W dowolnym momencie możesz przesłać nowy model TensorFlow Lite, a Twoja aplikacja pobierze nowy model i zacznie z niego korzystać, gdy aplikacja zostanie ponownie uruchomiona. Możesz zdefiniować warunki urządzenia wymagane, aby Twoja aplikacja próbowała zaktualizować model (patrz poniżej).

Połącz modele z aplikacją

Aby połączyć model TensorFlow Lite ze swoją aplikacją, skopiuj plik modelu (zwykle z .tflite lub .lite ) do folderu asset assets/ aplikacji. (Może być konieczne uprzednie utworzenie folderu przez kliknięcie app/ folderu prawym przyciskiem myszy, a następnie kliknięcie opcji Nowy > Folder > Folder zasobów .)

Następnie dodaj następujące elementy do pliku build.gradle aplikacji, aby upewnić się, że Gradle nie kompresuje modeli podczas tworzenia aplikacji:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Plik modelu zostanie dołączony do pakietu aplikacji i będzie dostępny dla ML Kit jako nieprzetworzony zasób.

Załaduj model

Aby użyć swojego modelu TensorFlow Lite w swojej aplikacji, najpierw skonfiguruj ML Kit z lokalizacjami, w których Twój model jest dostępny: zdalnie przy użyciu Firebase, w pamięci lokalnej lub w obu przypadkach. Jeśli określisz zarówno model lokalny, jak i zdalny, możesz użyć modelu zdalnego, jeśli jest dostępny, i wrócić do modelu przechowywanego lokalnie, jeśli model zdalny nie jest dostępny.

Skonfiguruj model hostowany przez Firebase

Jeśli hostowałeś swój model w Firebase, utwórz obiekt FirebaseCustomRemoteModel , określając nazwę, którą przypisałeś modelowi podczas przesyłania:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Następnie uruchom zadanie pobierania modelu, określając warunki, na jakich chcesz zezwolić na pobieranie. Jeśli modelu nie ma na urządzeniu lub dostępna jest nowsza wersja modelu, zadanie asynchronicznie pobierze model z Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Wiele aplikacji uruchamia zadanie pobierania w swoim kodzie inicjującym, ale możesz to zrobić w dowolnym momencie, zanim będzie trzeba użyć modelu.

Skonfiguruj model lokalny

Jeśli model został dołączony do Twojej aplikacji, utwórz obiekt FirebaseCustomLocalModel , określając nazwę pliku modelu TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Utwórz interpreter ze swojego modelu

Po skonfigurowaniu źródeł modelu utwórz obiekt FirebaseModelInterpreter na podstawie jednego z nich.

Jeśli masz tylko model w pakiecie lokalnym, po prostu utwórz interpreter z obiektu FirebaseCustomLocalModel :

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Jeśli masz model hostowany zdalnie, przed uruchomieniem musisz sprawdzić, czy został pobrany. Stan zadania pobierania modelu można sprawdzić za pomocą metody isModelDownloaded() menedżera modelu.

Chociaż musisz to tylko potwierdzić przed uruchomieniem interpretera, jeśli masz zarówno model hostowany zdalnie, jak i model w pakiecie lokalnym, warto przeprowadzić tę kontrolę podczas tworzenia instancji interpretera modelu: utwórz interpreter na podstawie modelu zdalnego, jeśli został pobrany, aw innym przypadku z modelu lokalnego.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Jeśli masz tylko zdalnie hostowany model, powinieneś wyłączyć funkcje związane z modelem — na przykład wyszarzyć lub ukryć część interfejsu użytkownika — do czasu potwierdzenia, że ​​model został pobrany. Możesz to zrobić, dołączając odbiornik do metody download() menedżera modeli:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Określ wejście i wyjście modelu

Następnie skonfiguruj formaty wejściowe i wyjściowe interpretera modelu.

Model TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe jedną lub więcej tablic wielowymiarowych. Te tablice zawierają wartości byte , int , long lub float . Musisz skonfigurować zestaw ML przy użyciu liczby i wymiarów („kształtu”) macierzy używanych przez Twój model.

Jeśli nie znasz kształtu i typu danych danych wejściowych i wyjściowych swojego modelu, możesz użyć interpretera języka Python TensorFlow Lite, aby sprawdzić swój model. Na przykład:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Po określeniu formatu danych wejściowych i wyjściowych modelu można skonfigurować interpreter modelu aplikacji, tworząc obiekt FirebaseModelInputOutputOptions .

Na przykład model klasyfikacji obrazów zmiennoprzecinkowych może przyjmować jako dane wejściowe tablicę wartości float N x 224 x 224 x 3, reprezentującą partię N 224 x 224 trójkanałowych obrazów (RGB), i generować jako dane wyjściowe listę 1000 wartości float , z których każda reprezentuje prawdopodobieństwo, że obraz należy do jednej z 1000 kategorii przewidywanych przez model.

W przypadku takiego modelu należy skonfigurować wejście i wyjście interpretera modelu, jak pokazano poniżej:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Wykonaj wnioskowanie na danych wejściowych

Na koniec, aby przeprowadzić wnioskowanie za pomocą modelu, pobierz dane wejściowe i przeprowadź na nich wszelkie przekształcenia, które są niezbędne do uzyskania tablicy wejściowej o odpowiednim kształcie dla twojego modelu.

Na przykład, jeśli masz model klasyfikacji obrazów z kształtem wejściowym [1 224 224 3] wartości zmiennoprzecinkowych, możesz wygenerować tablicę wejściową z obiektu Bitmap , jak pokazano w poniższym przykładzie:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Następnie utwórz obiekt FirebaseModelInputs z danymi wejściowymi i przekaż go wraz ze specyfikacją danych wejściowych i wyjściowych modelu do metody run interpretera modelu :

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Jeśli wywołanie powiedzie się, możesz uzyskać dane wyjściowe, wywołując getOutput() obiektu, który jest przekazywany do detektora sukcesu. Na przykład:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Na przykład, jeśli przeprowadzasz klasyfikację, w następnym kroku możesz zmapować indeksy wyniku na etykiety, które reprezentują:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Dodatek: Bezpieczeństwo modeli

Niezależnie od tego, w jaki sposób udostępniasz swoje modele TensorFlow Lite w ML Kit, ML Kit przechowuje je w standardowym, serializowanym formacie protobuf w lokalnej pamięci masowej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. Jednak w praktyce większość modeli jest tak specyficzna dla aplikacji i zaciemniona przez optymalizacje, że ryzyko jest podobne do ryzyka dezasemblacji i ponownego użycia kodu przez konkurencję. Niemniej jednak powinieneś być świadomy tego ryzyka, zanim użyjesz niestandardowego modelu w swojej aplikacji.

W interfejsie API systemu Android na poziomie 21 (Lollipop) i nowszych model jest pobierany do katalogu wykluczonego z automatycznego tworzenia kopii zapasowych .

W interfejsie API Androida na poziomie 20 lub starszym model jest pobierany do katalogu o nazwie com.google.firebase.ml.custom.models w prywatnej pamięci wewnętrznej aplikacji. Jeśli włączono tworzenie kopii zapasowych plików za pomocą narzędzia BackupAgent , można wykluczyć ten katalog.