Adds a Jensen-Shannon divergence to the training procedure.
nsl.lib.jensen_shannon_divergence(
labels,
predictions,
axis=None,
weights=1.0,
scope=None,
loss_collection=tf.compat.v1.GraphKeys.LOSSES,
reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS
)
For brevity, let P = `labels`, Q = `predictions`, and KL(P||Q) be the Kullback-Leibler divergence as defined in the description of the `nsl.lib.kl_divergence` function. The Jensen-Shannon divergence (JSD) is:

M = (P + Q) / 2
JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2
This function assumes that `predictions` and `labels` are the values of a multinomial distribution, i.e., each value is the probability of the corresponding class.
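
The formula above can be checked directly with a few TensorFlow ops. The sketch below is illustrative only: the distributions `p` and `q` are made-up example values, and the helper `kl` is a hypothetical inline implementation of KL(P||Q), not part of the NSL library.

```python
import tensorflow as tf

# Example multinomial distributions (each row sums to 1); illustrative values only.
p = tf.constant([[0.7, 0.2, 0.1]])
q = tf.constant([[0.5, 0.3, 0.2]])

def kl(a, b):
  # KL(a||b) = sum_i a_i * log(a_i / b_i), summed over the class axis.
  return tf.reduce_sum(a * tf.math.log(a / b), axis=-1)

# JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2, with M = (P + Q) / 2.
m = (p + q) / 2.0
jsd = kl(p, m) / 2.0 + kl(q, m) / 2.0
print(jsd)  # A symmetric, non-negative divergence between p and q.
```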
For the usage of `weights` and `reduction`, please refer to `tf.losses`; a usage sketch follows below.
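
As a rough usage sketch (assuming the package is imported as `neural_structured_learning`, TF 2.x eager execution, and that `axis=-1` marks the class dimension; the input values are made up for illustration), the function can be called directly on batched distributions:

```python
import neural_structured_learning as nsl
import tensorflow as tf

labels = tf.constant([[0.7, 0.2, 0.1],
                      [0.4, 0.4, 0.2]])
predictions = tf.constant([[0.5, 0.3, 0.2],
                           [0.3, 0.5, 0.2]])

# A scalar weight rescales the loss; with the default reduction
# (SUM_BY_NONZERO_WEIGHTS) the result is a single scalar Tensor.
loss = nsl.lib.jensen_shannon_divergence(
    labels=labels,
    predictions=predictions,
    axis=-1,
    weights=0.5)
```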
Returns |
---|
Weighted loss float `Tensor`. If `reduction` is `tf.compat.v1.losses.Reduction.MEAN`, this has the same shape as `labels`; otherwise, it is a scalar. |