Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Zacznij korzystać z transformacji TensorFlow

Ten przewodnik przedstawia podstawowe pojęcia tf.Transform i jak z nich korzystać. To będzie:

  • Zdefiniuj funkcję przetwarzania wstępnego , logiczny opis potoku, który przekształca nieprzetworzone dane w dane używane do trenowania modelu uczenia maszynowego.
  • Pokaż implementację Apache Beam używaną do przekształcania danych, konwertując funkcję przetwarzania wstępnego na potok Beam .
  • Pokaż dodatkowe przykłady użycia.

Zdefiniuj funkcję przetwarzania wstępnego

Funkcja przetwarzania wstępnego jest najważniejszą koncepcją tf.Transform . Funkcja przetwarzania wstępnego jest logicznym opisem transformacji zbioru danych. Funkcja przetwarzania wstępnego akceptuje i zwraca słownik tensorów, gdzie tensor oznacza Tensor lub 2D SparseTensor . Istnieją dwa rodzaje funkcji używanych do definiowania funkcji przetwarzania wstępnego:

  1. Dowolna funkcja, która akceptuje i zwraca tensory. Dodają one do wykresu operacje TensorFlow, które przekształcają dane surowe w dane przekształcone.
  2. Dowolny z analizatorów dostarczonych przez tf.Transform . Analizatory również akceptują i zwracają tensory, ale w przeciwieństwie do funkcji TensorFlow nie dodają operacji do wykresu. Zamiast tego analizatory powodują, że tf.Transform oblicza operację pełnego przebiegu poza TensorFlow. Używają wejściowych wartości tensorów w całym zbiorze danych do generowania stałego tensora, który jest zwracany jako dane wyjściowe. Na przykład tft.min oblicza minimum tensora w zbiorze danych. tf.Transform zapewnia stały zestaw analizatorów, ale zostanie on rozszerzony w przyszłych wersjach.

Przykład funkcji przetwarzania wstępnego

Łącząc analizatory i zwykłe funkcje TensorFlow, użytkownicy mogą tworzyć elastyczne potoki do przekształcania danych. Następująca funkcja przetwarzania wstępnego przekształca każdą z trzech funkcji na różne sposoby i łączy dwie z nich:

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam

def preprocessing_fn(inputs):
  x = inputs['x']
  y = inputs['y']
  s = inputs['s']
  x_centered = x - tft.mean(x)
  y_normalized = tft.scale_to_0_1(y)
  s_integerized = tft.compute_and_apply_vocabulary(s)
  x_centered_times_y_normalized = x_centered * y_normalized
  return {
      'x_centered': x_centered,
      'y_normalized': y_normalized,
      'x_centered_times_y_normalized': x_centered_times_y_normalized,
      's_integerized': s_integerized
  }

Tutaj x , y i sTensor które reprezentują cechy wejściowe. Pierwszy utworzony nowy tensor, x_centered , jest tworzony przez zastosowanie tft.mean do x i odjęcie go od x . tft.mean(x) zwraca tensor reprezentujący średnią tensora x . x_centered to tensor x przy odjęciu średniej.

Drugi nowy tensor, y_normalized , jest tworzony w podobny sposób, ale przy użyciu wygodnej metody tft.scale_to_0_1 . Ta metoda robi coś podobnego do obliczania x_centered , a mianowicie x_centered maksimum i minimum i używa ich do skalowania y .

Tensor s_integerized pokazuje przykład manipulacji na ciągach znaków. W tym przypadku bierzemy ciąg znaków i mapujemy go na liczbę całkowitą. Wykorzystuje to wygodną funkcję tft.compute_and_apply_vocabulary . Ta funkcja używa analizatora do obliczenia unikatowych wartości pobranych przez ciągi wejściowe, a następnie używa operacji TensorFlow do konwersji ciągów wejściowych na indeksy w tabeli wartości unikatowych.

Ostatnia kolumna pokazuje, że możliwe jest użycie operacji TensorFlow do tworzenia nowych funkcji poprzez łączenie tensorów.

