Implements the focal loss function.

    @tf.function
    tfa.losses.sigmoid_focal_crossentropy(
        y_true: tfa.types.TensorLike,
        y_pred: tfa.types.TensorLike,
        alpha: tfa.types.FloatTensorLike = 0.25,
        gamma: tfa.types.FloatTensorLike = 2.0,
        from_logits: bool = False
    ) -> tf.Tensor

Focal loss was first introduced in the RetinaNet paper, "Focal Loss for Dense Object Detection" (https://arxiv.org/pdf/1708.02002.pdf). It is useful for classification problems with highly imbalanced classes: it down-weights the loss of well-classified (easy) examples so that training concentrates on hard, misclassified ones. As a result, a sample the classifier gets wrong contributes a much larger loss than a sample it already classifies correctly. A typical use case is object detection, where background examples vastly outnumber examples of every other class.
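For reference, the focal loss defined in the paper can be written as follows, with $p$ the predicted probability of the positive class, $y \in \{0, 1\}$ the label, and $\log$ the natural logarithm:

$$
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise} \end{cases}
$$

$$
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t)
$$

As a worked example, assume this function's defaults $\alpha = 0.25$ and $\gamma = 2$ and a positive example ($y = 1$), once predicted confidently and correctly and once badly misclassified:

$$
\begin{aligned}
p_t = 0.9 &: \quad (1 - p_t)^{2} = 0.01, \quad \mathrm{FL} = 0.25 \cdot 0.01 \cdot (-\log 0.9) \approx 2.6 \times 10^{-4} \\
p_t = 0.1 &: \quad (1 - p_t)^{2} = 0.81, \quad \mathrm{FL} = 0.25 \cdot 0.81 \cdot (-\log 0.1) \approx 0.47
\end{aligned}
$$

Plain cross-entropy makes the hard example only about 22 times more costly than the easy one; the modulating factor $(1 - p_t)^{\gamma}$ widens that gap to roughly 1,770 times.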
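| Args | |
|---|---|
| `y_true` | True targets tensor. |
| `y_pred` | Predictions tensor: probabilities, or logits when `from_logits=True`. |
| `alpha` | Balancing factor: positive examples are weighted by `alpha`, negative examples by `1 - alpha`. |
| `gamma` | Modulating factor: the exponent of `(1 - p_t)`, which down-weights well-classified examples. `gamma = 0` recovers `alpha`-weighted binary cross-entropy. |
| `from_logits` | If `True`, `y_pred` is treated as logits and passed through a sigmoid; otherwise it is treated as probabilities. |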
| Returns |
|---|
| Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same shape as `y_true`; otherwise, it is scalar. |