TF Lattice canned estimators implement typical monotonic model architectures.

You can use TFL canned estimators to easily construct commonly used monotonic
model architectures. To construct a TFL canned estimator, construct a model
configuration from `tfl.configs` and pass it to the canned estimator
constructor. To use automated quantile calculation, canned estimators also
require a `feature_analysis_input_fn`, which is similar to the input function
used for training but runs for a single epoch or over a subset of the data. To
create a Crystals ensemble model using
`tfl.configs.CalibratedLatticeEnsembleConfig`, you will also need to provide a
`prefitting_input_fn` to the estimator constructor.

    import tensorflow_lattice as tfl

    feature_columns = ...
    model_config = tfl.configs.CalibratedLatticeConfig(...)
    # Single pass over the data, used only to compute feature quantiles.
    feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)
    train_input_fn = create_input_fn(num_epochs=100, ...)
    estimator = tfl.estimators.CannedClassifier(
        feature_columns=feature_columns,
        model_config=model_config,
        feature_analysis_input_fn=feature_analysis_input_fn)
    estimator.train(input_fn=train_input_fn)

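
To give a sense of what automated quantile calculation involves, here is a minimal sketch in plain Python of placing calibrator keypoints at evenly spaced quantiles of a feature seen in one pass over the data. It is an illustration only and not the library's implementation: the function name `quantile_keypoints`, the linear interpolation between neighbouring samples, and the sample `ages` data are all assumptions made for this example.

```python
def quantile_keypoints(values, num_keypoints):
    """Return keypoints at evenly spaced quantiles of `values`.

    Hypothetical helper for illustration; TF Lattice's own quantile
    calculation may differ in its details.
    """
    if num_keypoints < 2:
        raise ValueError("num_keypoints must be at least 2")
    ordered = sorted(values)
    if not ordered:
        raise ValueError("values must be non-empty")
    last = len(ordered) - 1
    keypoints = []
    for i in range(num_keypoints):
        # Fractional index of the i-th quantile within the sorted sample.
        position = i * last / (num_keypoints - 1)
        lower = int(position)
        upper = min(lower + 1, last)
        weight = position - lower
        # Linearly interpolate between the two neighbouring samples.
        keypoints.append(ordered[lower] * (1 - weight) + ordered[upper] * weight)
    return keypoints


# Made-up feature values standing in for a single epoch of data.
ages = [23, 35, 31, 47, 52, 29, 61, 38, 44]
keypoints = quantile_keypoints(ages, num_keypoints=5)
print(keypoints)
```

Because the keypoints depend only on the distribution of each feature, a single epoch or a representative subset of the data is enough to estimate them, which is why the analysis input function does not need to repeat the data the way the training input function does.
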
Supported models are defined in `tfl.configs`. Each model architecture can be
used for:

*   **Classification** using `tfl.estimators.CannedClassifier` with standard
    classification head (softmax cross-entropy loss).
*   **Regression** using `tfl.estimators.CannedRegressor` with standard
    regression head (squared loss).
*   **Custom head** using `tfl.estimators.CannedEstimator` with any custom head
    and loss.

This module also provides `tfl.estimators.get_model_graph` as a mechanism to
extract abstract model graphs and layer parameters from saved models. The
resulting graph (not a TF graph) can be used by the `tfl.visualization` module
for plotting and other visualization and analysis.

    from tensorflow_lattice import estimators, visualization

    model_graph = estimators.get_model_graph(saved_model_path)
    visualization.plot_feature_calibrator(model_graph, "feature_name")
    visualization.plot_all_calibrators(model_graph)
    visualization.draw_model_graph(model_graph)

Classes

`class CannedClassifier`: Canned classifier for TensorFlow lattice models.

`class CannedEstimator`: An estimator for TensorFlow lattice models.

`class CannedRegressor`: A regressor for TensorFlow lattice models.

`class WaitTimeOutError`: Timeout error when waiting for a file.

Functions

`get_model_graph(...)`: Returns all layers and parameters used in a saved model as a graph.

`transform_features(...)`: Parses the input features using the given feature columns.

Other Members | |
---|---|
FEATURES_SCOPE | `'features'` |
OUTPUT_NAME | `'output'` |
absolute_import | Instance of `__future__._Feature` |
division | Instance of `__future__._Feature` |
print_function | Instance of `__future__._Feature` |