Criar uma operação

Se você quiser criar uma operação que não seja coberta pela biblioteca TensorFlow existente, recomendamos que você tente primeiro escrever a operação em Python como uma composição de operações ou funções existentes do Python. Se isso não for possível, você pode criar uma operação C++ personalizada. Há vários motivos pelos quais você pode querer criar uma operação C++ personalizada:

  • Não é fácil ou possível expressar sua operação como uma composição de operações existentes.
  • Não é eficiente expressar sua operação como uma composição de primitivas existentes.
  • Você deseja fundir manualmente uma composição de primitivos que um compilador futuro acharia difícil de fundir.

Por exemplo, imagine que você deseja implementar algo como "pooling mediano", semelhante ao operador "MaxPool", mas calculando medianas sobre janelas deslizantes em vez de valores máximos. Fazer isso usando uma composição de operações pode ser possível (por exemplo, usando ExtractImagePatches e TopK), mas pode não ser tão eficiente em desempenho ou memória quanto uma operação nativa onde você pode fazer algo mais inteligente em uma única operação fundida. Como sempre, normalmente vale a pena tentar expressar o que você deseja usando a composição de operadores, optando apenas por adicionar uma nova operação se isso for difícil ou ineficiente.

Para incorporar sua operação personalizada, você precisará:

  1. Registre o novo op em um arquivo C++. O registro de op define uma interface (especificação) para a funcionalidade do op, que é independente da implementação do op. Por exemplo, o registro de op define o nome do op e as entradas e saídas do op. Ele também define a função de forma que é usada para inferência de forma de tensor.
  2. Implemente a operação em C++. A implementação de um op é conhecida como kernel, e é a implementação concreta da especificação que você registrou na Etapa 1. Pode haver vários kernels para diferentes tipos ou arquiteturas de entrada/saída (por exemplo, CPUs, GPUs).
  3. Crie um wrapper Python (opcional). Esse wrapper é a API pública usada para criar a operação em Python. Um wrapper padrão é gerado a partir do registro de operação, que pode ser usado diretamente ou adicionado.
  4. Escreva uma função para calcular gradientes para o op (opcional).
  5. Teste a op. Geralmente fazemos isso em Python por conveniência, mas você também pode testar a operação em C++. Se você definir gradientes, poderá verificá-los com o Python tf.test.compute_gradient_error . Veja relu_op_test.py como um exemplo que testa as funções diretas de operadores do tipo Relu e seus gradientes.

Pré-requisitos

Defina a interface operacional

Você define a interface de uma operação registrando-a no sistema TensorFlow. No registro, você especifica o nome do seu op, suas entradas (tipos e nomes) e saídas (tipos e nomes), bem como docstrings e quaisquer attrs que o op possa exigir.

Para ver como isso funciona, suponha que você queira criar um op que receba um tensor de int32 s e gere uma cópia do tensor, com todos menos o primeiro elemento definido como zero. Para fazer isso, crie um arquivo chamado zero_out.cc . Em seguida, adicione uma chamada à macro REGISTER_OP que define a interface para sua operação:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

Esta ZeroOut recebe um tensor to_zero de inteiros de 32 bits como entrada e gera um tensor zeroed de inteiros de 32 bits. O op também usa uma função de forma para garantir que o tensor de saída tenha a mesma forma que o tensor de entrada. Por exemplo, se a entrada for um tensor de forma [10, 20], então esta função de forma especifica que a forma de saída também é [10, 20].

Implemente o kernel para a operação

Depois de definir a interface, forneça uma ou mais implementações do op. Para criar um desses kernels, crie uma classe que estenda o OpKernel e substitua o método Compute . O método Compute fornece um argumento de context do tipo OpKernelContext* , a partir do qual você pode acessar coisas úteis como os tensores de entrada e saída.

Adicione seu kernel ao arquivo que você criou acima. O kernel pode ser algo assim:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

Depois de implementar seu kernel, você o registra no sistema TensorFlow. No registro, você especifica diferentes restrições sob as quais esse kernel será executado. Por exemplo, você pode ter um kernel feito para CPUs e um separado para GPUs.

Para fazer isso para a ZeroOut , adicione o seguinte a zero_out.cc :

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Kernels de CPU multi-thread

