การย้าย feature_columns ไปยัง Keras Preprocessing Layers . ของ TF2

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

การฝึกโมเดลมักจะมาพร้อมกับการประมวลผลคุณลักษณะล่วงหน้าจำนวนหนึ่ง โดยเฉพาะอย่างยิ่งเมื่อต้องจัดการกับข้อมูลที่มีโครงสร้าง เมื่อฝึก tf.estimator.Estimator ใน TF1 การประมวลผลล่วงหน้าของฟีเจอร์นี้มักจะเสร็จสิ้นด้วย tf.feature_column API ใน TF2 การประมวลผลล่วงหน้านี้สามารถทำได้โดยตรงกับเลเยอร์ Keras ซึ่งเรียกว่า เลเยอร์ การประมวลผลล่วงหน้า

ในคู่มือการย้ายข้อมูลนี้ คุณจะทำการเปลี่ยนแปลงคุณลักษณะทั่วไปบางอย่างโดยใช้ทั้งคอลัมน์คุณลักษณะและเลเยอร์การประมวลผลล่วงหน้า ตามด้วยการฝึกโมเดลที่สมบูรณ์ด้วย API ทั้งสอง

ขั้นแรก เริ่มต้นด้วยการนำเข้าที่จำเป็นสองสามอย่าง

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import math

และเพิ่มยูทิลิตี้สำหรับการเรียกคอลัมน์คุณลักษณะสำหรับการสาธิต:

def call_feature_columns(feature_columns, inputs):
  # This is a convenient way to call a `feature_column` outside of an estimator
  # to display its output.
  feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)
  return feature_layer(inputs)

การจัดการอินพุต

ในการใช้คอลัมน์คุณลักษณะที่มีตัวประมาณการ อินพุตของโมเดลมักถูกคาดหวังให้เป็นพจนานุกรมของเมตริกซ์เสมอ:

input_dict = {
  'foo': tf.constant([1]),
  'bar': tf.constant([0]),
  'baz': tf.constant([-1])
}

ต้องสร้างคอลัมน์คุณลักษณะแต่ละคอลัมน์ด้วยคีย์เพื่อสร้างดัชนีลงในข้อมูลต้นทาง เอาต์พุตของคอลัมน์คุณลักษณะทั้งหมดจะถูกต่อกันและใช้โดยโมเดลตัวประมาณ

columns = [
  tf1.feature_column.numeric_column('foo'),
  tf1.feature_column.numeric_column('bar'),
  tf1.feature_column.numeric_column('baz'),
]
call_feature_columns(columns, input_dict)
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[ 0., -1.,  1.]], dtype=float32)>

ใน Keras อินพุตโมเดลมีความยืดหยุ่นมากกว่ามาก tf.keras.Model สามารถจัดการอินพุตเทนเซอร์รายการเดียว รายการคุณสมบัติเทนเซอร์ หรือพจนานุกรมของคุณสมบัติเทนเซอร์ คุณสามารถจัดการอินพุตพจนานุกรมโดยส่งพจนานุกรมของ tf.keras.Input ในการสร้างแบบจำลอง อินพุตจะไม่ถูกต่อโดยอัตโนมัติ ซึ่งช่วยให้ใช้งานได้หลากหลายมากขึ้น สามารถต่อด้วย tf.keras.layers.Concatenate

inputs = {
  'foo': tf.keras.Input(shape=()),
  'bar': tf.keras.Input(shape=()),
  'baz': tf.keras.Input(shape=()),
}
# Inputs are typically transformed by preprocessing layers before concatenation.
outputs = tf.keras.layers.Concatenate()(inputs.values())
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model(input_dict)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 1.,  0., -1.], dtype=float32)>

รหัสจำนวนเต็มการเข้ารหัสแบบร้อนครั้งเดียว

การแปลงลักษณะทั่วไปคืออินพุตจำนวนเต็มการเข้ารหัสแบบร้อนครั้งเดียวของช่วงที่ทราบ นี่คือตัวอย่างการใช้คอลัมน์คุณลักษณะ:

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'type', num_buckets=3)
indicator_col = tf1.feature_column.indicator_column(categorical_col)
call_feature_columns(indicator_col, {'type': [0, 1, 2]})
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

การใช้เลเยอร์การประมวลผลล่วงหน้าของ Keras คอลัมน์เหล่านี้สามารถแทนที่ด้วยเลเยอร์ tf.keras.layers.CategoryEncoding เดียวโดยตั้งค่า output_mode เป็น 'one_hot' :

