tf.contrib.layers.weighted_sum_from_feature_columns
A tf.contrib.layers style linear prediction builder based on FeatureColumn.
tf.contrib.layers.weighted_sum_from_feature_columns(
columns_to_tensors, feature_columns, num_outputs, weight_collections=None,
trainable=True, scope=None
)
Generally, a single training example is described by a set of feature columns.
This function computes one weighted sum of those features for each of the
num_outputs outputs. In classification problems the weighted sums are the
logits; in linear regression problems they are the predictions themselves.
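Here is a minimal pure-Python sketch of the weighted sum for a single example with dense, real-valued features. It is not the TensorFlow implementation. `weighted_sum` and `sigmoid` are hypothetical helpers, and the weights are hand-picked rather than learned. Each of the `len(bias)` outputs, which plays the role of `num_outputs`, gets its own column of weights and its own bias term.

```python
import math
from typing import List, Sequence


def weighted_sum(features: Sequence[float],
                 weights: Sequence[Sequence[float]],
                 bias: Sequence[float]) -> List[float]:
    """Hypothetical helper: out[k] = sum_i features[i] * weights[i][k] + bias[k].

    `weights` has one row per input feature and one column per output,
    so len(bias) acts as num_outputs.
    """
    num_outputs = len(bias)
    return [
        sum(x * row[k] for x, row in zip(features, weights)) + bias[k]
        for k in range(num_outputs)
    ]


def sigmoid(z: float) -> float:
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


x = [2.0, -1.0]                 # one example with two real-valued features
w = [[0.5, 1.0], [0.25, -2.0]]  # 2 features x 2 outputs
b = [0.1, 0.0]                  # one bias per output

out = weighted_sum(x, w, b)     # approximately [0.85, 4.0]
# Classification: each weighted sum is a logit; sigmoid turns it into a probability.
probs = [sigmoid(z) for z in out]
# Regression: the weighted sum itself would be used directly as the prediction.
```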
Example:
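import tensorflow as tf  # TensorFlow 1.x only; tf.contrib was removed in TF 2.x

# Building model for training
feature_columns = (
    tf.contrib.layers.real_valued_column("my_feature1"),
    ...
)
columns_to_tensors = tf.io.parse_example(...)
# The function returns a tuple (predictions, column_to_variable, bias);
# only the predictions (here, logits) are needed for the loss.
logits, _, _ = tf.contrib.layers.weighted_sum_from_feature_columns(
    columns_to_tensors=columns_to_tensors,
    feature_columns=feature_columns,
    num_outputs=1)
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                               logits=logits)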
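The loss in the example above applies `tf.nn.sigmoid_cross_entropy_with_logits` to the logits. TensorFlow documents that op as computing `max(x, 0) - x * z + log(1 + exp(-abs(x)))` for logits `x` and labels `z`. This is a rearrangement of the logistic loss that avoids overflow for large logits. The scalar pure-Python sketch below reproduces that formula for illustration only. `naive_loss` is a hypothetical reference helper that evaluates the textbook form for comparison.

```python
import math


def sigmoid_cross_entropy_with_logits(label: float, logit: float) -> float:
    """Numerically stable logistic loss for one logit (scalar sketch).

    Equal to -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)) with
    x = logit and z = label, but exp() is only ever called on -|x|.
    """
    x, z = logit, label
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def naive_loss(label: float, logit: float) -> float:
    """Hypothetical textbook form; fails when sigmoid saturates to 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -label * math.log(p) - (1.0 - label) * math.log(1.0 - p)


# Agrees with the textbook form for moderate logits ...
moderate = sigmoid_cross_entropy_with_logits(1.0, 2.0)
# ... and stays finite for a very large logit, where naive_loss would
# need log(0) and raise ValueError.
large = sigmoid_cross_entropy_with_logits(0.0, 1000.0)
```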
Args:
  columns_to_tensors: A mapping from feature column to tensors. A string key
    refers to a base (untransformed) feature. A FeatureColumn can also be used
    as a key, which means that the column has already been transformed by the
    input pipeline; for example, inflow may have handled the transformations.
  feature_columns: A set containing all the feature columns. All items in the
    set should be instances of classes derived from FeatureColumn.
  num_outputs: An integer specifying the number of outputs. The signature
    above gives it no default value, so it must be passed explicitly.
  weight_collections: List of graph collections to which the weights are
    added.
  trainable: If True, also add the variables to the graph collection
    GraphKeys.TRAINABLE_VARIABLES (see tf.Variable).
  scope: Optional scope for variable_scope.
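The `columns_to_tensors` description implies a lookup rule. A tensor keyed by the FeatureColumn itself is used as-is. Otherwise, the column reads its base feature through the string key and transforms it. The following is a simplified pure-Python sketch of that rule under stated assumptions. `HypotheticalColumn` and `get_transformed` are hypothetical stand-ins, not TensorFlow classes or functions. Plain lists stand in for tensors. Caching the transformed value under the column key is an assumption about internal behavior, not documented API.

```python
from typing import Callable, Dict, List, Union


class HypotheticalColumn:
    """Hypothetical stand-in for a FeatureColumn (not a TensorFlow class)."""

    def __init__(self, name: str, source: str,
                 transform: Callable[[List[float]], List[float]]) -> None:
        self.name = name
        self.source = source        # string key of the base feature it reads
        self.transform = transform  # applied to the base feature's values

    def __repr__(self) -> str:
        return f"HypotheticalColumn({self.name!r})"


Key = Union[str, HypotheticalColumn]


def get_transformed(columns_to_tensors: Dict[Key, List[float]],
                    column: HypotheticalColumn) -> List[float]:
    """Return the transformed values for `column` (hypothetical helper)."""
    # Keyed by the column itself: already transformed upstream, use as-is.
    if column in columns_to_tensors:
        return columns_to_tensors[column]
    # Otherwise read the base feature by its string key and transform it.
    result = column.transform(columns_to_tensors[column.source])
    # Assumption: cache the result under the column key for later lookups.
    columns_to_tensors[column] = result
    return result


# Base feature keyed by a string: the column transforms it on demand.
scaled = HypotheticalColumn("age_scaled", "age", lambda v: [a / 100.0 for a in v])
features: Dict[Key, List[float]] = {"age": [25.0, 50.0]}
scaled_values = get_transformed(features, scaled)   # [0.25, 0.5]

# Column already transformed by the input pipeline: its base key is never read.
logged = HypotheticalColumn("income_log", "income", lambda v: v)
features[logged] = [10.0, 11.0]
logged_values = get_transformed(features, logged)   # [10.0, 11.0]
```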
Returns:
  A tuple containing:
  - A Tensor representing the predictions of the linear model.
  - A dictionary mapping each feature_column to its corresponding Variable.
  - A Variable used as the bias.
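The three parts of the returned tuple fit together as predictions = (sum over columns of column values * column weights) + bias. This toy pure-Python sketch makes two assumptions. Every column is already a dense list of floats per example. Each column's weights form a [dimension][num_outputs] matrix. `toy_weighted_sum_from_columns` is a hypothetical name, and plain lists stand in for Tensors and Variables. It returns the same three-part structure so that the unpacking mirrors the real call.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def toy_weighted_sum_from_columns(
    columns_to_values: Dict[str, Matrix],   # column -> [batch][dimension]
    column_to_weights: Dict[str, Matrix],   # column -> [dimension][num_outputs]
    bias: List[float],                      # [num_outputs]
) -> Tuple[Matrix, Dict[str, Matrix], List[float]]:
    """Hypothetical linear model returning (predictions, per-column weights, bias)."""
    num_outputs = len(bias)
    # Assumes at least one column; every column has the same batch size.
    batch_size = len(next(iter(columns_to_values.values())))
    # Start every example's outputs at the bias (copied, not aliased).
    predictions = [list(bias) for _ in range(batch_size)]
    for name, values in columns_to_values.items():
        weights = column_to_weights[name]
        for n, row in enumerate(values):
            for k in range(num_outputs):
                predictions[n][k] += sum(v * weights[d][k] for d, v in enumerate(row))
    return predictions, column_to_weights, bias


values = {"height": [[1.0], [2.0]], "color_onehot": [[1.0, 0.0], [0.0, 1.0]]}
weights = {"height": [[0.5]], "color_onehot": [[1.0], [-1.0]]}
preds, per_column, b = toy_weighted_sum_from_columns(values, weights, [0.25])
# preds == [[1.75], [0.25]]
```

Keeping the weights per column, rather than in one concatenated matrix, is what lets the second element of the tuple map each column to its own variable.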
Raises:
  ValueError: If a FeatureColumn cannot be used for linear predictions.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.