


An estimator for TensorFlow Linear and DNN joined classification models.

Inherits From: Estimator


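import tensorflow as tf

# `...` marks arguments that you fill in for your own data.
numeric_feature = tf.feature_column.numeric_column(...)
categorical_column_a = tf.feature_column.categorical_column_with_hash_bucket(...)
categorical_column_b = tf.feature_column.categorical_column_with_hash_bucket(...)

categorical_feature_a_x_categorical_feature_b = tf.feature_column.crossed_column(...)
categorical_feature_a_emb = tf.feature_column.embedding_column(
    categorical_column=categorical_column_a, dimension=...)
categorical_feature_b_emb = tf.feature_column.embedding_column(
    categorical_column=categorical_column_b, dimension=...)

estimator = tf.estimator.DNNLinearCombinedClassifier(
    # wide settings
    linear_feature_columns=[categorical_feature_a_x_categorical_feature_b],
    linear_optimizer=tf.keras.optimizers.Ftrl(...),
    # deep settings
    dnn_feature_columns=[
        categorical_feature_a_emb, categorical_feature_b_emb,
        numeric_feature],
    dnn_hidden_units=[1000, 500, 100],
    dnn_optimizer=tf.keras.optimizers.Adagrad(...),
    # warm-start settings: a checkpoint path or a tf.estimator.WarmStartSettings
    warm_start_from=...)

# To apply L1 and L2 regularization, you can set dnn_optimizer to an optimizer
# that supports it, e.g. Ftrl (the strengths below are illustrative values):
regularized_dnn_optimizer = tf.keras.optimizers.Ftrl(
    learning_rate=0.1,
    l1_regularization_strength=0.001,
    l2_regularization_strength=0.001)
# To apply learning rate decay, you can set dnn_optimizer to a callable:
decayed_dnn_optimizer = lambda: tf.keras.optimizers.Adam(
    learning_rate=tf.compat.v1.train.exponential_decay(
        learning_rate=0.1,
        global_step=tf.compat.v1.train.get_global_step(),
        decay_steps=10000,
        decay_rate=0.96))
# It is the same for linear_optimizer.

# Input builders
def input_fn_train():
  # Returns a tf.data.Dataset of (x, y) tuples, where y is the label's class
  # index.
  pass

def input_fn_eval():
  # Returns a tf.data.Dataset of (x, y) tuples, where y is the label's class
  # index.
  pass

def input_fn_predict():
  # Returns a tf.data.Dataset of (x, None) tuples.
  pass

estimator.train(input_fn=input_fn_train, steps=100)
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
predictions = estimator.predict(input_fn=input_fn_predict)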

The inputs to train and evaluate must contain the following features; otherwise a KeyError is raised:

  • for each column in dnn_feature_columns + linear_feature_columns:
    • if column is a CategoricalColumn, a feature with key=column.name whose value is a SparseTensor.
    • if column is a WeightedCategoricalColumn, two features: the first keyed by the id column name, the second keyed by the weight column name. Both values must be SparseTensors.
    • if column is a DenseColumn, a feature with key=column.name whose value is a Tensor.
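The key rules above can be illustrated with a small pure-Python sketch. ColumnSpec, required_feature_keys, and check_features are hypothetical names invented for this illustration, not part of TensorFlow; the sketch only checks that the required keys are present and does not check whether the values are SparseTensors or Tensors.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ColumnSpec:
    """Hypothetical stand-in for a feature column (not a TensorFlow class)."""
    kind: str                         # "categorical", "weighted_categorical" or "dense"
    key: str                          # column.name, or the id column name if weighted
    weight_key: Optional[str] = None  # weight column name (weighted_categorical only)


def required_feature_keys(columns: List[ColumnSpec]) -> List[str]:
    """Lists the feature keys the rules above require, in column order."""
    keys: List[str] = []
    for col in columns:
        keys.append(col.key)
        if col.kind == "weighted_categorical":
            # Weighted columns need a second feature holding the weights.
            keys.append(col.weight_key)
    return keys


def check_features(features: Dict[str, object], columns: List[ColumnSpec]) -> None:
    """Raises KeyError for the first required key missing from `features`."""
    for key in required_feature_keys(columns):
        if key not in features:
            raise KeyError(key)
```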

Loss is calculated by using softmax cross entropy.
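As a rough illustration, the per-example softmax cross entropy can be computed in plain Python as below, assuming labels are already class indices. The optional per-example weights and the division by batch size are an assumption modeled on the weight_column and loss_reduction (SUM_OVER_BATCH_SIZE) arguments documented below; the function names are hypothetical and this is not the TensorFlow implementation.

```python
import math
from typing import List, Optional, Sequence


def softmax_cross_entropy(logits: Sequence[float], label: int) -> float:
    """Per-example loss: -log(softmax(logits)[label])."""
    # Subtract the largest logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def weighted_batch_loss(
    batch_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    weights: Optional[Sequence[float]] = None,
) -> float:
    """Multiplies each example's loss by its weight, then sums and divides by
    the number of examples (a SUM_OVER_BATCH_SIZE-style reduction)."""
    if weights is None:
        weights = [1.0] * len(labels)  # unweighted: every example counts once
    losses: List[float] = [
        w * softmax_cross_entropy(z, y)
        for z, y, w in zip(batch_logits, labels, weights)
    ]
    return sum(losses) / len(losses)
```

With two equal logits the loss is log 2 for either label, which is a quick sanity check of the formula.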

model_dir Directory to save model parameters, graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
linear_feature_columns An iterable containing all the feature columns used by the linear part of the model. All items in the iterable must be instances of classes derived from FeatureColumn.
linear_optimizer An instance of tf.keras.optimizers.* used to apply gradients to the linear part of the model. Can also be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD') or a callable. Defaults to the FTRL optimizer.
dnn_feature_columns An iterable containing all the feature columns used by the deep part of the model. All items in the iterable must be instances of classes derived from FeatureColumn.
dnn_optimizer An instance of tf.keras.optimizers.* used to apply gradients to the deep part of the model. Can also be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD') or a callable. Defaults to the Adagrad optimizer.
dnn_hidden_units List of hidden units per layer. All layers are fully connected.
dnn_activation_fn Activation function applied to each layer. If None, tf.nn.relu is used.
dnn_dropout When not None, the probability of dropping out a given coordinate.
n_classes Number of label classes. Defaults to 2, namely binary classification. Must be > 1.
weight_column A string or a _NumericColumn created by tf.feature_column.numeric_column defining the feature column that represents weights. It is used to down-weight or boost examples during training: each example's loss is multiplied by its weight. If it is a string, it is used as a key to fetch the weight tensor from the features. If it is a _NumericColumn, the raw tensor is fetched by key weight_column.key, then weight_column.normalizer_fn is applied to it to get the weight tensor.
label_vocabulary A list of strings representing the possible label values. If given, labels must be of string type and take values in label_vocabulary. If not given, labels must already be encoded as integers or floats within [0, 1] for n_classes=2, and as integer values in {0, 1, ..., n_classes-1} for n_classes > 2. An error is raised if labels are strings and no vocabulary is provided.
config RunConfig object to configure the runtime settings.
warm_start_from A string filepath to a checkpoint to warm-start from, or a WarmStartSettings object to fully configure warm-starting. If the string filepath is provided instead of a WarmStartSettings, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged.
loss_reduction One of tf.losses.Reduction except NONE. Describes how to reduce the training loss over the batch. Defaults to SUM_OVER_BATCH_SIZE.
batch_norm Whether to use batch normalization after each hidden layer.
linear_sparse_combiner A string specifying how to reduce the linear model if a categorical column is multivalent. One of "mean", "sqrtn", and "sum" -- these are effectively different ways to do example-level normalization, which can be useful for bag-of-words features. For more details, see tf.feature_column.linear_model.
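The three linear_sparse_combiner options can be sketched for a single multivalent categorical feature as follows. This is a simplified pure-Python model, assuming each id looks up one scalar weight in the linear model and that optional per-id weights default to 1; combine_linear is a hypothetical helper, not a TensorFlow function.

```python
import math
from typing import Optional, Sequence


def combine_linear(
    ids: Sequence[int],
    linear_weights: Sequence[float],
    combiner: str,
    id_weights: Optional[Sequence[float]] = None,
) -> float:
    """Contribution of one multivalent categorical feature to the linear logit."""
    if not ids:
        return 0.0  # assumption: an empty feature contributes nothing
    if id_weights is None:
        id_weights = [1.0] * len(ids)  # unweighted column: every id counts once
    total = sum(w * linear_weights[i] for i, w in zip(ids, id_weights))
    if combiner == "sum":    # no normalization
        return total
    if combiner == "mean":   # divide by the sum of the id weights
        return total / sum(id_weights)
    if combiner == "sqrtn":  # divide by the square root of the summed squared weights
        return total / math.sqrt(sum(w * w for w in id_weights))
    raise ValueError(f"unknown combiner: {combiner!r}")
```

For ids [0, 2] with linear weights [1.0, 2.0, 3.0], "sum" gives 4.0, "mean" gives 2.0, and "sqrtn" gives 4 / sqrt(2), about 2.83, so "mean" and "sqrtn" keep an example with many ids from dominating the logit.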

ValueError If both linear_feature_columns and dnn_feature_columns are empty at the same time.

Eager Compatibility

Estimators can be used while eager execution is enabled. However, input_fn and all hooks are executed inside a graph context, so they must be written to be compatible with graph mode. input_fn code that uses tf.data generally works in both graph and eager modes.




model_fn Returns the model_fn which is bound to self.params.



eval_dir(name=None)

Returns the name of the directory where evaluation metrics are dumped.

name Name of the evaluation, if the user needs to run multiple evaluations on different data sets, such as training data vs. test data. Metrics for different evaluations are saved in separate folders and appear separately in TensorBoard.

A string containing the path of the directory that holds the evaluation metrics.
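The directory naming described above can be sketched as follows, assuming the convention used by tf.estimator.Estimator: a subdirectory named eval under model_dir, or eval_<name> when name is given. eval_dir_path is a hypothetical standalone function for illustration, not the estimator method itself.

```python
import os
from typing import Optional


def eval_dir_path(model_dir: str, name: Optional[str] = None) -> str:
    """Assumed convention: model_dir/eval, or model_dir/eval_<name> if named."""
    subdir = "eval" if not name else "eval_" + name
    return os.path.join(model_dir, subdir)
```

Under this convention, evaluate(..., name="train") and evaluate(..., name="test") write to sibling folders, which is why TensorBoard shows them as separate runs.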


evaluate(input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)

Evaluates the model given evaluation data input_fn.

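For each step, calls input_fn, which returns one batch of data. Evaluates until:

  • steps batches are processed, or
  • input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration).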

input_fn A function that constructs the input data for evaluation. See Premade Estimators for more information. The function should construct and return one of the following:

  • A tf.data.Dataset object: Outputs of the Dataset object must be a tuple (features, labels) with the same constraints as below.
  • A tuple (features, labels): Where features