Para escrever um kernel de CPU multithread, a função Shard em work_sharder.h pode ser usada. Esta função fragmenta uma função de computação entre os encadeamentos configurados para serem usados ​​para encadeamento intra-op (consulte intra_op_parallelism_threads em config.proto ).

Kernels de GPU

Um kernel de GPU é implementado em duas partes: o OpKernel e o kernel CUDA e seu código de lançamento.

Às vezes, a implementação do OpKernel é comum entre um kernel de CPU e GPU, como na inspeção de entradas e na alocação de saídas. Nesse caso, uma implementação sugerida é:

  1. Defina o OpKernel modelado no dispositivo e o tipo primitivo do tensor.
  2. Para fazer o cálculo real da saída, a função Compute chama uma estrutura functor modelada.
  3. A especialização desse functor para CPUDevice é definida no mesmo arquivo, mas a especialização para GPUDevice é definida em um arquivo .cu.cc, pois será compilado com o compilador CUDA.

Aqui está um exemplo de implementação.

// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif

#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

REGISTER_OP("Example")
    .Attr("T: numbertype")
    .Input("input: T")
    .Output("input_times_two: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));

    // Do the computation.
    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    ExampleFunctor<Device, T>()(
        context->eigen_device<Device>(),
        static_cast<int>(input_tensor.NumElements()),
        input_tensor.flat<T>().data(),
        output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);

// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \
  /* Declare explicit instantiations in kernel_example.cu.cc. */ \
  extern template class ExampleFunctor<GPUDevice, T>;            \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif  // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "kernel_example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;

// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    out[i] = 2 * __ldg(in + i);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
    const GPUDevice& d, int size, const T* in, T* out) {
  // Launch the cuda kernel.
  //
  // See core/util/gpu_kernel_helper.h for example of computing
  // block count and thread_per_block count.
  int block_count = 1024;
  int thread_per_block = 20;
  ExampleCudaKernel<T>
      <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}

// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;

#endif  // GOOGLE_CUDA

Construir a biblioteca de operações

Compile o op usando o compilador do sistema (instalação binária do TensorFlow)

Você deve ser capaz de compilar zero_out.cc com um compilador C++ como g++ ou clang disponível em seu sistema. O pacote PIP binário instala os arquivos de cabeçalho e a biblioteca que você precisa para compilar seu op em locais específicos do sistema. No entanto, a biblioteca python do get_include fornece a função get_include para obter o diretório de cabeçalho, e o diretório get_lib tem um objeto compartilhado para vincular. Aqui estão as saídas dessas funções em uma máquina Ubuntu.

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

Supondo que você tenha o g++ instalado, aqui está a sequência de comandos que você pode usar para compilar seu op em uma biblioteca dinâmica.

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++14 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

No macOS, o sinalizador adicional "-undefined dynamic_lookup" é necessário ao compilar o arquivo .so .

Nota sobre a versão do gcc >=5 : o gcc usa o novo C++ ABI desde a versão 5 . Os pacotes de pip binários disponíveis no site do TensorFlow são criados com gcc4 que usa a ABI mais antiga. Se você compilar sua biblioteca op com gcc>=5 , adicione -D_GLIBCXX_USE_CXX11_ABI=0 à linha de comando para tornar a biblioteca compatível com a abi mais antiga.

Compile o op usando o bazel (instalação de origem do TensorFlow)

Se você tiver fontes do TensorFlow instaladas, poderá usar o sistema de compilação do TensorFlow para compilar seu op. Coloque um arquivo BUILD com a seguinte regra de construção do Bazel no tensorflow/core/user_ops .

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

Execute o seguinte comando para compilar zero_out.so .

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

Para compilar a operação Example , com o Kernel CUDA, você precisa usar o parâmetro gpu_srcs de tf_custom_op_library . Coloque um arquivo BUILD com a seguinte regra de compilação do Bazel em uma nova pasta dentro do tensorflow/core/user_ops (por exemplo, "example_gpu").

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    # kernel_example.cc  kernel_example.cu.cc  kernel_example.h
    name = "kernel_example.so",
    srcs = ["kernel_example.h", "kernel_example.cc"],
    gpu_srcs = ["kernel_example.cu.cc", "kernel_example.h"],
)

Execute o seguinte comando para compilar kernel_example.so .

