Sử dụng mô hình TensorFlow Lite tuỳ chỉnh trên Android

Nếu ứng dụng của bạn sử dụng các mô hình TensorFlow Lite tuỳ chỉnh, thì bạn có thể dùng công nghệ học máy của Firebase để triển khai các mô hình. Bằng cách triển khai mô hình bằng Firebase, bạn có thể giảm kích thước tải xuống ban đầu của ứng dụng và cập nhật các mô hình học máy của ứng dụng mà không cần phát hành phiên bản mới của ứng dụng. Ngoài ra, với Cấu hình từ xa và Thử nghiệm A/B, bạn có thể tự động phân phát nhiều mô hình cho nhiều nhóm người dùng.

Mô hình TensorFlow Lite

Mô hình TensorFlow Lite là mô hình học máy được tối ưu hoá để chạy trên thiết bị di động. Cách tạo mô hình TensorFlow Lite:

Trước khi bắt đầu

  1. Thêm Firebase vào dự án Android của bạn nếu bạn chưa thực hiện.
  2. Trong tệp Gradle (cấp ứng dụng) mô-đun (thường là <project>/<app-module>/build.gradle.kts hoặc <project>/<app-module>/build.gradle), hãy thêm phần phụ thuộc cho thư viện trình tải mô hình máy học Firebase dành cho Android. Bạn nên sử dụng Firebase Android BoM để kiểm soát việc tạo phiên bản thư viện.

    Ngoài ra, trong quá trình thiết lập trình tải mô hình học máy xuống Firebase, bạn cần thêm SDK TensorFlow Lite vào ứng dụng của mình.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.1.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Bằng cách sử dụng Firebase Android BoM, ứng dụng của bạn sẽ luôn sử dụng các phiên bản tương thích của thư viện Android trên Firebase.

    (Thay thế) Thêm các phần phụ thuộc của thư viện Firebase mà không sử dụng BoM

    Nếu chọn không sử dụng BoM của Firebase, bạn phải chỉ định từng phiên bản thư viện Firebase trong dòng phần phụ thuộc.

    Xin lưu ý rằng nếu sử dụng nhiều thư viện Firebase trong ứng dụng, thì bạn nên sử dụng BoM để quản lý các phiên bản thư viện, qua đó đảm bảo tất cả các phiên bản đều tương thích.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.0")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Bạn đang tìm một mô-đun thư viện dành riêng cho Kotlin? Kể từ tháng 10 năm 2023 (Firebase BoM 32.5.0), cả nhà phát triển Kotlin và Java đều có thể phụ thuộc vào mô-đun thư viện chính (để biết thông tin chi tiết, vui lòng xem Câu hỏi thường gặp về sáng kiến này).
  3. Trong tệp kê khai của ứng dụng, hãy khai báo rằng ứng dụng cần có quyền INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Triển khai mô hình

Triển khai các mô hình TensorFlow tuỳ chỉnh bằng bảng điều khiển của Firebase hoặc SDK Node.js và Python Admin của Firebase. Xem phần Triển khai và quản lý các mô hình tuỳ chỉnh.

Sau khi thêm một mô hình tùy chỉnh vào dự án Firebase, bạn có thể tham chiếu mô hình đó trong các ứng dụng của mình bằng cách sử dụng tên mà bạn đã chỉ định. Bạn có thể triển khai mô hình TensorFlow Lite mới và tải mô hình mới xuống thiết bị của người dùng bất cứ lúc nào bằng cách gọi getModel() (xem bên dưới).

2. Tải mô hình xuống thiết bị và khởi chạy trình thông dịch TensorFlow Lite

Để sử dụng mô hình TensorFlow Lite trong ứng dụng của bạn, trước tiên, hãy sử dụng SDK máy học Firebase để tải phiên bản mới nhất của mô hình xuống thiết bị. Sau đó, hãy tạo thực thể cho trình thông dịch TensorFlow Lite bằng mô hình.

Để bắt đầu tải mô hình xuống, hãy gọi phương thức getModel() của trình tải mô hình xuống, trong đó chỉ định tên mà bạn đã chỉ định cho mô hình khi tải mô hình lên, liệu bạn có muốn luôn tải mô hình mới nhất xuống hay không và các điều kiện mà bạn muốn cho phép tải xuống.

Bạn có thể chọn một trong ba hành vi tải xuống sau:

