A biblioteca de tarefas do TensorFlow Lite fornece APIs nativas/Android/iOS pré-criadas sobre a mesma infraestrutura que abstrai o TensorFlow. Você pode estender a infraestrutura da API de tarefas para criar APIs personalizadas se seu modelo não for compatível com as bibliotecas de tarefas existentes.
Visão geral
A infraestrutura da API de tarefas tem uma estrutura de duas camadas: a camada C++ inferior que encapsula o tempo de execução TFLite nativo e a camada superior Java/ObjC que se comunica com a camada C++ por meio de JNI ou wrapper nativo.
A implementação de toda a lógica do TensorFlow apenas em C++ minimiza o custo, maximiza o desempenho da inferência e simplifica o fluxo de trabalho geral entre plataformas.
Para criar uma classe Task, estenda a BaseTaskApi para fornecer lógica de conversão entre a interface do modelo TFLite e a interface da API Task e, em seguida, use os utilitários Java/ObjC para criar as APIs correspondentes. Com todos os detalhes do TensorFlow ocultos, você pode implantar o modelo TFLite em seus aplicativos sem nenhum conhecimento de machine learning.
O TensorFlow Lite fornece algumas APIs pré-criadas para as tarefas mais populares do Vision e NLP . Você pode criar suas próprias APIs para outras tarefas usando a infraestrutura da API de tarefas.
Crie sua própria API com a infraestrutura da API de tarefas
API C++
Todos os detalhes do TFLite são implementados na API nativa. Crie um objeto de API usando uma das funções de fábrica e obtenha resultados de modelo chamando funções definidas na interface.
Uso de amostra
Aqui está um exemplo usando o C++ BertQuestionAnswerer
para MobileBert .
char kBertModelPath[] = "path/to/model.tflite";
// Create the API from a model file
std::unique_ptr<BertQuestionAnswerer> question_answerer =
BertQuestionAnswerer::CreateFromFile(kBertModelPath);
char kContext[] = ...; // context of a question to be answered
char kQuestion[] = ...; // question to be answered
// ask a question
std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
// answers[0].text is the best answer
Construindo a API
Para construir um objeto de API, você deve fornecer as seguintes informações estendendo BaseTaskApi
Determine a E/S da API - Sua API deve expor entradas/saídas semelhantes em diferentes plataformas. por exemplo
BertQuestionAnswerer
recebe duas strings(std::string& context, std::string& question)
como entrada e gera um vetor de possíveis respostas e probabilidades comostd::vector<QaAnswer>
. Isso é feito especificando os tipos correspondentes no parâmetro de modelo doBaseTaskApi
. Com os parâmetros do modelo especificados, a funçãoBaseTaskApi::Infer
terá os tipos corretos de entrada/saída. Essa função pode ser chamada diretamente pelos clientes da API, mas é uma boa prática envolvê-la dentro de uma função específica do modelo, neste caso,BertQuestionAnswerer::Answer
.class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Model specific function delegating calls to BaseTaskApi::Infer std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) { return Infer(context, question).value(); } }
Fornecer lógica de conversão entre API I/O e tensor de entrada/saída do modelo - Com tipos de entrada e saída especificados, as subclasses também precisam implementar as funções tipadas
BaseTaskApi::Preprocess
eBaseTaskApi::Postprocess
. As duas funções fornecem entradas e saídas do TFLiteFlatBuffer
. A subclasse é responsável por atribuir valores da API I/O aos tensores de I/O. Veja o exemplo de implementação completo emBertQuestionAnswerer
.class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Convert API input into tensors absl::Status BertQuestionAnswerer::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model const std::string& context, const std::string& query // InputType of the API ) { // Perform tokenization on input strings ... // Populate IDs, Masks and SegmentIDs to corresponding input tensors PopulateTensor(input_ids, input_tensors[0]); PopulateTensor(input_mask, input_tensors[1]); PopulateTensor(segment_ids, input_tensors[2]); return absl::OkStatus(); } // Convert output tensors into API output StatusOr<std::vector<QaAnswer>> // OutputType BertQuestionAnswerer::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model ) { // Get start/end logits of prediction result from output tensors std::vector<float> end_logits; std::vector<float> start_logits; // output_tensors[0]: end_logits FLOAT[1, 384] PopulateVector(output_tensors[0], &end_logits); // output_tensors[1]: start_logits FLOAT[1, 384] PopulateVector(output_tensors[1], &start_logits); ... std::vector<QaAnswer::Pos> orig_results; // Look up the indices from vocabulary file and build results ... return orig_results; } }
Crie funções de fábrica da API - Um arquivo de modelo e um
OpResolver
são necessários para inicializar otflite::Interpreter
.TaskAPIFactory
fornece funções de utilitário para criar instâncias BaseTaskApi.Você também deve fornecer quaisquer arquivos associados ao modelo. por exemplo,
BertQuestionAnswerer
também pode ter um arquivo adicional para o vocabulário do seu tokenizer.class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Factory function to create the API instance StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateBertQuestionAnswerer( const std::string& path_to_model, // model to passed to TaskApiFactory const std::string& path_to_vocab // additional model specific files ) { // Creates an API object by calling one of the utils from TaskAPIFactory std::unique_ptr<BertQuestionAnswerer> api_to_init; ASSIGN_OR_RETURN( api_to_init, core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( path_to_model, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), kNumLiteThreads)); // Perform additional model specific initializations // In this case building a vocabulary vector from the vocab file. api_to_init->InitializeVocab(path_to_vocab); return api_to_init; } }
API do Android
Crie APIs Android definindo a interface Java/Kotlin e delegando a lógica para a camada C++ por meio de JNI. A API do Android requer que a API nativa seja criada primeiro.
Uso de amostra
Aqui está um exemplo usando Java BertQuestionAnswerer
para MobileBert .
String BERT_MODEL_FILE = "path/to/model.tflite";
String VOCAB_FILE = "path/to/vocab.txt";
// Create the API from a model file and vocabulary file
BertQuestionAnswerer bertQuestionAnswerer =
BertQuestionAnswerer.createBertQuestionAnswerer(
ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);
String CONTEXT = ...; // context of a question to be answered
String QUESTION = ...; // question to be answered
// ask a question
List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
// answers.get(0).text is the best answer
Construindo a API
Semelhante às APIs nativas, para construir um objeto de API, o cliente precisa fornecer as seguintes informações estendendo BaseTaskApi
, que fornece manipulações JNI para todas as APIs de tarefas Java.
Determinar a API I/O - Isso geralmente espelha as interfaces nativas. por exemplo
BertQuestionAnswerer
recebe(String context, String question)
como entrada e saídaList<QaAnswer>
. A implementação chama uma função nativa privada com assinatura semelhante, exceto que tem um parâmetro adicionallong nativeHandle
, que é o ponteiro retornado de C++.class BertQuestionAnswerer extends BaseTaskApi { public List<QaAnswer> answer(String context, String question) { return answerNative(getNativeHandle(), context, question); } private static native List<QaAnswer> answerNative( long nativeHandle, // C++ pointer String context, String question // API I/O ); }
Criar funções de fábrica da API - Isso também espelha as funções de fábrica nativas, exceto que as funções de fábrica do Android também precisam usar
Context
para acesso a arquivos. A implementação chama um dos utilitários emTaskJniUtils
para construir o objeto de API C++ correspondente e passar seu ponteiro para o construtorBaseTaskApi
.class BertQuestionAnswerer extends BaseTaskApi { private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "bert_question_answerer_jni"; // Extending super constructor by providing the // native handle(pointer of corresponding C++ API object) private BertQuestionAnswerer(long nativeHandle) { super(nativeHandle); } public static BertQuestionAnswerer createBertQuestionAnswerer( Context context, // Accessing Android files String pathToModel, String pathToVocab) { return new BertQuestionAnswerer( // The util first try loads the JNI module with name // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files, // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers // is called with the buffer for a C++ API object pointer TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( context, BertQuestionAnswerer::initJniWithBertByteBuffers, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel, pathToVocab)); } // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. // returns C++ API object pointer casted to long private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); }
Implemente o módulo JNI para funções nativas - Todos os métodos nativos Java são implementados chamando uma função nativa correspondente do módulo JNI. As funções de fábrica criariam um objeto de API nativo e retornariam seu ponteiro como um tipo longo para Java. Em chamadas posteriores para a API Java, o ponteiro de tipo longo é passado de volta para JNI e convertido de volta para o objeto de API nativo. Os resultados da API nativa são então convertidos novamente em resultados Java.
Por exemplo, é assim que bert_question_answerer_jni é implementado.
// Implements BertQuestionAnswerer::initJniWithBertByteBuffers extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers( JNIEnv* env, jclass thiz, jobjectArray model_buffers) { // Convert Java ByteBuffer object into a buffer that can be read by native factory functions absl::string_view model = GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); // Creates the native API object absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status = BertQuestionAnswerer::CreateFromBuffer( model.data(), model.size()); if (status.ok()) { // converts the object pointer to jlong and return to Java. return reinterpret_cast<jlong>(status->release()); } else { return kInvalidPointer; } } // Implements BertQuestionAnswerer::answerNative extern "C" JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative( JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) { // Convert long to native API object pointer QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); // Calls the native API std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context), JStringToString(env, question)); // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>) jclass qa_answer_class = env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer"); jmethodID qa_answer_ctor = env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V"); return ConvertVectorToArrayList<QaAnswer>( env, results, [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) { jstring text = env->NewStringUTF(ans.text.data()); jobject qa_answer = env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start, ans.pos.end, ans.pos.logit); env->DeleteLocalRef(text); return qa_answer; }); } // Implements BaseTaskApi::deinitJni by delete the native object extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni( JNIEnv* env, jobject thiz, jlong native_handle) { delete reinterpret_cast<QuestionAnswerer*>(native_handle); }
API do iOS
Crie APIs iOS envolvendo um objeto API nativo em um objeto API ObjC. O objeto de API criado pode ser usado em ObjC ou Swift. A API do iOS requer que a API nativa seja criada primeiro.
Uso de amostra
Aqui está um exemplo usando ObjC TFLBertQuestionAnswerer
para MobileBert em Swift.
static let mobileBertModelPath = "path/to/model.tflite";
// Create the API from a model file and vocabulary file
let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
modelPath: mobileBertModelPath)
static let context = ...; // context of a question to be answered
static let question = ...; // question to be answered
// ask a question
let answers = mobileBertAnswerer.answer(
context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
// answers.[0].text is the best answer
Construindo a API
A API do iOS é um wrapper ObjC simples sobre a API nativa. Crie a API seguindo as etapas abaixo:
Definir o wrapper ObjC - Defina uma classe ObjC e delegue as implementações ao objeto de API nativo correspondente. Observe que as dependências nativas só podem aparecer em um arquivo .mm devido à incapacidade do Swift de interoperar com C++.
- arquivo .h
@interface TFLBertQuestionAnswerer : NSObject // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath vocabPath:(NSString*)vocabPath NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:)); // Delegate calls to the native BertQuestionAnswerer::Answer - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context question:(NSString*)question NS_SWIFT_NAME(answer(context:question:)); }
- arquivo .mm
using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer; @implementation TFLBertQuestionAnswerer { // define an iVar for the native API object std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer; } // Initialize the native API object + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath vocabPath:(NSString *)vocabPath { absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer = BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath), MakeString(vocabPath)); _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer"); return [[TFLBertQuestionAnswerer alloc] initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())]; } // Calls the native API and converts C++ results into ObjC results - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question { std::vector<QaAnswerCPP> results = _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question)); return [self arrayFromVector:results]; } }
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2021-11-05 UTC.