nsl.estimator.add_graph_regularization
Adds graph regularization to a tf.estimator.Estimator.
nsl.estimator.add_graph_regularization(
    estimator, embedding_fn, optimizer_fn=None, graph_reg_config=None
)
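Used in the notebooks
Graph-based Neural Structured Learning in TFX (https://www.tensorflow.org/tfx/tutorials/tfx/neural_structured_learning)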
Args
estimator: An object of type tf.estimator.Estimator.
embedding_fn: A function that returns the embeddings or logits to be used for graph regularization. It accepts (1) the input layer, a dictionary mapping feature names to the corresponding batched tensor values, as its first argument; (2) an instance of tf.estimator.ModeKeys as its second argument, indicating whether the mode is training, evaluation, or prediction; and (3) an optional third argument named params, a dict similar to the params argument of the tf.estimator.Estimator model_fn. The params argument receives whatever was passed to estimator as its own params argument when estimator was created.
optimizer_fn: A function that accepts no arguments and returns an instance of tf.train.Optimizer (tf.compat.v1.train.Optimizer in TensorFlow 2).
graph_reg_config: An instance of nsl.configs.GraphRegConfig that specifies various hyperparameters for graph regularization.
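Because params is optional, a wrapper around embedding_fn has to decide whether to forward it. The sketch below is a hypothetical, framework-free illustration of that contract, not NSL's implementation: the helper call_embedding_fn, the feature name "x", and the "scale" hyperparameter are invented for this example, plain lists stand in for batched tensors, and the string "train" stands in for tf.estimator.ModeKeys.TRAIN (whose value is the string 'train').

```python
import inspect


def call_embedding_fn(embedding_fn, features, mode, params):
    """Calls embedding_fn, forwarding `params` only if it declares that argument.

    Hypothetical helper illustrating the documented embedding_fn contract.
    """
    if "params" in inspect.signature(embedding_fn).parameters:
        return embedding_fn(features, mode, params=params)
    return embedding_fn(features, mode)


def embedding_fn_two_args(features, mode):
    # Two-argument form: use the raw feature vector as the "embedding".
    return features["x"]


def embedding_fn_with_params(features, mode, params):
    # Three-argument form: scale the features by a value read from params.
    scale = params.get("scale", 1.0)
    return [scale * v for v in features["x"]]


features = {"x": [1.0, 2.0]}
params = {"scale": 3.0}
print(call_embedding_fn(embedding_fn_two_args, features, "train", params))     # [1.0, 2.0]
print(call_embedding_fn(embedding_fn_with_params, features, "train", params))  # [3.0, 6.0]
```

Both variants satisfy the same contract; only a function that names a params argument receives it.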
Returns
A modified tf.estimator.Estimator object with graph regularization incorporated into its loss.
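As a rough illustration of what "incorporated into its loss" can mean, the framework-free sketch below assumes the regularizer is a weighted sum of squared L2 distances between each sample's embedding and its neighbors' embeddings, averaged over the batch and scaled by a multiplier, then added to the base loss. This is a simplified assumption, not NSL's exact computation: the actual distance, neighbor weighting, and reduction are set through nsl.configs.GraphRegConfig and may differ. The functions graph_regularized_loss and squared_l2 are hypothetical.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def graph_regularized_loss(base_loss, sample_embeddings, neighbor_embeddings,
                           neighbor_weights, multiplier):
    """Simplified sketch: base loss plus a weighted neighbor-distance penalty.

    neighbor_embeddings[i] and neighbor_weights[i] hold the neighbor embeddings
    of sample i and their edge weights. Averaging over samples is an assumption.
    """
    per_sample = []
    for emb, nbrs, weights in zip(sample_embeddings, neighbor_embeddings,
                                  neighbor_weights):
        # Weighted sum of distances from this sample to each of its neighbors.
        per_sample.append(sum(w * squared_l2(emb, n) for n, w in zip(nbrs, weights)))
    graph_loss = sum(per_sample) / len(per_sample)
    return base_loss + multiplier * graph_loss


samples = [[1.0, 0.0], [0.0, 1.0]]
neighbors = [[[1.0, 1.0]], [[0.0, 1.0], [2.0, 1.0]]]
weights = [[1.0], [0.5, 0.5]]
print(graph_regularized_loss(0.5, samples, neighbors, weights, multiplier=0.25))  # 0.875
```

Here the per-sample penalties are 1.0 and 2.0, their mean is 1.5, and 0.5 + 0.25 * 1.5 gives 0.875. Setting the multiplier to 0 recovers the unregularized loss.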
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-01-26 UTC.