Создать операцию

Если вы хотите создать операцию, которая не поддерживается существующей библиотекой TensorFlow, мы рекомендуем сначала попробовать написать операцию на Python как композицию существующих операций или функций Python. Если это невозможно, вы можете создать собственную операцию C++. Существует несколько причин, по которым вам может потребоваться создать собственную операцию C++:

  • Непросто и невозможно выразить вашу операцию как композицию существующих операций.
  • Неэффективно выражать операцию как композицию существующих примитивов.
  • Вы хотите вручную объединить композицию примитивов, которую будущему компилятору будет сложно объединить.

Например, представьте, что вы хотите реализовать что-то вроде «объединения медиан», аналогичного оператору «MaxPool», но вычислять медианы по скользящим окнам, а не по максимальным значениям. Выполнение этого с использованием композиции операций может быть возможным (например, с использованием ExtractImagePatches и TopK), но оно может быть не таким эффективным с точки зрения производительности или памяти, как собственная операция, где вы можете сделать что-то более умное в одной объединенной операции. Как всегда, обычно сначала стоит попытаться выразить то, что вы хотите, с помощью композиции операторов, добавляя новую операцию только в том случае, если это окажется сложным или неэффективным.

Чтобы включить свою собственную операцию, вам необходимо:

  1. Зарегистрируйте новую операцию в файле C++. Регистрация операции определяет интерфейс (спецификацию) для функциональности операции, которая не зависит от реализации операции. Например, регистрация операции определяет имя операции, а также ее входы и выходы. Он также определяет функцию формы, которая используется для вывода формы тензора.
  2. Реализуйте операцию на C++. Реализация операции называется ядром и представляет собой конкретную реализацию спецификации, которую вы зарегистрировали на шаге 1. Может существовать несколько ядер для разных типов ввода/вывода или архитектур (например, ЦП, ГП).
  3. Создайте оболочку Python (необязательно). Эта оболочка представляет собой общедоступный API, который используется для создания операции в Python. На основе регистрации операции создается оболочка по умолчанию, которую можно использовать напрямую или добавлять.
  4. Напишите функцию для вычисления градиентов для операции (необязательно).
  5. Проверьте оп. Обычно для удобства мы делаем это на Python, но вы также можете протестировать эту операцию на C++. Если вы определяете градиенты, вы можете проверить их с помощью Python tf.test.compute_gradient_error . См. relu_op_test.py как пример проверки прямых функций Relu-подобных операторов и их градиентов.

Предварительные условия

Определить операционный интерфейс

Вы определяете интерфейс операции, регистрируя ее в системе TensorFlow. При регистрации вы указываете имя вашей операции, ее входные данные (типы и имена) и выходные данные (типы и имена), а также строки документации и любые атрибуты , которые могут потребоваться для этой операции.

Чтобы увидеть, как это работает, предположим, что вы хотите создать операцию, которая принимает тензор int32 и выводит копию тензора, при этом все элементы, кроме первого, установлены в ноль. Для этого создайте файл с именем zero_out.cc . Затем добавьте вызов макроса REGISTER_OP , который определяет интерфейс вашей операции:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

Эта операция ZeroOut принимает в качестве входных данных один тензор to_zero 32-битных целых чисел и выводит zeroed тензор 32-битных целых чисел. Операция также использует функцию формы, чтобы гарантировать, что выходной тензор имеет ту же форму, что и входной тензор. Например, если входные данные представляют собой тензор формы [10, 20], то эта функция формы указывает, что выходная форма также имеет форму [10, 20].

Реализовать ядро ​​для операции

После определения интерфейса предоставьте одну или несколько реализаций операции. Чтобы создать одно из этих ядер, создайте класс, расширяющий OpKernel и переопределяющий метод Compute . Метод Compute предоставляет один context аргумент типа OpKernelContext* , из которого вы можете получить доступ к таким полезным вещам, как входные и выходные тензоры.

Добавьте свое ядро ​​в файл, который вы создали выше. Ядро может выглядеть примерно так:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

После реализации ядра вы регистрируете его в системе TensorFlow. При регистрации вы указываете различные ограничения, при которых будет работать это ядро. Например, у вас может быть одно ядро ​​для процессоров и отдельное для графических процессоров.

Чтобы сделать это для операции ZeroOut , добавьте в zero_out.cc следующее:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Многопоточные ядра процессора