$ bazel build --config opt //tensorflow/core/user_ops/example_gpu:kernel_example.so

Use o op em Python

A API Python do TensorFlow fornece a função tf.load_op_library para carregar a biblioteca dinâmica e registrar a operação com a estrutura do TensorFlow. load_op_library retorna um módulo Python que contém os wrappers Python para o op e o kernel. Assim, depois de construir o op, você pode fazer o seguinte para executá-lo a partir do Python:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
print(zero_out_module.zero_out([[1, 2], [3, 4]]).numpy())

# Prints
array([[1, 0], [0, 0]], dtype=int32)

Tenha em mente que a função gerada receberá um nome snake_case (para cumprir com PEP8 ). Portanto, se seu op for chamado ZeroOut nos arquivos C++, a função python será chamada zero_out .

Para tornar o op disponível como uma função regular importável de um módulo Python, talvez seja útil ter a load_op_library import um arquivo de origem Python da seguinte maneira:

import tensorflow as tf

zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out

Verifique se a operação funciona

Uma boa maneira de verificar se você implementou sua operação com sucesso é escrever um teste para ela. Crie o arquivo zero_out_op_test.py com o conteúdo:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

Em seguida, execute seu teste (supondo que você tenha o tensorflow instalado):

$ python zero_out_op_test.py

Crie recursos avançados em sua operação

Agora que você sabe como construir uma operação e implementação básicas (e um tanto restritas), veremos algumas das coisas mais complicadas que você normalmente precisará para construir em sua operação. Isso inclui:

Verificações condicionais e validação

O exemplo acima assumiu que o op aplicado a um tensor de qualquer forma. E se fosse aplicado apenas a vetores? Isso significa adicionar uma verificação à implementação OpKernel acima.

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

Isso afirma que a entrada é um vetor e retorna tendo definido o status InvalidArgument se não for. A macro OP_REQUIRES recebe três argumentos:

Alternativamente, se você quiser testar se um objeto Status retornado de alguma função é um erro e, em caso afirmativo, retorná-lo, use OP_REQUIRES_OK . Ambas as macros retornam da função em caso de erro.

Registro de operação

Attrs

Ops pode ter attrs, cujos valores são definidos quando o op é adicionado a um gráfico. Estes são usados ​​para configurar o op, e seus valores podem ser acessados ​​tanto na implementação do kernel quanto nos tipos de entradas e saídas no registro do op. Prefira usar uma entrada em vez de um attr quando possível, pois as entradas são mais flexíveis. Isso ocorre porque attrs são constantes e devem ser definidos no momento da construção do gráfico. Em contraste, as entradas são Tensores cujos valores podem ser dinâmicos; isto é, as entradas podem mudar a cada passo, ser definidas usando um feed, etc. Attrs são usados ​​para coisas que não podem ser feitas com entradas: qualquer configuração que afete a assinatura (número ou tipo de entradas ou saídas) ou que possa' t mudar de passo a passo.

Você define um attr ao registrar o op, especificando seu nome e tipo usando o método Attr , que espera uma especificação da forma:

<name>: <attr-type-expr>

onde <name> começa com uma letra e pode ser composto por caracteres alfanuméricos e sublinhados, e <attr-type-expr> é uma expressão de tipo da forma descrita abaixo .

Por exemplo, se você quiser que o ZeroOut op preserve um índice especificado pelo usuário, em vez de apenas o elemento 0, você pode registrar o op assim:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

(Observe que o conjunto de tipos de atributo é diferente do tf.DType usado para entradas e saídas.)

Seu kernel pode então acessar este attr em seu construtor através do parâmetro context :

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
 private:
  int preserve_index_;
};

que pode ser usado no método Compute :

  void Compute(OpKernelContext* context) override {
    // ...

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }

Tipos de atribuição

Os seguintes tipos são suportados em um attr:

  • string : Qualquer sequência de bytes (não precisa ser UTF8).
  • int : Um inteiro com sinal.
  • float : Um número de ponto flutuante.
  • bool : Verdadeiro ou falso.
  • type : um dos valores (não-ref) de DataType .
  • shape : Um TensorShapeProto .
  • list(<type>) : Uma lista de <type> , onde <type> é um dos tipos acima. Observe que list(list(<type>)) é inválido.

