tf_agents.specs.bandit_spec_utils.get_context_dims_from_spec

Returns the global and per-arm context dimensions.

If the policy accepts per-arm features, this function returns a tuple of the global and per-arm context dimensions. Otherwise, it returns the (global) context dimension and zero.

Args:
context_spec: A nest of tensor specs containing the observation spec.
accepts_per_arm_features: (bool) Whether the context_spec is for a policy that accepts per-arm features.

Returns: A 2-tuple of ints: the global and per-arm context dimensions. If the policy does not accept per-arm features, the per-arm context dimension is 0.
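The following is a minimal, dependency-free sketch of the behavior described above, not the TF-Agents implementation. It assumes a per-arm observation spec is a dict whose global entry has shape [global_dim] and whose per-arm entry has shape [num_arms, per_arm_dim]. The key names `"global"` and `"per_arm"`, the `SpecStub` class, and the `context_dims_sketch` function are hypothetical stand-ins used only for illustration.

```python
from dataclasses import dataclass
from typing import Mapping, Tuple, Union


@dataclass(frozen=True)
class SpecStub:
    """Hypothetical stand-in for a tensor spec; only `shape` is used."""
    shape: Tuple[int, ...]


# Assumed dictionary keys for a per-arm observation spec (illustrative only).
GLOBAL_KEY = "global"
PER_ARM_KEY = "per_arm"


def context_dims_sketch(
    context_spec: Union[SpecStub, Mapping[str, SpecStub]],
    accepts_per_arm_features: bool,
) -> Tuple[int, int]:
    """Returns (global_dim, per_arm_dim); per_arm_dim is 0 when the
    policy does not accept per-arm features."""
    if accepts_per_arm_features:
        # Global features are assumed to have shape [global_dim].
        global_dim = context_spec[GLOBAL_KEY].shape[0]
        # Per-arm features are assumed to have shape [num_arms, per_arm_dim].
        per_arm_dim = context_spec[PER_ARM_KEY].shape[-1]
        return global_dim, per_arm_dim
    # Without per-arm features, the context is a single rank-1 spec.
    if len(context_spec.shape) != 1:
        raise ValueError(
            f"Expected a rank-1 context spec, got shape {context_spec.shape}")
    return context_spec.shape[0], 0


# A plain 7-dimensional context: no per-arm features.
print(context_dims_sketch(SpecStub(shape=(7,)), accepts_per_arm_features=False))
# prints (7, 0)

# A per-arm spec with 5 global dims and 4 arms of 3 features each.
print(context_dims_sketch(
    {GLOBAL_KEY: SpecStub(shape=(5,)), PER_ARM_KEY: SpecStub(shape=(4, 3))},
    accepts_per_arm_features=True,
))
# prints (5, 3)
```

Reading the per-arm dimension from the last axis reflects the assumption that the leading axis of the per-arm spec enumerates arms, so the number of arms does not affect the returned dimension.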