Funkcja przetwarzania wstępnego definiuje potok operacji na zbiorze danych. Aby zastosować potok, opieramy się na konkretnej implementacji API tf.Transform . Implementacja Apache Beam zapewnia PTransform która stosuje funkcję wstępnego przetwarzania danych użytkownika. Typowy przepływ pracy użytkownika tf.Transform konstruuje funkcję przetwarzania wstępnego, a następnie włącza ją do większego potoku Beam, tworząc dane do uczenia.

Dozowanie

Tworzenie partii jest ważną częścią TensorFlow. Ponieważ jednym z celów tf.Transform jest zapewnienie wykresu TensorFlow do przetwarzania wstępnego, który można włączyć do wykresu serwowania (i opcjonalnie wykresu tf.Transform jest również ważną koncepcją w tf.Transform .

Chociaż nie jest to oczywiste w powyższym przykładzie, funkcja wstępnego przetwarzania zdefiniowana przez użytkownika jest przekazywana tensorom reprezentującym partie, a nie poszczególne instancje, jak to ma miejsce podczas uczenia i udostępniania za pomocą TensorFlow. Z drugiej strony analizatory wykonują obliczenia na całym zestawie danych, które zwracają pojedynczą wartość, a nie partię wartości. x to Tensor o kształcie (batch_size,) , natomiast tft.mean(x) to Tensor o kształcie () . Odejmowanie x - tft.mean(x) nadaje emisję, w której wartość tft.mean(x) jest odejmowana od każdego elementu wsadu reprezentowanego przez x .

Implementacja Apache Beam

Chociaż funkcja przetwarzania wstępnego ma być logicznym opisem potoku przetwarzania wstępnego zaimplementowanego w wielu strukturach przetwarzania danych, tf.Transform zapewnia implementację kanoniczną używaną w Apache Beam. Ta implementacja demonstruje funkcjonalność wymaganą od implementacji. Nie ma formalnego API dla tej funkcjonalności, więc każda implementacja może używać API, które jest idiomatyczne dla jej określonej struktury przetwarzania danych.

Implementacja Apache Beam zapewnia dwa PTransform używane do przetwarzania danych dla funkcji przetwarzania wstępnego. Poniżej przedstawiono użycie złożonego PTransform AnalyzeAndTransformDataset :

raw_data = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'}
]