Veja também: op_def_builder.cc:FinalizeAttr para uma lista definitiva.

Valores e restrições padrão

Attrs pode ter valores padrão e alguns tipos de attrs podem ter restrições. Para definir um attr com restrições, você pode usar os seguintes <attr-type-expr> s:

{'<string1>', '<string2>'} : O valor deve ser uma string que tenha o valor <string1> ou <string2> . O nome do tipo, string , está implícito quando você usa essa sintaxe. Isso emula um enum:

REGISTER_OP("EnumExample")
    .Attr("e: {'apple', 'orange'}");

{<type1>, <type2>} : o valor é do tipo type e deve ser <type1> ou <type2> , onde <type1> e <type2> são suportados tf.DType . Você não especifica que o tipo do attr é type . Isso está implícito quando você tem uma lista de tipos em {...} . Por exemplo, neste caso o attr t é um tipo que deve ser um int32 , um float ou um bool :

REGISTER_OP("RestrictedTypeExample")
    .Attr("t: {int32, float, bool}");

Existem atalhos para restrições de tipo comuns:

  • numbertype : Tipo type restrito aos tipos numéricos (não string e não bool).
  • realnumbertype : Como numbertype sem tipos complexos.
  • quantizedtype : Como numbertype , mas apenas os tipos de números quantizados.

As listas específicas de tipos permitidos por eles são definidas pelas funções (como NumberTypes() ) em tensorflow/core/framework/types.h . Neste exemplo, o attr t deve ser um dos tipos numéricos:

REGISTER_OP("NumberType")
    .Attr("t: numbertype");

Para esta operação:

tf.number_type(t=tf.int32)  # Valid
tf.number_type(t=tf.bool)   # Invalid

As listas podem ser combinadas com outras listas e tipos únicos. A seguinte op permite que attr t seja qualquer um dos tipos numéricos, ou o tipo bool:

REGISTER_OP("NumberOrBooleanType")
    .Attr("t: {numbertype, bool}");

Para esta operação:

tf.number_or_boolean_type(t=tf.int32)  # Valid
tf.number_or_boolean_type(t=tf.bool)   # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid

int >= <n> : O valor deve ser um int cujo valor seja maior ou igual a <n> , onde <n> é um número natural. Por exemplo, o seguinte registro op especifica que o attr a deve ter um valor que seja pelo menos 2 :

REGISTER_OP("MinIntExample")
    .Attr("a: int >= 2");

list(<type>) >= <n> : Uma lista do tipo <type> cujo comprimento é maior ou igual a <n> . Por exemplo, o registro op a seguir especifica que o attr a é uma lista de tipos ( int32 ou float ), e que deve haver pelo menos 3 deles:

REGISTER_OP("TypeListExample")
    .Attr("a: list({int32, float}) >= 3");

Para definir um valor padrão para um attr (tornando-o opcional no código gerado), adicione = <default> ao final, como em:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

Além disso, uma restrição e um valor padrão podem ser especificados:

REGISTER_OP("AttrConstraintAndDefaultExample")
    .Attr("i: int >= 1 = 1");

A sintaxe suportada do valor padrão é o que seria usado na representação proto da definição GraphDef resultante.

Aqui estão alguns exemplos de como especificar um padrão para todos os tipos:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

Observe em particular que os valores do tipo type usam tf.DType .

Polimorfismo

Tipo de polimorfismo

Para operações que podem receber tipos diferentes como entrada ou produzir tipos de saída diferentes, você pode especificar um attr em um tipo de entrada ou saída no registro de operação. Normalmente, você registraria um OpKernel para cada tipo suportado.

Por exemplo, se você quiser que o ZeroOut op funcione em float s além de int32 s, seu registro de op pode ser assim:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Seu registro op agora especifica que o tipo de entrada deve ser float ou int32 e que sua saída será do mesmo tipo, já que ambos têm o tipo T .

Nomeação

