欢迎参加我们将于 2022 年 10 月 18 日举办的 Firebase 峰会(线上线下同时进行),了解 Firebase 如何帮助您加快应用开发速度、满怀信心地发布应用并在之后需要时轻松地扩大应用规模。立即报名

Użyj modelu TensorFlow Lite do wnioskowania z ML Kit na Androidzie

Możesz użyć ML Kit, aby przeprowadzić wnioskowanie na urządzeniu z modelem TensorFlow Lite .

Ten interfejs API wymaga Android SDK na poziomie 16 (Jelly Bean) lub nowszym.

Zanim zaczniesz

  1. Jeśli jeszcze tego nie zrobiłeś, dodaj Firebase do swojego projektu na Androida .
  2. Dodaj zależności dla bibliotek ML Kit Android do pliku Gradle modułu (na poziomie aplikacji) (zwykle app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. Przekonwertuj model TensorFlow, którego chcesz użyć, na format TensorFlow Lite. Zobacz TOCO: Konwerter optymalizujący TensorFlow Lite .

Hostuj lub pakuj swój model

Zanim będzie można użyć modelu TensorFlow Lite do wnioskowania w aplikacji, należy udostępnić model ML Kit. ML Kit może korzystać z modeli TensorFlow Lite hostowanych zdalnie za pomocą Firebase, w pakiecie z plikiem binarnym aplikacji lub z jednym i drugim.

Hostując model w Firebase, możesz zaktualizować model bez wydawania nowej wersji aplikacji, a także za pomocą Zdalnej konfiguracji i testów A/B możesz dynamicznie udostępniać różne modele różnym grupom użytkowników.

Jeśli zdecydujesz się udostępnić model tylko przez hostowanie go w Firebase, a nie łączyć go z aplikacją, możesz zmniejszyć początkowy rozmiar pobieranej aplikacji. Pamiętaj jednak, że jeśli model nie jest dołączony do Twojej aplikacji, żadne funkcje związane z modelem nie będą dostępne, dopóki Twoja aplikacja nie pobierze modelu po raz pierwszy.

Łącząc swój model z aplikacją, możesz mieć pewność, że funkcje ML aplikacji będą nadal działać, gdy model hostowany przez Firebase nie jest dostępny.

Hostuj modele w Firebase

Aby hostować swój model TensorFlow Lite w Firebase:

  1. W sekcji ML Kit konsoli Firebase kliknij kartę Niestandardowe .
  2. Kliknij Dodaj model niestandardowy (lub Dodaj kolejny model ).
  3. Podaj nazwę, która będzie używana do identyfikacji Twojego modelu w projekcie Firebase, a następnie prześlij plik modelu TensorFlow Lite (zazwyczaj kończący się na .tflite lub .lite ).
  4. W manifeście aplikacji zadeklaruj, że wymagane jest uprawnienie INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Po dodaniu niestandardowego modelu do projektu Firebase możesz odwoływać się do modelu w swoich aplikacjach, używając określonej przez siebie nazwy. W dowolnym momencie możesz przesłać nowy model TensorFlow Lite, a Twoja aplikacja pobierze nowy model i zacznie go używać po ponownym uruchomieniu aplikacji. Możesz zdefiniować warunki urządzenia wymagane przez aplikację, aby spróbować zaktualizować model (patrz poniżej).

Połącz modele z aplikacją

Aby połączyć model TensorFlow Lite z aplikacją, skopiuj plik modelu (zazwyczaj kończący się na .tflite lub .lite ) do folderu asset assets/ s aplikacji. (Być może najpierw trzeba będzie utworzyć folder, klikając prawym przyciskiem myszy app/ folder, a następnie klikając Nowy > Folder > Folder zasobów .)

Następnie dodaj następujące elementy do pliku build.gradle swojej aplikacji, aby upewnić się, że Gradle nie kompresuje modeli podczas tworzenia aplikacji:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Plik modelu zostanie dołączony do pakietu aplikacji i będzie dostępny dla ML Kit jako surowy zasób.

Załaduj model

Aby użyć modelu TensorFlow Lite w swojej aplikacji, najpierw skonfiguruj zestaw ML Kit z lokalizacjami, w których model jest dostępny: zdalnie przy użyciu Firebase, w pamięci lokalnej lub w obu tych miejscach. Jeśli określisz zarówno model lokalny, jak i zdalny, możesz użyć modelu zdalnego, jeśli jest dostępny, i wrócić do modelu przechowywanego lokalnie, jeśli model zdalny nie jest dostępny.

Skonfiguruj model hostowany w Firebase

Jeśli hostujesz swój model za pomocą Firebase, utwórz obiekt FirebaseCustomRemoteModel , określając nazwę przypisaną modelowi podczas przesyłania:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Następnie uruchom zadanie pobierania modelu, określając warunki, na jakich chcesz zezwolić na pobieranie. Jeśli modelu nie ma na urządzeniu lub jeśli dostępna jest nowsza wersja modelu, zadanie asynchronicznie pobierze model z Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Wiele aplikacji rozpoczyna zadanie pobierania w swoim kodzie inicjującym, ale możesz to zrobić w dowolnym momencie, zanim będziesz musiał użyć modelu.

Skonfiguruj model lokalny

Jeśli model został powiązany z aplikacją, utwórz obiekt FirebaseCustomLocalModel , określając nazwę pliku modelu TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Stwórz tłumacza na podstawie swojego modelu

Po skonfigurowaniu źródeł modelu utwórz na podstawie jednego z nich obiekt FirebaseModelInterpreter .

Jeśli masz tylko model powiązany lokalnie, po prostu utwórz interpreter z obiektu FirebaseCustomLocalModel :

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Jeśli masz model hostowany zdalnie, przed uruchomieniem musisz sprawdzić, czy został pobrany. Możesz sprawdzić status zadania pobierania modelu za pomocą metody isModelDownloaded() menedżera modeli.

Chociaż musisz to tylko potwierdzić przed uruchomieniem interpretera, jeśli masz zarówno model hostowany zdalnie, jak i model wiązany lokalnie, sensowne może być wykonanie tego sprawdzenia podczas tworzenia instancji interpretera modelu: utwórz interpreter z modelu zdalnego, jeśli został pobrany, az modelu lokalnego w inny sposób.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Jeśli masz tylko model hostowany zdalnie, wyłącz funkcje związane z modelem — na przykład wyszarzanie lub ukrycie części interfejsu użytkownika — do momentu potwierdzenia, że ​​model został pobrany. Możesz to zrobić, dołączając detektor do metody download() menedżera modeli:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Określ dane wejściowe i wyjściowe modelu

Następnie skonfiguruj formaty wejściowe i wyjściowe interpretera modelu.

Model TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe jedną lub więcej tablic wielowymiarowych. Te tablice zawierają wartości byte , int , long lub float . Musisz skonfigurować ML Kit z liczbą i wymiarami ("kształtem") tablic używanych przez Twój model.

Jeśli nie znasz kształtu i typu danych danych wejściowych i wyjściowych modelu, możesz użyć interpretera TensorFlow Lite Python do sprawdzenia modelu. Na przykład:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Po określeniu formatu danych wejściowych i wyjściowych modelu możesz skonfigurować interpreter modelu aplikacji, tworząc obiekt FirebaseModelInputOutputOptions .

Na przykład model klasyfikacji obrazów zmiennoprzecinkowych może przyjmować jako dane wejściowe tablicę wartości float N x224x224x3, reprezentującą partię obrazów trójkanałowych N 224x224 (RGB) i generować jako dane wyjściowe listę 1000 wartości float , z których każda reprezentuje prawdopodobieństwo, że obraz należy do jednej z 1000 kategorii przewidywanych przez model.

Dla takiego modelu należy skonfigurować wejście i wyjście interpretera modelu, jak pokazano poniżej:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Wykonaj wnioskowanie na danych wejściowych

Na koniec, aby wykonać wnioskowanie przy użyciu modelu, pobierz dane wejściowe i wykonaj dowolne przekształcenia na danych, które są niezbędne do uzyskania tablicy wejściowej o odpowiednim kształcie dla modelu.

Na przykład, jeśli masz model klasyfikacji obrazów z kształtem wejściowym równym [1 224 224 3] wartości zmiennoprzecinkowych, możesz wygenerować tablicę wejściową z obiektu Bitmap , jak pokazano w poniższym przykładzie:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Następnie utwórz obiekt FirebaseModelInputs z danymi wejściowymi i przekaż go oraz specyfikację danych wejściowych i wyjściowych modelu do metody run interpretera modelu :

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Jeśli wywołanie się powiedzie, możesz uzyskać dane wyjściowe, wywołując getOutput() obiektu, który jest przekazywany do detektora sukcesu. Na przykład:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Na przykład, jeśli przeprowadzasz klasyfikację, w następnym kroku możesz zmapować indeksy wyniku do etykiet, które reprezentują:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Załącznik: Zabezpieczenia modelu

Niezależnie od tego, w jaki sposób udostępniasz swoje modele TensorFlow Lite w ML Kit, ML Kit przechowuje je w standardowym serializowanym formacie protobuf w lokalnej pamięci masowej.

W teorii oznacza to, że każdy może skopiować Twój model. Jednak w praktyce większość modeli jest tak specyficzna dla aplikacji i zaciemniona przez optymalizacje, że ryzyko jest podobne do tego, jakie ma konkurencja w przypadku demontażu i ponownego użycia kodu. Niemniej jednak powinieneś być świadomy tego ryzyka, zanim użyjesz niestandardowego modelu w swojej aplikacji.

Na poziomie 21 interfejsu API systemu Android (Lollipop) i nowszych model jest pobierany do katalogu wykluczonego z automatycznego tworzenia kopii zapasowych .

W przypadku interfejsu API systemu Android na poziomie 20 i starszych model jest pobierany do katalogu o nazwie com.google.firebase.ml.custom.models w prywatnej pamięci wewnętrznej aplikacji. Jeśli włączono tworzenie kopii zapasowej plików za pomocą BackupAgent , możesz wykluczyć ten katalog.