SquaredHinge

classe pubblica SquaredHinge

Calcola la perdita di cerniera al quadrato tra etichette e previsioni.

loss = square(maximum(1 - labels * predictions, 0))

si prevede che i valori labels siano -1 o 1. Se vengono fornite etichette binarie (0 o 1), verranno convertite in -1 o 1.

Utilizzo autonomo:

    Operand<TFloat32> labels =
        tf.constant(new float[][] { {0., 1.}, {0., 0.} });
    Operand<TFloat32> predictions =
        tf.constant(new float[][] { {0.6f, 0.4f}, {0.4f, 0.6f} });
    SquaredHinge squaredHinge = new SquaredHinge(tf);
    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
    // produces 1.86f
 

Chiamata con peso campione:

    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
    Operand<TFloat32> result = squaredHinge.call(labels, predictions,
                                                  sampleWeight);
    // produces 0.73f
 

Utilizzando il tipo di riduzione SUM :

    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.SUM);
    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
    // produces 3.72f
 

Utilizzando il tipo di riduzione NONE :

    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.NONE);
    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
    // produces [1.46f, 2.26f]
 

Campi ereditati

Costruttori pubblici

SquaredHinge (Ops tf)
Crea una perdita di cerniera quadrata utilizzando getSimpleName() come nome della perdita e una riduzione della perdita di REDUCTION_DEFAULT
SquaredHinge (Ops tf, Riduzione riduzione)
Crea una perdita cerniera quadrata utilizzando getSimpleName() come nome della perdita
SquaredHinge (Ops tf, Nome stringa, Riduzione riduzione)
Crea una cerniera quadrata

Metodi pubblici

<T estende TNumero > Operando <T>
chiamata ( Operando <? estende TNumber > etichette, Operando <T> previsioni, Operando <T> sampleWeights)
Genera un operando che calcola la perdita.

Metodi ereditati

Costruttori pubblici

pubblico SquaredHinge (Ops tf)

Crea una perdita di cerniera quadrata utilizzando getSimpleName() come nome della perdita e una riduzione della perdita di REDUCTION_DEFAULT

Parametri
tf le operazioni TensorFlow

public SquaredHinge (Ops tf, Riduzione riduzione)

Crea una perdita cerniera quadrata utilizzando getSimpleName() come nome della perdita

Parametri
tf le operazioni TensorFlow
riduzione Tipo di riduzione da applicare alla perdita.

public SquaredHinge (Ops tf, Nome stringa, Riduzione riduzione)

Crea una cerniera quadrata

Parametri
tf le operazioni TensorFlow
nome il nome della perdita
riduzione Tipo di riduzione da applicare alla perdita.

Metodi pubblici

chiamata pubblica dell'operando <T> ( Operando <? estende TNumber > etichette, previsioni dell'operando <T>, operando <T> sampleWeights)

Genera un operando che calcola la perdita.

Se eseguito in modalità Grafico, il calcolo genererà TFInvalidArgumentException se i valori dell'etichetta non sono nel set [-1., 0., 1.]. In modalità Eager, questa chiamata genererà IllegalArgumentException , se i valori dell'etichetta non sono nel set [-1., 0., 1.].

Parametri
etichette i valori di verità o le etichette devono essere -1, 0 o 1. I valori dovrebbero essere -1 o 1. Se vengono fornite etichette binarie (0 o 1), verranno convertite in -1 o 1.
previsioni le previsioni, i valori devono essere compresi nell'intervallo [0. a 1.] compreso.
campionePesi SampleWeights opzionale funge da coefficiente per la perdita. Se viene fornito uno scalare, la perdita viene semplicemente ridimensionata in base al valore fornito. Se SampleWeights è un tensore di dimensione [batch_size], la perdita totale per ciascun campione del batch viene riscalata dall'elemento corrispondente nel vettore SampleWeights. Se la forma di SampleWeights è [batch_size, d0, .. dN-1] (o può essere trasmessa a questa forma), ogni elemento di perdita delle previsioni viene ridimensionato in base al valore corrispondente di SampleWeights. (Nota su dN-1: tutte le funzioni di perdita si riducono di 1 dimensione, solitamente asse=-1.)
Ritorni
  • la perdita
Lancia
IllegalArgumentException se le previsioni sono fuori dall'intervallo [0.-1.].