Entradas, saídas e attrs geralmente devem receber nomes snake_case. A única exceção são os attrs que são usados ​​como o tipo de uma entrada ou no tipo de uma saída. Esses atributos podem ser inferidos quando o op é adicionado ao gráfico e, portanto, não aparecem na função do op. Por exemplo, esta última definição de ZeroOut irá gerar uma função Python que se parece com:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """

Se to_zero for passado por um tensor int32 , então T será automaticamente definido como int32 (bem, na verdade, DT_INT32 ). Esses atributos inferidos recebem nomes em maiúsculas ou CamelCase.

Compare isso com um op que tem um tipo attr que determina o tipo de saída:

REGISTER_OP("StringToNumber")
    .Input("string_tensor: string")
    .Output("output: out_type")
    .Attr("out_type: {float, int32} = DT_FLOAT");
    .Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");

Nesse caso, o usuário deve especificar o tipo de saída, como no Python gerado:

def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
      Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
Exemplo de polimorfismo de tipo
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

Para preservar a compatibilidade com versões anteriores , você deve especificar um valor padrão ao adicionar um attr a uma operação existente:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

Digamos que você queira adicionar mais tipos, digamos double :

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Em vez de escrever outro OpKernel com código redundante como acima, muitas vezes você poderá usar um modelo C++. Você ainda terá um registro de kernel (chamada REGISTER_KERNEL_BUILDER ) por sobrecarga.

template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

Se você tiver mais de algumas sobrecargas, poderá colocar o registro em uma macro.

#include "tensorflow/core/framework/op_kernel.h"

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);

#undef REGISTER_KERNEL

Dependendo da lista de tipos para os quais você está registrando o kernel, você pode usar uma macro fornecida por tensorflow/core/framework/register_types.h :

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
Listar entradas e saídas

Além de poder aceitar ou produzir diferentes tipos, ops podem consumir ou produzir um número variável de tensores.

No próximo exemplo, o attr T contém uma lista de tipos e é usado como o tipo in entrada e out . A entrada e a saída são listas de tensores desse tipo (e o número e os tipos de tensores na saída são os mesmos da entrada, pois ambos têm o tipo T ).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

Você também pode colocar restrições sobre quais tipos podem ser especificados na lista. Neste próximo caso, a entrada é uma lista de tensores float e double . O op aceita, por exemplo, tipos de entrada (float, double, float) e nesse caso o tipo de saída também seria (float, double, float) .

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

Se você quiser que todos os tensores em uma lista sejam do mesmo tipo, você pode fazer algo como:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

Isso aceita uma lista de tensores int32 e usa um int attr N para especificar o comprimento da lista.

Isso pode ser feito do tipo polimórfico também. No próximo exemplo, a entrada é uma lista de tensores (com comprimento "N" ) do mesmo tipo (mas não especificado) ( "T" ), e a saída é um único tensor de tipo correspondente:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

Por padrão, as listas de tensores têm um comprimento mínimo de 1. Você pode alterar esse padrão usando uma restrição ">=" no attr correspondente . Neste próximo exemplo, a entrada é uma lista de pelo menos 2 tensores int32 :

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

A mesma sintaxe funciona com attrs "list(type)" :

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

Entradas e saídas

Para resumir o acima, um registro de operação pode ter várias entradas e saídas:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

Cada especificação de entrada ou saída é da forma:

<name>: <io-type-expr>

onde <name> começa com uma letra e pode ser composto por caracteres alfanuméricos e sublinhados. <io-type-expr> é uma das seguintes expressões de tipo:

  • <type> , onde <type> é um tipo de entrada suportado (por exemplo, float , int32 , string ). Isso especifica um único tensor do tipo fornecido.

    Consulte tf.DType .

    REGISTER_OP("BuiltInTypesExample")
        .Input("integers: int32")
        .Input("complex_numbers: complex64");
    
  • <attr-type> , onde <attr-type> é o nome de um Attr com type type ou list(type) (com uma possível restrição de tipo). Essa sintaxe permite operações polimórficas .

    REGISTER_OP("PolymorphicSingleInput")
        .Attr("T: type")
        .Input("in: T");
    
    REGISTER_OP("RestrictedPolymorphicSingleInput")
        .Attr("T: {int32, int64}")
        .Input("in: T");
    

    Fazer referência a um attr do tipo list(type) permite que você aceite uma sequência de tensores.

    REGISTER_OP("ArbitraryTensorSequenceExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    
    REGISTER_OP("RestrictedTensorSequenceExample")
        .Attr("T: list({int32, int64})")
        .Input("in: T")
        .Output("out: T");
    

    Observe que o número e os tipos de tensores na saída out são os mesmos que na entrada in , pois ambos são do tipo T .

  • Para uma sequência de tensores com o mesmo tipo: <number> * <type> , onde <number> é o nome de um Attr com tipo int . O <type> pode ser um tf.DType ou o nome de um attr com type type . Como exemplo do primeiro, este op aceita uma lista de tensores int32 :

    REGISTER_OP("Int32SequenceExample")
        .Attr("NumTensors: int")
        .Input("in: NumTensors * int32")
    

    Considerando que este op aceita uma lista de tensores de qualquer tipo, desde que sejam todos iguais:

    REGISTER_OP("SameTypeSequenceExample")
        .Attr("NumTensors: int")
        .Attr("T: type")
        .Input("in: NumTensors * T")
    
  • Para uma referência a um tensor: Ref(<type>) , onde <type> é um dos tipos anteriores.

Qualquer attr usado no tipo de uma entrada será inferido. Por convenção, esses atributos inferidos usam nomes maiúsculos (como T ou N ). Caso contrário, entradas, saídas e attrs têm nomes como parâmetros de função (por exemplo num_outputs ). Para obter mais detalhes, consulte a seção anterior sobre nomenclatura .

Para obter mais detalhes, consulte tensorflow/core/framework/op_def_builder.h .

Compatibilidade com versões anteriores

Vamos supor que você escreveu uma operação personalizada e agradável e a compartilhou com outras pessoas, para que você tenha clientes satisfeitos usando sua operação. No entanto, você gostaria de fazer alterações na operação de alguma forma.

Em geral, as alterações nas especificações de check-in existentes devem ser compatíveis com versões anteriores: alterar a especificação de uma operação não deve interromper os buffers de protocolo GraphDef serializados anteriores construídos a partir de especificações mais antigas. Os detalhes da compatibilidade do GraphDef são descritos aqui .

Existem várias maneiras de preservar a compatibilidade com versões anteriores.

  1. Qualquer novo attrs adicionado a uma operação deve ter valores padrão definidos e, com esse valor padrão, o op deve ter o comportamento original. Para alterar uma operação de não polimórfica para polimórfica, você deve fornecer um valor padrão ao novo tipo attr para preservar a assinatura original por padrão. Por exemplo, se sua operação foi:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: float")
        .Output("out: float");
    

    você pode torná-lo polimórfico de maneira compatível com versões anteriores usando:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: T")
        .Output("out: T")
        .Attr("T: numerictype = DT_FLOAT");
    
  2. Você pode fazer uma restrição em um attr com segurança menos restritiva. Por exemplo, você pode alterar de {int32, int64} para {int32, int64, float} ou type . Ou você pode mudar de {"apple", "orange"} para {"apple", "banana", "orange"} ou string .

  3. Você pode alterar entradas/saídas individuais em entradas/saídas de lista, desde que o padrão para o tipo de lista corresponda à assinatura antiga.

  4. Você pode adicionar uma nova entrada/saída de lista, se o padrão for vazio.

  5. Namespace quaisquer novas operações que você criar, prefixando os nomes das operações com algo exclusivo para seu projeto. Isso evita que sua operação colida com qualquer operação que possa ser incluída em versões futuras do TensorFlow.

  6. Planejar com antecedência! Tente antecipar usos futuros para a operação. Algumas alterações de assinatura não podem ser feitas de maneira compatível (por exemplo, transformar uma lista do mesmo tipo em uma lista de tipos variados).

A lista completa de alterações seguras e inseguras pode ser encontrada em tensorflow/core/framework/op_compatibility_test.cc . Se você não puder fazer sua alteração em uma operação compatível com versões anteriores, crie uma nova operação com um novo nome com a nova semântica.

Observe também que, embora essas alterações possam manter a compatibilidade com o GraphDef , o código Python gerado pode mudar de uma maneira que não é compatível com os chamadores antigos. A API Python pode ser mantida compatível por mudanças cuidadosas em um wrapper Python escrito à mão, mantendo a assinatura antiga, exceto possivelmente adicionando novos argumentos opcionais ao final. Geralmente, alterações incompatíveis só podem ser feitas quando o TensorFlow altera as versões principais e devem estar em conformidade com a semântica da versão do GraphDef .

Suporte a GPU

Você pode implementar diferentes OpKernels e registrar um para CPU e outro para GPU, assim como você pode registrar kernels para diferentes tipos . Existem vários exemplos de kernels com suporte a GPU em tensorflow/core/kernels/ . Observe que alguns kernels têm uma versão de CPU em um arquivo .cc , uma versão de GPU em um arquivo que termina em _gpu.cu.cc e algum código compartilhado em um arquivo .h .

Por exemplo, o tf.pad tem tudo menos o kernel da GPU em tensorflow/core/kernels/pad_op.cc . O kernel da GPU está em tensorflow/core/kernels/pad_op_gpu.cu.cc , e o código compartilhado é uma classe de modelo definida em tensorflow/core/kernels/pad_op.h . Organizamos o código dessa maneira por dois motivos: permite compartilhar código comum entre as implementações de CPU e GPU e coloca a implementação de GPU em um arquivo separado para que possa ser compilado apenas pelo compilador de GPU.

Uma coisa a notar, mesmo quando a versão do kernel da GPU do pad é usada, ele ainda precisa de sua entrada de "paddings" na memória da CPU. Para marcar que as entradas ou saídas são mantidas na CPU, adicione uma chamada HostMemory() ao registro do kernel, por exemplo:

#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

Compilando o kernel para o dispositivo GPU

Veja cuda_op_kernel.cu.cc para um exemplo que usa um kernel CUDA para implementar um op. O tf_custom_op_library aceita um argumento gpu_srcs no qual a lista de arquivos fonte contendo os kernels CUDA (arquivos *.cu.cc ) pode ser especificada. Para uso com uma instalação binária do TensorFlow, os kernels CUDA devem ser compilados com o compilador nvcc da NVIDIA. Aqui está a sequência de comandos que você pode usar para compilar cuda_op_kernel.cu.cc e cuda_op_kernel.cc em uma única biblioteca carregável dinamicamente:

nvcc -std=c++14 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
  ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC

g++ -std=c++14 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
  cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}

cuda_op_kernel.so produzido acima pode ser carregado normalmente em Python, usando a função tf.load_op_library .

Observe que se suas bibliotecas CUDA não estiverem instaladas em /usr/local/lib64 , você precisará especificar o caminho explicitamente no segundo comando (g++) acima. Por exemplo, adicione -L /usr/local/cuda-8.0/lib64/ se seu CUDA estiver instalado em /usr/local/cuda-8.0 .

Implemente o gradiente em Python

Dado um gráfico de operações, o TensorFlow usa diferenciação automática (backpropagation) para adicionar novas operações que representam gradientes em relação às operações existentes. Para fazer a diferenciação automática funcionar para novas operações, você deve registrar uma função gradiente que calcula gradientes em relação às entradas das operações, dados os gradientes em relação às saídas das operações.

Matematicamente, se um op calcula \(y = f(x)\) o gradiente registrado op converte gradientes \(\partial L/ \partial y\) de perda \(L\) em relação a\(y\) em gradientes \(\partial L/ \partial x\) em relação a \(x\) através da regra da cadeia:

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.\]

No caso de ZeroOut , apenas uma entrada na entrada afeta a saída, então o gradiente em relação à entrada é um tensor esparso "one hot". Isso é expresso da seguinte forma:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can use
      to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input

Detalhes sobre como registrar funções de gradiente com tf.RegisterGradient :

  • Para um op com uma saída, a função gradiente pegará um tf.Operation , op e um tf.Tensor grad e construirá novos ops a partir dos tensores op.inputs[i] , op.outputs[i] e grad . Informações sobre quaisquer attrs podem ser encontradas tf.Operation.get_attr .

  • Se o op tiver várias saídas, a função gradiente usará op e grads , onde grads é uma lista de gradientes em relação a cada saída. O resultado da função gradiente deve ser uma lista de objetos Tensor representando os gradientes em relação a cada entrada.

  • Se não houver gradiente bem definido para alguma entrada, como para entradas inteiras usadas como índices, o gradiente retornado correspondente deve ser None . Por exemplo, para um op pegando um tensor de ponto flutuante x e um índice inteiro i , a função gradiente return [x_grad, None] .

  • Se não houver nenhum gradiente significativo para o op, muitas vezes você não terá que registrar nenhum gradiente e, desde que o gradiente do op nunca seja necessário, você ficará bem. Em alguns casos, um op não tem gradiente bem definido, mas pode estar envolvido no cálculo do gradiente. Aqui você pode usar ops.NotDifferentiable para propagar automaticamente zeros para trás.

Observe que no momento em que a função gradiente é chamada, apenas o gráfico de fluxo de dados de operações está disponível, não os dados do tensor em si. Assim, toda a computação deve ser realizada usando outras operações de tensorflow, para serem executadas em tempo de execução do gráfico.

Funções de forma em C++

A API do TensorFlow possui um recurso chamado "inferência de forma" que fornece informações sobre as formas dos tensores sem precisar executar o gráfico. A inferência de forma é suportada por "funções de forma" que são registradas para cada tipo de operação na declaração C++ REGISTER_OP e executam duas funções: afirmar que as formas das entradas são compatíveis durante a construção do gráfico e especificar as formas das saídas.

As funções de forma são definidas como operações na classe shape_inference::InferenceContext . Por exemplo, na função de forma para ZeroOut:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

c->set_output(0, c->input(0)); declara que a forma da primeira saída deve ser definida para a forma da primeira entrada. Se a saída for selecionada por seu índice como no exemplo acima, o segundo parâmetro de set_output deve ser um objeto ShapeHandle . Você pode criar um objeto ShapeHandle vazio por seu construtor padrão. O objeto ShapeHandle para uma entrada com índice idx pode ser obtido por c->input(idx) .

Existem várias funções de forma comuns que se aplicam a muitas operações, como shape_inference::UnchangedShape , que pode ser encontrada em common_shape_fns.h e usada da seguinte forma:

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);

Uma função de forma também pode restringir a forma de uma entrada. Para a versão de ZeroOut com uma restrição de forma vetorial , a função de forma seria a seguinte:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
      c->set_output(0, input);
      return Status::OK();
    });

A chamada WithRank valida que a forma de entrada c->input(0) tem uma forma com exatamente uma dimensão (ou se a forma de entrada for desconhecida, a forma de saída será um vetor com uma dimensão desconhecida).

Se seu op for polimórfico com várias entradas , você poderá usar membros de InferenceContext para determinar o número de formas a serem verificadas e Merge para validar se as formas são todas compatíveis (alternativamente, acesse atributos que indicam os comprimentos, com InferenceContext::GetAttr , que fornece acesso aos atributos do op).

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      ::tensorflow::shape_inference::ShapeHandle output;
      for (size_t i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

Como a inferência de forma é um recurso opcional e as formas dos tensores podem variar dinamicamente, as funções de forma devem ser robustas para informações de forma incompletas para qualquer uma das entradas. O método Merge em InferenceContext permite que o chamador declare que duas formas são iguais, mesmo que uma ou ambas não tenham informações completas. As funções de forma são definidas para todas as operações principais do TensorFlow e fornecem muitos exemplos de uso diferentes.

A classe InferenceContext tem várias funções que podem ser usadas para definir manipulações de funções de forma. Por exemplo, você pode validar que uma determinada dimensão tem um valor muito específico usando InferenceContext::Dim e InferenceContext::WithValue ; você pode especificar que uma dimensão de saída é a soma/produto de duas dimensões de entrada usando InferenceContext::Add e InferenceContext::Multiply . Consulte a classe InferenceContext para todas as várias manipulações de forma que você pode especificar. O exemplo a seguir define a forma da primeira saída para (n, 3), onde a primeira entrada tem a forma (n, ...)

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
    return Status::OK();
});

Se você tiver uma função de formato complicada, considere adicionar um teste para validar que várias combinações de formato de entrada produzem as combinações de formato de saída esperadas. Você pode ver exemplos de como escrever esses testes em alguns dos nossos testes de operações principais . (A sintaxe de INFER_OK e INFER_ERROR é um pouco enigmática, mas tente ser compacto ao representar especificações de forma de entrada e saída em testes. Por enquanto, veja os comentários ao redor desses testes para ter uma noção da especificação da string de forma).

Crie um pacote pip para sua operação personalizada

Para construir um pacote pip para seu op, veja o exemplo tensorflow/custom-op . Este guia mostra como criar operações personalizadas do pacote pip do TensorFlow em vez de criar o TensorFlow a partir da origem.