Linear Thompson Sampling Agent.
Inherits From: LinearBanditAgent, TFAgent
tf_agents.bandits.agents.linear_thompson_sampling_agent.LinearThompsonSamplingAgent(
time_step_spec: tf_agents.typing.types.TimeStep,
action_spec: tf_agents.typing.types.BoundedTensorSpec,
variable_collection: Optional[tf_agents.bandits.agents.linear_bandit_agent.LinearBanditVariableCollection] = None,
alpha: float = 1.0,
gamma: float = 1.0,
use_eigendecomp: bool = False,
tikhonov_weight: float = 1.0,
add_bias: bool = False,
emit_policy_info: Sequence[Text] = (),
observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
accepts_per_arm_features: bool = False,
debug_summaries: bool = False,
summarize_grads_and_vars: bool = False,
enable_summaries: bool = True,
dtype: tf.DType = tf.float32,
name: Optional[Text] = None
)
Implements the Linear Thompson Sampling Agent from the following paper:
"Thompson Sampling for Contextual Bandits with Linear Payoffs",
Shipra Agrawal, Navin Goyal, ICML 2013. The algorithm actually implemented is
Algorithm 3 from the supplementary material of the paper, available at
http://proceedings.mlr.press/v28/agrawal13-supp.pdf.
In a nutshell, the agent maintains two parameters, weight_covariances and
parameter_estimators, and updates them based on experience. The inverses of
the weight covariance parameters are updated with the outer products of the
observations using the Woodbury matrix inverse update, while the parameter
estimators are updated with the reward-weighted observation vectors for every
action.
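The update described above can be sketched for a single arm in NumPy. This is a minimal illustration under stated assumptions, not the TF-Agents implementation: the class ArmPosterior and the function select_action are hypothetical, the inverse covariance is assumed to start at (tikhonov_weight * I)^-1, the Woodbury update is used in its rank-one (Sherman-Morrison) form, and the Thompson sampling step assumes that weights are drawn from a Gaussian centred on the point estimate with covariance alpha^2 times the inverse covariance. The forgetting factor gamma is left out.

```python
import numpy as np


class ArmPosterior:
    """Hypothetical per-arm state for linear Thompson sampling (illustration only)."""

    def __init__(self, context_dim, tikhonov_weight=1.0):
        # Assumed starting point: B^{-1} = (tikhonov_weight * I)^{-1}.
        self.cov_inv = np.eye(context_dim) / tikhonov_weight
        # Reward-weighted sum of observed contexts, f = sum_t r_t * x_t.
        self.reward_weighted_sum = np.zeros(context_dim)

    def update(self, x, reward):
        # Sherman-Morrison (rank-one Woodbury) update of B^{-1} after B <- B + x x^T.
        # Valid because B^{-1} is symmetric, so x^T B^{-1} == (B^{-1} x)^T.
        bx = self.cov_inv @ x
        self.cov_inv -= np.outer(bx, bx) / (1.0 + x @ bx)
        self.reward_weighted_sum += reward * x

    def estimate(self):
        # Point estimate of the arm's weight vector, theta_hat = B^{-1} f.
        return self.cov_inv @ self.reward_weighted_sum

    def sample(self, rng, alpha=1.0):
        # Assumed sampling rule: theta ~ N(theta_hat, alpha^2 * B^{-1}).
        return rng.multivariate_normal(self.estimate(), alpha ** 2 * self.cov_inv)


def select_action(arms, x, rng, alpha=1.0):
    # Sample one weight vector per arm and play the arm with the highest sampled reward.
    sampled_rewards = [arm.sample(rng, alpha) @ x for arm in arms]
    return int(np.argmax(sampled_rewards))
```

Updating the inverse directly costs O(d^2) per observation instead of the O(d^3) needed to re-invert the matrix, which is the usual motivation for the Woodbury form.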
Args | |
---|---|
time_step_spec | A TimeStep spec describing the expected TimeSteps. |
action_spec | A scalar BoundedTensorSpec with int32 or int64 dtype describing the number of actions for this agent. |
variable_collection | Instance of LinearBanditVariableCollection. Collection of variables to be updated by the agent. If None, a new instance of LinearBanditVariableCollection will be created. |
alpha | (float) Positive scalar. This is the exploration parameter that multiplies the confidence intervals. |
gamma | (float) Forgetting factor in [0.0, 1.0]. When set to 1.0, the algorithm does not forget. |
use_eigendecomp | (bool) Whether to use eigendecomposition. The default solver is Conjugate Gradient. |
tikhonov_weight | (float) Tikhonov regularization term. |
add_bias | (bool) If True, a bias term will be added to the linear reward estimation. |
emit_policy_info | (tuple of strings) The side information to include in the policy info. Allowed values can be found in policy_utilities.PolicyInfo. |
observation_and_action_constraint_splitter | A function used for masking valid/invalid actions with each state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the bandit agent and policy, and 2) the boolean mask. This function should also work with a TensorSpec as input, and should output TensorSpec objects for the observation and mask. |
accepts_per_arm_features | (bool) Whether the agent accepts per-arm features. |
debug_summaries | A Python bool, default False. When True, debug summaries are gathered. |
summarize_grads_and_vars | A Python bool, default False. When True, gradients and network variable summaries are written during training. |
enable_summaries | A Python bool, default True. When False, no summaries (debug or otherwise) are written. |
dtype | The type of the parameters stored and updated by the agent. Should be one of tf.float32 and tf.float64. Defaults to tf.float32. |
name | A name for this instance of LinearThompsonSamplingAgent. |
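To make the observation_and_action_constraint_splitter contract concrete, the following sketch assumes a hypothetical dictionary observation with keys "context" and "mask". A real splitter receives tensors (and, per the description above, TensorSpecs), but plain dictionary indexing behaves the same way on a dictionary of specs. The helper masked_argmax is also hypothetical; it shows one way the returned boolean mask could restrict the choice to valid actions.

```python
import numpy as np


def observation_and_action_constraint_splitter(observation):
    # Hypothetical layout: the environment emits {"context": ..., "mask": ...}.
    # Returns (observation for the agent and policy, boolean action mask).
    return observation["context"], observation["mask"]


def masked_argmax(sampled_rewards, mask):
    # Invalid actions get -inf so they can never be selected.
    masked = np.where(mask, sampled_rewards, -np.inf)
    return int(np.argmax(masked))


obs = {
    "context": np.array([0.2, -1.0, 0.7]),
    "mask": np.array([True, False, True, True]),
}
context, mask = observation_and_action_constraint_splitter(obs)
# One sampled reward per action; action 1 has the highest value but is masked out.
print(masked_argmax(np.array([0.1, 5.0, 0.3, -0.2]), mask))  # prints 2
```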
Raises | |
---|---|
ValueError | If dtype is not one of tf.float32 or tf.float64. |
Attributes | |
---|---|
action_spec | TensorSpec describing the action produced by the agent. |
alpha | |
collect_data_context | |
collect_data_spec | Returns a Trajectory spec, as expected by the collect_policy. |
collect_policy | Returns a policy that can be used to collect data from the environment. |
cov_matrix | |
data_context | |
data_vector | |
debug_summaries | |
eig_matrix | |
eig_vals | |
num_actions | |
num_samples | |
policy | Returns the current policy held by the agent. |
summaries_enabled | |
summarize_grads_and_vars | |
theta | Returns the matrix of per-arm feature weights. The returned matrix has shape (num_actions, context_dim). It is equivalent to a stacking of the theta vectors from the paper. |
time_step_spec | Describes the TimeStep tensors expected by the agent. |
train_sequence_length | The number of time steps needed in experience tensors passed to train. Train requires experience to be a Trajectory containing tensors shaped [B, T, ...]; this property describes the value of T required. For example, for non-RNN DQN training, T=2 because DQN requires single transitions. If this value is None, then train can handle an unknown T (it can be determined at runtime from the data). |
train_step_counter | |
training_data_spec | Returns a trajectory spec, as expected by the train() function. |
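The relation between the theta attribute and the use_eigendecomp, tikhonov_weight and gamma arguments can be illustrated with a NumPy sketch. It rests on assumptions this page does not confirm: each arm keeps a matrix A and a vector b (compare the cov_matrix and data_vector attributes), gamma discounts both before every new observation is added, and each row of theta solves (A + tikhonov_weight * I) theta_a = b. The two solvers, a hand-written conjugate gradient and an eigendecomposition via numpy.linalg.eigh, should give the same result; all function names are hypothetical.

```python
import numpy as np


def forgetful_update(cov, data, x, reward, gamma=1.0):
    # Assumed rule: discount old statistics by gamma, then add the new observation.
    return gamma * cov + np.outer(x, x), gamma * data + reward * x


def solve_cg(a, b, tol=1e-10, max_iters=None):
    # Conjugate gradient for a symmetric positive-definite system a @ x = b.
    x = np.zeros_like(b)
    r = b - a @ x
    p = r.copy()
    rs_old = r @ r
    for _ in range(max_iters or 10 * len(b)):
        if np.sqrt(rs_old) < tol:
            break
        ap = a @ p
        step = rs_old / (p @ ap)
        x = x + step * p
        r = r - step * ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def solve_eig(a, b):
    # a = V diag(w) V^T, so a^{-1} b = V diag(1 / w) V^T b.
    w, v = np.linalg.eigh(a)
    return v @ ((v.T @ b) / w)


def compute_theta(covs, datas, tikhonov_weight=1.0, use_eigendecomp=False):
    # Stack one regularized weight vector per arm into a (num_actions, context_dim) matrix.
    solver = solve_eig if use_eigendecomp else solve_cg
    rows = []
    for cov, data in zip(covs, datas):
        reg = cov + tikhonov_weight * np.eye(len(data))
        rows.append(solver(reg, data))
    return np.stack(rows)
```

Conjugate gradient needs only matrix-vector products, while an eigendecomposition costs more to compute but could be reused across solves with the same matrix.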
Methods
compute_summaries
compute_summaries(
loss: tf_agents.typing.types.Tensor
)
initialize
initialize() -> Optional[tf.Operation]
Initializes the agent.
Returns | |
---|---|
An operation that can be used to initialize the agent. |
Raises | |
---|---|
RuntimeError | If the class was not initialized properly (super().__init__ was not called). |
loss
loss(
experience: tf_agents.typing.types.NestedTensor,
weights: Optional[types.Tensor] = None,
training: bool = False,
**kwargs
) -> tf_agents.agents.tf_agent.LossInfo
Gets loss from the agent.
If the user calls this from _train, it must be in a tf.GradientTape scope
in order to apply gradients to trainable variables.
If intermediate gradient steps are needed, _loss and _train will return
different values since _loss only supports updating all gradients at once
after all losses have been calculated.
Args | |
---|---|
experience | A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.training_data_spec. All tensors in experience must be shaped [batch, time, ...] where time must be equal to self.train_sequence_length if that property is not None. |
weights | (optional) A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-batch loss, but the implementation is up to the agent. |
training | Explicit argument to pass to loss. This typically affects network computation paths like dropout and batch normalization. |
**kwargs | Any additional data as args to loss. |
Returns | |
---|---|
A LossInfo loss tuple containing loss and info tensors. |
Raises | |
---|---|
RuntimeError | If the class was not initialized properly (super().__init__ was not called). |
post_process_policy
post_process_policy() -> tf_agents.policies.TFPolicy
Post-process policies after training.
The policies of some agents require expensive post-processing after training before they can be used; for example, a recommender agent might require rebuilding an index of actions. For such agents, this method returns a post-processed version of the policy. The post-processing may either update the existing policies in place or create a new policy, depending on the agent. The default implementation, for agents that do not override this method, returns agent.policy.
Returns | |
---|---|
The post-processed policy. |
preprocess_sequence
preprocess_sequence(
experience: tf_agents.typing.types.NestedTensor
) -> tf_agents.typing.types.NestedTensor
Defines the preprocess_sequence function to be fed into replay buffers.
This defines how the collected data is preprocessed before training.
Defaults to pass-through for most agents.
The structure of experience must match that of self.collect_data_spec.
Args | |
---|---|
experience | A Trajectory shaped [batch, time, ...] or [time, ...] which represents the collected experience data. |
Returns | |
---|---|
A post-processed Trajectory with the same shape as the input. |
train
train(
experience: tf_agents.typing.types.NestedTensor,
weights: Optional[types.Tensor] = None,
**kwargs
) -> tf_agents.agents.tf_agent.LossInfo
Trains the agent.
Args | |
---|---|
experience | A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.training_data_spec. All tensors in experience must be shaped [batch, time, ...] where time must be equal to self.train_sequence_length if that property is not None. |
weights | (optional) A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-batch loss, but the implementation is up to the agent. |
**kwargs | Any additional data to pass to the subclass. |
Returns | |
---|---|
A LossInfo loss tuple containing loss and info tensors. |
Raises | |
---|---|
RuntimeError | If the class was not initialized properly (super().__init__ was not called). |
update_alpha
update_alpha(
alpha
)
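update_alpha has no description on this page; judging by its name and its alpha argument, it replaces the exploration parameter alpha. A common pattern is to shrink exploration as training progresses. The schedule below is a hypothetical plain-Python helper, and the commented-out call assumes an existing agent; whether update_alpha accepts a plain Python float, or requires alpha to have been created as a variable, is not stated here.

```python
import math


def decayed_alpha(step, initial_alpha=1.0, decay_rate=0.01, min_alpha=0.05):
    # Hypothetical exponential decay of the exploration parameter, floored at min_alpha.
    return max(min_alpha, initial_alpha * math.exp(-decay_rate * step))


for step in (0, 100, 1000):
    new_alpha = decayed_alpha(step)
    # agent.update_alpha(new_alpha)  # assumes an existing LinearThompsonSamplingAgent
    print(step, round(new_alpha, 3))
```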