one_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=3, output_mode='one_hot')
one_hot_layer([0, 1, 2])
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

การปรับคุณสมบัติตัวเลขให้เป็นมาตรฐาน

เมื่อจัดการคุณลักษณะจุดลอยตัวแบบต่อเนื่องด้วยคอลัมน์คุณลักษณะ คุณต้องใช้ tf.feature_column.numeric_column ในกรณีที่อินพุตถูกทำให้เป็นมาตรฐานแล้ว การแปลงเป็น Keras นั้นไม่สำคัญ คุณสามารถใช้ tf.keras.Input ลงในโมเดลของคุณได้โดยตรงดังที่แสดงไว้ด้านบน

numeric_column สามารถใช้เพื่อทำให้อินพุตเป็นมาตรฐานได้เช่นกัน:

def normalize(x):
  mean, variance = (2.0, 1.0)
  return (x - mean) / math.sqrt(variance)
numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)
call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-2.],
       [-1.],
       [ 0.]], dtype=float32)>

ตรงกันข้ามกับ Keras การทำให้เป็นมาตรฐานนี้สามารถทำได้ด้วย tf.keras.layers.Normalization

normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)
normalization_layer(tf.constant([[0.], [1.], [2.]]))
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-2.],
       [-1.],
       [ 0.]], dtype=float32)>

Bucketizing และคุณสมบัติตัวเลขการเข้ารหัสแบบร้อนครั้งเดียว

การเปลี่ยนแปลงทั่วไปอีกประการหนึ่งของอินพุตจุดลอยตัวแบบต่อเนื่องคือการเพิ่มจำนวนเต็มจากนั้นให้เป็นจำนวนเต็มของช่วงคงที่

ในคอลัมน์คุณลักษณะ สามารถทำได้ด้วย tf.feature_column.bucketized_column :

numeric_col = tf1.feature_column.numeric_column('col')
bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])
call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)>

ใน Keras สิ่งนี้สามารถแทนที่ด้วย tf.keras.layers.Discretization :

discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])
one_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=4, output_mode='one_hot')
one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)>

ข้อมูลสตริงการเข้ารหัสแบบร้อนครั้งเดียวพร้อมคำศัพท์

การจัดการคุณลักษณะสตริงมักต้องใช้การค้นหาคำศัพท์เพื่อแปลสตริงเป็นดัชนี ต่อไปนี้คือตัวอย่างการใช้คอลัมน์คุณลักษณะเพื่อค้นหาสตริง จากนั้นจึงเข้ารหัสดัชนีแบบลัดครั้งเดียว:

vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'sizes',
    vocabulary_list=['small', 'medium', 'large'],
    num_oov_buckets=0)
indicator_col = tf1.feature_column.indicator_column(vocab_col)
call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>

ใช้เลเยอร์การประมวลผลล่วงหน้าของ tf.keras.layers.StringLookup ใช้เลเยอร์ output_mode โดยตั้งค่า output_mode เป็น 'one_hot' :

string_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'],
    num_oov_indices=0,
    output_mode='one_hot')
string_lookup_layer(['small', 'medium', 'large'])
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>
ตัวยึดตำแหน่ง23

การฝังข้อมูลสตริงด้วยคำศัพท์

สำหรับคำศัพท์ที่ใหญ่ขึ้น มักจะจำเป็นต้องมีการฝังเพื่อประสิทธิภาพที่ดี ต่อไปนี้คือตัวอย่างการฝังคุณลักษณะสตริงโดยใช้คอลัมน์คุณลักษณะ:

vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'col',
    vocabulary_list=['small', 'medium', 'large'],
    num_oov_buckets=0)
embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)
call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[-0.01798586, -0.2808677 ,  0.27639154,  0.06081508],
       [ 0.05771849,  0.02464074,  0.20080602,  0.50164527],
       [-0.9208247 , -0.40816694, -0.49132794,  0.9203153 ]],
      dtype=float32)>

การใช้เลเยอร์การประมวลผลล่วงหน้าของ Keras สามารถทำได้โดยการรวมเลเยอร์ tf.keras.layers.StringLookup และเลเยอร์ tf.keras.layers.Embedding เอาต์พุตเริ่มต้นสำหรับ StringLookup จะเป็นดัชนีจำนวนเต็มซึ่งสามารถป้อนโดยตรงในการฝัง