Чтобы написать многопоточное ядро ​​ЦП, можно использовать функцию Shard в work_sharder.h . Эта функция распределяет вычислительную функцию по потокам, сконфигурированным для использования внутриоперационной многопоточности (см.tra_op_parallelism_threads в config.proto ).

Ядра графического процессора

Ядро графического процессора состоит из двух частей: OpKernel, ядра CUDA и его кода запуска.

Иногда реализация OpKernel является общей для ядра ЦП и графического процессора, например, при проверке входных данных и распределении выходных данных. В этом случае предлагаемая реализация заключается в следующем:

  1. Определите шаблон OpKernel на устройстве и примитивный тип тензора.
  2. Чтобы выполнить фактическое вычисление выходных данных, функция Compute вызывает шаблонную структуру-функтор.
  3. Специализация этого функтора для CPUDevice определяется в том же файле, но специализация для GPUDevice определяется в файле .cu.cc, поскольку он будет скомпилирован с помощью компилятора CUDA.

Вот пример реализации.

// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif

#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

REGISTER_OP("Example")
    .Attr("T: numbertype")
    .Input("input: T")
    .Output("input_times_two: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));

    // Do the computation.
    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    ExampleFunctor<Device, T>()(
        context->eigen_device<Device>(),
        static_cast<int>(input_tensor.NumElements()),
        input_tensor.flat<T>().data(),
        output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);

// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \
  /* Declare explicit instantiations in kernel_example.cu.cc. */ \
  extern template class ExampleFunctor<GPUDevice, T>;            \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif  // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "kernel_example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;

// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    out[i] = 2 * __ldg(in + i);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
    const GPUDevice& d, int size, const T* in, T* out) {
  // Launch the cuda kernel.
  //
  // See core/util/gpu_kernel_helper.h for example of computing
  // block count and thread_per_block count.
  int block_count = 1024;
  int thread_per_block = 20;
  ExampleCudaKernel<T>
      <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}

// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;

#endif  // GOOGLE_CUDA

Создайте операционную библиотеку

Скомпилируйте операцию с помощью системного компилятора (двоичная установка TensorFlow).

Вы сможете скомпилировать zero_out.cc с помощью компилятора C++ , такого как g++ или clang доступного в вашей системе. Бинарный пакет PIP устанавливает файлы заголовков и библиотеку, необходимую для компиляции вашей операции, в местах, специфичных для системы. Однако библиотека Python TensorFlow предоставляет функцию get_include для получения каталога заголовков, а каталог get_lib имеет общий объект для связи. Вот выходные данные этих функций на машине с Ubuntu.

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

Предполагая, что у вас установлен g++ , вот последовательность команд, которую вы можете использовать для компиляции вашей операции в динамическую библиотеку.

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++14 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

В macOS при создании файла .so требуется дополнительный флаг «-undefined Dynamic_lookup».

Примечание по версии gcc >=5 : начиная с версии 5 , gcc использует новый C++ ABI . TensorFlow 2.8 и более ранние версии были созданы с помощью gcc4 , который использует более старый ABI. Если вы используете эти версии TensorFlow и пытаетесь скомпилировать свою библиотеку op с помощью gcc>=5 , добавьте -D_GLIBCXX_USE_CXX11_ABI=0 в командную строку, чтобы сделать библиотеку совместимой со старой версией ABI. Пакеты TensorFlow 2.9+ по умолчанию совместимы с более новым ABI.

Скомпилируйте операцию с помощью bazel (установка исходного кода TensorFlow)

Если у вас установлены исходные коды TensorFlow, вы можете использовать систему сборки TensorFlow для компиляции вашего проекта. Поместите файл BUILD со следующим правилом сборки Bazel в каталог tensorflow/core/user_ops .

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

Запустите следующую команду, чтобы создать zero_out.so .

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

Для компиляции операции Example с ядром CUDA вам необходимо использовать параметр gpu_srcs tf_custom_op_library . Поместите файл BUILD со следующим правилом сборки Bazel в новую папку внутри каталога tensorflow/core/user_ops (например, «example_gpu»).

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    # kernel_example.cc  kernel_example.cu.cc  kernel_example.h
    name = "kernel_example.so",
    srcs = ["kernel_example.h", "kernel_example.cc"],
    gpu_srcs = ["kernel_example.cu.cc", "kernel_example.h"],
)

