Look for spec's shape, check that its outer dims are 1, and remove them.

If spec.shape[i] != 1 for any i in range(outer_ndim), we stop removing singleton batch dimensions at i and return the spec with only the dimensions before i removed. This is necessary to handle the outputs of inconsistent layers like tf.keras.layers.LSTM(), which may take an input of shape (batch, time, dim) = (1, 1, Nin) and, when time == 1, emit only the batch dimension, so the output shape is (1, Nout). We log an error in these cases.
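The stopping rule above can be illustrated with a minimal pure-Python sketch. It does not use TensorFlow: `SimpleSpec` is a hypothetical stand-in for tf.TypeSpec that only carries a shape tuple (with None for unknown dims), and the function body shows the described behavior, not the library's actual implementation.

```python
import logging
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class SimpleSpec:
    """Hypothetical stand-in for tf.TypeSpec: a shape plus a dtype label."""
    shape: Optional[Tuple[Optional[int], ...]]
    dtype: str = "float32"


def remove_singleton_batch_spec_dim(spec, outer_ndim: int):
    """Sketch: strip up to `outer_ndim` leading size-1 dims from spec.shape."""
    shape = getattr(spec, "shape", None)
    if shape is None:
        # Mirrors the documented ValueError for specs without a shape.
        raise ValueError(f"Spec lacks a shape: {spec!r}")
    removed = 0
    for i in range(outer_ndim):
        if i >= len(shape):
            # Fewer dims than outer_ndim: nothing left to strip.
            logging.error("len(shape) <= %d; stopping. Shape: %s", i, shape)
            break
        if shape[i] != 1:
            # Non-singleton (or unknown) outer dim: stop here and keep the rest.
            logging.error("shape[%d] = %s != 1; stopping. Shape: %s",
                          i, shape[i], shape)
            break
        removed += 1
    # Keep every other field (here, dtype) and drop only the stripped dims.
    return replace(spec, shape=tuple(shape[removed:]))


# A (batch, time, dim) = (1, 1, 5) spec loses both outer dims.
print(remove_singleton_batch_spec_dim(SimpleSpec((1, 1, 5)), 2).shape)
# An LSTM-style (1, Nout) output with Nout = 7: only the batch dim is removed.
print(remove_singleton_batch_spec_dim(SimpleSpec((1, 7)), 2).shape)
```

In the second call, shape[1] is 7 rather than 1, so the loop stops at i = 1, logs an error, and returns a spec of shape (7,) instead of failing.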

Args:
  spec: A tf.TypeSpec.
  outer_ndim: The maximum number of outer singleton dims to remove.

Returns:
  A tf.TypeSpec, the spec without its outer batch dimension(s).

Raises:
  ValueError: If spec lacks a shape property.