firebase-ml-model-interpreter
라이브러리 22.0.2 버전에는 기기에서의 커스텀 모델 위치를 가져오는 새로운 getLatestModelFile()
메서드가 도입되었습니다. 이 메서드를 사용하면 TensorFlow Lite Interpreter
객체를 직접 인스턴스화할 수 있습니다. 이 객체는 FirebaseModelInterpreter
래퍼 대신 사용할 수 있습니다.
앞으로는 이 방법이 권장됩니다. TensorFlow Lite 인터프리터 버전이 더 이상 Firebase 라이브러리 버전과 결합되어 있지 않으므로 필요할 경우 새 버전의 TensorFlow Lite로 자유롭게 업그레이드하거나 커스텀 TensorFlow Lite 빌드를 더 간편하게 사용할 수 있습니다.
이 페이지에서는 FirebaseModelInterpreter
에서 TensorFlow Lite Interpreter
로 마이그레이션하는 방법을 보여줍니다.
1. 프로젝트 종속 항목 업데이트
firebase-ml-model-interpreter
라이브러리 22.0.2 이상 버전과 tensorflow-lite
라이브러리를 포함하도록 프로젝트의 종속 항목을 업데이트합니다.
이전
implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.1")
이후
implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.2")
implementation("org.tensorflow:tensorflow-lite:2.0.0")
2. FirebaseModelInterpreter 대신 TensorFlow Lite 인터프리터 만들기
FirebaseModelInterpreter
를 만드는 대신 getLatestModelFile()
로 기기에서의 모델 위치를 가져와서 TensorFlow Lite Interpreter
를 만드는 데 사용합니다.
이전
Kotlin+KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)
Java
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
이후
Kotlin+KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener { task ->
val modelFile = task.getResult()
if (modelFile != null) {
// Instantiate an org.tensorflow.lite.Interpreter object.
interpreter = Interpreter(modelFile)
}
}
Java
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener(new OnCompleteListener<File>() {
@Override
public void onComplete(@NonNull Task<File> task) {
File modelFile = task.getResult();
if (modelFile != null) {
// Instantiate an org.tensorflow.lite.Interpreter object.
Interpreter interpreter = new Interpreter(modelFile);
}
}
});
3. 입력 및 출력 준비 코드 업데이트
FirebaseModelInterpreter
를 사용하는 경우 인터프리터를 실행할 때 인터프리터에 FirebaseModelInputOutputOptions
객체를 전달하여 모델의 입력 및 출력 모양을 지정하게 됩니다.
TensorFlow Lite 인터프리터의 경우 모델의 입력 및 출력에 적합한 크기의 ByteBuffer
객체를 할당합니다.
예를 들어 모델의 입력 모양이 [1 224 224 3] float
값이고 출력 모양이 [1 1000] float
인 경우 다음과 같이 변경합니다.
이전
Kotlin+KTX
val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
.build()
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.
val inputs = FirebaseModelInputs.Builder()
.add(input)
.build()
interpreter.run(inputs, inputOutputOptions)
.addOnSuccessListener { outputs ->
// ...
}
.addOnFailureListener {
// Task failed with an exception.
// ...
}
Java
FirebaseModelInputOutputOptions inputOutputOptions =
new FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
.build();
float[][][][] input = new float[1][224][224][3];
// Then populate with input data.
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
.add(input)
.build();
interpreter.run(inputs, inputOutputOptions)
.addOnSuccessListener(
new OnSuccessListener<FirebaseModelOutputs>() {
@Override
public void onSuccess(FirebaseModelOutputs result) {
// ...
}
})
.addOnFailureListener(
new OnFailureListener() {
@Override
public void onFailure(@NonNull Exception e) {
// Task failed with an exception
// ...
}
});
이후
Kotlin+KTX
val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.
val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())
interpreter.run(inputBuffer, outputBuffer)
Java
int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.
int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());
interpreter.run(inputBuffer, outputBuffer);
4. 출력 처리 코드 업데이트
마지막으로 FirebaseModelOutputs
객체의 getOutput()
메서드를 사용하여 모델의 출력을 가져오는 대신 ByteBuffer
출력을 사용 사례에 적합한 구조로 변환합니다.
예를 들어 분류를 수행하는 경우 다음과 같이 변경할 수 있습니다.
이전
Kotlin+KTX
val output = result.getOutput(0)
val probabilities = output[0]
try {
val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
for (probability in probabilities) {
val label: String = reader.readLine()
println("$label: $probability")
}
} catch (e: IOException) {
// File not found?
}
Java
float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(getAssets().open("custom_labels.txt")));
for (float probability : probabilities) {
String label = reader.readLine();
Log.i(TAG, String.format("%s: %1.4f", label, probability));
}
} catch (IOException e) {
// File not found?
}
이후
Kotlin+KTX
modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
val reader = BufferedReader(
InputStreamReader(assets.open("custom_labels.txt")))
for (i in probabilities.capacity()) {
val label: String = reader.readLine()
val probability = probabilities.get(i)
println("$label: $probability")
}
} catch (e: IOException) {
// File not found?
}
Java
modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(getAssets().open("custom_labels.txt")));
for (int i = 0; i < probabilities.capacity(); i++) {
String label = reader.readLine();
float probability = probabilities.get(i);
Log.i(TAG, String.format("%s: %1.4f", label, probability));
}
} catch (IOException e) {
// File not found?
}