View source on GitHub
Adds adversarial regularization to a tf.estimator.Estimator.
nsl.estimator.add_adversarial_regularization(
estimator, optimizer_fn=None, adv_config=None
)
The returned estimator includes the adversarial loss as a regularization
term in its training objective and is trained with the optimizer returned by
optimizer_fn. Because this optimizer replaces the one used in the base
estimator, optimizer_fn should create the same optimizer, with the same
hyperparameters, as the base estimator.
If optimizer_fn is not set, the default is tf.train.AdagradOptimizer
with learning_rate=0.05.
| Args | |
|---|---|
| estimator | A tf.estimator.Estimator object, the base model. |
| optimizer_fn | A function that accepts no arguments and returns an instance of tf.train.Optimizer. This optimizer (instead of the one used in estimator) will be used to train the model. If not specified, defaults to tf.train.AdagradOptimizer with learning_rate=0.05. |
| adv_config | An instance of nsl.configs.AdvRegConfig that specifies various hyperparameters for adversarial regularization. |
| Returns | |
|---|---|
| A modified tf.estimator.Estimator object with adversarial regularization incorporated into its loss. | |
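The call described above might be wired up roughly as follows. This is a minimal sketch, not the library's reference usage: the helper names make_optimizer and build_adversarial_estimator are hypothetical, the multiplier and adv_step_size values are illustrative only, and tf.compat.v1.train.AdagradOptimizer is assumed as the TF 2.x-compatible spelling of tf.train.AdagradOptimizer. The adversarial config is built with nsl.configs.make_adv_reg_config, which returns an nsl.configs.AdvRegConfig.

```python
def make_optimizer():
    """Zero-argument factory to pass as optimizer_fn.

    Assumption: tf.compat.v1.train.AdagradOptimizer stands in for the
    tf.train.AdagradOptimizer named in this page. In practice, return the
    same optimizer and hyperparameters that the base estimator uses.
    """
    # Imported lazily so this helper module can be loaded without TensorFlow.
    import tensorflow as tf

    return tf.compat.v1.train.AdagradOptimizer(learning_rate=0.05)


def build_adversarial_estimator(base_estimator, multiplier=0.2, adv_step_size=0.05):
    """Wraps base_estimator (a tf.estimator.Estimator) with adversarial regularization.

    The default multiplier and adv_step_size are illustrative values only.
    """
    import neural_structured_learning as nsl

    adv_config = nsl.configs.make_adv_reg_config(
        multiplier=multiplier,        # weight of the adversarial loss term
        adv_step_size=adv_step_size,  # size of the adversarial perturbation step
    )
    return nsl.estimator.add_adversarial_regularization(
        base_estimator,
        optimizer_fn=make_optimizer,  # the function itself, not an optimizer instance
        adv_config=adv_config,
    )
```

Given an existing base estimator, build_adversarial_estimator(base_estimator) returns an estimator that can be trained with the usual train(input_fn=...) call. Note that optimizer_fn receives the factory function rather than an optimizer instance, matching the "accepts no arguments" contract in the Args table.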