A nest of BoundedTensorSpec representing the actions.
reward_network
An instance of a tf_agents.network.Network, callable via
network(observation, step_type) -> (output, final_state).
temperature
float or callable that returns a float. The temperature used
in the Boltzmann exploration.
boltzmann_gumbel_exploration_constant
optional positive float. When
provided, the policy implements Neural Bandit with Boltzmann-Gumbel
exploration from the paper: N. Cesa-Bianchi et al., "Boltzmann
Exploration Done Right", NIPS 2017.
observation_and_action_constraint_splitter
A function used to mask valid/invalid actions based on each state
of the environment. The function takes in a full observation and returns a
tuple consisting of 1) the part of the observation intended as input to the
network and 2) the mask. The mask should be a 0-1 Tensor of shape
[batch_size, num_actions]. This function should also work with a TensorSpec
as input, and should output TensorSpec objects for the observation and
mask. (A minimal splitter is shown in the construction sketch after the
Raises section below.)
accepts_per_arm_features
(bool) Whether the policy accepts per-arm
features.
constraints
iterable of constraint objects that are instances of
tf_agents.bandits.agents.BaseConstraint.
emit_policy_info
(tuple of strings) what side information we want to get
as part of the policy info. Allowed values can be found in
policy_utilities.PolicyInfo.
num_samples_list
list or tuple of tf.Variable objects. Used only in
Boltzmann-Gumbel exploration; otherwise, empty.
name
The name of this policy. All variables in this module will fall
under that name. Defaults to the class name.
Raises
NotImplementedError
If action_spec contains more than one
BoundedTensorSpec or the BoundedTensorSpec is not valid.
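The sketch below illustrates these constructor arguments, including a splitter for masked
observations. It assumes this page documents BoltzmannRewardPredictionPolicy from
tf_agents.bandits.policies.boltzmann_reward_prediction_policy (the class is not named in
this section); the spec shapes, layer sizes, and temperature value are illustrative only.

```python
# A minimal construction sketch. Assumption: the class documented here is
# BoltzmannRewardPredictionPolicy; shapes and sizes below are made up.
import tensorflow as tf
from tf_agents.bandits.policies import boltzmann_reward_prediction_policy
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

num_actions = 4

# Full observation: the network input plus a 0-1 mask of valid actions.
observation_spec = {
    'features': tensor_spec.TensorSpec([8], tf.float32),
    'mask': tensor_spec.TensorSpec([num_actions], tf.int32),
}
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=num_actions - 1)

def splitter(obs):
  # Works on Tensors and on TensorSpecs alike, as required above.
  return obs['features'], obs['mask']

# Any network callable as network(observation, step_type) -> (output, state)
# with one output per action can serve as the reward network; QNetwork is
# one convenient option.
reward_net = q_network.QNetwork(
    input_tensor_spec=observation_spec['features'],
    action_spec=action_spec,
    fc_layer_params=(32,))
reward_net.create_variables()

policy = boltzmann_reward_prediction_policy.BoltzmannRewardPredictionPolicy(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    reward_network=reward_net,
    temperature=0.5,  # a float or a callable returning a float
    observation_and_action_constraint_splitter=splitter)
```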
Attributes
accepts_per_arm_features
action_spec
Describes the TensorSpecs of the Tensors expected by step(action).
action can be a single Tensor, or a nested dict, list or tuple of
Tensors.
collect_data_spec
Describes the Tensors written when using this policy with an environment.
emit_log_probability
Whether this policy instance emits log probabilities or not.
info_spec
Describes the Tensors emitted as info by action and distribution.
info can be an empty tuple, a single Tensor, or a nested dict,
list or tuple of Tensors.
observation_and_action_constraint_splitter
policy_state_spec
Describes the Tensors expected by step(_, policy_state).
policy_state can be an empty tuple, a single Tensor, or a nested dict,
list or tuple of Tensors.
policy_step_spec
Describes the output of action().
time_step_spec
Describes the TimeStep tensors returned by step().
trajectory_spec
Describes the Tensors written when using this policy with an environment.
validate_args
Whether action & distribution validate input and output args.
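The attributes above can be read directly off a constructed policy. A short illustration,
reusing the `policy` from the construction sketch above:

```python
# Illustrative attribute access; all names follow the attribute list on this page.
print(policy.action_spec)
print(policy.time_step_spec)
print(policy.collect_data_spec)      # trajectories written when collecting
print(policy.policy_state_spec)
print(policy.emit_log_probability)
print(policy.observation_and_action_constraint_splitter)
```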
Generates the next action given the time_step and policy_state.
Args
time_step
A TimeStep tuple corresponding to time_step_spec().
policy_state
A Tensor, or a nested dict, list or tuple of Tensors
representing the previous policy_state.
seed
Seed to use if action performs sampling (optional).
Returns
A PolicyStep named tuple containing:
action: An action Tensor matching the action_spec.
state: A policy state tensor to be fed into the next call to action.
info: Optional side information such as action log probabilities.
Raises
RuntimeError
If the subclass's __init__ didn't call super().__init__().
ValueError or TypeError
If validate_args is True and inputs or outputs do not match
time_step_spec, policy_state_spec, or policy_step_spec.
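A minimal sketch of calling action(), reusing the policy, specs, and imports from the
construction sketch above; the batch size, feature values, and mask contents are
illustrative assumptions.

```python
# Illustrative action() call; `policy`, `ts`, and `num_actions` come from the
# construction sketch above.
batch_size = 2
observation = {
    'features': tf.random.normal([batch_size, 8]),
    'mask': tf.constant([[1, 1, 0, 1],
                         [0, 1, 1, 1]], dtype=tf.int32),
}
time_step = ts.restart(observation, batch_size=batch_size)

policy_state = policy.get_initial_state(batch_size=batch_size)
policy_step = policy.action(time_step, policy_state, seed=1234)

# policy_step.action: int32 Tensor of shape [batch_size], restricted to
# actions whose mask entry is 1; policy_step.state feeds the next call.
print(policy_step.action)
```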
Generates the distribution over next actions given the time_step.
Args
time_step
A TimeStep tuple corresponding to time_step_spec().
policy_state
A Tensor, or a nested dict, list or tuple of Tensors
representing the previous policy_state.
Returns
A PolicyStep named tuple containing:
action: A tf.distribution capturing the distribution of next actions.
state: A policy state tensor for the next call to distribution.
info: Optional side information such as action log probabilities.
Raises
ValueError or TypeError
If validate_args is True and inputs or outputs do not match
time_step_spec, policy_state_spec, or policy_step_spec.
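A minimal sketch of calling distribution(), reusing the time_step and policy_state from
the action() sketch above. The concrete distribution class returned is an implementation
detail, so the sketch only samples from it.

```python
# Illustrative distribution() call, built on the action() sketch above.
dist_step = policy.distribution(time_step, policy_state)
action_distribution = dist_step.action   # a tf.distribution over the arms
sampled_actions = action_distribution.sample()  # shape [batch_size]
print(sampled_actions)
```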