Loại tệp tải xuống Mô tả
LOCAL_MODEL Lấy mẫu cục bộ từ thiết bị. Nếu không có mô hình cục bộ, mã này sẽ hoạt động như LATEST_MODEL. Hãy sử dụng loại tệp tải xuống này nếu bạn không muốn kiểm tra thông tin cập nhật về mô hình. Ví dụ: bạn đang sử dụng Cấu hình từ xa để truy xuất tên mô hình và bạn luôn tải các mô hình lên dưới tên mới (nên chọn).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Tải mô hình cục bộ từ thiết bị và bắt đầu cập nhật mô hình ở chế độ nền. Nếu không có mô hình cục bộ, mã này sẽ hoạt động như LATEST_MODEL.
MÔ HÌNH MỚI NHẤT Tải mẫu mới nhất. Nếu mô hình cục bộ là phiên bản mới nhất, hàm sẽ trả về mô hình cục bộ. Nếu không, hãy tải mô hình mới nhất xuống. Hành vi này sẽ chặn cho đến khi phiên bản mới nhất được tải xuống (không khuyến khích). Chỉ dùng hành vi này trong trường hợp bạn rõ ràng cần có phiên bản mới nhất.

Bạn nên tắt chức năng liên quan đến mô hình (ví dụ: chuyển sang màu xám hoặc ẩn một phần giao diện người dùng) cho đến khi bạn xác nhận mô hình đã được tải xuống.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Nhiều ứng dụng bắt đầu tác vụ tải xuống trong mã khởi chạy, nhưng bạn có thể làm vậy bất cứ lúc nào trước khi cần sử dụng mô hình.

3. Tiến hành suy luận về dữ liệu đầu vào

Nhận các hình dạng đầu vào và đầu ra của mô hình

Trình phiên dịch mô hình TensorFlow Lite lấy dữ liệu đầu vào và tạo ra một hoặc nhiều mảng đa chiều dưới dạng đầu ra. Các mảng này chứa giá trị byte, int, long hoặc float. Trước khi có thể truyền dữ liệu vào mô hình hoặc sử dụng kết quả của mô hình đó, bạn phải biết số lượng và kích thước ("hình dạng") của các mảng mà mô hình của bạn sử dụng.

Nếu bạn tự tạo mô hình hoặc nếu định dạng đầu vào và đầu ra của mô hình đã được ghi lại, thì có thể bạn đã có thông tin này. Nếu không biết hình dạng và loại dữ liệu của đầu vào và đầu ra của mô hình, bạn có thể sử dụng trình thông dịch TensorFlow Lite để kiểm tra mô hình của mình. Ví dụ:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Kết quả ví dụ:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Chạy trình phiên dịch

Sau khi xác định định dạng đầu vào và đầu ra của mô hình, hãy lấy dữ liệu đầu vào và thực hiện mọi phép biến đổi đối với dữ liệu cần thiết để có được dữ liệu đầu vào có hình dạng phù hợp với mô hình.

Ví dụ: nếu có mô hình phân loại hình ảnh với hình dạng đầu vào là các giá trị dấu phẩy động [1 224 224 3], bạn có thể tạo dữ liệu đầu vào ByteBuffer từ đối tượng Bitmap như trong ví dụ sau:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Sau đó, hãy phân bổ một ByteBuffer đủ lớn để chứa dữ liệu đầu ra của mô hình và truyền vùng đệm đầu vào và vùng đệm đầu ra đến phương thức run() của trình thông dịch TensorFlow Lite. Ví dụ: đối với hình dạng đầu ra của giá trị dấu phẩy động [1 1000]:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Cách bạn sử dụng dữ liệu đầu ra phụ thuộc vào mô hình bạn đang sử dụng.

Ví dụ: nếu đang thực hiện phân loại, thì ở bước tiếp theo, bạn có thể ánh xạ các chỉ mục của kết quả với các nhãn đại diện cho chúng:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Phụ lục: Bảo mật mô hình

Bất kể bạn cung cấp các mô hình TensorFlow Lite cho công nghệ học máy Firebase bằng cách nào, công nghệ học máy Firebase đều lưu trữ các mô hình đó ở định dạng protobuf được chuyển đổi tuần tự tiêu chuẩn trong bộ nhớ cục bộ.

Về mặt lý thuyết, điều này có nghĩa là bất kỳ ai cũng có thể sao chép mô hình của bạn. Tuy nhiên, trong thực tế, hầu hết các mô hình đều dành riêng cho ứng dụng và bị làm rối mã nguồn bởi các tính năng tối ưu hoá, dẫn đến rủi ro tương tự như các đối thủ cạnh tranh sẽ tháo rời và sử dụng lại mã của bạn. Tuy nhiên, bạn nên lưu ý rủi ro này trước khi sử dụng mô hình tuỳ chỉnh trong ứng dụng của mình.

Trên API Android cấp 21 (Lollipop) trở lên, mô hình sẽ được tải xuống một thư mục bị loại trừ khỏi tính năng sao lưu tự động.

Trên API Android cấp 20 trở xuống, mô hình được tải xuống thư mục có tên là com.google.firebase.ml.custom.models trong bộ nhớ trong riêng tư của ứng dụng. Nếu đã bật tính năng sao lưu tệp bằng BackupAgent, thì bạn có thể chọn loại trừ thư mục này.