Has the ability to load and apply an ML model.
tfx_bsl.public.beam.run_inference.ModelHandler()
Methods
batch_elements_kwargs
batch_elements_kwargs() -> Mapping[str, Any]
Returns: kwargs suitable for beam.BatchElements.
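A minimal sketch of how a handler might expose batching limits. The class below is a hypothetical stand-in that does not import Beam (a real handler would subclass ModelHandler); min_batch_size and max_batch_size are existing beam.BatchElements parameters.

```python
from typing import Any, Mapping


class SmallBatchHandlerSketch:
    """Hypothetical stand-in; a real handler would subclass ModelHandler."""

    def __init__(self, min_batch_size: int = 1, max_batch_size: int = 32):
        # Stored once so batch_elements_kwargs() stays cheap to call.
        self._batching_kwargs = {
            "min_batch_size": min_batch_size,
            "max_batch_size": max_batch_size,
        }

    def batch_elements_kwargs(self) -> Mapping[str, Any]:
        # Intended to be forwarded as keyword arguments to beam.BatchElements.
        return self._batching_kwargs


handler = SmallBatchHandlerSketch(max_batch_size=8)
kwargs = handler.batch_elements_kwargs()
```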
get_metrics_namespace
get_metrics_namespace() -> str
Returns: A namespace for metrics collected by the RunInference transform.
get_num_bytes
get_num_bytes(
batch: Sequence[ExampleT]
) -> int
Returns: The number of bytes of data for a batch.
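One possible way to compute the size of a batch, assuming every example is already a serialized bytes object. This is an illustrative sketch written as a free function, not the library's own implementation; handlers for other example types would need a different measure.

```python
from typing import Sequence


def get_num_bytes(batch: Sequence[bytes]) -> int:
    # Assumes each example is a bytes object, so len() gives its size.
    return sum(len(example) for example in batch)


total = get_num_bytes([b"abc", b"de"])
```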
get_postprocess_fns
get_postprocess_fns() -> Iterable[Callable[[Any], Any]]
Gets all postprocessing functions to be run after inference. Functions are returned in the order in which they should be applied.
get_preprocess_fns
get_preprocess_fns() -> Iterable[Callable[[Any], Any]]
Gets all preprocessing functions to be run before batching/inference. Functions are returned in the order in which they should be applied.
get_resource_hints
get_resource_hints() -> dict
Returns: Resource hints for the transform.
load_model
load_model() -> ModelT
Loads and initializes a model for processing.
model_copies
model_copies() -> int
Returns the maximum number of model copies that should be loaded at one time. This only affects model handlers that use share_model_across_processes to share their model across processes instead of loading it per process.
override_metrics
override_metrics(
metrics_namespace: str = ''
) -> bool
Returns a boolean representing whether or not a model handler will override metrics reporting. If True, RunInference will not report any metrics.
run_inference
run_inference(
batch: Sequence[ExampleT],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionT]
Runs inferences on a batch of examples.
Args | |
---|---|
batch | A sequence of examples or features. |
model | The model used to make inferences. |
inference_args | Extra arguments for models whose inference call requires extra parameters. |

Returns | |
---|---|
An Iterable of Predictions. | |
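The relationship between load_model and run_inference can be sketched as follows. DoublingHandlerSketch is a hypothetical class that only mirrors the two method signatures and does not import Beam; the "model" is a plain Python callable standing in for a real framework model.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Sequence


class DoublingHandlerSketch:
    """Hypothetical stand-in that mirrors load_model and run_inference."""

    def load_model(self) -> Callable[[float], float]:
        # A real handler would load weights from disk or a model registry.
        return lambda x: 2.0 * x

    def run_inference(
        self,
        batch: Sequence[float],
        model: Callable[[float], float],
        inference_args: Optional[Dict[str, Any]] = None,
    ) -> Iterable[float]:
        # One prediction per example, in the same order as the batch.
        return [model(example) for example in batch]


handler = DoublingHandlerSketch()
model = handler.load_model()
predictions = list(handler.run_inference([1.0, 2.5], model))
```

In a pipeline, the model returned by load_model is the object later passed to run_inference as the model argument.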
set_environment_vars
set_environment_vars()
Sets environment variables using a dictionary provided via kwargs. Keys are the environment variable names, and values are the environment variable values. Child ModelHandler classes should set _env_vars via kwargs in __init__, or else call super().__init__().
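A rough sketch of the pattern described above. EnvVarHandlerSketch is a hypothetical class that does not import Beam, and the env_vars keyword name and the SKETCH_FLAG variable are assumptions made for illustration.

```python
import os


class EnvVarHandlerSketch:
    """Hypothetical stand-in showing the _env_vars convention."""

    def __init__(self, **kwargs):
        # Assumption: the dictionary arrives under an "env_vars" keyword.
        self._env_vars = kwargs.get("env_vars", {})

    def set_environment_vars(self) -> None:
        # Copy each name/value pair into the current process environment.
        for name, value in self._env_vars.items():
            os.environ[name] = value


handler = EnvVarHandlerSketch(env_vars={"SKETCH_FLAG": "1"})
handler.set_environment_vars()
```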
share_model_across_processes
share_model_across_processes() -> bool
Returns a boolean representing whether or not a model should be shared across multiple processes instead of being loaded per process. This is primarily useful for large models that can't fit multiple copies in memory. Multi-process support may vary by runner, but this will fall back to loading per process as necessary. See https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html
should_skip_batching
should_skip_batching() -> bool
Whether RunInference's batching should be skipped. Can be flipped to True by using with_no_batching.
update_model_path
update_model_path(
model_path: Optional[str] = None
)
Update the model path produced by side inputs. update_model_path should be used when a ModelHandler represents a single model, not multiple models. This will be true in most cases. For more information see the website section on model updates https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh
update_model_paths
update_model_paths(
model: ModelT,
model_paths: Optional[Union[str, List[KeyModelPathMapping]]] = None
)
Update the model paths produced by side inputs. update_model_paths should be used when updating multiple models at once (e.g. when using a KeyedModelHandler that holds multiple models). For more information see the KeyedModelHandler documentation https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.KeyedModelHandler and the website section on model updates https://beam.apache.org/documentation/ml/about-ml/#automatic-model-refresh
validate_inference_args
validate_inference_args(
inference_args: Optional[Dict[str, Any]]
)
Validates inference_args passed in the inference call.
Because most frameworks do not need extra arguments in their predict() call, the default behavior is to error out if inference_args are present.
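The default behavior described above, and one way a handler might relax it, could look roughly like this. Both functions are illustrative sketches rather than the library's code; the temperature key is a hypothetical argument name, and treating an empty dictionary as "no arguments" is an assumption of this sketch.

```python
from typing import Any, Dict, Optional, Sequence


def validate_inference_args_sketch(
        inference_args: Optional[Dict[str, Any]]) -> None:
    # Truthiness check: None and {} both pass in this sketch.
    if inference_args:
        raise ValueError(
            "inference_args were provided, but this handler expects none.")


def validate_allowing_keys(
        inference_args: Optional[Dict[str, Any]],
        allowed: Sequence[str] = ("temperature",)) -> None:
    # An override might accept only the keys its predict call understands.
    unexpected = set(inference_args or {}) - set(allowed)
    if unexpected:
        raise ValueError(f"Unexpected inference_args: {sorted(unexpected)}")


try:
    validate_inference_args_sketch({"temperature": 0.5})
    default_rejected = False
except ValueError:
    default_rejected = True

validate_allowing_keys({"temperature": 0.5})  # accepted by the override
```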
with_no_batching
with_no_batching() -> 'ModelHandler[Union[
ExampleT, Iterable[ExampleT]], PostProcessT, ModelT, PostProcessT]'
Returns a new ModelHandler which does not require batching of inputs so that RunInference will skip this step. RunInference will expect the input to be pre-batched and passed in as an Iterable of records. If you skip batching, any preprocessing functions should accept a batch of data, not just a single record.
This option is only recommended if you want to do custom batching yourself. If you just want to pass in records without a batching dimension, it is recommended to (1) add max_batch_size=1 to batch_elements_kwargs and (2) remove the batching dimension as part of your inference call (by calling record=batch[0]).
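The two-step alternative to with_no_batching can be sketched as below. SingleRecordHandlerSketch is a hypothetical class that does not import Beam, and str.upper stands in for a model that takes a single record rather than a batch.

```python
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence


class SingleRecordHandlerSketch:
    """Hypothetical stand-in for a handler whose model takes one record."""

    def batch_elements_kwargs(self) -> Mapping[str, Any]:
        # Step (1): cap batches at a single element.
        return {"max_batch_size": 1}

    def load_model(self) -> Callable[[str], str]:
        return str.upper

    def run_inference(
        self,
        batch: Sequence[str],
        model: Callable[[str], str],
        inference_args: Optional[Dict[str, Any]] = None,
    ) -> Iterable[str]:
        # Step (2): drop the batching dimension before calling the model.
        record = batch[0]
        return [model(record)]


handler = SingleRecordHandlerSketch()
result = list(handler.run_inference(["hello"], handler.load_model()))
```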
with_postprocess_fn
with_postprocess_fn(
fn: Callable[[PredictionT], PostProcessT]
) -> 'ModelHandler[ExampleT, PostProcessT, ModelT, PostProcessT]'
Returns a new ModelHandler with a postprocessing function associated with it. The postprocessing function will be run after inference and should map the base ModelHandler's output type to your desired output type. If you apply multiple postprocessing functions, they will be run on your original inference result in order from first applied to last applied.
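The ordering rule for multiple postprocessing functions can be illustrated without Beam. apply_postprocess_fns is a hypothetical helper that only models the order of application described above, not how RunInference is implemented.

```python
from typing import Any, Callable, Sequence


def apply_postprocess_fns(
        prediction: Any, fns: Sequence[Callable[[Any], Any]]) -> Any:
    # Functions run from first applied to last applied.
    for fn in fns:
        prediction = fn(prediction)
    return prediction


def add_one(p: int) -> int:
    return p + 1


def times_ten(p: int) -> int:
    return p * 10


# Models with_postprocess_fn(add_one) followed by with_postprocess_fn(times_ten).
result = apply_postprocess_fns(2, [add_one, times_ten])  # (2 + 1) * 10
```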
with_preprocess_fn
with_preprocess_fn(
fn: Callable[[PreProcessT], ExampleT]
) -> 'ModelHandler[PreProcessT, PredictionT, ModelT, PreProcessT]'
Returns a new ModelHandler with a preprocessing function associated with it. The preprocessing function will be run before batching/inference and should map your input PCollection to the base ModelHandler's input type. If you apply multiple preprocessing functions, they will be run on your original PCollection in order from last applied to first applied.
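Preprocessing runs in the reverse order: the most recently applied function sees the raw element first, because it maps the new input type to what the earlier functions expect. apply_preprocess_fns is a hypothetical helper that only models this ordering, not how RunInference is implemented.

```python
from typing import Any, Callable, Sequence


def apply_preprocess_fns(
        element: Any,
        fns_in_application_order: Sequence[Callable[[Any], Any]]) -> Any:
    # Functions run from last applied to first applied.
    for fn in reversed(fns_in_application_order):
        element = fn(element)
    return element


def add_one(x: int) -> int:
    return x + 1


def times_ten(x: int) -> int:
    return x * 10


# Models with_preprocess_fn(add_one) followed by with_preprocess_fn(times_ten).
result = apply_preprocess_fns(2, [add_one, times_ten])  # (2 * 10) + 1
```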