tf_agents.agents.data_converter.AsTrajectory
Class that validates and converts other data types to Trajectory.
tf_agents.agents.data_converter.AsTrajectory(
    data_context: tf_agents.agents.data_converter.DataContext,
    sequence_length: typing.Optional[int] = None,
    num_outer_dims: te.Literal[1, 2] = 2
)
Note that validation and conversion allow values to contain dictionaries
with extra keys compared to the specs in the data context. These
additional entries (for example, extra observations) are ignored and dropped
during conversion. This non-strict checking lets users provide additional info
and observation keys as input without having to prune them manually before
converting.
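The pruning behaviour described above can be illustrated with a minimal pure-Python sketch. It assumes specs and values are nested dicts; `prune_to_spec` is a hypothetical helper, not part of TF-Agents, and strings stand in for the TensorSpecs and tensors the real converter works with. Treating a key that is missing from the value as an error is also an assumption of the sketch.

```python
from typing import Any, Mapping


def prune_to_spec(value: Any, spec: Any) -> Any:
    """Returns a copy of `value` keeping only the dict keys present in `spec`.

    Hypothetical helper, not part of TF-Agents. Only mappings are recursed
    into; any non-mapping spec entry is treated as a leaf and returned as-is.
    """
    if isinstance(spec, Mapping):
        if not isinstance(value, Mapping):
            raise TypeError(f"Expected a mapping, got {type(value).__name__}")
        missing = set(spec) - set(value)
        if missing:
            # Extra keys are tolerated; keys required by the spec are not.
            raise ValueError(f"Missing required keys: {sorted(missing)}")
        # Iterate over the spec's keys so that extra keys in `value` are dropped.
        return {key: prune_to_spec(value[key], spec[key]) for key in spec}
    return value


# Example: the spec lists two observation keys; the input carries a third.
observation_spec = {"image": "float32[64, 64, 3]", "speed": "float32[]"}
observation = {"image": "<image tensor>", "speed": 1.5, "debug_id": 7}
print(prune_to_spec(observation, observation_spec))
# {'image': '<image tensor>', 'speed': 1.5}
```

Iterating over the spec rather than the value is what makes the check non-strict: anything the spec does not mention (here `debug_id`) is silently discarded instead of triggering an error.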
Args
data_context: An instance of DataContext, typically accessed from the
TFAgent.data_context property.
sequence_length: The required time dimension value (if any), typically
determined by the subclass of TFAgent.
num_outer_dims: Expected number of outer dimensions, either 1 or 2. If 1,
call expects an outer batch dimension. If 2, call expects the two outer
dimensions [batch, time].
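The following is a minimal sketch of how `num_outer_dims` and `sequence_length` constrain the leading dimensions of a tensor shape. `check_outer_dims` is a hypothetical helper working on plain shape lists: it simply treats the first `num_outer_dims` entries as outer dims (the real converter compares against the per-step specs). It also assumes that `sequence_length` constrains the time dimension only when `num_outer_dims` is 2. This is not the TF-Agents implementation.

```python
from typing import Optional, Sequence, Tuple


def check_outer_dims(
    shape: Sequence[int],
    num_outer_dims: int = 2,
    sequence_length: Optional[int] = None,
) -> Tuple[int, ...]:
    """Validates and returns the outer dims of `shape` (hypothetical helper)."""
    if num_outer_dims not in (1, 2):
        raise ValueError(f"num_outer_dims must be 1 or 2, got {num_outer_dims}")
    if len(shape) < num_outer_dims:
        raise ValueError(
            f"Expected at least {num_outer_dims} outer dims, got shape {tuple(shape)}"
        )
    # The leading entries are taken as outer dims: [batch] or [batch, time].
    outer = tuple(shape[:num_outer_dims])
    # Assumption: sequence_length pins the time (second) dimension.
    if num_outer_dims == 2 and sequence_length is not None and outer[1] != sequence_length:
        raise ValueError(f"Expected time dimension {sequence_length}, got {outer[1]}")
    return outer


print(check_outer_dims([32, 2, 5], sequence_length=2))  # (32, 2)
print(check_outer_dims([32, 5], num_outer_dims=1))      # (32,)
```

Agents that train on pairs of adjacent steps are a typical reason to fix `sequence_length` (for example to 2), so a batch with a different time dimension is rejected early rather than failing deeper in the loss computation.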
Methods
__call__
__call__(
    value: typing.Any
) -> tf_agents.trajectories.Trajectory
Converts value to a Trajectory, performing data validation and pruning.
- If value is already a Trajectory, only validation is performed.
- If value is a Transition with tensors containing two ([B, T]) outer dims,
  it is simply repackaged to a Trajectory and then validated.
- If value is a Transition with tensors containing one ([B]) outer dim, a
  ValueError is raised.
Args
value: A Trajectory or Transition object to convert.
Returns
A validated and pruned Trajectory.
Raises
TypeError: If value is not a Trajectory or a Transition.
ValueError or TypeError: If value has a structure that doesn't match the
converter's spec.
ValueError: If value is a Transition without a time dimension, as training
Trajectories typically have batch and time dimensions.
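The dispatch that `__call__` performs can be sketched in plain Python. This is a hypothetical model, not TF-Agents code. `Trajectory` and `Transition` are toy stand-ins that carry only an observation shape. `OBSERVATION_SPEC_SHAPE` is an assumed per-step spec shape. Outer dims are whatever precedes that spec shape, which is roughly how the real converter infers them from the data context's specs.

```python
from typing import Any, NamedTuple, Tuple

# Hypothetical per-step observation shape, standing in for the spec that a
# real DataContext would carry.
OBSERVATION_SPEC_SHAPE: Tuple[int, ...] = (3,)


class Trajectory(NamedTuple):
    """Toy stand-in for tf_agents.trajectories.Trajectory (shapes only)."""
    observation_shape: Tuple[int, ...]


class Transition(NamedTuple):
    """Toy stand-in for tf_agents.trajectories.Transition (shapes only)."""
    observation_shape: Tuple[int, ...]


def _outer_dims(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Outer dims are the entries that precede the per-step spec shape.
    inner = len(OBSERVATION_SPEC_SHAPE)
    if len(shape) < inner or tuple(shape[len(shape) - inner:]) != OBSERVATION_SPEC_SHAPE:
        raise ValueError(f"Shape {shape} does not match spec {OBSERVATION_SPEC_SHAPE}")
    return tuple(shape[: len(shape) - inner])


def _validate(traj: Trajectory) -> Trajectory:
    # Training trajectories are expected to have [batch, time] outer dims.
    if len(_outer_dims(traj.observation_shape)) != 2:
        raise ValueError("Expected outer dims [batch, time]")
    return traj


def as_trajectory(value: Any) -> Trajectory:
    """Hypothetical sketch of the documented AsTrajectory.__call__ dispatch."""
    if isinstance(value, Trajectory):
        return _validate(value)  # Already a Trajectory: validation only.
    if isinstance(value, Transition):
        if len(_outer_dims(value.observation_shape)) == 1:
            # Only [B]: there is no time dimension to build a Trajectory from.
            raise ValueError("Transition has no time dimension")
        # [B, T]: repackage as a Trajectory, then validate.
        return _validate(Trajectory(value.observation_shape))
    raise TypeError(f"Expected Trajectory or Transition, got {type(value).__name__}")


print(as_trajectory(Transition((32, 2, 3))))
# Trajectory(observation_shape=(32, 2, 3))
```

Checking the type before anything else mirrors the documented TypeError for unsupported inputs, while shape problems in an otherwise supported type surface as ValueError.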
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.