tf_agents.bandits.policies.linear_bandit_policy.LinearBanditPolicy

Linear Bandit Policy to be used by LinUCB, LinTS and possibly others.

Inherits From: TFPolicy

Args
action_spec TensorSpec containing the action specification.
cov_matrix list of the covariance matrices A in the paper. If the policy accepts per-arm features, the length of this list is 1, as there is only one model. Otherwise, there is one A matrix per arm.
data_vector list of the b vectors in the paper. The b vector is a weighted sum of the observations, where the weight is the corresponding reward. If the policy accepts per-arm features, this list should be of length 1, as there is only one reward model maintained. Otherwise, each arm has its own vector b.
num_samples list of the number of samples per arm, unless the policy accepts per-arm features, in which case this is just the total number of samples seen.
time_step_spec A TimeStep spec of the expected time_steps.
exploration_strategy An Enum of type ExplorationStrategy. The strategy used to incorporate exploration when choosing actions. Currently supported strategies are optimistic and sampling.
alpha a float value used to scale the confidence intervals.
eig_vals list of eigenvalues for each covariance matrix (one per arm, unless the policy accepts per-arm features).
eig_matrix list of eigenvectors for each covariance matrix (one per arm, unless the policy accepts per-arm features).
tikhonov_weight (float) Tikhonov regularization term.
add_bias If true, a bias term will be added to the linear reward estimation.
emit_policy_info (tuple of strings) what side information we want to get as part of the policy info. Allowed values can be found in policy_utilities.PolicyInfo.
emit_log_probability Whether to emit log probabilities.
accepts_per_arm_features (bool) Whether the policy accepts per-arm features.
observation_and_action_constraint_splitter A function used for masking valid/invalid actions with each state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the bandit policy and 2) the mask. The mask should be a 0-1 Tensor of shape [batch_size, num_actions]. This function should also work with a TensorSpec as input, and should output TensorSpec objects for the observation and mask.
theta An optional 2-d tf.Tensor of the theta vectors shaped as [k, n], where k denotes the number of arms and n denotes the overall context dimension. When accepts_per_arm_features is true, k is expected to be 1 and n is the total dimension of the (flattened) global features and the (flattened) per-arm features. When theta is supplied, the policy assumes it is consistent with the value computed from the other arguments cov_matrix, data_vector, and tikhonov_weight; if it is not, the policy may behave unexpectedly. Supplying a precomputed theta is most useful for users who want a greedy policy that selects actions solely based on the theta vectors, because it may significantly reduce the policy's inference latency.
name The name of this policy.
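The arguments above fit together as in textbook LinUCB and LinTS. The pure-Python sketch below is not the TF-Agents implementation; it assumes that each arm's parameters are estimated as theta = (A + tikhonov_weight * I)^-1 b, that the optimistic strategy scores an arm as theta^T x + alpha * sqrt(x^T (A + tikhonov_weight * I)^-1 x), and that the sampling strategy draws the score from a normal distribution with mean theta^T x and standard deviation alpha times that same width. It covers only the case without per-arm features (one A, b and sample count per arm). The helpers solve, regularize, observe, arm_scores and argmax are hypothetical, and the statistics update in observe is the textbook LinUCB rule, not necessarily what the TF-Agents agents do.

```python
import math
import random


def solve(a, b):
    """Solve a @ x = b by Gauss-Jordan elimination (a must be non-singular)."""
    n = len(a)
    # Augmented matrix [a | b]; built from copies, so the inputs stay untouched.
    m = [[float(v) for v in row] + [float(b[i])] for i, row in enumerate(a)]
    for col in range(n):
        # Partial pivoting: move the row with the largest entry in this column up.
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        p = m[col][col]
        m[col] = [v / p for v in m[col]]
        for r in range(n):
            if r != col:
                f = m[r][col]
                m[r] = [vr - f * vc for vr, vc in zip(m[r], m[col])]
    return [m[i][n] for i in range(n)]


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def regularize(a, tikhonov_weight):
    """Return a new matrix A + tikhonov_weight * I."""
    return [[v + (tikhonov_weight if i == j else 0.0) for j, v in enumerate(row)]
            for i, row in enumerate(a)]


def observe(cov_matrix, data_vector, num_samples, arm, x, reward):
    """Assumed textbook LinUCB update: A += x x^T, b += r * x, n += 1."""
    a = cov_matrix[arm]
    for i in range(len(x)):
        for j in range(len(x)):
            a[i][j] += x[i] * x[j]
    data_vector[arm] = [bi + reward * xi for bi, xi in zip(data_vector[arm], x)]
    num_samples[arm] += 1


def arm_scores(cov_matrix, data_vector, x, alpha, tikhonov_weight,
               strategy="optimistic", rng=random):
    """Score every arm for context x; a policy would pick the argmax."""
    scores = []
    for a, b in zip(cov_matrix, data_vector):
        a_reg = regularize(a, tikhonov_weight)
        theta = solve(a_reg, b)                      # (A + lambda*I)^-1 b
        mean = dot(theta, x)                         # estimated reward
        width = math.sqrt(dot(x, solve(a_reg, x)))   # sqrt(x^T (A + lambda*I)^-1 x)
        if strategy == "optimistic":                 # upper-confidence score
            scores.append(mean + alpha * width)
        elif strategy == "sampling":                 # Thompson-style random draw
            scores.append(rng.gauss(mean, alpha * width))
        else:
            raise ValueError(f"unknown strategy: {strategy}")
    return scores


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


d, num_arms = 2, 2
cov_matrix = [[[0.0] * d for _ in range(d)] for _ in range(num_arms)]
data_vector = [[0.0] * d for _ in range(num_arms)]
num_samples = [0] * num_arms
for _ in range(2):  # arm 0 is tried twice and is rewarded both times
    observe(cov_matrix, data_vector, num_samples, arm=0, x=[1.0, 0.0], reward=1.0)

x = [1.0, 0.0]
print(argmax(arm_scores(cov_matrix, data_vector, x, alpha=0.0, tikhonov_weight=1.0)))   # 0
print(argmax(arm_scores(cov_matrix, data_vector, x, alpha=10.0, tikhonov_weight=1.0)))  # 1
```

With alpha=0.0 the scores are pure estimates, so the sketch greedily keeps choosing arm 0 (estimate 2/3). With alpha=10.0 the untried arm 1 wins: its estimate is 0, but its confidence width is still 1, while arm 0's width has shrunk to sqrt(1/3).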

Attributes
action_spec Describes the TensorSpecs of the Tensors expected by step(action).

action can be a single Tensor, or a nested dict, list or tuple of Tensors.

collect_data_spec Describes the Tensors written when using this policy with an environment.
emit_log_probability Whether this policy instance emits log probabilities or not.
info_spec Describes the Tensors emitted as info by action and distribution.

info can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

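observation_and_action_constraint_splitter The function used to separate the policy input from the action mask in each observation; see the constructor argument of the same name.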

policy_state_spec Describes the Tensors expected by step(_, policy_state).

policy_state can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

policy_step_spec Describes the output of action().
time_step_spec Describes the TimeStep tensors returned by step().
trajectory_spec Describes the Tensors written when using this policy with an environment.
validate_args Whether action & distribution validate input and output args.

Methods

action

Generates next action given the time_step and policy_state.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.
seed Seed to use if action performs sampling (optional).

Returns
A PolicyStep named tuple containing:
action: An action Tensor matching the action_spec.
state: A policy state tensor to be fed into the next call to action.
info: Optional side information such as action log probabilities.

Raises
RuntimeError If subclass __init__ didn't call super().__init__.
ValueError or TypeError: If validate_args is True and inputs or outputs do not match time_step_spec, policy_state_spec, or policy_step_spec.
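How a greedy linear policy might turn a precomputed theta, the observation_and_action_constraint_splitter, add_bias and accepts_per_arm_features into a single action is sketched below in pure Python. This is an illustration, not the TF-Agents code: the observation layout (a dict with the hypothetical keys "global", "per_arm" and "mask"), the helpers splitter and greedy_action, the choice to model the bias as a constant 1.0 appended to the global features, and the [global, per-arm] concatenation order are all assumptions. The shapes follow the theta argument: [k, n] without per-arm features, [1, n] with them.

```python
def dot(u, v):
    """Inner product of two equal-length sequences."""
    return sum(a * b for a, b in zip(u, v))


def splitter(observation):
    """Hypothetical splitter: returns (policy input, 0-1 action mask)."""
    policy_input = {k: v for k, v in observation.items() if k != "mask"}
    return policy_input, observation["mask"]


def greedy_action(theta, observation, add_bias=False):
    """Return the index of the highest-scoring arm allowed by the mask."""
    features, mask = splitter(observation)
    global_x = list(features["global"])
    if add_bias:
        # Assumption: the bias is modelled as a constant 1.0 feature.
        global_x.append(1.0)
    if "per_arm" in features:
        # Per-arm case: a single theta row scores [global, arm] for every arm.
        scores = [dot(theta[0], global_x + list(arm)) for arm in features["per_arm"]]
    else:
        # One theta row per arm, all applied to the same global context.
        scores = [dot(row, global_x) for row in theta]
    # Masked-out arms get -inf so they can never be the argmax.
    masked = [s if m else float("-inf") for s, m in zip(scores, mask)]
    return max(range(len(masked)), key=masked.__getitem__)


theta = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # k = 3 arms, n = 2
obs = {"global": [0.2, 0.9], "mask": [1, 0, 1]}
print(greedy_action(theta, obs))  # 2: arm 1 would score highest but is masked out

per_arm_theta = [[1.0, 2.0]]                    # k = 1, n = 1 global + 1 per-arm dim
per_arm_obs = {"global": [0.5], "per_arm": [[0.1], [0.3]], "mask": [1, 1]}
print(greedy_action(per_arm_theta, per_arm_obs))  # 1
```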

distribution

Generates the distribution over next actions given the time_step.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.

Returns
A PolicyStep named tuple containing:
action: A tf.distribution capturing the distribution of next actions.
state: A policy state tensor for the next call to distribution.
info: Optional side information such as action log probabilities.

Raises
ValueError or TypeError: If validate_args is True and inputs or outputs do not match time_step_spec, policy_state_spec, or policy_step_spec.

get_initial_state

Returns an initial state usable by the policy.

Args
batch_size Tensor or constant: size of the batch dimension. Can be None, in which case no dimensions get added.

Returns
A nested object of type policy_state containing properly initialized Tensors.

update

Update the current policy with another policy.

This includes copying the variables from the other policy.

Args
policy Another policy it can update from.
tau A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard update. This is used for trainable variables.
tau_non_trainable A float scalar in [0, 1] for non_trainable variables. If None, will copy from tau.
sort_variables_by_name A bool, when True would sort the variables by name before doing the update.

Returns
A TF op to do the update.
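The interplay of tau, tau_non_trainable and sort_variables_by_name can be sketched in pure Python. The sketch assumes the usual soft-update rule v_target <- (1 - tau) * v_target + tau * v_source, with target and source variables paired by position; it is not the TF-Agents implementation. Variables are modelled as plain dicts with hypothetical "name", "value" and "trainable" keys instead of tf.Variable objects, and soft_update is a hypothetical helper that updates the target in place rather than returning a TF op.

```python
def soft_update(target_vars, source_vars, tau=1.0, tau_non_trainable=None,
                sort_variables_by_name=False):
    """Move each target variable toward its paired source variable, in place.

    Assumed rule: v_target <- (1 - w) * v_target + w * v_source, where w is
    tau for trainable variables and tau_non_trainable otherwise.
    """
    if tau_non_trainable is None:
        tau_non_trainable = tau  # mirrors "If None, will copy from tau."
    if sort_variables_by_name:
        # Sorting both lists makes the positional pairing line up by name.
        target_vars = sorted(target_vars, key=lambda v: v["name"])
        source_vars = sorted(source_vars, key=lambda v: v["name"])
    for t_var, s_var in zip(target_vars, source_vars):
        w = tau if t_var["trainable"] else tau_non_trainable
        # w == 1.0 reduces to a hard copy of the source value.
        t_var["value"] = (1.0 - w) * t_var["value"] + w * s_var["value"]


target = [{"name": "b", "value": 0.0, "trainable": True},
          {"name": "a", "value": 0.0, "trainable": False}]
source = [{"name": "a", "value": 10.0, "trainable": False},
          {"name": "b", "value": 4.0, "trainable": True}]
soft_update(target, source, tau=0.5, tau_non_trainable=1.0,
            sort_variables_by_name=True)
print([(v["name"], v["value"]) for v in target])  # [('b', 2.0), ('a', 10.0)]
```

Here the two lists hold the same variables in different orders, so without sort_variables_by_name the positional pairing would blend "b" with "a". The trainable "b" moves halfway toward its source (tau=0.5), while the non-trainable "a" is copied outright (tau_non_trainable=1.0).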