Utiliser un modèle TensorFlow Lite personnalisé sur les plates-formes Apple

Si votre application utilise des modèles TensorFlow Lite personnalisés, vous pouvez utiliser Firebase ML pour déployer vos modèles. En déployant des modèles avec Firebase, vous pouvez réduire la taille de téléchargement initiale de votre application et mettre à jour les modèles ML de votre application sans publier une nouvelle version de votre application. Et, avec Remote Config et A/B Testing, vous pouvez proposer de manière dynamique différents modèles à différents groupes d'utilisateurs.

Conditions préalables

  • La bibliothèque MLModelDownloader est uniquement disponible pour Swift.
  • TensorFlow Lite fonctionne uniquement sur les appareils utilisant iOS 9 et versions ultérieures.

Modèles TensorFlow Lite

Les modèles TensorFlow Lite sont des modèles ML optimisés pour fonctionner sur des appareils mobiles. Pour obtenir un modèle TensorFlow Lite :

Avant que tu commences

Pour utiliser TensorFlowLite avec Firebase, vous devez utiliser CocoaPods car TensorFlowLite ne prend actuellement pas en charge l'installation avec Swift Package Manager. Consultez le guide d'installation de CocoaPods pour obtenir des instructions sur la façon d'installer MLModelDownloader .

Une fois installés, importez Firebase et TensorFlowLite afin de les utiliser.

Rapide

import FirebaseMLModelDownloader
import TensorFlowLite

1. Déployez votre modèle

Déployez vos modèles TensorFlow personnalisés à l'aide de la console Firebase ou des SDK Firebase Admin Python et Node.js. Voir Déployer et gérer des modèles personnalisés .

Après avoir ajouté un modèle personnalisé à votre projet Firebase, vous pouvez référencer le modèle dans vos applications en utilisant le nom que vous avez spécifié. À tout moment, vous pouvez déployer un nouveau modèle TensorFlow Lite et télécharger le nouveau modèle sur les appareils des utilisateurs en appelant getModel() (voir ci-dessous).

2. Téléchargez le modèle sur l'appareil et initialisez un interpréteur TensorFlow Lite

Pour utiliser votre modèle TensorFlow Lite dans votre application, utilisez d'abord le SDK Firebase ML pour télécharger la dernière version du modèle sur l'appareil.

Pour démarrer le téléchargement du modèle, appelez la méthode getModel() du téléchargeur de modèles, en spécifiant le nom que vous avez attribué au modèle lorsque vous l'avez téléchargé, si vous souhaitez toujours télécharger le dernier modèle et les conditions dans lesquelles vous souhaitez autoriser le téléchargement.

Vous pouvez choisir parmi trois comportements de téléchargement :

Type de téléchargement Description
localModel Obtenez le modèle local de l'appareil. Si aucun modèle local n'est disponible, cela se comporte comme latestModel . Utilisez ce type de téléchargement si vous ne souhaitez pas rechercher les mises à jour du modèle. Par exemple, vous utilisez Remote Config pour récupérer les noms de modèles et vous téléchargez toujours des modèles sous de nouveaux noms (recommandé).
localModelUpdateInBackground Obtenez le modèle local de l'appareil et commencez à mettre à jour le modèle en arrière-plan. Si aucun modèle local n'est disponible, cela se comporte comme latestModel .
latestModel Obtenez le dernier modèle. Si le modèle local est la dernière version, renvoie le modèle local. Sinon, téléchargez le dernier modèle. Ce comportement se bloquera jusqu'à ce que la dernière version soit téléchargée (non recommandé). Utilisez ce comportement uniquement dans les cas où vous avez explicitement besoin de la dernière version.

Vous devez désactiver les fonctionnalités liées au modèle (par exemple, griser ou masquer une partie de votre interface utilisateur) jusqu'à ce que vous confirmiez que le modèle a été téléchargé.

Rapide