Запустите следующую команду для сборки kernel_example.so .

$ bazel build --config opt //tensorflow/core/user_ops/example_gpu:kernel_example.so

Используйте операцию в Python

API TensorFlow Python предоставляет функцию tf.load_op_library для загрузки динамической библиотеки и регистрации операции в платформе TensorFlow. load_op_library возвращает модуль Python, содержащий оболочки Python для операции и ядра. Таким образом, после того как вы создали операцию, вы можете сделать следующее, чтобы запустить ее из Python:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
print(zero_out_module.zero_out([[1, 2], [3, 4]]).numpy())

# Prints
array([[1, 0], [0, 0]], dtype=int32)

Имейте в виду, что сгенерированной функции будет присвоено имя Snake_case (для соответствия PEP8 ). Итак, если ваша операция называется ZeroOut в файлах C++, функция Python будет называться zero_out .

Чтобы сделать операцию доступной как обычную функцию import из модуля Python, может быть полезно иметь вызов load_op_library в исходном файле Python следующим образом:

import tensorflow as tf

zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out

Убедитесь, что операция работает

Хороший способ убедиться, что вы успешно реализовали свою операцию, — написать для нее тест. Создайте файл zero_out_op_test.py с содержимым:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

Затем запустите тест (при условии, что у вас установлен tensorflow):

$ python zero_out_op_test.py

Встройте расширенные функции в свою работу

Теперь, когда вы знаете, как построить базовую (и несколько ограниченную) операцию и ее реализацию, мы рассмотрим некоторые из более сложных вещей, которые вам обычно придется встроить в свою операцию. Это включает в себя:

Условные проверки и валидация

В приведенном выше примере предполагалось, что операция op применима к тензору любой формы. Что, если это применимо только к векторам? Это означает добавление проверки в приведенную выше реализацию OpKernel.

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

Это подтверждает, что входные данные являются вектором, и возвращает значение, установившее статус InvalidArgument , если это не так. Макрос OP_REQUIRES принимает три аргумента:

  • context , который может быть указателем OpKernelContext или OpKernelConstruction (см. tensorflow/core/framework/op_kernel.h ), для его метода SetStatus() .
  • Состояние. Например, в tensorflow/core/framework/tensor_shape.h есть функции для проверки формы тензора.
  • Саму ошибку, представленную объектом Status , см. tensorflow/core/platform/status.h . Status имеет как тип (часто InvalidArgument , но см. список типов), так и сообщение. Функции для создания ошибки можно найти в tensorflow/core/platform/errors.h .

В качестве альтернативы, если вы хотите проверить, является ли объект Status , возвращаемый какой-либо функцией, ошибкой, и если да, то вернуть его, используйте OP_REQUIRES_OK . Оба этих макроса возвращаются из функции в случае ошибки.

Операционная регистрация

Атрибуты

Операции могут иметь атрибуты, значения которых устанавливаются при добавлении операции на график. Они используются для настройки операции, и к их значениям можно получить доступ как в реализации ядра, так и в типах входов и выходов при регистрации операции. Предпочитайте использовать входные данные вместо атрибута, когда это возможно, поскольку входные данные более гибкие. Это связано с тем, что атрибуты являются константами и должны быть определены во время построения графа. Напротив, входные данные представляют собой тензоры, значения которых могут быть динамическими; то есть входные данные могут изменяться на каждом шаге, задаваться с помощью канала и т. д. Attrs используются для вещей, которые невозможно сделать с входными данными: любая конфигурация, которая влияет на подпись (количество или тип входных или выходных данных) или которая может t меняется от шага к шагу.

Вы определяете атрибут при регистрации операции, указывая его имя и тип с помощью метода Attr , который ожидает спецификации формы:

<name>: <attr-type-expr>

где <name> начинается с буквы и может состоять из буквенно-цифровых символов и символов подчеркивания, а <attr-type-expr> — это выражение типа формы , описанной ниже .

Например, если вы хотите, чтобы операция ZeroOut сохраняла указанный пользователем индекс, а не только 0-й элемент, вы можете зарегистрировать операцию следующим образом:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

(Обратите внимание, что набор типов атрибутов отличается от tf.DType используемого для входных и выходных данных.)

Затем ваше ядро ​​сможет получить доступ к этому атрибуту в своем конструкторе через параметр context :

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
 private:
  int preserve_index_;
};

