Sử dụng mô hình TensorFlow Lite để suy luận với ML Kit trên Android

Bạn có thể sử dụng ML Kit để thực hiện suy luận trên thiết bị bằng mô hình TensorFlow Lite .

API này yêu cầu Android SDK cấp 16 (Jelly Bean) hoặc mới hơn.

Trước khi bắt đầu

  1. Nếu bạn chưa có, hãy thêm Firebase vào dự án Android của bạn .
  2. Thêm các phần phụ thuộc cho thư viện ML Kit Android vào tệp Gradle mô-đun (cấp ứng dụng) của bạn (thường là app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. Chuyển đổi mô hình TensorFlow bạn muốn sử dụng sang định dạng TensorFlow Lite. Xem TOCO: Bộ chuyển đổi tối ưu hóa TensorFlow Lite .

Lưu trữ hoặc đóng gói mô hình của bạn

Trước khi có thể sử dụng mô hình TensorFlow Lite để suy luận trong ứng dụng của mình, bạn phải cung cấp mô hình đó cho ML Kit. ML Kit có thể sử dụng các mô hình TensorFlow Lite được lưu trữ từ xa bằng Firebase, đi kèm với tệp nhị phân ứng dụng hoặc cả hai.

Bằng cách lưu trữ một mô hình trên Firebase, bạn có thể cập nhật mô hình mà không cần phát hành phiên bản ứng dụng mới, đồng thời bạn có thể sử dụng Cấu hình từ xa và Thử nghiệm A/B để phân phát động các mô hình khác nhau cho các nhóm người dùng khác nhau.

Nếu bạn chọn chỉ cung cấp mô hình bằng cách lưu trữ mô hình đó với Firebase chứ không kết hợp mô hình đó với ứng dụng của mình thì bạn có thể giảm kích thước tải xuống ban đầu của ứng dụng. Tuy nhiên, hãy nhớ rằng nếu mô hình không đi kèm với ứng dụng của bạn thì mọi chức năng liên quan đến mô hình sẽ không khả dụng cho đến khi ứng dụng của bạn tải mô hình xuống lần đầu tiên.

Bằng cách kết hợp mô hình với ứng dụng của bạn, bạn có thể đảm bảo các tính năng ML của ứng dụng vẫn hoạt động khi mô hình được lưu trữ trên Firebase không khả dụng.

Lưu trữ mô hình trên Firebase

Để lưu trữ mô hình TensorFlow Lite của bạn trên Firebase:

  1. Trong phần Bộ công cụ ML của bảng điều khiển Firebase , hãy nhấp vào tab Tùy chỉnh .
  2. Nhấp vào Thêm mô hình tùy chỉnh (hoặc Thêm mô hình khác ).
  3. Chỉ định tên sẽ được sử dụng để xác định mô hình của bạn trong dự án Firebase, sau đó tải tệp mô hình TensorFlow Lite lên (thường kết thúc bằng .tflite hoặc .lite ).
  4. Trong bảng kê khai ứng dụng của bạn, hãy khai báo rằng cần có quyền INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Sau khi thêm mô hình tùy chỉnh vào dự án Firebase, bạn có thể tham chiếu mô hình đó trong ứng dụng của mình bằng tên bạn đã chỉ định. Bất cứ lúc nào, bạn có thể tải lên mô hình TensorFlow Lite mới và ứng dụng của bạn sẽ tải xuống mô hình mới và bắt đầu sử dụng mô hình đó khi ứng dụng khởi động lại lần tiếp theo. Bạn có thể xác định các điều kiện thiết bị cần thiết để ứng dụng của mình cố gắng cập nhật mô hình (xem bên dưới).

Gói mô hình với một ứng dụng

Để kết hợp mô hình TensorFlow Lite với ứng dụng của bạn, hãy sao chép tệp mô hình (thường kết thúc bằng .tflite hoặc .lite ) vào thư mục assets/ của ứng dụng. (Trước tiên, bạn có thể cần tạo thư mục bằng cách nhấp chuột phải vào app/ thư mục, sau đó nhấp vào Mới > Thư mục > Thư mục nội dung .)

Sau đó, thêm phần sau vào tệp build.gradle của ứng dụng để đảm bảo Gradle không nén các mô hình khi xây dựng ứng dụng:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Tệp mô hình sẽ được bao gồm trong gói ứng dụng và có sẵn cho ML Kit dưới dạng nội dung thô.

Tải mô hình

Để sử dụng mô hình TensorFlow Lite trong ứng dụng của bạn, trước tiên hãy định cấu hình Bộ ML với các vị trí có sẵn mô hình của bạn: sử dụng Firebase từ xa, trong bộ nhớ cục bộ hoặc cả hai. Nếu bạn chỉ định cả mô hình cục bộ và mô hình từ xa, bạn có thể sử dụng mô hình từ xa nếu có sẵn và quay lại mô hình được lưu trữ cục bộ nếu mô hình từ xa không có sẵn.

Định cấu hình mô hình được lưu trữ trên Firebase

Nếu bạn đã lưu trữ mô hình của mình bằng Firebase, hãy tạo đối tượng FirebaseCustomRemoteModel , chỉ định tên bạn đã gán cho mô hình khi tải mô hình lên:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Sau đó, bắt đầu tác vụ tải xuống mô hình, chỉ định các điều kiện mà bạn muốn cho phép tải xuống. Nếu mô hình không có trên thiết bị hoặc nếu có phiên bản mới hơn của mô hình thì tác vụ sẽ tải xuống mô hình một cách không đồng bộ từ Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Nhiều ứng dụng bắt đầu tác vụ tải xuống trong mã khởi tạo của chúng nhưng bạn có thể làm như vậy bất kỳ lúc nào trước khi cần sử dụng mô hình.

Định cấu hình mô hình cục bộ

Nếu bạn kết hợp mô hình với ứng dụng của mình, hãy tạo đối tượng FirebaseCustomLocalModel , chỉ định tên tệp của mô hình TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tạo một trình thông dịch từ mô hình của bạn

Sau khi bạn định cấu hình các nguồn mô hình của mình, hãy tạo đối tượng FirebaseModelInterpreter từ một trong các nguồn đó.

Nếu bạn chỉ có một mô hình được đóng gói cục bộ, chỉ cần tạo một trình thông dịch từ đối tượng FirebaseCustomLocalModel của bạn:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Nếu bạn có một mô hình được lưu trữ từ xa, bạn sẽ phải kiểm tra xem nó đã được tải xuống chưa trước khi chạy nó. Bạn có thể kiểm tra trạng thái của tác vụ tải xuống mô hình bằng phương thức isModelDownloaded() của trình quản lý mô hình.

Mặc dù bạn chỉ phải xác nhận điều này trước khi chạy trình thông dịch, nhưng nếu bạn có cả mô hình được lưu trữ từ xa và mô hình được đóng gói cục bộ, bạn có thể thực hiện kiểm tra này khi khởi tạo trình thông dịch mô hình: tạo trình thông dịch từ mô hình từ xa nếu nó đã được tải xuống và từ mô hình cục bộ.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Nếu chỉ có một mô hình được lưu trữ từ xa, bạn nên tắt chức năng liên quan đến mô hình—ví dụ: chuyển sang màu xám hoặc ẩn một phần giao diện người dùng—cho đến khi bạn xác nhận rằng mô hình đã được tải xuống. Bạn có thể làm như vậy bằng cách đính kèm một trình nghe vào phương thức download() của trình quản lý mô hình:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Chỉ định đầu vào và đầu ra của mô hình

Tiếp theo, định cấu hình định dạng đầu vào và đầu ra của trình thông dịch mô hình.

Mô hình TensorFlow Lite lấy đầu vào và tạo ra một hoặc nhiều mảng đa chiều làm đầu ra. Các mảng này chứa các giá trị byte , int , long hoặc float . Bạn phải định cấu hình Bộ ML với số lượng và kích thước ("hình dạng") của mảng mà mô hình của bạn sử dụng.

Nếu không biết hình dạng và kiểu dữ liệu của đầu vào và đầu ra của mô hình, bạn có thể sử dụng trình thông dịch Python TensorFlow Lite để kiểm tra mô hình của mình. Ví dụ:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Sau khi xác định định dạng đầu vào và đầu ra của mô hình, bạn có thể định cấu hình trình thông dịch mô hình của ứng dụng bằng cách tạo đối tượng FirebaseModelInputOutputOptions .

Ví dụ: mô hình phân loại hình ảnh dấu phẩy động có thể lấy đầu vào là một mảng giá trị float N x224x224x3, đại diện cho một loạt N hình ảnh ba kênh (RGB) N 224x224 và tạo ra dưới dạng đầu ra một danh sách gồm 1000 giá trị float , mỗi giá trị đại diện cho xác suất hình ảnh là thành viên của một trong 1000 loại mà mô hình dự đoán.

Đối với mô hình như vậy, bạn sẽ định cấu hình đầu vào và đầu ra của trình thông dịch mô hình như dưới đây:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Thực hiện suy luận trên dữ liệu đầu vào

Cuối cùng, để thực hiện suy luận bằng mô hình, hãy lấy dữ liệu đầu vào của bạn và thực hiện bất kỳ phép biến đổi nào trên dữ liệu cần thiết để có được mảng đầu vào có hình dạng phù hợp cho mô hình của bạn.

Ví dụ: nếu bạn có mô hình phân loại hình ảnh với hình dạng đầu vào là các giá trị dấu phẩy động [1 224 224 3], bạn có thể tạo một mảng đầu vào từ đối tượng Bitmap như trong ví dụ sau:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Sau đó, tạo một đối tượng FirebaseModelInputs với dữ liệu đầu vào của bạn và chuyển nó cũng như thông số đầu vào và đầu ra của mô hình tới phương thức run của trình thông dịch mô hình :

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Nếu cuộc gọi thành công, bạn có thể nhận được kết quả đầu ra bằng cách gọi phương thức getOutput() của đối tượng được chuyển đến trình nghe thành công. Ví dụ:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Cách bạn sử dụng đầu ra tùy thuộc vào kiểu máy bạn đang sử dụng.

Ví dụ: nếu bạn đang thực hiện phân loại, bước tiếp theo, bạn có thể ánh xạ các chỉ mục của kết quả tới các nhãn mà chúng đại diện:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Phụ lục: Bảo mật mô hình

Bất kể bạn cung cấp các mô hình TensorFlow Lite của mình cho ML Kit bằng cách nào, ML Kit sẽ lưu trữ chúng ở định dạng protobuf được tuần tự hóa tiêu chuẩn trong bộ nhớ cục bộ.

Về lý thuyết, điều này có nghĩa là bất kỳ ai cũng có thể sao chép mô hình của bạn. Tuy nhiên, trên thực tế, hầu hết các mô hình đều dành riêng cho ứng dụng và bị che khuất bởi sự tối ưu hóa nên rủi ro tương tự như rủi ro của các đối thủ cạnh tranh khi tháo rời và sử dụng lại mã của bạn. Tuy nhiên, bạn nên lưu ý đến rủi ro này trước khi sử dụng mô hình tùy chỉnh trong ứng dụng của mình.

Trên API Android cấp 21 (Lollipop) trở lên, mô hình được tải xuống thư mục được loại trừ khỏi quá trình sao lưu tự động .

Trên API Android cấp 20 trở lên, mô hình được tải xuống thư mục có tên com.google.firebase.ml.custom.models trong bộ nhớ trong riêng tư của ứng dụng. Nếu bạn đã bật sao lưu tệp bằng BackupAgent , bạn có thể chọn loại trừ thư mục này.