string_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'], num_oov_indices=0)
embedding = tf.keras.layers.Embedding(3, 4)
embedding(string_lookup_layer(['small', 'medium', 'large']))
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[ 0.04838837, -0.04014301,  0.02001903, -0.01150769],
       [-0.04580117, -0.04319514,  0.03725603, -0.00572466],
       [-0.0401094 ,  0.00997342,  0.00111955,  0.00132702]],
      dtype=float32)>

สรุปข้อมูลหมวดหมู่ถ่วงน้ำหนัก

ในบางกรณี คุณจำเป็นต้องจัดการกับข้อมูลที่เป็นหมวดหมู่ซึ่งแต่ละหมวดหมู่มาพร้อมกับน้ำหนักที่เกี่ยวข้อง ในคอลัมน์คุณลักษณะ สิ่งนี้ถูกจัดการด้วย tf.feature_column.weighted_categorical_column เมื่อจับคู่กับ indicator_column จะมีผลรวมน้ำหนักต่อหมวดหมู่

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
    categorical_col, 'weights')
indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)
call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4203: sparse_merge (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.
Instructions for updating:
No similar op available at this time.
<tf.Tensor: shape=(1, 20), dtype=float32, numpy=
array([[0. , 0. , 0. , 0. , 0. , 1.2, 0. , 0. , 0. , 0. , 0. , 1.5, 0. ,

        0. , 0. , 0. , 0. , 2. , 0. , 0. ]], dtype=float32)>

ใน Keras สิ่งนี้สามารถทำได้โดยส่งอินพุต count_weights ไปยัง tf.keras.layers.CategoryEncoding ด้วย output_mode='count'

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

# Using sparse output is more efficient when `num_tokens` is large.
count_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=20, output_mode='count', sparse=True)
tf.sparse.to_dense(count_layer(ids, count_weights=weights))
<tf.Tensor: shape=(1, 20), dtype=float32, numpy=
array([[0. , 0. , 0. , 0. , 0. , 1.2, 0. , 0. , 0. , 0. , 0. , 1.5, 0. ,

        0. , 0. , 0. , 0. , 2. , 0. , 0. ]], dtype=float32)>

การฝังข้อมูลหมวดหมู่ถ่วงน้ำหนัก

คุณอาจต้องการฝังอินพุตการจัดหมวดหมู่แบบถ่วงน้ำหนักอีกทางหนึ่ง ในคอลัมน์คุณลักษณะ embedding_column มีอาร์กิวเมนต์ตัว combiner หากตัวอย่างใดมีรายการหลายรายการสำหรับหมวดหมู่หนึ่ง รายการเหล่านั้นจะถูกรวมตามการตั้งค่าอาร์กิวเมนต์ (โดยค่าเริ่มต้น 'mean' )

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
    categorical_col, 'weights')
embedding_col = tf1.feature_column.embedding_column(
    weighted_categorical_col, 4, combiner='mean')
call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})
<tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[ 0.02666993,  0.289671  ,  0.18065728, -0.21045178]],
      dtype=float32)>
ตัวยึดตำแหน่ง33

ใน combiner ไม่มีตัวเลือกตัวรวมสำหรับ tf.keras.layers.Embedding แต่คุณสามารถบรรลุผลเช่นเดียวกันกับ tf.keras.layers.Dense embedding_column ด้านบนเป็นเพียงการรวมเวกเตอร์การฝังเชิงเส้นตามน้ำหนักหมวดหมู่ แม้ว่าจะไม่ชัดเจนในตอนแรก แต่ก็เทียบเท่ากับการแสดงอินพุตตามหมวดหมู่ของคุณเป็นเวกเตอร์น้ำหนักแบบเบาบางของขนาด (num_tokens) และคูณด้วยเคอร์เนลที่มีรูปร่าง Dense (embedding_size, num_tokens)

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])

# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line
# would be eqivalent to an `embedding_column` with `combiner='sum'`.
weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)

count_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=20, output_mode='count', sparse=True)
embedding_layer = tf.keras.layers.Dense(4, use_bias=False)
embedding_layer(count_layer(ids, count_weights=weights))
<tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.03897291, -0.27131438,  0.09332469,  0.04333957]],
      dtype=float32)>

ตัวอย่างการฝึกที่สมบูรณ์

ในการแสดงเวิร์กโฟลว์การฝึกอบรมที่สมบูรณ์ ขั้นแรกให้เตรียมข้อมูลที่มีคุณลักษณะสามประเภทที่แตกต่างกัน:

features = {
    'type': [0, 1, 1],
    'size': ['small', 'small', 'medium'],
    'weight': [2.7, 1.8, 1.6],
}
labels = [1, 1, 0]
predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}

กำหนดค่าคงที่ทั่วไปบางอย่างสำหรับทั้ง TF1 และ TF2 เวิร์กโฟลว์:

vocab = ['small', 'medium', 'large']
one_hot_dims = 3
embedding_dims = 4
weight_mean = 2.0
weight_variance = 1.0

ด้วยคอลัมน์คุณสมบัติ

คอลัมน์คุณลักษณะต้องถูกส่งผ่านเป็นรายการไปยังตัวประมาณการสร้าง และจะถูกเรียกโดยปริยายในระหว่างการฝึกอบรม

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'type', num_buckets=one_hot_dims)
# Convert index to one-hot; e.g. [2] -> [0,0,1].
indicator_col = tf1.feature_column.indicator_column(categorical_col)

# Convert strings to indices; e.g. ['small'] -> [1].
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'size', vocabulary_list=vocab, num_oov_buckets=1)
# Embed the indices.
embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)

normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
numeric_col = tf1.feature_column.numeric_column(
    'weight', normalizer_fn=normalizer_fn)

estimator = tf1.estimator.DNNClassifier(
    feature_columns=[indicator_col, embedding_col, numeric_col],
    hidden_units=[1])

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp8lwbuor2
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp8lwbuor2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp8lwbuor2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.54634213, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmp8lwbuor2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Loss for final step: 0.7308526.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f90685d53d0>
ตัวยึดตำแหน่ง39

คอลัมน์คุณลักษณะจะยังใช้เพื่อแปลงข้อมูลอินพุตเมื่อเรียกใช้การอนุมานในโมเดล

def _predict_fn():
  return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)

next(estimator.predict(_predict_fn))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp8lwbuor2/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
{'logits': array([0.5172372], dtype=float32),
 'logistic': array([0.6265015], dtype=float32),
 'probabilities': array([0.37349847, 0.6265015 ], dtype=float32),
 'class_ids': array([1]),
 'classes': array([b'1'], dtype=object),
 'all_class_ids': array([0, 1], dtype=int32),
 'all_classes': array([b'0', b'1'], dtype=object)}
ตัวยึดตำแหน่ง41

ด้วยเลเยอร์การประมวลผลล่วงหน้าของ Keras

เลเยอร์การประมวลผลล่วงหน้าของ Keras นั้นยืดหยุ่นกว่าในตำแหน่งที่สามารถเรียกได้ สามารถใช้เลเยอร์กับเทนเซอร์ได้โดยตรง ใช้ภายในไพพ์ไลน์อินพุต tf.data หรือสร้างโดยตรงในโมเดล Keras ที่ฝึกได้

ในตัวอย่างนี้ คุณจะใช้เลเยอร์การประมวลผลล่วงหน้าภายในไปป์ไลน์อินพุต tf.data ในการดำเนินการนี้ คุณสามารถกำหนด tf.keras.Model แยกต่างหากเพื่อประมวลผลคุณลักษณะอินพุตของคุณล่วงหน้า โมเดลนี้ไม่สามารถฝึกได้ แต่เป็นวิธีที่สะดวกในการจัดกลุ่มเลเยอร์การประมวลผลล่วงหน้า

inputs = {
  'type': tf.keras.Input(shape=(), dtype='int64'),
  'size': tf.keras.Input(shape=(), dtype='string'),
  'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Convert index to one-hot; e.g. [2] -> [0,0,1].
type_output = tf.keras.layers.CategoryEncoding(
      one_hot_dims, output_mode='one_hot')(inputs['type'])
# Convert size strings to indices; e.g. ['small'] -> [1].
size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
weight_output = tf.keras.layers.Normalization(
      axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])
outputs = {
  'type': type_output,
  'size': size_output,
  'weight': weight_output,
}
preprocessing_model = tf.keras.Model(inputs, outputs)

ตอนนี้คุณสามารถใช้โมเดลนี้ในการเรียก tf.data.Dataset.map โปรดทราบว่าฟังก์ชันที่ส่งไปยัง map จะถูกแปลงเป็น tf.function โดยอัตโนมัติ และมีการใช้คำเตือนตามปกติในการเขียนโค้ด tf.function (ไม่มีผลข้างเคียง)

# Apply the preprocessing in tf.data.Dataset.map.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),
                      num_parallel_calls=tf.data.AUTOTUNE)
