tf_agents.utils.common.get_episode_mask

Create a mask that is 0.0 for all final steps, 1.0 elsewhere.

Args:
  time_steps: A TimeStep namedtuple representing a batch of steps.

Returns:
  A float32 Tensor with 0s where step_type == LAST and 1s otherwise.
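The sketch below illustrates the masking rule described above. It is a minimal, stdlib-only stand-in for illustration, not the TF-Agents implementation: the library function works on TensorFlow tensors, whereas this version uses plain Python lists. The names `episode_mask_sketch`, the simplified `TimeStep` namedtuple, and the `values` list are hypothetical. The integer step-type codes (FIRST=0, MID=1, LAST=2) are assumed to match TF-Agents' StepType.

```python
import collections
from typing import List

# Hypothetical stand-ins for the StepType codes (assumed FIRST=0, MID=1, LAST=2).
FIRST, MID, LAST = 0, 1, 2

# Hypothetical minimal TimeStep; only step_type is used to build the mask.
TimeStep = collections.namedtuple(
    "TimeStep", ["step_type", "reward", "discount", "observation"])


def episode_mask_sketch(time_steps: TimeStep) -> List[float]:
    """Return 0.0 where step_type == LAST and 1.0 everywhere else."""
    return [0.0 if st == LAST else 1.0 for st in time_steps.step_type]


# A batch of four steps; the third one ends an episode.
batch = TimeStep(
    step_type=[FIRST, MID, LAST, FIRST],
    reward=[0.0, 1.0, 1.0, 0.0],
    discount=[1.0, 1.0, 0.0, 1.0],
    observation=[None] * 4,
)
mask = episode_mask_sketch(batch)
print(mask)  # [1.0, 1.0, 0.0, 1.0]

# Example use: zero out hypothetical per-step values at final steps.
values = [0.5, 0.7, 0.9, 0.2]
masked = [v * m for v, m in zip(values, mask)]
print(masked)  # [0.5, 0.7, 0.0, 0.2]
```

Because the mask is a float rather than a boolean, it can be multiplied directly into per-step quantities. This removes contributions from final steps without any branching.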