raw_data_metadata = ...
transformed_dataset, transform_fn = (
    (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
        preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset

Zawartość transformed_data jest pokazana poniżej i zawiera przekształcone kolumny w tym samym formacie co nieprzetworzone dane. W szczególności wartości s_integerized to [0, 1, 0] te wartości zależą od tego, jak słowa hello i world zostały odwzorowane na liczby całkowite, co jest deterministyczne. Dla kolumny x_centered , odjęliśmy średnią, tak więc wartości z kolumny x , które wynosiły [1.0, 2.0, 3.0] , stały się [-1.0, 0.0, 1.0] . Podobnie pozostałe kolumny są zgodne z ich oczekiwanymi wartościami.

[{u's_integerized': 0,
  u'x_centered': -1.0,
  u'x_centered_times_y_normalized': -0.0,
  u'y_normalized': 0.0},
 {u's_integerized': 1,
  u'x_centered': 0.0,
  u'x_centered_times_y_normalized': 0.0,
  u'y_normalized': 0.5},
 {u's_integerized': 0,
  u'x_centered': 1.0,
  u'x_centered_times_y_normalized': 1.0,
  u'y_normalized': 1.0}]

Zarówno raw_data jak i transformed_data to raw_data danych. Następne dwie sekcje pokazują, jak implementacja Beam reprezentuje zestawy danych oraz jak odczytywać i zapisywać dane na dysku. Druga wartość zwracana, transform_fn , reprezentuje transformację zastosowaną do danych, omówioną szczegółowo poniżej.

Zestaw AnalyzeAndTransformDataset to zestawienie dwóch podstawowych przekształceń dostarczonych przez implementację AnalyzeDataset i TransformDataset . Zatem następujące dwa fragmenty kodu są równoważne:

transformed_data, transform_fn = (
    my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)
transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()

transform_fn to czysta funkcja, która reprezentuje operację wykonywaną na każdym wierszu zbioru danych. W szczególności wartości analizatora są już obliczone i traktowane jako stałe. W tym przykładzie transform_fn zawiera jako stałe średnią z kolumny x , min i max kolumny y oraz słownictwo używane do odwzorowywania ciągów na liczby całkowite.

Ważną cechą tf.Transform jest to, że transform_fn reprezentuje mapę w wierszach - jest to czysta funkcja stosowana do każdego wiersza oddzielnie. Wszystkie obliczenia związane z agregacją wierszy są wykonywane w AnalyzeDataset . Ponadto transform_fn jest reprezentowany jako Graph TensorFlow, który można osadzić w obsługującym grafie.

AnalyzeAndTransformDataset jest udostępniony do optymalizacji w tym szczególnym przypadku. Jest to ten sam wzorzec używany w scikit-learning , zapewniający metody fit , transform i fit_transform .

Formaty danych i schemat

Implementacja TFT Beam akceptuje dwa różne formaty danych wejściowych. Format „instance dict” (jak widać w powyższym przykładzie oraz w simple_example.py ) jest formatem intuicyjnym i jest odpowiedni dla małych zestawów danych, podczas gdy format TFXIO ( Apache Arrow ) zapewnia lepszą wydajność i jest odpowiedni dla dużych zestawów danych.

Implementacja Beam mówi, w jakim formacie będzie wejściowy PCollection przez "metadane" towarzyszące PCollection:

(raw_data, raw_data_metadata) | tft.AnalyzeDataset(...)
  • Jeśli raw_data_metadata jest dataset_metadata.DatasetMetadata (patrz poniżej, określanym jako „przykład DICT" format”fragment), a następnie raw_data Oczekuje się, że w formacie„wystąpienie dict”.
  • Jeśli raw_data_metadata jest tfxio.TensorAdapterConfig (patrz niżej, „Format TFXIO” sekcja), następnie raw_data ma być w formacie TFXIO.

Format „instancji”

W poprzednich przykładach kodu pominięto kod definiujący raw_data_metadata . Metadane zawierają schemat, który definiuje układ danych, dzięki czemu są odczytywane i zapisywane w różnych formatach. Nawet format w pamięci przedstawiony w ostatniej sekcji nie jest samoopisujący i wymaga schematu, aby mógł zostać zinterpretowany jako tensory.

Oto definicja schematu dla przykładowych danych:

from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

raw_data_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
        's': tf.io.FixedLenFeature([], tf.string),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'x': tf.io.FixedLenFeature([], tf.float32),
    }))

Klasa dataset_schema.Schema zawiera informacje potrzebne do przeanalizowania danych z ich formatu na dysku lub w pamięci na tensory. Zwykle jest konstruowany przez wywołanie schema_utils.schema_from_feature_spec z schema_utils.schema_from_feature_spec mapującym klucze funkcji do wartości tf.io.FixedLenFeature , tf.io.VarLenFeature i tf.io.SparseFeature . Więcej informacji można znaleźć w dokumentacji tf.parse_example .

Powyżej używamy tf.io.FixedLenFeature aby wskazać, że każda cecha zawiera stałą liczbę wartości, w tym przypadku pojedynczą wartość skalarną. Ponieważ instancje partii tf.Transform , rzeczywisty Tensor reprezentujący tf.Transform będzie miał kształt (None,) gdzie nieznany wymiar jest wymiarem partii.

Format TFXIO

W tym formacie oczekuje się, że dane będą zawarte w pliku pyarrow.RecordBatch . W przypadku danych tabelarycznych nasza implementacja Apache Beam akceptuje RecordBatch Arrow RecordBatch które składają się z kolumn następujących typów:

  • pa.list_(<primitive>) , gdzie <primitive> to pa.int64() , pa.float32() pa.binary() lub pa.large_binary() .

  • pa.large_list(<primitive>)

RecordBatch zestaw danych wejściowych, którego użyliśmy powyżej, przedstawiony jako RecordBatch , wygląda następująco:

raw_data = [
    pa.record_batch([
        pa.array([[1], [2], [3]], pa.list_(pa.float32())),
        pa.array([[1], [2], [3]], pa.list_(pa.float32())),
        pa.array([['hello'], ['world'], ['hello']], pa.list_(pa.binary())),
    ], ['x', 'y', 's'])
]

