欢迎参加我们将于 2022 年 10 月 18 日举办的 Firebase 峰会(线上线下同时进行),了解 Firebase 如何帮助您加快应用开发速度、满怀信心地发布应用并在之后需要时轻松地扩大应用规模。立即报名

Di chuyển từ API mô hình tùy chỉnh kế thừa

Phiên bản 22.0.2 của thư viện firebase-ml-model-interpreter giới thiệu phương thức getLatestModelFile() mới, lấy vị trí trên thiết bị của các mô hình tùy chỉnh. Bạn có thể sử dụng phương pháp này để khởi tạo trực tiếp đối tượng Interpreter TensorFlow Lite, đối tượng này bạn có thể sử dụng thay vì trình bao bọc FirebaseModelInterpreter .

Trong tương lai, đây là cách tiếp cận ưa thích. Vì phiên bản trình thông dịch TensorFlow Lite không còn được kết hợp với phiên bản thư viện Firebase, bạn có thể linh hoạt hơn để nâng cấp lên phiên bản mới của TensorFlow Lite khi bạn muốn hoặc dễ dàng sử dụng các bản dựng TensorFlow Lite tùy chỉnh.

Trang này cho biết cách bạn có thể chuyển từ sử dụng FirebaseModelInterpreter sang TensorFlow Lite Interpreter .

1. Cập nhật các phụ thuộc của dự án

Cập nhật các phần phụ thuộc của dự án của bạn để bao gồm phiên bản 22.0.2 của thư viện firebase-ml-model-interpreter -preter (hoặc mới hơn) và thư viện tensorflow-lite :

Trước

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.1'

Sau

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.2'
implementation 'org.tensorflow:tensorflow-lite:2.0.0'

2. Tạo trình thông dịch TensorFlow Lite thay vì trình thông dịch FirebaseModelInterpreter

Thay vì tạo FirebaseModelInterpreter , hãy lấy vị trí của mô hình trên thiết bị bằng getLatestModelFile() và sử dụng nó để tạo TensorFlow Lite Interpreter .

Trước

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
        new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Sau

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            interpreter = Interpreter(modelFile)
        }
    }

3. Cập nhật mã chuẩn bị đầu vào và đầu ra

Với FirebaseModelInterpreter , bạn chỉ định hình dạng đầu vào và đầu ra của mô hình bằng cách chuyển đối tượng FirebaseModelInputOutputOptions tới trình thông dịch khi bạn chạy nó.

Đối với trình thông dịch TensorFlow Lite, thay vào đó bạn phân bổ các đối tượng ByteBuffer với kích thước phù hợp cho đầu vào và đầu ra của mô hình của bạn.

Ví dụ: nếu mô hình của bạn có hình dạng đầu vào là [1 224 224 3] giá trị float và hình dạng đầu ra của [1 1000] giá trị float , hãy thực hiện những thay đổi sau:

Trước

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

float[][][][] input = new float[1][224][224][3];
// Then populate with input data.

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();

interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()

val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.

val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()

interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }

Sau

Java

int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
        ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.

int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
        ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());

interpreter.run(inputBuffer, outputBuffer);

Kotlin+KTX

val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.

val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())

interpreter.run(inputBuffer, outputBuffer)

4. Cập nhật mã xử lý đầu ra

Cuối cùng, thay vì nhận đầu ra của mô hình bằng phương thức getOutput() của đối tượng FirebaseModelOutputs , hãy chuyển đổi đầu ra ByteBuffer sang bất kỳ cấu trúc nào thuận tiện cho trường hợp sử dụng của bạn.

Ví dụ: nếu bạn đang phân loại, bạn có thể thực hiện các thay đổi như sau:

Trước

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
          new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

val output = result.getOutput(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Sau

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}