An adapter that converts an existing model into one with vector-valued support.

Inherits From: Model

This adapter makes it convenient to use the Inference Gym models with inference algorithms which cannot handle structured events. It does so by reshaping individual event Tensors and concatenating them into a single vector.

The resultant vector-valued model has updated properties and sample transformations which reflect the transformation above. Note that the sample transformations will still return structured values, as those generally cannot be as readily flattened.
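
The reshape-and-concatenate step described above can be pictured with plain Python lists. This is only an illustrative sketch, not the library's implementation: the helper names (`flatten_component`, `reshape`, `flatten_event`, `split_vector`) and the example event are hypothetical, and ordering the components by dict insertion order is an assumption.

```python
import math


def flatten_component(x):
    """Recursively flatten a scalar or nested list into a flat list."""
    if isinstance(x, (list, tuple)):
        flat = []
        for item in x:
            flat.extend(flatten_component(item))
        return flat
    return [x]


def reshape(flat, shape):
    """Rebuild a nested list with the given shape from a flat list."""
    if not shape:
        return flat[0]  # An empty shape denotes a scalar.
    step = math.prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


def flatten_event(event, shapes):
    """Concatenate all components of a structured event into one vector."""
    vector = []
    for name in shapes:  # Assumption: components are ordered as in `shapes`.
        vector.extend(flatten_component(event[name]))
    return vector


def split_vector(vector, shapes):
    """Invert `flatten_event`: slice the vector and restore each shape."""
    event, offset = {}, 0
    for name, shape in shapes.items():
        size = math.prod(shape)
        event[name] = reshape(vector[offset:offset + size], shape)
        offset += size
    return event


# Hypothetical event structure: a scalar, a vector, and a 2x2 matrix.
shapes = {'a': [], 'b': [3], 'c': [2, 2]}
event = {'a': 0.5, 'b': [1.0, 2.0, 3.0], 'c': [[4.0, 5.0], [6.0, 7.0]]}

vector = flatten_event(event, shapes)    # A single flat list of 8 numbers.
restored = split_vector(vector, shapes)  # The original structure again.
```

By the same arithmetic, components with shapes `[]`, `[400]`, and `[100]` flatten into a vector of length 1 + 400 + 100 = 501.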

There are only two restrictions on the models that can be handled by this class:

  1. The individual Tensors in an event must all have the same dtype.
  2. The default_event_space_bijector must apply to a single tensor at a time.

The second restriction will be lifted soon.
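
The first restriction amounts to a uniqueness check over the component dtypes. The sketch below shows one way such a check might look. It is not the library's code: `common_dtype` is a hypothetical helper, and dtypes are represented as plain strings so that the example runs without TensorFlow.

```python
def common_dtype(dtypes):
    """Return the single dtype shared by all components, or raise TypeError."""
    unique = sorted(set(dtypes.values()))
    if len(unique) != 1:
        raise TypeError(
            'Expected exactly one unique dtype, got: {}'.format(unique))
    return unique[0]


# All components share a dtype, so the check passes.
ok = common_dtype({'x': 'float32', 'y': 'float32'})

# Mixed dtypes cannot be concatenated into one vector, so the check fails.
try:
    common_dtype({'x': 'float32', 'y': 'int32'})
    raised = False
except TypeError:
    raised = True
```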


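# Assumes the TensorFlow backend of the Inference Gym.
from inference_gym import using_tensorflow as gym

base_model = gym.targets.SyntheticItemResponseTheory()
vec_model = gym.targets.VectorModel(base_model)

base_model.dtype
# ==> {
#         'mean_student_ability': tf.float32,
#         'student_ability': tf.float32,
#         'question_difficulty': tf.float32,
#     }

vec_model.dtype
# ==> tf.float32

base_model.event_shape
# ==> {
#         'mean_student_ability': [],
#         'student_ability': [400],
#         'question_difficulty': [100],
#     }

vec_model.event_shape
# ==> [501]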

model An Inference Gym model.

TypeError If model has more than one unique Tensor dtype.

default_event_space_bijector Bijector mapping the reals (R**n) to the event space of this model.
dtype The DType of Tensors handled by this model.
event_shape Shape of a single sample from the model as a TensorShape.

May be partially defined or unknown.

name Python str name prefixed to Ops created by this class.
sample_transformations A dictionary of names to SampleTransformations.
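
To illustrate why it helps for each component's bijector to act on a single tensor at a time, here is a toy sketch of assembling a map on flat vectors from per-component maps. Everything in it is an assumption made for illustration: `make_vector_bijector` is a hypothetical helper, the component maps are plain elementwise Python functions rather than Bijector objects, and only the forward direction is shown.

```python
import math


def make_vector_bijector(component_fns, sizes):
    """Build a map on flat vectors from one elementwise map per component."""
    def forward(vector):
        out, offset = [], 0
        for fn, size in zip(component_fns, sizes):
            # Each component's map only ever sees its own slice.
            out.extend(fn(x) for x in vector[offset:offset + size])
            offset += size
        return out
    return forward


# Hypothetical model: one unconstrained scalar, then two positive values.
forward = make_vector_bijector([lambda x: x, math.exp], sizes=[1, 2])

# Maps a point in R**3 to the constrained space: the first entry is
# unchanged and the last two are exponentiated, hence positive.
constrained = forward([-1.5, 0.0, 1.0])
```

Because each map touches only its own slice, the combined map is invertible whenever the individual maps are.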

Child Classes

class SampleTransformation




The un-normalized log density evaluated at a point.

This corresponds to the target distribution associated with the model, often its posterior.

value A (nest of) Tensor to evaluate the log density at.
name Python str name prefixed to Ops created by this method.

unnormalized_log_prob A floating point Tensor.
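
As a reminder of what "un-normalized" means for a posterior, the toy sketch below compares an un-normalized log density with the exact one for a conjugate Gaussian model. The model, the data, and the function names are invented for illustration and are unrelated to any Inference Gym target. The assumed model is theta ~ Normal(0, 1) with observations y_i ~ Normal(theta, 1).

```python
import math

observations = [0.3, -0.1, 1.2]  # Made-up data for illustration.


def toy_unnormalized_log_prob(theta):
    """Log prior plus log likelihood, dropping all constant terms."""
    log_prior = -0.5 * theta ** 2
    log_likelihood = sum(-0.5 * (y - theta) ** 2 for y in observations)
    return log_prior + log_likelihood


def toy_normalized_log_posterior(theta):
    """Exact posterior log density, available here thanks to conjugacy."""
    n = len(observations)
    mean = sum(observations) / (n + 1)
    variance = 1.0 / (n + 1)
    return (-0.5 * math.log(2 * math.pi * variance)
            - 0.5 * (theta - mean) ** 2 / variance)


# The two functions differ only by a constant, so differences between
# any two points agree.
diff_unnormalized = (toy_unnormalized_log_prob(0.2)
                     - toy_unnormalized_log_prob(1.0))
diff_normalized = (toy_normalized_log_posterior(0.2)
                   - toy_normalized_log_posterior(1.0))
```

Inference algorithms that rely only on density ratios or gradients are unaffected by the missing constant, so the un-normalized form is sufficient for them.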