tfp.experimental.nn.util.make_kernel_bias

Creates the kernel and bias as tf.Variable objects.

Args:
  kernel_shape: ...
  bias_shape: ...
  kernel_initializer: ...
    Default value: None (i.e., tf.initializers.glorot_uniform()).
  bias_initializer: ...
    Default value: None (i.e., tf.initializers.zeros()).
  kernel_batch_ndims: ...
    Default value: 0.
  bias_batch_ndims: ...
    Default value: 0.
  dtype: ...
    Default value: tf.float32.
  kernel_name: ...
    Default value: "kernel".
  bias_name: ...
    Default value: "bias".

Returns:
  kernel: ...
  bias: ...
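For orientation, here is a minimal sketch of what the default initializers amount to: a Glorot-uniform kernel and an all-zeros bias, each wrapped in a tf.Variable. It is not the TFP implementation. The name make_kernel_bias_sketch is hypothetical, the sketch assumes a dense kernel of shape [fan_in, fan_out], and it ignores kernel_batch_ndims and bias_batch_ndims, whose exact semantics are not described here.

```python
import math

import tensorflow as tf


def make_kernel_bias_sketch(kernel_shape, bias_shape, dtype=tf.float32,
                            kernel_name="kernel", bias_name="bias"):
  """Hypothetical stand-in for make_kernel_bias's default initializers."""
  # Assumption: a dense kernel of shape [fan_in, fan_out]; no batch dims.
  if len(kernel_shape) != 2:
    raise ValueError("This sketch only handles 2-D kernels.")
  fan_in, fan_out = kernel_shape
  # Glorot uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
  limit = math.sqrt(6. / (fan_in + fan_out))
  kernel = tf.Variable(
      tf.random.uniform(kernel_shape, minval=-limit, maxval=limit, dtype=dtype),
      name=kernel_name)
  # The default bias initializer is all zeros.
  bias = tf.Variable(tf.zeros(bias_shape, dtype=dtype), name=bias_name)
  return kernel, bias


kernel, bias = make_kernel_bias_sketch([4, 3], [3])
print(kernel.shape, bias.shape)
```

The limit sqrt(6 / (fan_in + fan_out)) is the same half-width that make_uniform((fan_in + fan_out) / 2.) computes in the recommendations.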

Recommended kernel initializers, by activation function:

#   tf.nn.relu    ==> tf.initializers.he_*
#   tf.nn.elu     ==> tf.initializers.he_*
#   tf.nn.selu    ==> tf.initializers.lecun_*
#   tf.nn.tanh    ==> tf.initializers.glorot_*
#   tf.nn.sigmoid ==> tf.initializers.glorot_*
#   tf.nn.softmax ==> tf.initializers.glorot_*
#   None          ==> tf.initializers.glorot_*
# https://towardsdatascience.com/hyper-parameters-in-action-part-ii-weight-initializers-35aee1a28404
# https://stats.stackexchange.com/a/393012/1835
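Each initializer family in the table targets a specific kernel variance: He uses 2 / fan_in, LeCun uses 1 / fan_in, and Glorot uses 2 / (fan_in + fan_out). The small lookup below sketches that mapping. recommended_variance and the string activation names are hypothetical (the table itself refers to the tf.nn functions). The reciprocal of the returned variance is the size argument that make_uniform and make_normal take, e.g. fan_in / 2. for He.

```python
def recommended_variance(activation, fan_in, fan_out):
  """Hypothetical helper: kernel variance targeted by the recommended family."""
  if activation in ("relu", "elu"):
    return 2. / fan_in               # He: Var = 2 / fan_in
  if activation == "selu":
    return 1. / fan_in               # LeCun: Var = 1 / fan_in
  if activation in ("tanh", "sigmoid", "softmax", None):
    return 2. / (fan_in + fan_out)   # Glorot: Var = 2 / (fan_in + fan_out)
  raise ValueError("No recommendation listed for %r" % (activation,))


# The `size` argument of make_uniform / make_normal is 1 / variance.
for act in ("relu", "selu", "tanh", None):
  print(act, 1. / recommended_variance(act, fan_in=256, fan_out=128))
```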

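import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def make_uniform(size):
  # Uniform(-s, s) has variance s**2 / 3 = 1 / size.
  s = tf.math.rsqrt(size / 3.)
  return tfd.Uniform(low=-s, high=s)

def make_normal(size):
  # Truncate at +/- 2 standard deviations of the untruncated normal. Truncation
  # shrinks the std by the constant below, so dividing by it leaves a final
  # variance of 1 / size. TruncatedNormal's low/high are absolute bounds, so
  # they must scale with s.
  # Constant is: `scipy.stats.truncnorm.std(loc=0., scale=1., a=-2., b=2.)`.
  s = tf.math.rsqrt(size) / 0.87962566103423978
  return tfd.TruncatedNormal(loc=0., scale=s, low=-2. * s, high=2. * s)

# Hypothetical layer sizes, for illustration only.
fan_in, fan_out = 256., 128.

# He.  https://arxiv.org/abs/1502.01852
he_uniform = make_uniform(fan_in / 2.)
he_normal = make_normal(fan_in / 2.)

# Glorot (aka Xavier). http://proceedings.mlr.press/v9/glorot10a.html
glorot_uniform = make_uniform((fan_in + fan_out) / 2.)
glorot_normal = make_normal((fan_in + fan_out) / 2.)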
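The constant 0.87962566103423978 and the claim that both helpers target variance 1 / size can be checked with the standard library alone. This sketch uses the closed-form moments of a standard normal truncated to [a, b]: with Z = cdf(b) - cdf(a), the mean is (pdf(a) - pdf(b)) / Z and the variance is 1 + (a*pdf(a) - b*pdf(b)) / Z - mean**2. The helper names and size = 128. are hypothetical choices for illustration.

```python
import math


def std_normal_pdf(x):
  return math.exp(-0.5 * x * x) / math.sqrt(2. * math.pi)


def std_normal_cdf(x):
  return 0.5 * (1. + math.erf(x / math.sqrt(2.)))


def truncated_std(a, b):
  """Std of N(0, 1) truncated to [a, b] (closed-form moments)."""
  z = std_normal_cdf(b) - std_normal_cdf(a)
  pa, pb = std_normal_pdf(a), std_normal_pdf(b)
  mean = (pa - pb) / z
  var = 1. + (a * pa - b * pb) / z - mean ** 2
  return math.sqrt(var)


c = truncated_std(-2., 2.)
size = 128.  # e.g. He with a hypothetical fan_in of 256

# Uniform(-s, s) has variance s**2 / 3; with s = sqrt(3 / size) that is 1 / size.
s_uniform = math.sqrt(3. / size)
print(s_uniform ** 2 / 3., 1. / size)

# Truncating N(0, s**2) at +/- 2s scales its std by c, so s = 1 / (sqrt(size) * c)
# leaves a std of 1 / sqrt(size).
s_normal = 1. / (math.sqrt(size) * c)
print(c, (s_normal * c) ** 2, 1. / size)
```

Because c is defined for truncation at +/- 2 standard deviations, the bounds passed to TruncatedNormal have to be expressed in units of s; fixed bounds of -2 and 2 would barely truncate a narrow distribution and leave its variance near s**2 rather than 1 / size.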