TFX per dispositivi mobili

Introduzione

Questa guida dimostra come Tensorflow Extended (TFX) può creare e valutare modelli di machine learning che verranno distribuiti sul dispositivo. TFX ora fornisce il supporto nativo per TFLite , che rende possibile eseguire un'inferenza altamente efficiente sui dispositivi mobili.

Questa guida ti guida attraverso le modifiche che possono essere apportate a qualsiasi pipeline per generare e valutare i modelli TFLite. Forniamo qui un esempio completo, dimostrando come TFX può addestrare e valutare i modelli TFLite addestrati dal set di dati MNIST . Inoltre, mostriamo come la stessa pipeline può essere utilizzata per esportare simultaneamente sia il SavedModel standard basato su Keras sia quello TFLite, consentendo agli utenti di confrontare la qualità dei due.

Presumiamo che tu abbia familiarità con TFX, i nostri componenti e le nostre pipeline. In caso contrario, consulta questo tutorial .

Passi

Sono necessari solo due passaggi per creare e valutare un modello TFLite in TFX. Il primo passo è invocare il riscrittore TFLite nel contesto del TFX Trainer per convertire il modello TensorFlow addestrato in uno TFLite. Il secondo passo è configurare l'Evaluator per valutare i modelli TFLite. Ora ne discuteremo ciascuno a turno.

Invocare il riscrittore TFLite all'interno del Trainer.

Il TFX Trainer prevede che nel file del modulo venga specificato un run_fn definito dall'utente. Questo run_fn definisce il modello da addestrare, lo addestra per il numero specificato di iterazioni ed esporta il modello addestrato.

Nel resto di questa sezione forniamo frammenti di codice che mostrano le modifiche necessarie per richiamare il riscrittore TFLite ed esportare un modello TFLite. Tutto questo codice si trova nel run_fn del modulo MNIST TFLite .

Come mostrato nel codice seguente, dobbiamo prima creare una firma che prenda come input un Tensor per ogni caratteristica. Tieni presente che si tratta di una deviazione dalla maggior parte dei modelli esistenti in TFX, che accettano prototipi tf.Example serializzati come input.

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

Quindi il modello Keras viene salvato come SavedModel nello stesso modo in cui lo è normalmente.

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

Infine, creiamo un'istanza del rewriter TFLite ( tfrw ) e la invochiamo su SavedModel per ottenere il modello TFLite. Memorizziamo questo modello TFLite nel serving_model_dir fornito dal chiamante di run_fn . In questo modo, il modello TFLite viene archiviato nella posizione in cui tutti i componenti TFX downstream si aspetteranno di trovare il modello.

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

Valutazione del modello TFLite.

Il valutatore TFX offre la possibilità di analizzare modelli addestrati per comprenderne la qualità in un'ampia gamma di parametri. Oltre ad analizzare SavedModels, TFX Evaluator è ora in grado di analizzare anche i modelli TFLite.

Il seguente frammento di codice (riprodotto dalla pipeline MNIST ), mostra come configurare un valutatore che analizza un modello TFLite.

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

Come mostrato sopra, l'unica modifica che dobbiamo apportare è impostare il campo model_type su tf_lite . Non sono necessarie altre modifiche alla configurazione per analizzare il modello TFLite. Indipendentemente dal fatto che venga analizzato un modello TFLite o un SavedModel, l'output Evaluator avrà esattamente la stessa struttura.

Tuttavia, tieni presente che il Valutatore presuppone che il modello TFLite sia salvato in un file denominato tflite all'interno di trainer_lite.outputs['model'].