# Display a preprocessed input sample.
next(dataset.take(1).as_numpy_iterator())
({'type': array([[1., 0., 0.]], dtype=float32),
  'size': array([1]),
  'weight': array([0.70000005], dtype=float32)},
 array([1], dtype=int32))

ถัดไป คุณสามารถกำหนด Model แยกกันที่มีเลเยอร์ที่สามารถฝึกได้ สังเกตว่าอินพุตของโมเดลนี้สะท้อนถึงประเภทและรูปร่างของฟีเจอร์ที่ประมวลผลล่วงหน้าอย่างไร

inputs = {
  'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),
  'size': tf.keras.Input(shape=(), dtype='int64'),
  'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Since the embedding is trainable, it needs to be part of the training model.
embedding = tf.keras.layers.Embedding(len(vocab), embedding_dims)
outputs = tf.keras.layers.Concatenate()([
  inputs['type'],
  embedding(inputs['size']),
  tf.expand_dims(inputs['weight'], -1),
])
outputs = tf.keras.layers.Dense(1)(outputs)
training_model = tf.keras.Model(inputs, outputs)

ตอนนี้คุณสามารถฝึก training_model ด้วย tf.keras.Model.fit

# Train on the preprocessed data.
training_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
training_model.fit(dataset)
3/3 [==============================] - 0s 3ms/step - loss: 0.7248
<keras.callbacks.History at 0x7f9041a294d0>

สุดท้าย ในเวลาอนุมาน อาจเป็นประโยชน์ในการรวมขั้นตอนที่แยกจากกันเหล่านี้เป็นโมเดลเดียวที่จัดการอินพุตคุณสมบัติดิบ

inputs = preprocessing_model.input
outpus = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outpus)

predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)
inference_model.predict(predict_dataset)
array([[0.936637]], dtype=float32)

โมเดลที่ประกอบขึ้นนี้สามารถบันทึกเป็น SavedModel เพื่อใช้ในภายหลังได้

inference_model.save('model')
restored_model = tf.keras.models.load_model('model')
restored_model.predict(predict_dataset)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
2021-10-27 01:23:25.649967: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
array([[0.936637]], dtype=float32)
ตัวยึดตำแหน่ง51

ตารางคุณสมบัติสมมูลของคอลัมน์

สำหรับการอ้างอิง นี่คือการติดต่อโดยประมาณระหว่างคอลัมน์คุณลักษณะและเลเยอร์การประมวลผลล่วงหน้า:

คอลัมน์คุณลักษณะ Keras Layer
feature_column.bucketized_column layers.Discretization
feature_column.categorical_column_with_hash_bucket layers.Hashing
feature_column.categorical_column_with_identity layers.CategoryEncoding
feature_column.categorical_column_with_vocabulary_file layers.StringLookup หรือ layers.IntegerLookup
feature_column.categorical_column_with_vocabulary_list layers.StringLookup หรือ layers.IntegerLookup
feature_column.crossed_column ไม่ได้ดำเนินการ.
feature_column.embedding_column layers.Embedding
feature_column.indicator_column output_mode='one_hot' หรือ output_mode='multi_hot' *
feature_column.numeric_column layers.Normalization
feature_column.sequence_categorical_column_with_hash_bucket layers.Hashing
feature_column.sequence_categorical_column_with_identity layers.CategoryEncoding
feature_column.sequence_categorical_column_with_vocabulary_file layers.StringLookup , layers.IntegerLookup หรือ layer.TextVectorization
feature_column.sequence_categorical_column_with_vocabulary_list layers.StringLookup , layers.IntegerLookup หรือ layer.TextVectorization
feature_column.sequence_numeric_column layers.Normalization
feature_column.weighted_categorical_column layers.CategoryEncoding

* output_mode สามารถส่งต่อไปยัง layers.CategoryEncoding , layers.StringLookup , layers.IntegerLookup และ layers.TextVectorization

layers.TextVectorization สามารถจัดการการป้อนข้อความในรูปแบบอิสระได้โดยตรง (เช่น ทั้งประโยคหรือย่อหน้า) นี่ไม่ใช่การแทนที่แบบหนึ่งต่อหนึ่งสำหรับการจัดการลำดับตามหมวดหมู่ใน TF1 แต่อาจเสนอการแทนที่ที่สะดวกสำหรับการประมวลผลข้อความเฉพาะกิจล่วงหน้า

ขั้นตอนถัดไป