Sử dụng mô hình TensorFlow Lite để suy luận bằng Bộ công cụ học máy trên Android

Bạn có thể sử dụng Bộ công cụ học máy để tiến hành suy luận trên thiết bị bằng mô hình TensorFlow Lite.

API này yêu cầu SDK Android cấp 16 (Jelly Bean) trở lên.

Trước khi bắt đầu

  1. Thêm Firebase vào dự án Android của bạn nếu bạn chưa thực hiện.
  2. Thêm các phần phụ thuộc của thư viện Android Bộ công cụ học máy vào tệp Gradle của mô-đun (cấp ứng dụng) (thường là app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. Chuyển đổi mô hình TensorFlow mà bạn muốn sử dụng sang định dạng TensorFlow Lite. Xem TOCO: Trình chuyển đổi tối ưu hoá TensorFlow Lite.

Lưu trữ hoặc nhóm mô hình của bạn

Trước khi có thể sử dụng mô hình TensorFlow Lite để suy luận trong ứng dụng của mình, bạn phải cung cấp mô hình đó cho Bộ công cụ học máy. Bộ công cụ học máy có thể sử dụng các mô hình TensorFlow Lite được lưu trữ từ xa bằng Firebase, kết hợp với tệp nhị phân của ứng dụng hoặc cả hai.

Bằng cách lưu trữ một mô hình trên Firebase, bạn có thể cập nhật mô hình mà không cần phát hành phiên bản ứng dụng mới, đồng thời có thể sử dụng tính năng Cấu hình từ xa và Thử nghiệm A/B để phân phát linh động các mô hình khác nhau cho nhiều nhóm người dùng.

Nếu chọn chỉ cung cấp mô hình bằng cách lưu trữ mô hình đó trong Firebase chứ không gói với ứng dụng của mình, thì bạn có thể giảm kích thước tải xuống ban đầu của ứng dụng. Tuy nhiên, hãy lưu ý rằng nếu mô hình không đi kèm với ứng dụng của bạn, thì mọi chức năng liên quan đến mô hình sẽ không dùng được cho đến khi ứng dụng của bạn tải mô hình xuống lần đầu tiên.

Bằng cách kết hợp mô hình với ứng dụng của mình, bạn có thể đảm bảo các tính năng học máy của ứng dụng vẫn hoạt động khi mô hình do Firebase lưu trữ không có sẵn.

Mô hình lưu trữ trên Firebase

Cách lưu trữ mô hình TensorFlow Lite trên Firebase:

  1. Trong mục Bộ công cụ học máy của bảng điều khiển của Firebase, hãy nhấp vào thẻ Tuỳ chỉnh.
  2. Nhấp vào Thêm mô hình tuỳ chỉnh (hoặc Thêm mô hình khác).
  3. Đặt tên sẽ được dùng để xác định mô hình của bạn trong dự án Firebase, sau đó tải tệp mô hình TensorFlow Lite lên (thường có đuôi .tflite hoặc .lite).
  4. Trong tệp kê khai của ứng dụng, hãy khai báo rằng ứng dụng cần có quyền INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Sau khi thêm một mô hình tùy chỉnh vào dự án Firebase, bạn có thể tham chiếu mô hình đó trong các ứng dụng của mình bằng cách sử dụng tên mà bạn đã chỉ định. Bạn có thể tải một mô hình TensorFlow Lite mới lên bất cứ lúc nào và ứng dụng của bạn sẽ tải mô hình mới xuống và bắt đầu sử dụng mô hình đó khi ứng dụng khởi động lại vào lần tới. Bạn có thể xác định các điều kiện cần thiết về thiết bị để ứng dụng của bạn cố gắng cập nhật mô hình (xem bên dưới).

Gộp các mô hình bằng một ứng dụng

Để nhóm mô hình TensorFlow Lite với ứng dụng, hãy sao chép tệp mô hình (thường kết thúc bằng .tflite hoặc .lite) vào thư mục assets/ của ứng dụng. (Trước tiên, bạn có thể phải tạo thư mục bằng cách nhấp chuột phải vào thư mục app/, sau đó nhấp vào New > Folder > Assets Folder (Mới > Thư mục > Thư mục thành phần).)

Sau đó, hãy thêm đoạn mã sau vào tệp build.gradle của ứng dụng để đảm bảo Gradle không nén các mô hình khi tạo ứng dụng:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Tệp mô hình sẽ được đưa vào gói ứng dụng và được cung cấp trong Bộ công cụ học máy dưới dạng tài sản thô.

Tải mô hình

Để sử dụng mô hình TensorFlow Lite trong ứng dụng của bạn, trước tiên, hãy định cấu hình Bộ công cụ học máy với những vị trí có thể sử dụng mô hình của bạn: từ xa bằng Firebase, trong bộ nhớ cục bộ hoặc cả hai. Nếu chỉ định cả mô hình cục bộ lẫn mô hình từ xa, bạn có thể sử dụng mô hình từ xa nếu có sẵn và quay lại mô hình được lưu trữ cục bộ nếu không có mô hình từ xa.

Định cấu hình mô hình lưu trữ trên Firebase

Nếu bạn lưu trữ mô hình bằng Firebase, hãy tạo một đối tượng FirebaseCustomRemoteModel, trong đó chỉ định tên mà bạn đã chỉ định cho mô hình khi tải lên:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Sau đó, hãy bắt đầu tác vụ tải mô hình xuống, chỉ định các điều kiện mà bạn muốn cho phép tải xuống. Nếu mô hình không có trên thiết bị hoặc nếu có phiên bản mới hơn của mô hình, tác vụ sẽ tải xuống không đồng bộ mô hình từ Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Nhiều ứng dụng bắt đầu tác vụ tải xuống trong mã khởi chạy, nhưng bạn có thể làm vậy bất cứ lúc nào trước khi cần sử dụng mô hình.

Định cấu hình mô hình cục bộ

Nếu bạn đã đóng gói mô hình với ứng dụng, hãy tạo một đối tượng FirebaseCustomLocalModel, trong đó chỉ định tên tệp của mô hình TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tạo phiên dịch từ mô hình của bạn

Sau khi bạn định cấu hình các nguồn mô hình, hãy tạo đối tượng FirebaseModelInterpreter từ một trong các nguồn đó.

Nếu chỉ có mô hình được gói cục bộ, bạn chỉ cần tạo trình thông dịch từ đối tượng FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Nếu có một mô hình được lưu trữ từ xa, bạn sẽ phải kiểm tra xem mô hình đó đã được tải xuống hay chưa trước khi chạy. Bạn có thể kiểm tra trạng thái của tác vụ tải mô hình xuống bằng phương thức isModelDownloaded() của trình quản lý mô hình.

Mặc dù bạn chỉ phải xác nhận điều này trước khi chạy trình phiên dịch, nhưng nếu có cả mô hình được lưu trữ từ xa và mô hình được gói cục bộ, thì bạn nên thực hiện bước kiểm tra này khi tạo thực thể cho trình phiên dịch mô hình: tạo trình thông dịch từ mô hình từ xa nếu đã tải xuống và từ mô hình cục bộ nếu không.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Nếu chỉ có một mô hình được lưu trữ từ xa, bạn nên tắt chức năng liên quan đến mô hình (ví dụ: chuyển sang màu xám hoặc ẩn một phần giao diện người dùng) cho đến khi bạn xác nhận mô hình đã được tải xuống. Bạn có thể thực hiện việc này bằng cách đính kèm trình nghe vào phương thức download() của trình quản lý mô hình:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Chỉ định dữ liệu đầu vào và đầu ra của mô hình

Tiếp theo, hãy định cấu hình định dạng đầu vào và đầu ra của trình phiên dịch mô hình.

Mô hình TensorFlow Lite lấy dữ liệu đầu vào và tạo ra một hoặc nhiều mảng đa chiều dưới dạng đầu ra. Các mảng này chứa giá trị byte, int, long hoặc float. Bạn phải định cấu hình Bộ công cụ học máy bằng số lượng và kích thước ("hình dạng") của các mảng mà mô hình của bạn sử dụng.

Nếu không biết hình dạng và loại dữ liệu của đầu vào và đầu ra của mô hình, bạn có thể sử dụng trình thông dịch TensorFlow Lite Python để kiểm tra mô hình của mình. Ví dụ:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Sau khi xác định định dạng cho đầu vào và đầu ra của mô hình, bạn có thể định cấu hình cho trình thông dịch mô hình của ứng dụng bằng cách tạo một đối tượng FirebaseModelInputOutputOptions.

Ví dụ: một mô hình phân loại hình ảnh dấu phẩy động có thể lấy dữ liệu đầu vào một mảng Nx224x224x3 của các giá trị float, đại diện cho một loạt hình ảnh ba kênh (RGB) N 224x224 và tạo ra một danh sách gồm 1.000 giá trị float, mỗi giá trị thể hiện xác suất mà hình ảnh là một thành viên của một trong 1.000 danh mục mà mô hình dự đoán.

Đối với mô hình như vậy, bạn sẽ định cấu hình đầu vào và đầu ra của trình thông dịch mô hình như sau:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Tiến hành suy luận về dữ liệu đầu vào

Cuối cùng, để tiến hành suy luận bằng mô hình, hãy lấy dữ liệu đầu vào và thực hiện mọi phép biến đổi đối với dữ liệu cần thiết để có được một mảng đầu vào có hình dạng phù hợp với mô hình của bạn.

Ví dụ: nếu có mô hình phân loại hình ảnh với hình dạng đầu vào là các giá trị dấu phẩy động [1 224 224 3], bạn có thể tạo một mảng đầu vào từ đối tượng Bitmap như trong ví dụ sau:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Sau đó, hãy tạo một đối tượng FirebaseModelInputs bằng dữ liệu đầu vào, truyền đối tượng đó cũng như thông số đầu vào và đầu ra của mô hình sang phương thức run của trình phiên dịch mô hình:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Nếu lệnh gọi thành công, bạn có thể nhận kết quả bằng cách gọi phương thức getOutput() của đối tượng được truyền đến trình nghe thành công. Ví dụ:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Cách bạn sử dụng dữ liệu đầu ra phụ thuộc vào mô hình bạn đang sử dụng.

Ví dụ: nếu đang thực hiện phân loại, thì ở bước tiếp theo, bạn có thể ánh xạ các chỉ mục của kết quả với các nhãn đại diện cho chúng:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Phụ lục: Bảo mật mô hình

Bất kể bạn cung cấp các mô hình TensorFlow Lite bằng cách nào cho Bộ công cụ học máy, Bộ công cụ học máy đều lưu trữ các mô hình đó ở định dạng protobuf được chuyển đổi tuần tự tiêu chuẩn trong bộ nhớ cục bộ.

Về mặt lý thuyết, điều này có nghĩa là bất kỳ ai cũng có thể sao chép mô hình của bạn. Tuy nhiên, trong thực tế, hầu hết các mô hình đều dành riêng cho ứng dụng và bị làm rối mã nguồn bởi các tính năng tối ưu hoá, dẫn đến rủi ro tương tự như các đối thủ cạnh tranh sẽ tháo rời và sử dụng lại mã của bạn. Tuy nhiên, bạn nên lưu ý rủi ro này trước khi sử dụng mô hình tuỳ chỉnh trong ứng dụng của mình.

Trên API Android cấp 21 (Lollipop) trở lên, mô hình sẽ được tải xuống một thư mục bị loại trừ khỏi tính năng sao lưu tự động.

Trên API Android cấp 20 trở xuống, mô hình được tải xuống thư mục có tên là com.google.firebase.ml.custom.models trong bộ nhớ trong riêng tư của ứng dụng. Nếu đã bật tính năng sao lưu tệp bằng BackupAgent, thì bạn có thể chọn loại trừ thư mục này.