An environment based on an arbitrary classification problem.
Inherits From: BanditTFEnvironment, TFEnvironment
tf_agents.bandits.environments.classification_environment.ClassificationBanditEnvironment(
    dataset: tf.data.Dataset,
    reward_distribution: types.Distribution,
    batch_size: tf_agents.typing.types.Int,
    label_dtype_cast: Optional[tf.DType] = None,
    shuffle_buffer_size: Optional[types.Int] = None,
    repeat_dataset: Optional[bool] = True,
    prefetch_size: Optional[types.Int] = None,
    seed: Optional[types.Int] = None,
    name: Optional[Text] = 'classification'
)
Args | |
---|---|
dataset | a tf.data.Dataset consisting of two Tensors, [inputs, labels], where inputs can be of any shape and labels are integer class labels. The label tensor can be of any rank as long as it has 1 element. |
reward_distribution | a tfd.Distribution with event_shape [num_classes, num_actions]. Entry [i, j] is the reward for taking action j for an instance of class i. |
batch_size | if dataset is batched, this is the size of the batches. |
label_dtype_cast | if not None, casts dataset labels to this dtype. |
shuffle_buffer_size | If None, do not shuffle. Otherwise, a shuffle buffer of the specified size is used in the environment's dataset. |
repeat_dataset | If True, repeats the dataset so that stepping the environment past the end of the dataset does not raise OutOfRangeError: End of sequence errors. |
prefetch_size | If None, do not prefetch. Otherwise, a prefetch buffer of the specified size is used in the environment's dataset. |
seed | Used to make results deterministic. |
name | The name of this environment instance. |
Raises | |
---|---|
ValueError | if reward_distribution does not have an event shape with rank 2. |
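A minimal construction sketch (not part of the class reference): it assumes a toy two-class dataset with 4-dimensional contexts and a deterministic 2x3 reward table, wrapped in tfd.Independent so the distribution's event shape has rank 2 as required above.

```python
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.bandits.environments import classification_environment

tfd = tfp.distributions

# Toy data, assumed for illustration: 4-dimensional contexts, 2 classes.
contexts = tf.random.normal(shape=[100, 4])
labels = tf.random.uniform(shape=[100], maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((contexts, labels))

# Reward table with event_shape [num_classes=2, num_actions=3]:
# entry [i, j] is the reward for taking action j on an instance of class i.
reward_distribution = tfd.Independent(
    tfd.Deterministic(loc=[[1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0]]),
    reinterpreted_batch_ndims=2)

env = classification_environment.ClassificationBanditEnvironment(
    dataset=dataset,
    reward_distribution=reward_distribution,
    batch_size=8,
    shuffle_buffer_size=100,
    repeat_dataset=True,
    seed=42)
```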
Attributes | |
---|---|
batch_size | |
batched | |
name | |
Methods
action_spec
action_spec()
Describes the specs of the Tensors expected by step(action).
action can be a single Tensor, or a nested dict, list or tuple of Tensors.
Returns | |
---|---|
A single TensorSpec, or a nested dict, list or tuple of TensorSpec objects, which describe the shape and dtype of each Tensor expected by step(). |
compute_optimal_action
compute_optimal_action() -> tf_agents.typing.types.NestedTensor
compute_optimal_reward
compute_optimal_reward() -> tf_agents.typing.types.NestedTensor
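Both helpers look up, for a batch of instances, the best achievable action and reward under reward_distribution given the true class labels. A common pattern, sketched below under the assumption that the env object from the construction example above is in scope, is to feed compute_optimal_reward into tf_agents.bandits.metrics.tf_metrics.RegretMetric as the baseline-reward function.

```python
from tf_agents.bandits.metrics import tf_metrics as bandit_metrics

def optimal_reward_fn(unused_observation):
  # The environment already tracks the labels internally, so the observation
  # passed in by the metric is not needed here.
  return env.compute_optimal_reward()

regret_metric = bandit_metrics.RegretMetric(optimal_reward_fn)
```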
current_time_step
current_time_step()
Returns the current TimeStep.
Returns | |
---|---|
A TimeStep namedtuple containing: step_type, a StepType value; reward, the reward at this time_step; discount, a discount in the range [0, 1]; and observation, a Tensor or a nested dict, list or tuple of Tensors corresponding to observation_spec(). |
observation_spec
observation_spec()
Defines the TensorSpec of observations provided by the environment.
Returns | |
---|---|
A TensorSpec, or a nested dict, list or tuple of TensorSpec objects, which describe the observation. |
render
render()
Renders a frame from the environment.
Raises | |
---|---|
NotImplementedError | If the environment does not support rendering. |
reset
reset()
Resets the environment and returns the current time_step.
Returns | |
---|---|
A TimeStep namedtuple containing: step_type, a StepType value; reward, the reward at this time_step; discount, a discount in the range [0, 1]; and observation, a Tensor or a nested dict, list or tuple of Tensors corresponding to observation_spec(). |
reward_spec
reward_spec()
Defines the TensorSpec of rewards provided by the environment.
Returns | |
---|---|
A TensorSpec, or a nested dict, list or tuple of TensorSpec objects, which describe the reward. |
step
step(
    action
)
Steps the environment according to the action.
If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step should reset the environment (it is expected that whoever defines this method calls reset in this case), start a new sequence, and ignore action.
This method will also start a new sequence if called after the environment has been constructed and reset() has not been called. In this case action will be ignored.
Expected sequences look like:
time_step -> action -> next_time_step
The action should depend on the previous time_step for correctness.
Args | |
---|---|
action | A Tensor, or a nested dict, list or tuple of Tensors corresponding to action_spec(). |
Returns | |
---|---|
A TimeStep namedtuple containing: step_type, a StepType value; reward, the reward at this time_step; discount, a discount in the range [0, 1]; and observation, a Tensor or a nested dict, list or tuple of Tensors corresponding to observation_spec(). |
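A small interaction sketch, reusing the assumed three-action env from the construction example above; uniformly random actions stand in for a real bandit policy.

```python
time_step = env.reset()
for _ in range(5):
  # Uniformly random actions in {0, 1, 2}, one per batch element
  # (the batch size of 8 matches the construction example above).
  # A real agent's policy would choose actions from time_step.observation.
  action = tf.random.uniform(shape=[8], maxval=3, dtype=tf.int32)
  time_step = env.step(action)
  print(time_step.reward)
```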
time_step_spec
time_step_spec()
Describes the TimeStep specs of Tensors returned by step().
Returns | |
---|---|
A TimeStep namedtuple containing TensorSpec objects defining the Tensors returned by step(), i.e. (step_type, reward, discount, observation). |
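For completeness, the spec accessors above can be inspected directly on an instance; a short sketch, again assuming the env from the construction example:

```python
# Specs are derived from the dataset elements and the reward distribution.
print(env.action_spec())       # bounded integer action spec
print(env.observation_spec())  # context (inputs) spec
print(env.time_step_spec())    # (step_type, reward, discount, observation)
```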