Selects target classes for an adversarial attack (classification only).
nsl.lib.get_target_indices(
logits, labels, adv_target_config
)
Args | |
---|---|
`logits` | Tensor of shape `[batch_size, num_classes]` and `dtype=tf.float32`. |
`labels` | `int` tensor of shape `[batch_size]` containing the ground-truth labels. |
`adv_target_config` | Instance of `nsl.configs.AdvTargetConfig` specifying the adversarial target configuration. |
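To make "selecting a target class" concrete, the following framework-free sketch picks one target index per example from a batch of logits. It is not the library's implementation: the `select_targets` helper and its strategy names (`"second"`, `"least"`, `"random"`, `"ground_truth"`) are hypothetical stand-ins for whatever options `nsl.configs.AdvTargetConfig` actually provides, and plain Python lists stand in for the `tf.float32` logits and `int` label tensors.

```python
import random
from typing import List, Optional


def select_targets(
    logits: List[List[float]],
    labels: List[int],
    method: str,
    seed: Optional[int] = None,
) -> List[int]:
    """Hypothetical sketch: returns one target class index per example.

    logits: [batch_size][num_classes] scores; labels: [batch_size] true classes.
    The strategy names are illustrative assumptions, not nsl.configs values.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")
    rng = random.Random(seed)
    targets = []
    for row, label in zip(logits, labels):
        num_classes = len(row)
        if num_classes < 2:
            raise ValueError("at least two classes are required")
        if method == "second":
            # Class indices sorted from highest to lowest score; take the runner-up.
            ranked = sorted(range(num_classes), key=row.__getitem__, reverse=True)
            targets.append(ranked[1])
        elif method == "least":
            # Lowest-scoring class (first one wins on ties, like argmin).
            targets.append(min(range(num_classes), key=row.__getitem__))
        elif method == "random":
            # A nonzero offset modulo num_classes never lands on the true label.
            offset = rng.randrange(1, num_classes)
            targets.append((label + offset) % num_classes)
        elif method == "ground_truth":
            # Use the true label itself as the target index.
            targets.append(label)
        else:
            raise ValueError(f"unknown method: {method!r}")
    return targets
```

For example, with `logits = [[2.0, 0.5, 1.0], [0.1, 3.0, -1.0]]` and `labels = [0, 1]`, the `"second"` strategy yields `[2, 0]` and `"least"` yields `[1, 2]`. The random strategy draws a nonzero offset from the true label so that the chosen target can never coincide with it.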
Returns | |
---|---|
A tensor of shape `[batch_size]` and `dtype=tf.int32` containing the target class index for each example. |
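The result holds one class index per example. One possible downstream use, assumed here rather than taken from this function's documentation, is expanding those indices into one-hot target vectors whose length matches the class axis of `logits` (in TensorFlow, `tf.one_hot(indices, depth)` does this). The `to_one_hot` helper below is hypothetical and uses plain lists in place of tensors; it also checks that every index lies in `[0, num_classes)`.

```python
from typing import List


def to_one_hot(indices: List[int], num_classes: int) -> List[List[float]]:
    """Hypothetical helper: expands [batch_size] target indices to one-hot rows.

    Each output row has length num_classes, matching the class axis of logits.
    """
    rows = []
    for idx in indices:
        # Reject indices that do not name a valid class.
        if not 0 <= idx < num_classes:
            raise ValueError(f"target index {idx} out of range [0, {num_classes})")
        row = [0.0] * num_classes
        row[idx] = 1.0
        rows.append(row)
    return rows
```

For instance, `to_one_hot([2, 0], 3)` returns `[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]`.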