Podobnie jak w przypadku DatasetMetadata, który musi towarzyszyć formatowi „instance dict”, plik tfxio.TensorAdapterConfig jest potrzebny do dołączenia do RecordBatch . Składa się ze schematu Arrow obiektów RecordBatch i TensorRepresentations celu jednoznacznego określenia sposobu interpretacji kolumn w RecordBatch es jako TensorFlow Tensors (w tym między innymi tf.Tensor, tf.SparseTensor).

TensorRepresentations to Dict[Text, TensorRepresentation] który ustanawia relację między Dict[Text, TensorRepresentation] który akceptuje preprocessing_fn a kolumnami w RecordBatch es. Na przykład:

tensor_representation = {
    'x': text_format.Parse(
        """dense_tensor { column_name: "col1" shape { dim { size: 2 } } }"""
        schema_pb2.TensorRepresentation())
}

Oznacza, że inputs['x'] w preprocessing_fn powinny być gęstym tf.Tensorem, którego wartości pochodzą z kolumny o nazwie 'col1' w wejściowych RecordBatch , a jego kształt [batch_size, 2] ) powinien mieć wartość [batch_size, 2] .

TensorRepresentation to Protobuf zdefiniowany w metadanych TensorFlow .

Wejście i wyjście z Apache Beam

Do tej pory widzieliśmy dane wejściowe i wyjściowe na listach Pythona (w RecordBatch lub słownikach instancji). Jest to uproszczenie, które polega na zdolności Apache Beam do pracy z listami, a także z jego główną reprezentacją danych, PCollection .

PCollection to reprezentacja danych, która tworzy część potoku Beam. PTransform Beam jest tworzony przez zastosowanie różnych PTransform , w tym AnalyzeDataset i TransformDataset , oraz uruchomienie potoku. PCollection nie jest tworzony w pamięci głównego PCollection binarnego, ale zamiast tego jest rozprowadzany wśród pracowników (chociaż ta sekcja używa trybu wykonywania w pamięci).

Pre-puszkach PCollection Źródła ( TFXIO )

Format RecordBatch przez naszą implementację jest popularnym formatem akceptowanym przez inne biblioteki TFX. Dlatego TFX oferuje wygodne „źródła” (aka TFXIO ), które odczytują pliki w różnych formatach na dysku i tworzą RecordBatch a także mogą podawać TensorAdapterConfig , w tym wywnioskowane TensorRepresentations .

Te TFXIO można znaleźć w pakiecie tfx_bsl ( tfx_bsl.public.tfxio ).

Przykład: zbiór danych „Dochód ze spisu ludności”

Poniższy przykład wymaga zarówno odczytu, jak i zapisu danych na dysku oraz przedstawienia danych jako PCollection (nie listy), zobacz: census_example.py . Poniżej pokazujemy, jak pobrać dane i uruchomić ten przykład. Zbiór danych „Census Income” jest dostarczany przez repozytorium UCI Machine Learning . Ten zbiór danych zawiera zarówno dane jakościowe, jak i liczbowe.

Dane są w formacie CSV, oto pierwsze dwa wiersze:

39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K
50, Self-emp-not-inc, 83311, Bachelors, 13, Married-civ-spouse, Exec-managerial, Husband, White, Male, 0, 0, 13, United-States, <=50K

Kolumny zestawu danych są jakościowe lub liczbowe. Ten zbiór danych opisuje problem klasyfikacyjny: przewidywanie ostatniej kolumny, w której osoba zarabia więcej lub mniej niż 50 000 rocznie. Jednak z punktu widzenia tf.Transform ta etykieta to tylko kolejna kategoryczna kolumna.

Używamy TFXIO , BeamRecordCsvTFXIO aby przetłumaczyć linie CSV na RecordBatches . TFXIO wymaga dwóch ważnych informacji:

  • Schemat metadanych TensorFlow zawierający informacje o typie i kształcie każdej kolumny CSV. TensorRepresentation to opcjonalna część schematu; jeśli nie zostaną podane (co ma miejsce w tym przykładzie), zostaną wywnioskowane z informacji o typie i kształcie. Schemat można uzyskać za pomocą funkcji pomocniczej, którą zapewniamy do przetłumaczenia specyfikacji parsowania TF (pokazanych w tym przykładzie) lub przez uruchomienie walidacji danych TensorFlow .

  • lista nazw kolumn w kolejności, w jakiej pojawiają się w pliku CSV. Zauważ, że te nazwy muszą pasować do nazw funkcji w schemacie.