который затем можно использовать в методе Compute :

  void Compute(OpKernelContext* context) override {
    // ...

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }

Типы атрибутов

В атрибуте поддерживаются следующие типы:

  • string : любая последовательность байтов (не обязательно UTF8).
  • int : целое число со знаком.
  • float : число с плавающей запятой.
  • bool : Истина или ложь.
  • type : одно из (не ссылочных) значений DataType .
  • shape : TensorShapeProto .
  • list(<type>) : список <type> , где <type> — один из вышеуказанных типов. Обратите внимание, что list(list(<type>)) недействителен.

См. также: op_def_builder.cc:FinalizeAttr для получения окончательного списка.

Значения и ограничения по умолчанию

Атрибуты могут иметь значения по умолчанию, а некоторые типы атрибутов могут иметь ограничения. Чтобы определить атрибут с ограничениями, вы можете использовать следующие <attr-type-expr> :

{'<string1>', '<string2>'} : значение должно быть строкой, имеющей значение <string1> или <string2> . Имя типа string подразумевается при использовании этого синтаксиса. Это эмулирует перечисление:

REGISTER_OP("EnumExample")
    .Attr("e: {'apple', 'orange'}");

{<type1>, <type2>} : значение имеет тип type и должно быть одним из <type1> или <type2> , где <type1> и <type2> поддерживаются tf.DType . Вы не указываете, что тип атрибута — type . Это подразумевается, когда у вас есть список типов в {...} . Например, в этом случае attr t — это тип, который должен быть int32 , float или bool :

REGISTER_OP("RestrictedTypeExample")
    .Attr("t: {int32, float, bool}");

Существуют ярлыки для ограничений общего типа:

  • numbertype : Тип type , ограниченный числовыми типами (не строковыми и не логическими).
  • realnumbertype : как и numbertype без сложных типов.
  • quantizedtype : Как и numbertype , но только типы квантованных чисел.

Конкретные списки типов, разрешенных ими, определяются функциями (например, NumberTypes() ) в tensorflow/core/framework/types.h . В этом примере атрибут t должен быть одним из числовых типов:

REGISTER_OP("NumberType")
    .Attr("t: numbertype");

Для этой операции:

tf.number_type(t=tf.int32)  # Valid
tf.number_type(t=tf.bool)   # Invalid

Списки можно комбинировать с другими списками и отдельными типами. Следующая операция позволяет attr t иметь любой числовой тип или тип bool:

REGISTER_OP("NumberOrBooleanType")
    .Attr("t: {numbertype, bool}");

Для этой операции:

tf.number_or_boolean_type(t=tf.int32)  # Valid
tf.number_or_boolean_type(t=tf.bool)   # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid

int >= <n> : значение должно быть целым числом, значение которого больше или равно <n> , где <n> — натуральное число. Например, следующая регистрация операции указывает, что атрибут a должен иметь значение не менее 2 :

REGISTER_OP("MinIntExample")
    .Attr("a: int >= 2");

list(<type>) >= <n> : список типа <type> , длина которого больше или равна <n> . Например, следующая регистрация операции указывает, что attr a представляет собой список типов ( int32 или float ), и что их должно быть как минимум 3:

REGISTER_OP("TypeListExample")
    .Attr("a: list({int32, float}) >= 3");

Чтобы установить значение по умолчанию для атрибута (сделав его необязательным в сгенерированном коде), добавьте = <default> в конец, например:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

Кроме того, можно указать как ограничение, так и значение по умолчанию:

REGISTER_OP("AttrConstraintAndDefaultExample")
    .Attr("i: int >= 1 = 1");

Поддерживаемый синтаксис значения по умолчанию — это тот, который будет использоваться в прототипном представлении результирующего определения GraphDef.

Вот примеры того, как указать значение по умолчанию для всех типов:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

Обратите внимание, в частности, что значения типа type используют tf.DType .

Полиморфизм

Тип полиморфизма

Для операций, которые могут принимать разные типы в качестве входных данных или создавать разные типы выходных данных, вы можете указать атрибут в типе ввода или вывода при регистрации операции. Обычно после этого вы регистрируете OpKernel для каждого поддерживаемого типа.

Например, если вы хотите, чтобы операция ZeroOut работала с float в дополнение к int32 , регистрация вашей операции может выглядеть так:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Ваша регистрация операции теперь указывает, что тип ввода должен быть float или int32 , и что его вывод будет того же типа, поскольку оба имеют тип T

