
An environment based on an arbitrary classification problem.

Inherits From: BanditTFEnvironment, TFEnvironment

dataset: a tf.data.Dataset consisting of two Tensors, [inputs, labels], where inputs can be of any shape and labels are integer class labels. The label tensor can be of any rank as long as it has exactly 1 element.
reward_distribution: a tfd.Distribution with event_shape [num_classes, num_actions]. Entry [i, j] is the reward for taking action j on an instance of class i.
batch_size: if dataset is batched, the size of its batches.
label_dtype_cast: if not None, dataset labels are cast to this dtype.
shuffle_buffer_size: if None, the dataset is not shuffled. Otherwise, a shuffle buffer of the specified size is used in the environment's dataset.
repeat_dataset: if True, the environment repeats the dataset, which avoids "OutOfRangeError: End of sequence" errors when the environment is stepped past the end of the dataset.
prefetch_size: if None, nothing is prefetched. Otherwise, a prefetch buffer of the specified size is used in the environment's dataset.
seed: used to make results deterministic.
name: the name of this environment instance.
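
To make the reward_distribution convention concrete, here is a minimal pure-Python sketch of the lookup it describes: given a [num_classes, num_actions] reward table, the reward for an instance is the entry at [label, action]. The fixed table, the function name rewards_for_batch, and the use of plain lists instead of Tensors are illustrative assumptions; the sketch does not reproduce how the environment samples from the distribution.

```python
def rewards_for_batch(labels, actions, reward_table):
    """Look up one reward per instance.

    reward_table[i][j] is the reward for taking action j on an instance of
    class i, mirroring the [num_classes, num_actions] event shape.
    """
    return [reward_table[label][action] for label, action in zip(labels, actions)]


# Hypothetical 3-class, 3-action table: 1.0 for picking the true class, else 0.0.
reward_table = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]

labels = [0, 2, 1]   # integer class labels, one per instance
actions = [0, 1, 1]  # one action per instance
print(rewards_for_batch(labels, actions, reward_table))  # [1.0, 0.0, 1.0]
```

With this particular table the bandit reduces to rewarding correct classification; any other table expresses a different cost structure over (class, action) pairs.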

ValueError if reward_distribution does not have an event shape with rank 2.
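
The rank-2 requirement follows from the [num_classes, num_actions] layout described for reward_distribution. A hypothetical pure-Python version of such a check (check_reward_event_shape is not a library function) might look like:

```python
def check_reward_event_shape(event_shape):
    """Hypothetical check: accept only rank-2 event shapes.

    Returns (num_classes, num_actions) for a valid shape and raises
    ValueError otherwise.
    """
    shape = tuple(event_shape)
    if len(shape) != 2:
        raise ValueError(
            f"reward_distribution event shape must have rank 2, got {shape}")
    num_classes, num_actions = shape
    return num_classes, num_actions


print(check_reward_event_shape([10, 4]))  # (10, 4)
```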

action_spec()

Describes the specs of the Tensors expected by step(action).

action can be a single Tensor, or a nested dict, list or tuple of Tensors.

A single TensorSpec, or a nested dict, list or tuple of TensorSpec objects, which describe the shape and dtype of each Tensor expected by step().
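
For a bandit environment like this one, an action is typically a single integer index in [0, num_actions). The sketch below models such a spec with a hypothetical BoundedSpec dataclass rather than the real TensorSpec classes; the int32 dtype and the bounds derived from num_actions are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BoundedSpec:
    """Hypothetical stand-in for a bounded spec: shape, dtype name, inclusive bounds."""
    shape: tuple
    dtype: str
    minimum: int
    maximum: int

    def is_valid(self, action):
        # A scalar (shape ()) action must be an int within [minimum, maximum].
        return (self.shape == () and isinstance(action, int)
                and self.minimum <= action <= self.maximum)


num_actions = 4  # assumption: would come from reward_distribution's event shape
action_spec = BoundedSpec(shape=(), dtype="int32", minimum=0, maximum=num_actions - 1)
print(action_spec.is_valid(3), action_spec.is_valid(4))  # True False
```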

current_time_step()

Returns the current TimeStep.

A TimeStep namedtuple containing:
  step_type: A StepType value.
  reward: Reward at this time_step.
  discount: A discount in the range [0, 1].
  observation: A Tensor, or a nested dict, list or tuple of Tensors corresponding to observation_spec().
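
The four fields listed above can be modeled in plain Python as a NamedTuple. This is a sketch, not the TF-Agents class: StepType here is a hypothetical holder for the integer codes FIRST = 0, MID = 1, LAST = 2 (the values TF-Agents uses), and reward, discount, and observation are plain Python values instead of Tensors.

```python
from typing import Any, NamedTuple


class StepType:
    """Hypothetical holder for the step-type codes."""
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    step_type: int      # one of the StepType codes
    reward: float       # reward at this time step
    discount: float     # expected to lie in [0, 1]
    observation: Any    # any value, or a nested dict/list/tuple of values

    def is_last(self) -> bool:
        return self.step_type == StepType.LAST


ts = TimeStep(step_type=StepType.FIRST, reward=0.0, discount=1.0,
              observation=[0.5, -1.2])
print(ts.is_last())  # False
```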


observation_spec()

Defines the TensorSpec of observations provided by the environment.

A TensorSpec, or a nested dict, list or tuple of TensorSpec objects, which describe the observation.
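
Because specs may be nested in dicts, lists, or tuples, code that consumes them usually flattens the structure into a list of leaves. TF-Agents relies on tf.nest for this; the function below is a simplified, hypothetical stand-in that visits dict values in sorted-key order (the convention tf.nest follows) and uses strings as placeholder leaf specs.

```python
def flatten_nest(structure):
    """Return the leaves of a nested dict/list/tuple structure in a fixed order.

    Dict values are visited in sorted-key order; anything that is not a
    dict, list, or tuple is treated as a leaf.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in flatten_nest(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten_nest(item)]
    return [structure]


# Hypothetical observation spec, with strings standing in for TensorSpecs.
observation_spec = {"pixels": "float32[28, 28]", "meta": ("int32[]", "float32[3]")}
print(flatten_nest(observation_spec))
# ['int32[]', 'float32[3]', 'float32[28, 28]']
```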


render()

Renders a frame from the environment.

NotImplementedError if the environment does not support rendering.


reset()

Resets the environment and returns the current time_step.

A TimeStep namedtuple containing:
  step_type: A StepType value.
  reward: Reward at this time_step.
  discount: A discount in the range [0, 1].
  observation: A Tensor, or a nested dict, list or tuple of Tensors corresponding to observation_spec().


reward_spec()

Defines the TensorSpec of rewards provided by the environment.

A TensorSpec, or a nested dict, list or tuple of TensorSpec objects, which describe the reward.


step(action)

Steps the environment according to the action.

If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step should reset the environment, start a new sequence, and ignore action. (Whoever implements this method is expected to call reset in this case.)

This method will also start a new sequence if called after the environment has been constructed and reset() has not been called. In this case action will be ignored.

Expected sequences look like:

time_step -> action -> next_time_step

The action should depend on the previous time_step for correctness.

action: A Tensor, or a nested dict, list or tuple of Tensors corresponding to action_spec().

A TimeStep namedtuple containing:
  step_type: A StepType value.
  reward: Reward at this time_step.
  discount: A discount in the range [0, 1].
  observation: A Tensor, or a nested dict, list or tuple of Tensors corresponding to observation_spec().
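
The reset rules above (a step after LAST, or a step before any reset, starts a new sequence and ignores action) can be illustrated with a toy environment. ToyEnv, its two-transition episodes, its reward (the action value), and its discount of 0.0 on LAST are invented for this sketch and say nothing about how the classification environment structures its episodes.

```python
from typing import Any, NamedTuple

FIRST, MID, LAST = 0, 1, 2  # step-type codes (TF-Agents uses the same values)


class TimeStep(NamedTuple):
    step_type: int
    reward: float
    discount: float
    observation: Any


class ToyEnv:
    """Hypothetical environment that follows the documented step() contract.

    Each episode lasts `episode_length` transitions, the observation is a step
    counter, and the reward is simply the action value.
    """

    def __init__(self, episode_length=2):
        self._episode_length = episode_length
        self._t = 0
        self._current = None  # None: reset() has not been called yet

    def reset(self):
        self._t = 0
        self._current = TimeStep(FIRST, 0.0, 1.0, self._t)
        return self._current

    def step(self, action):
        # Start a new sequence, ignoring `action`, if reset() was never called
        # or the previous time step was LAST.
        if self._current is None or self._current.step_type == LAST:
            return self.reset()
        self._t += 1
        done = self._t >= self._episode_length
        self._current = TimeStep(
            step_type=LAST if done else MID,
            reward=float(action),
            discount=0.0 if done else 1.0,
            observation=self._t,
        )
        return self._current


env = ToyEnv()
# The first step() call starts a sequence (action 5 is ignored); the call after
# LAST starts another one (action 7 is ignored).
print([env.step(a).step_type for a in [5, 1, 1, 7]])  # [0, 1, 2, 0]

# time_step -> action -> next_time_step, with each action chosen from the
# previous time step.
time_step = env.reset()
while time_step.step_type != LAST:
    action = time_step.observation + 1
    time_step = env.step(action)
```

The final loop ends on a LAST time step whose reward is 2.0, because each action is computed from the observation in the time step that preceded it.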


time_step_spec()

Describes the TimeStep specs of Tensors returned by step().

A TimeStep namedtuple containing TensorSpec objects defining the Tensors returned by step(), i.e. (step_type, reward, discount, observation).
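
The spec returned here mirrors the TimeStep structure, with a spec in each field instead of a value. A hypothetical sketch follows, with (dtype, shape) tuples standing in for TensorSpec objects; the scalar int32 step_type and scalar float32 discount follow TF-Agents' usual defaults but should be treated as assumptions here.

```python
from typing import Any, NamedTuple


class TimeStep(NamedTuple):
    step_type: Any
    reward: Any
    discount: Any
    observation: Any


def time_step_spec(observation_spec, reward_spec):
    """Hypothetical: build a TimeStep whose fields are (dtype, shape) specs."""
    return TimeStep(
        step_type=("int32", ()),   # assumed scalar int32 step type
        reward=reward_spec,
        discount=("float32", ()),  # assumed scalar float32 discount
        observation=observation_spec,
    )


spec = time_step_spec(observation_spec=("float32", (784,)),
                      reward_spec=("float32", ()))
print(spec.observation)  # ('float32', (784,))
```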