let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
    .getModel(name: "your_model",
              downloadType: .localModelUpdateInBackground,
              conditions: conditions) { result in
        switch (result) {
        case .success(let customModel):
            do {
                // Download complete. Depending on your app, you could enable the ML
                // feature, or switch from the local model to the remote model, etc.

                // The CustomModel object contains the local path of the model file,
                // which you can use to instantiate a TensorFlow Lite interpreter.
                let interpreter = try Interpreter(modelPath: customModel.path)
            } catch {
                // Error. Bad model file?
            }
        case .failure(let error):
            // Download was unsuccessful. Don't enable ML features.
            print(error)
        }
}

De nombreuses applications démarrent la tâche de téléchargement dans leur code d'initialisation, mais vous pouvez le faire à tout moment avant de devoir utiliser le modèle.

3. Effectuer une inférence sur les données d'entrée

Obtenez les formes d'entrée et de sortie de votre modèle

L'interpréteur de modèle TensorFlow Lite prend en entrée et produit en sortie un ou plusieurs tableaux multidimensionnels. Ces tableaux contiennent des valeurs byte , int , long ou float . Avant de pouvoir transmettre des données à un modèle ou utiliser son résultat, vous devez connaître le nombre et les dimensions (« forme ») des tableaux utilisés par votre modèle.

Si vous avez créé le modèle vous-même ou si le format d'entrée et de sortie du modèle est documenté, vous disposez peut-être déjà de ces informations. Si vous ne connaissez pas la forme et le type de données des entrées et sorties de votre modèle, vous pouvez utiliser l'interpréteur TensorFlow Lite pour inspecter votre modèle. Par exemple:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Exemple de sortie :

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Exécuter l'interprète

Après avoir déterminé le format de l'entrée et de la sortie de votre modèle, récupérez vos données d'entrée et effectuez toutes les transformations sur les données nécessaires pour obtenir une entrée de la forme appropriée pour votre modèle.

Par exemple, si votre modèle traite des images et que votre modèle a des dimensions d'entrée de [1, 224, 224, 3] valeurs à virgule flottante, vous devrez peut-être redimensionner les valeurs de couleur de l'image selon une plage à virgule flottante, comme dans l'exemple suivant. :

Rapide

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

var inputData = Data()
for row in 0 ..&lt; 224 {
  for col in 0 ..&lt; 224 {
    let offset = 4 * (row * context.width + col)
    // (Ignore offset 0, the unused alpha channel)
    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    var normalizedRed = Float32(red) / 255.0
    var normalizedGreen = Float32(green) / 255.0
    var normalizedBlue = Float32(blue) / 255.0

    // Append normalized values to Data object in RGB order.
    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
    var bytes = [UInt8](repeating: 0, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedRed, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedGreen, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&ammp;bytes, &amp;normalizedBlue, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
  }
}

Ensuite, copiez votre entrée NSData dans l'interpréteur et exécutez-le :

Rapide

try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()

Vous pouvez obtenir la sortie du modèle en appelant la méthode output(at:) de l’interpréteur. La façon dont vous utilisez la sortie dépend du modèle que vous utilisez.

Par exemple, si vous effectuez une classification, vous pouvez ensuite mapper les index du résultat aux étiquettes qu'ils représentent :

Rapide

let output = try interpreter.output(at: 0)
let probabilities =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in labels.indices {
    print("\(labels[i]): \(probabilities[i])")
}

Annexe : sécurité du modèle

Quelle que soit la manière dont vous mettez vos modèles TensorFlow Lite à la disposition de Firebase ML, Firebase ML les stocke au format protobuf sérialisé standard dans le stockage local.

En théorie, cela signifie que n’importe qui peut copier votre modèle. Cependant, dans la pratique, la plupart des modèles sont si spécifiques à l'application et obscurcis par les optimisations que le risque est similaire à celui de concurrents désassemblant et réutilisant votre code. Néanmoins, vous devez être conscient de ce risque avant d'utiliser un modèle personnalisé dans votre application.