W tym przykładzie zezwalamy na brak funkcji education-num . Oznacza to, że jest reprezentowany jako tf.io.VarLenFeature w feature_spec oraz jako tf.SparseTensor w preprocessing_fn . Inne funkcje staną się tf.Tensor o tej samej nazwie w preprocessing_fn .

csv_tfxio = tfxio.BeamRecordCsvTFXIO(
    physical_format='text', column_names=ordered_columns, schema=SCHEMA)

record_batches = (
    p
    | 'ReadTrainData' >> textio.ReadFromText(train_data_file)
    | ...  # fix up csv lines
    | 'ToRecordBatches' >> csv_tfxio.BeamSource())

tensor_adapter_config = csv_tfxio.TensorAdapterConfig()

Zauważ, że musieliśmy zrobić kilka dodatkowych poprawek po wczytaniu linii CSV. W przeciwnym razie moglibyśmy polegać na CsvTFXIO aby obsłużyć zarówno odczyt plików, jak i tłumaczenie na pliki RecordBatch :

csv_tfxio = tfxio.CsvTFXIO(train_data_file, column_name=ordered_columns,
                           schema=SCHEMA)
record_batches = p | 'TFXIORead' >> csv_tfxio.BeamSource()
tensor_adapter_config = csv_tfxio.TensorAdapterConfig()

Przetwarzanie wstępne jest podobne do poprzedniego przykładu, z wyjątkiem tego, że funkcja przetwarzania wstępnego jest generowana programowo zamiast ręcznego określania każdej kolumny. W poniższej funkcji przetwarzania wstępnego NUMERICAL_COLUMNS i CATEGORICAL_COLUMNS to listy zawierające nazwy kolumn liczbowych i kategorialnych:

def preprocessing_fn(inputs):
  """Preprocess input columns into transformed columns."""
  # Since we are modifying some features and leaving others unchanged, we
  # start by setting `outputs` to a copy of `inputs.
  outputs = inputs.copy()

  # Scale numeric columns to have range [0, 1].
  for key in NUMERIC_FEATURE_KEYS:
    outputs[key] = tft.scale_to_0_1(outputs[key])

  for key in OPTIONAL_NUMERIC_FEATURE_KEYS:
    # This is a SparseTensor because it is optional. Here we fill in a default
    # value when it is missing.
      sparse = tf.sparse.SparseTensor(outputs[key].indices, outputs[key].values,
                                      [outputs[key].dense_shape[0], 1])
      dense = tf.sparse.to_dense(sp_input=sparse, default_value=0.)
    # Reshaping from a batch of vectors of size 1 to a batch to scalars.
    dense = tf.squeeze(dense, axis=1)
    outputs[key] = tft.scale_to_0_1(dense)

  # For all categorical columns except the label column, we generate a
  # vocabulary but do not modify the feature.  This vocabulary is instead
  # used in the trainer, by means of a feature column, to convert the feature
  # from a string to an integer id.
  for key in CATEGORICAL_FEATURE_KEYS:
    tft.vocabulary(inputs[key], vocab_filename=key)

  # For the label column we provide the mapping from string to index.
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys=['>50K', '<=50K'],
      values=tf.cast(tf.range(2), tf.int64),
      key_dtype=tf.string,
      value_dtype=tf.int64)
  table = tf.lookup.StaticHashTable(initializer, default_value=-1)

  outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])

  return outputs

Jedną z różnic w stosunku do poprzedniego przykładu jest to, że kolumna etykiety ręcznie określa mapowanie z ciągu na indeks. Zatem '>50' jest mapowane na 0 a '<=50K' jest mapowane na 1 ponieważ warto wiedzieć, który indeks w uczonym modelu odpowiada której etykiecie.

Zmienna record_batches reprezentuje PCollection of pyarrow.RecordBatch es. Wartość tensor_adapter_config jest podawana przez csv_tfxio , który jest wywnioskowany ze csv_tfxio SCHEMA (i ostatecznie, w tym przykładzie, ze specyfikacji parsowania TF).