Именование

Входные, выходные данные и атрибуты обычно должны иметь имена Snake_case. Единственным исключением являются атрибуты, которые используются в качестве типа ввода или типа вывода. Эти атрибуты можно вывести, когда операция добавляется в график, и поэтому они не отображаются в функции операции. Например, это последнее определение ZeroOut сгенерирует функцию Python, которая выглядит следующим образом:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """

Если to_zero передается тензор int32 , то T автоматически устанавливается в int32 (ну, на самом деле DT_INT32 ). Этим выведенным атрибутам присваиваются имена с заглавной буквы или в CamelCase.

Сравните это с операцией, у которой есть тип attr, определяющий тип вывода:

REGISTER_OP("StringToNumber")
    .Input("string_tensor: string")
    .Output("output: out_type")
    .Attr("out_type: {float, int32} = DT_FLOAT");
    .Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");

В этом случае пользователю необходимо указать тип вывода, как в сгенерированном Python:

def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
      Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
Пример полиморфизма типов
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

Чтобы сохранить обратную совместимость , вам следует указать значение по умолчанию при добавлении атрибута к существующей операции:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

Допустим, вы хотите добавить больше типов, скажем, double :

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Вместо того, чтобы писать еще один OpKernel с избыточным кодом, как указано выше, часто вместо этого вы можете использовать шаблон C++. У вас по-прежнему будет одна регистрация ядра (вызов REGISTER_KERNEL_BUILDER ) для каждой перегрузки.

template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

Если у вас больше пары перегрузок, вы можете поместить регистрацию в макрос.

#include "tensorflow/core/framework/op_kernel.h"

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);

#undef REGISTER_KERNEL

В зависимости от списка типов, для которых вы регистрируете ядро, вы можете использовать макрос, предоставленный tensorflow/core/framework/register_types.h :

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
Перечислить входы и выходы

Помимо возможности принимать или создавать разные типы, операции могут потреблять или создавать переменное количество тензоров.

В следующем примере атрибут T содержит список типов и используется как тип in и out данных. Входные и выходные данные представляют собой списки тензоров этого типа (и количество и типы тензоров на выходе такие же, как на входе, поскольку оба имеют тип T ).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

Вы также можете установить ограничения на то, какие типы можно указывать в списке. В следующем случае входные данные представляют собой список тензоров float и double чисел. Операция принимает, например, типы ввода (float, double, float) и в этом случае тип вывода также будет (float, double, float) .

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

Если вы хотите, чтобы все тензоры в списке были одного типа, вы можете сделать что-то вроде:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

Он принимает список тензоров int32 и использует атрибут int N для указания длины списка.

Это также можно сделать полиморфным типом . В следующем примере входные данные представляют собой список тензоров (длиной "N" ) одного и того же (но неуказанного) типа ( "T" ), а выходные данные — один тензор соответствующего типа:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

По умолчанию тензорные списки имеют минимальную длину 1. Вы можете изменить это значение по умолчанию, используя ограничение ">=" для соответствующего атрибута . В следующем примере входными данными является список как минимум из двух тензоров int32 :

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

Тот же синтаксис работает с атрибутами "list(type)" :

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

Входы и выходы

Подводя итог вышесказанному, регистрация операции может иметь несколько входов и выходов:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

Каждая спецификация ввода или вывода имеет форму:

<name>: <io-type-expr>

где <name> начинается с буквы и может состоять из буквенно-цифровых символов и символов подчеркивания. <io-type-expr> — это одно из следующих выражений типа:

  • <type> , где <type> — поддерживаемый тип ввода (например, float , int32 , string ). Это указывает один тензор данного типа.

    См. tf.DType .

    REGISTER_OP("BuiltInTypesExample")
        .Input("integers: int32")
        .Input("complex_numbers: complex64");
    
  • <attr-type> , где <attr-type> — имя Attr с type или list(type) (с возможным ограничением типа). Этот синтаксис допускает полиморфные операции .

    REGISTER_OP("PolymorphicSingleInput")
        .Attr("T: type")
        .Input("in: T");
    
    REGISTER_OP("RestrictedPolymorphicSingleInput")
        .Attr("T: {int32, int64}")
        .Input("in: T");
    

    Ссылка на атрибут list(type) позволяет принять последовательность тензоров.

    REGISTER_OP("ArbitraryTensorSequenceExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    
    REGISTER_OP("RestrictedTensorSequenceExample")
        .Attr("T: list({int32, int64})")
        .Input("in: T")
        .Output("out: T");
    

    Обратите внимание, что количество и типы тензоров на выходе out такие же, как и на входе in , поскольку оба имеют тип T

  • Для последовательности тензоров одного и того же типа: <number> * <type> , где <number> — имя Attr с типом int . <type> может быть либо tf.DType , либо именем атрибута с type . В качестве примера первого, эта операция принимает список тензоров int32 :

    REGISTER_OP("Int32SequenceExample")
        .Attr("NumTensors: int")
        .Input("in: NumTensors * int32")
    

    В то время как эта операция принимает список тензоров любого типа, если они все одинаковы:

    REGISTER_OP("SameTypeSequenceExample")
        .Attr("NumTensors: int")
        .Attr("T: type")
        .Input("in: NumTensors * T")
    
  • Для ссылки на тензор: Ref(<type>) , где <type> — один из предыдущих типов.

Любой атрибут, используемый в типе ввода, будет выведен. По соглашению эти выведенные атрибуты используют заглавные имена (например, T или N ). В противном случае входы, выходы и атрибуты имеют имена, подобные параметрам функции (например, num_outputs ). Более подробную информацию см. в предыдущем разделе об именовании .

Более подробную информацию см. в tensorflow/core/framework/op_def_builder.h .

Обратная совместимость

Предположим, вы написали хороший индивидуальный проект и поделились им с другими, и у вас есть довольные клиенты, использующие вашу операцию. Однако вы хотели бы каким-то образом внести изменения в операцию.

В общем, изменения в существующих, зарегистрированных спецификациях должны быть обратно совместимыми: изменение спецификации операции не должно нарушать предыдущие сериализованные буферы протокола GraphDef , созданные на основе старых спецификаций. Подробности совместимости GraphDef описаны здесь .

Есть несколько способов сохранить обратную совместимость.

  1. Любые новые атрибуты, добавленные в операцию, должны иметь определенные значения по умолчанию, и с этим значением по умолчанию операция должна иметь исходное поведение. Чтобы изменить операцию с неполиморфной на полиморфную, вы должны задать значение по умолчанию для нового типа attr, чтобы сохранить исходную подпись по умолчанию. Например, если ваша операция была:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: float")
        .Output("out: float");
    

    вы можете сделать его полиморфным обратно совместимым способом, используя:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: T")
        .Output("out: T")
        .Attr("T: numerictype = DT_FLOAT");
    
  2. Вы можете безопасно сделать ограничение атрибута менее строгим. Например, вы можете изменить {int32, int64} на {int32, int64, float} или type . Или вы можете изменить значение с {"apple", "orange"} на {"apple", "banana", "orange"} или string .

  3. Вы можете изменить отдельные входы/выходы на входы/выходы списка, если значение по умолчанию для типа списка соответствует старой сигнатуре.

  4. Вы можете добавить новый список ввода/вывода, если он по умолчанию пуст.

  5. Пространство имен для любых новых операций, которые вы создаете, добавляя к именам операций что-то уникальное для вашего проекта. Это позволяет избежать конфликта вашей операции с любыми операциями, которые могут быть включены в будущие версии TensorFlow.

  6. Планируйте заранее! Постарайтесь предвидеть будущее использование операции. Некоторые изменения сигнатур невозможно выполнить совместимым способом (например, превратить список одного типа в список разных типов).

Полный список безопасных и небезопасных изменений можно найти в tensorflow/core/framework/op_compatibility_test.cc . Если вы не можете сделать изменения в операции обратно совместимыми, создайте новую операцию с новым именем и новой семантикой.

Также обратите внимание, что хотя эти изменения могут обеспечить совместимость GraphDef , сгенерированный код Python может измениться таким образом, что это будет несовместимо со старыми вызывающими объектами. Совместимость API Python можно сохранить путем осторожных изменений в написанной вручную оболочке Python, сохранив старую подпись, за исключением, возможно, добавления в конец новых необязательных аргументов. Обычно несовместимые изменения могут быть внесены только тогда, когда TensorFlow меняет основные версии, и они должны соответствовать семантике версии GraphDef .

Поддержка графического процессора

Вы можете реализовать разные OpKernels и зарегистрировать один для CPU, а другой для GPU, точно так же, как вы можете зарегистрировать ядра для разных типов . В tensorflow/core/kernels/ есть несколько примеров ядер с поддержкой графического процессора. Обратите внимание, что некоторые ядра имеют версию ЦП в файле .cc , версию графического процессора в файле, заканчивающемся на _gpu.cu.cc , а некоторый общий код содержится в файле .h .

Например, в tf.pad есть все, кроме ядра графического процессора, в tensorflow/core/kernels/pad_op.cc . Ядро графического процессора находится в tensorflow/core/kernels/pad_op_gpu.cu.cc , а общий код — это шаблонный класс, определенный в tensorflow/core/kernels/pad_op.h . Мы организуем код таким образом по двум причинам: он позволяет использовать общий код для реализаций ЦП и графического процессора, а также помещает реализацию графического процессора в отдельный файл, чтобы его можно было скомпилировать только компилятором графического процессора.

Следует отметить одну вещь: даже когда используется версия pad для ядра графического процессора, ей все равно требуется ввод "paddings" в память ЦП. Чтобы отметить, что входные или выходные данные хранятся в ЦП, добавьте вызов HostMemory() к регистрации ядра, например:

#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

Компиляция ядра для устройства с графическим процессором

Посмотрите cuda_op_kernel.cu.cc на пример, который использует ядро ​​CUDA для реализации операции. tf_custom_op_library принимает аргумент gpu_srcs , в котором можно указать список исходных файлов, содержащих ядра CUDA (файлы *.cu.cc ). Для использования с бинарной установкой TensorFlow ядра CUDA необходимо скомпилировать с помощью компилятора NVIDIA nvcc . Вот последовательность команд, которую вы можете использовать для компиляции cuda_op_kernel.cu.cc и cuda_op_kernel.cc в единую динамически загружаемую библиотеку:

nvcc -std=c++14 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
  ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC

g++ -std=c++14 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
  cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}

cuda_op_kernel.so созданный выше, можно загрузить в Python как обычно, используя функцию tf.load_op_library .

Обратите внимание: если ваши библиотеки CUDA не установлены в /usr/local/lib64 , вам нужно будет явно указать путь во второй команде (g++) выше. Например, добавьте -L /usr/local/cuda-8.0/lib64/ если ваш CUDA установлен в /usr/local/cuda-8.0 .

Реализация градиента в Python

Учитывая график операций, TensorFlow использует автоматическое дифференцирование (обратное распространение ошибки) для добавления новых операций, представляющих градиенты по отношению к существующим операциям. Чтобы автоматическое дифференцирование работало для новых операций, вы должны зарегистрировать функцию градиента, которая вычисляет градиенты относительно входных данных операций с учетом градиентов относительно выходных данных операций.

Математически, если операция вычисляет \(y = f(x)\) зарегистрированная операция градиента преобразует градиенты \(\partial L/ \partial y\) потери \(L\) относительно\(y\) в градиенты \(\partial L/ \partial x\) относительно \(x\) по правилу цепочки:

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.\]

В случае ZeroOut только одна запись на входе влияет на выход, поэтому градиент по отношению к входу представляет собой разреженный «горячий» тензор. Это выражается следующим образом:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can use
      to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input

Подробности о регистрации функций градиента с помощью tf.RegisterGradient :

  • Для операции с одним выходом функция градиента возьмет tf.Operation , op и tf.Tensor grad и создаст новые операции из тензоров op.inputs[i] , op.outputs[i] и grad . Информацию о любых атрибутах можно найти через tf.Operation.get_attr .

  • Если op имеет несколько выходов, функция градиента будет принимать op и grads , где grads — это список градиентов по отношению к каждому выходу. Результатом функции градиента должен быть список объектов Tensor , представляющих градиенты по отношению к каждому входу.

  • Если для некоторых входных данных нет четко определенного градиента, например для целочисленных входных данных, используемых в качестве индексов, соответствующий возвращаемый градиент должен быть None . Например, для операции, принимающей тензор с плавающей запятой x и целочисленный индекс i , функция градиента return [x_grad, None] .

  • Если для операции вообще нет значимого градиента, вам часто не придется регистрировать какой-либо градиент, и пока градиент операции никогда не понадобится, все будет в порядке. В некоторых случаях операция не имеет четко определенного градиента, но может участвовать в вычислении градиента. Здесь вы можете использовать ops.NotDifferentiable для автоматического распространения нулей в обратном направлении.

Обратите внимание, что во время вызова функции градиента доступен только граф потока данных операций, а не сами данные тензора. Таким образом, все вычисления должны выполняться с использованием других операций тензорного потока, которые будут выполняться во время выполнения графа.

Добавляйте подсказки типов при регистрации пользовательского градиента для типа операции, чтобы сделать код более читабельным, отлаживаемым, простым в обслуживании и более надежным благодаря проверке данных. Например, при использовании op в качестве параметра функции укажите, что функция градиента будет принимать tf.Operation в качестве типа параметра.

Функции формы в C++

API TensorFlow имеет функцию под названием «вывод формы», которая предоставляет информацию о формах тензоров без необходимости выполнения графа. Вывод формы поддерживается «функциями формы», которые зарегистрированы для каждого типа операции в объявлении C++ REGISTER_OP и выполняют две роли: утверждение совместимости входных форм во время построения графа и указание форм для выходных данных.

Функции формы определяются как операции в классе shape_inference::InferenceContext . Например, в функции формы для ZeroOut:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

c->set_output(0, c->input(0)); объявляет, что форма первого вывода должна быть установлена ​​в соответствии с формой первого ввода. Если выходные данные выбираются по индексу, как в приведенном выше примере, второй параметр set_output должен быть объектом ShapeHandle . Вы можете создать пустой объект ShapeHandle с помощью его конструктора по умолчанию. Объект ShapeHandle для ввода с индексом idx можно получить с помощью c->input(idx) .

Существует ряд общих функций формы, которые применяются ко многим операциям, например shape_inference::UnchangedShape , которую можно найти в common_shape_fns.h и использовать следующим образом:

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);

Функция формы также может ограничивать форму ввода. Для версии ZeroOut с ограничением векторной формы функция формы будет следующей:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
      c->set_output(0, input);
      return Status::OK();
    });

Вызов WithRank проверяет, что входная форма c->input(0) имеет форму ровно с одним измерением (или, если входная форма неизвестна, выходная форма будет вектором с одним неизвестным измерением).

Если ваша операция является полиморфной с несколькими входными параметрами , вы можете использовать члены InferenceContext , чтобы определить количество проверяемых фигур, и Merge , чтобы убедиться, что все фигуры совместимы (альтернативно, получить доступ к атрибутам, указывающим длины, с помощью InferenceContext::GetAttr , который обеспечивает доступ к атрибутам операции).

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      ::tensorflow::shape_inference::ShapeHandle output;
      for (size_t i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

Поскольку вывод формы является дополнительной функцией, а формы тензоров могут изменяться динамически, функции формы должны быть устойчивы к неполной информации о форме для любого из входных данных. Метод Merge в InferenceContext позволяет вызывающему объекту утверждать, что две фигуры одинаковы, даже если одна из них или обе не имеют полной информации. Функции формы определены для всех основных операций TensorFlow и предоставляют множество различных примеров использования.

Класс InferenceContext имеет ряд функций, которые можно использовать для определения манипуляций с функциями формы. Например, вы можете проверить, что конкретное измерение имеет очень конкретное значение, используя InferenceContext::Dim и InferenceContext::WithValue ; вы можете указать, что выходное измерение является суммой/произведением двух входных измерений, используя InferenceContext::Add и InferenceContext::Multiply . См. класс InferenceContext для всех различных манипуляций с фигурами, которые вы можете указать. В следующем примере для формы первого вывода задается значение (n, 3), где первый вход имеет форму (n,...).

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
    return Status::OK();
});

Если у вас сложная функция формы, вам следует рассмотреть возможность добавления теста для проверки того, что различные комбинации входных форм создают ожидаемые комбинации выходных форм. Вы можете увидеть примеры написания этих тестов в некоторых наших основных операционных тестах . (Синтаксис INFER_OK и INFER_ERROR немного загадочный, но постарайтесь быть компактным при представлении спецификаций входных и выходных форм в тестах. А пока посмотрите комментарии в этих тестах, чтобы получить представление о спецификации строки формы).

Создайте пакет pip для вашей пользовательской операции.

Чтобы создать пакет pip для вашей операции, см. пример tensorflow/custom-op . В этом руководстве показано, как создавать собственные операции из пакета pip TensorFlow вместо сборки TensorFlow из исходного кода.