Initializer capable of adapting its scale to the shape of weight tensors.
With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a truncated or untruncated normal distribution with a mean of zero and a standard deviation (after truncation, if used) of stddev = Math.sqrt(scale / n), where n is:
- the number of input units in the weight tensor, if mode=FAN_IN
- the number of output units, if mode=FAN_OUT
- the average of the numbers of input and output units, if mode=FAN_AVG
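The following is a minimal plain-Java sketch, not the library's implementation, of how n and stddev can be computed from a weight shape, and how a truncated normal with the stated post-truncation standard deviation can be sampled. It assumes the Keras convention for fan values (for rank >= 2, the last two dimensions are input and output units and any leading dimensions form a receptive field that multiplies both), truncation at two pre-truncation standard deviations, and the matching correction constant 0.87962566103423978 (the standard deviation of a standard normal truncated to [-2, 2]). The class, enum, and method names are hypothetical.

```java
import java.util.Random;

/** Hypothetical sketch of variance-scaling stddev and truncated-normal sampling. */
public class VarianceScalingNormalSketch {

    // Hypothetical mirror of the documented modes; not the library's enum.
    enum Mode { FAN_IN, FAN_OUT, FAN_AVG }

    // Standard deviation of a standard normal truncated to [-2, 2].
    static final double TRUNCATION_FACTOR = 0.87962566103423978;

    // Assumed Keras-style fans: {fan_in, fan_out}.
    static double[] fans(long[] shape) {
        if (shape.length == 0) return new double[] {1, 1};
        if (shape.length == 1) return new double[] {shape[0], shape[0]};
        double receptiveField = 1;
        for (int i = 0; i < shape.length - 2; i++) receptiveField *= shape[i];
        double fanIn = shape[shape.length - 2] * receptiveField;
        double fanOut = shape[shape.length - 1] * receptiveField;
        return new double[] {fanIn, fanOut};
    }

    // n as described above: fan_in, fan_out, or their average.
    static double n(long[] shape, Mode mode) {
        double[] f = fans(shape);
        switch (mode) {
            case FAN_IN: return f[0];
            case FAN_OUT: return f[1];
            default: return (f[0] + f[1]) / 2.0;
        }
    }

    // stddev = sqrt(scale / n)
    static double stddev(double scale, long[] shape, Mode mode) {
        return Math.sqrt(scale / n(shape, mode));
    }

    // Rejection-samples a normal truncated at +/-2 sigma, where sigma is widened
    // so that the standard deviation *after* truncation is approximately target.
    static double sampleTruncated(Random rng, double target) {
        double sigma = target / TRUNCATION_FACTOR;
        while (true) {
            double z = rng.nextGaussian();
            if (Math.abs(z) <= 2.0) return z * sigma;
        }
    }

    public static void main(String[] args) {
        long[] shape = {256, 128};  // hypothetical dense kernel: 256 inputs, 128 outputs
        double target = stddev(2.0, shape, Mode.FAN_IN);  // sqrt(2 / 256)
        Random rng = new Random(1234L);
        int count = 200_000;
        double sumSq = 0;
        for (int i = 0; i < count; i++) {
            double v = sampleTruncated(rng, target);
            sumSq += v * v;  // mean is zero, so this accumulates the variance
        }
        System.out.printf("target stddev = %.5f, empirical = %.5f%n",
                target, Math.sqrt(sumSq / count));
    }
}
```

The empirical value printed by main should land close to the target, which is the point of dividing by the correction constant: without it, truncation would shrink the spread to about 88% of the requested stddev.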
With distribution=UNIFORM, samples are drawn from a uniform distribution within [-limit, limit], where limit = Math.sqrt(3 * scale / n).
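The factor of 3 in the limit comes from the variance of a uniform distribution: U(-limit, limit) has variance limit^2 / 3, so choosing limit = sqrt(3 * scale / n) gives the same variance, scale / n, as the normal case. The sketch below checks that identity numerically with plain java.util.Random sampling; it does not call TensorFlow, and the class and method names and the fan value n = 64 are hypothetical.

```java
import java.util.Random;

/** Hypothetical sketch of the UNIFORM variance-scaling limit, not the library API. */
public class VarianceScalingUniformSketch {

    // limit = sqrt(3 * scale / n); U(-limit, limit) then has variance scale / n.
    static double limit(double scale, double n) {
        return Math.sqrt(3.0 * scale / n);
    }

    // Draws one value uniformly from [-limit, limit).
    static double sample(Random rng, double limit) {
        return (2.0 * rng.nextDouble() - 1.0) * limit;
    }

    public static void main(String[] args) {
        double scale = 1.0;
        double n = 64;                 // hypothetical fan value
        double lim = limit(scale, n);  // sqrt(3 / 64)
        Random rng = new Random(1234L);
        int count = 200_000;
        double sumSq = 0;
        for (int i = 0; i < count; i++) {
            double v = sample(rng, lim);
            sumSq += v * v;            // mean is zero, so this accumulates the variance
        }
        System.out.printf("limit = %.5f, target variance = %.5f, empirical = %.5f%n",
                lim, scale / n, sumSq / count);
    }
}
```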
Examples:
```java
// tf is an existing org.tensorflow.op.Ops instance
long seed = 1234L;
float scale = 0.1f;
VarianceScaling<TFloat32, TFloat32> initializer =
    new org.tensorflow.framework.initializers.VarianceScaling<>(
        tf, scale, Mode.FAN_IN, Distribution.UNIFORM, seed);
// tf.constant(Shape.of(2, 2)) supplies the output dimensions [2, 2]
Operand<TFloat32> values =
    initializer.call(tf.constant(Shape.of(2, 2)), TFloat32.class);
```
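As a worked check on the example, assuming the UNIFORM formula above and fan values taken from the two dimensions of the shape: for shape (2, 2) both dimensions are 2, so n = 2 whichever mode is chosen, and limit = sqrt(3 * 0.1 / 2) = sqrt(0.15), roughly 0.3873. Every entry of values should therefore fall within about [-0.3873, 0.3873]. The snippet only reproduces this arithmetic; the class and method names are hypothetical.

```java
/** Hypothetical worked arithmetic for the bound on the example's generated values. */
public class VarianceScalingExampleBound {

    static double exampleLimit() {
        double scale = 0.1;     // same scale as the example
        long[] shape = {2, 2};  // same shape as the example
        // Both dimensions are 2, so fan_in = fan_out = fan_avg = 2
        // regardless of which dimension is treated as input units.
        double n = shape[0];
        return Math.sqrt(3.0 * scale / n);  // sqrt(0.15)
    }

    public static void main(String[] args) {
        System.out.println("limit = " + exampleLimit());
    }
}
```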
Nested Classes
| enum | VarianceScaling.Distribution | The random distribution to use when initializing the values. | |
| enum | VarianceScaling.Mode | The mode to use for calculating the fan values. | |
Constants
| double | SCALE_DEFAULT |
Fields
| public static final VarianceScaling.Distribution | DISTRIBUTION_DEFAULT | |
| public static final VarianceScaling.Mode | MODE_DEFAULT |
Public Constructors
| VarianceScaling(Ops tf, long seed) | Creates a VarianceScaling initializer. |
| VarianceScaling(Ops tf, double scale, VarianceScaling.Mode mode, VarianceScaling.Distribution distribution, long seed) | Creates a VarianceScaling initializer. |
Public Methods
| Operand<T> |
Inherited Methods
Constants
public static final double SCALE_DEFAULT
Fields
Public Constructors
public VarianceScaling (Ops tf, long seed)
Creates a VarianceScaling Initializer
Parameters
| tf | the TensorFlow Ops |
|---|---|
| seed | Used to create random seeds. |
public VarianceScaling (Ops tf, double scale, VarianceScaling.Mode mode, VarianceScaling.Distribution distribution, long seed)
Creates a VarianceScaling Initializer
Parameters
| tf | the TensorFlow Ops |
|---|---|
| scale | Scaling factor (must be positive). |
| mode | The mode used to compute n (FAN_IN, FAN_OUT, or FAN_AVG). |
| distribution | Random distribution to use. |
| seed | Used to create random seeds. |