Ostatnim etapem jest zapisanie transformowanych danych na dysk i ma podobną postać do odczytu danych surowych. Schemat używany do tego jest częścią danych wyjściowych AnalyzeAndTransformDataset który wnioskuje o schemat dla danych wyjściowych. Kod do zapisu na dysku jest pokazany poniżej. Schemat jest częścią metadanych, ale używa ich zamiennie w tf.Transform API (tj. Przekazuje metadane do ExampleProtoCoder ). Należy pamiętać, że zapisuje to w innym formacie. Zamiast textio.WriteToText , użyj wbudowanej obsługi Beam dla formatu TFRecord i użyj kodera do kodowania danych jako Example protokołów. Jest to lepszy format do wykorzystania podczas treningu, jak pokazano w następnej sekcji. transformed_eval_data_base zapewnia podstawową nazwę pliku dla poszczególnych zapisywanych fragmentów.

transformed_data | "WriteTrainData" >> tfrecordio.WriteToTFRecord(
    transformed_eval_data_base,
    coder=tft.coders.ExampleProtoCoder(transformed_metadata))

Oprócz danych transform_fn , transform_fn jest również zapisywane z metadanymi:

_ = (
    transform_fn
    | 'WriteTransformFn' >> tft_beam.WriteTransformFn(working_dir))
transformed_metadata | 'WriteMetadata' >> tft_beam.WriteMetadata(
    transformed_metadata_file, pipeline=p)

Uruchom cały potok Beam za pomocą p.run().wait_until_finish() . Do tego momentu potok Beam reprezentuje odroczone, rozproszone obliczenia. Zawiera instrukcje dotyczące tego, co zostanie zrobione, ale instrukcje nie zostały wykonane. To końcowe wywołanie wykonuje określony potok.

Pobierz zbiór danych spisu

Pobierz zestaw danych spisu za pomocą następujących poleceń powłoki:

  wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data
  wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test

Podczas uruchamiania skryptu census_example.py katalog zawierający te dane jako pierwszy argument. Skrypt tworzy tymczasowy podkatalog w celu dodania wstępnie przetworzonych danych.

Integracja ze szkoleniami TensorFlow

Ostatnia sekcja census_example.py pokazuje, jak wstępnie przetworzone dane są wykorzystywane do census_example.py modelu. Szczegółowe informacje można znaleźć w dokumentacji estymatorów . Pierwszym krokiem jest skonstruowanie Estimator który wymaga opisu wstępnie przetworzonych kolumn. Każda kolumna liczbowa jest opisana jako kolumna real_valued_column która jest opakowaniem wokół gęstego wektora o stałym rozmiarze ( 1 w tym przykładzie). Każda kolumna kategorialna jest odwzorowywana z łańcucha na liczby całkowite, a następnie przekazywana do kolumny indicator_column . tft.TFTransformOutput służy do znajdowania ścieżki pliku słownika dla każdej funkcji kategorialnej.

real_valued_columns = [feature_column.real_valued_column(key)
                       for key in NUMERIC_FEATURE_KEYS]

one_hot_columns = [
    tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_file(
            key=key,
            vocabulary_file=tf_transform_output.vocabulary_file_by_name(
                vocab_filename=key)))
    for key in CATEGORICAL_FEATURE_KEYS]

estimator = tf.estimator.LinearClassifier(real_valued_columns + one_hot_columns)

Następnym krokiem jest utworzenie konstruktora, który wygeneruje funkcję wejściową do szkolenia i oceny. Różni się od treningu używanego przez tf.Learn ponieważ specyfikacja funkcji nie jest wymagana do przeanalizowania przekształconych danych. Zamiast tego użyj metadanych dla przekształconych danych, aby wygenerować specyfikację funkcji.

def _make_training_input_fn(tf_transform_output, transformed_examples,
                            batch_size):
  ...
  def input_fn():
    """Input function for training and eval."""
    dataset = tf.data.experimental.make_batched_features_dataset(
        ..., tf_transform_output.transformed_feature_spec(), ...)

    transformed_features = tf.compat.v1.data.make_one_shot_iterator(
        dataset).get_next()
    ...

  return input_fn

Pozostały kod jest taki sam, jak w przypadku klasy Estimator . Przykład zawiera również kod do wyeksportowania modelu w formacie SavedModel . Wyeksportowany model może być używany przez Tensorflow Serving lub Cloud ML Engine .