# nsl.keras.adversarial_loss

[View source on GitHub](https://github.com/tensorflow/neural-structured-learning/blob/v1.4.0/neural_structured_learning/keras/adversarial_regularization.py#L32-L151)
Computes the adversarial loss for `model` given `features` and `labels`.
```python
nsl.keras.adversarial_loss(
    features,
    labels,
    model,
    loss_fn,
    sample_weights=None,
    adv_config=None,
    predictions=None,
    labeled_loss=None,
    gradient_tape=None,
    model_kwargs=None
)
```
This utility function adds adversarial perturbations to the input `features`, runs the `model` on the perturbed features for predictions, and returns the corresponding loss `loss_fn(labels, model(perturbed_features))`. This function can be used in a Keras subclassed model and a custom training loop, and can also be used freely as a helper function in eager execution mode.
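For instance, here is a minimal sketch of the subclassed-model usage, assuming a model compiled with a loss and an optimizer. The class name, layer choice, and returned metrics below are illustrative, not part of the API:

```python
import tensorflow as tf
import neural_structured_learning as nsl

class AdvRegularizedModel(tf.keras.Model):
  """Illustrative subclassed model with adversarial regularization."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense(inputs)

  def train_step(self, data):
    x, y = data
    with tf.GradientTape() as tape_w:
      # The inner tape watches the input so that adversarial_loss can
      # differentiate the labeled loss with respect to x.
      with tf.GradientTape() as tape_x:
        tape_x.watch(x)
        labeled_loss = self.compiled_loss(y, self(x, training=True))
      adv_loss = nsl.keras.adversarial_loss(
          x, y, self, self.compiled_loss,
          labeled_loss=labeled_loss, gradient_tape=tape_x)
      total_loss = labeled_loss + adv_loss
    gradients = tape_w.gradient(total_loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    return {'loss': total_loss}

# Usage (illustrative):
# model = AdvRegularizedModel()
# model.compile(optimizer='sgd', loss='mse')
# model.fit(train_dataset)
```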
The adversarial perturbation is based on the gradient of the labeled loss on the original input features, i.e. `loss_fn(labels, model(features))`. Therefore, this function needs to compute the model's predictions on the input features as `model(features)`, and the labeled loss as `loss_fn(labels, predictions)`. If the predictions or labeled loss have already been computed, they can be passed in via the `predictions` and `labeled_loss` arguments to save computational resources. Note that in eager execution mode, `gradient_tape` needs to be set accordingly when passing in `predictions` or `labeled_loss`, so that the gradient can be computed correctly.
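Conversely, when neither value is precomputed, the function can be called directly and will run the forward pass itself. A minimal sketch with toy data (the shapes and values are illustrative):

```python
import tensorflow as tf
import neural_structured_learning as nsl

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
loss_fn = tf.keras.losses.MeanSquaredError()
x = tf.random.normal((4, 2))  # toy batch of 4 examples, 2 features each
y = tf.random.normal((4, 1))

# Without precomputed predictions or labeled_loss, no gradient_tape is
# required: the function computes the forward pass and gradient itself.
adv_loss = nsl.keras.adversarial_loss(x, y, model, loss_fn)
```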
#### Example:

```python
# A linear regression model (for demonstrating the usage only)
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD()

# Custom training loop. (The actual training data is omitted for clarity.)
for x, y in train_dataset:
  with tf.GradientTape() as tape_w:

    # A separate GradientTape is needed for watching the input.
    with tf.GradientTape() as tape_x:
      tape_x.watch(x)

      # Regular forward pass.
      labeled_loss = loss_fn(y, model(x))

    # Calculates the adversarial loss. This will reuse labeled_loss and will
    # consume tape_x.
    adv_loss = nsl.keras.adversarial_loss(
        x, y, model, loss_fn, labeled_loss=labeled_loss, gradient_tape=tape_x)

    # Combines both losses. This could also be a weighted combination.
    total_loss = labeled_loss + adv_loss

  # Regular backward pass.
  gradients = tape_w.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```
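The perturbation hyperparameters can be customized by passing an `adv_config`, continuing the example above. The specific values below are illustrative, not recommended defaults:

```python
# Illustrative hyperparameters: multiplier scales the adversarial loss
# term, adv_step_size sets the perturbation magnitude, and adv_grad_norm
# selects the norm used to normalize the input gradient.
adv_config = nsl.configs.make_adv_reg_config(
    multiplier=0.2, adv_step_size=0.05, adv_grad_norm='infinity')
adv_loss = nsl.keras.adversarial_loss(
    x, y, model, loss_fn, adv_config=adv_config,
    labeled_loss=labeled_loss, gradient_tape=tape_x)
```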
| Args | |
|------|---|
| `features` | Input features; should be a `Tensor` or a collection of `Tensor` objects. If it is a collection, the first dimension of all `Tensor` objects inside should be the same (i.e. the batch size). |
| `labels` | Target labels. |
| `model` | A callable that takes `features` as inputs and computes `predictions` as outputs. An example would be a `tf.keras.Model` object. |
| `loss_fn` | A callable which calculates the labeled loss from `labels`, `predictions`, and `sample_weight`. An example would be a `tf.keras.losses.Loss` object. |
| `sample_weights` | (optional) A 1-D `Tensor` of weights for the examples, with the same length as the first dimension of `features`. |
| `adv_config` | (optional) An `nsl.configs.AdvRegConfig` object for adversarial regularization hyperparameters. Use `nsl.configs.make_adv_reg_config` to construct one. |
| `predictions` | (optional) Precomputed value of `model(features)`. If set, the value will be reused when calculating adversarial regularization. In eager mode, `gradient_tape` has to be set as well. |
| `labeled_loss` | (optional) Precomputed value of `loss_fn(labels, model(features))`. If set, the value will be reused when calculating adversarial regularization. In eager mode, `gradient_tape` has to be set as well. |
| `gradient_tape` | (optional) A `tf.GradientTape` object watching `features`. |
| `model_kwargs` | (optional) A dictionary of additional keyword arguments to be passed to the `model`. |
| Returns | |
|---------|---|
| A `Tensor` for the adversarial regularization loss, i.e. the labeled loss on adversarially perturbed features. | |