Keras Tuner Oracle that wraps the AI Platform Vizier backend.
tfc.CloudOracle(
project_id: Text,
region: Text,
objective: Union[Text, oracle_module.Objective] = None,
hyperparameters: hp_module.HyperParameters = None,
study_config: Optional[Dict[Text, Any]] = None,
max_trials: int = None,
study_id: Optional[Text] = None
)
This is an implementation of a KerasTuner Oracle that uses Google Cloud's Vizier Service.
Each Oracle class implements a particular hyperparameter tuning algorithm. An Oracle is passed as an argument to a Tuner. The Oracle tells the Tuner which hyperparameters should be tried next.
To learn more about Keras Tuner Oracles please refer to: https://keras-team.github.io/keras-tuner/documentation/oracles/
AI Platform Vizier is a black-box optimization service that helps you tune hyperparameters in complex machine learning (ML) models. When ML models have many different hyperparameters, it can be difficult and time consuming to tune them manually. AI Platform Vizier optimizes your model's output by tuning the hyperparameters for you. To learn more about AI Platform Vizier service see: https://cloud.google.com/ai-platform/optimizer/docs/overview
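Examples:
import tensorflow_cloud as tfc

# project_id and hyperparameters are assumed to be defined by the caller.
oracle = tfc.CloudOracle(
    project_id=project_id,
    region='us-central1',
    objective='accuracy',
    hyperparameters=hyperparameters,
    study_config=None,
    max_trials=4,
    study_id=None,
)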
Args | |
---|---|
project_id | A GCP project ID. |
region | A GCP region, e.g. 'us-central1'. |
objective | A string or an Objective. If a string, the direction of the optimization (min or max) will be inferred. |
hyperparameters | Mandatory and must include definitions for all hyperparameters used during the search. Can be used to override (or register in advance) hyperparameters in the search space. |
study_config | Study configuration for the Vizier service. |
max_trials | Maximum total number of trials (model configurations) to test. If None, the search continues until it reaches the Vizier trial limit for each study. Users may stop the search externally (e.g. by killing the job). Note that the Oracle may interrupt the search before max_trials models have been tested. |
study_id | An identifier of the study. If not supplied, a system-determined unique ID is assigned. The full study name will be projects/{project_id}/locations/{region}/studies/{study_id}, and the full trial name will be {study name}/trials/{trial_id}. |
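The study and trial name formats described for study_id can be assembled with plain string formatting. The following is a minimal sketch: the helper names study_name and trial_name and the example IDs are hypothetical and are not part of tfc.CloudOracle.

```python
def study_name(project_id: str, region: str, study_id: str) -> str:
    """Builds a full study name in the documented format."""
    return f"projects/{project_id}/locations/{region}/studies/{study_id}"


def trial_name(study: str, trial_id: str) -> str:
    """Builds a full trial name from a full study name and a trial ID."""
    return f"{study}/trials/{trial_id}"


# Placeholder values, for illustration only.
example_study = study_name("my-project", "us-central1", "my-study")
example_trial = trial_name(example_study, "1")
print(example_study)
print(example_trial)
```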
Methods
create_trial
create_trial(
tuner_id: Text
) -> trial_module.Trial
Create a new Trial to be run by the Tuner.
Args | |
---|---|
tuner_id | An ID that identifies the Tuner requesting a Trial. Tuners that should run the same trial (for instance, when running a multi-worker model) should have the same ID. If multiple suggestTrialsRequests have the same tuner_id, the service will return the identical suggested trial if the trial is PENDING, and provide a new trial if the last suggested trial was completed. |

Returns | |
---|---|
A Trial object containing a set of hyperparameter values to run in a Tuner. |

Raises | |
---|---|
SuggestionInactiveError | Indicates that a suggestion was requested from an inactive study. |
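The tuner_id rule for create_trial can be illustrated with a toy in-memory model. This sketch is not the Vizier API: the class ToySuggestionService and its methods are hypothetical, and it only mimics the behavior that a pending trial is returned again for the same tuner_id while a completed trial is followed by a new one.

```python
import itertools
from typing import Dict


class ToySuggestionService:
    """Hypothetical stand-in that models only the PENDING/completed rule."""

    def __init__(self) -> None:
        self._ids = itertools.count(1)
        self._pending: Dict[str, str] = {}  # tuner_id -> pending trial ID

    def suggest(self, tuner_id: str) -> str:
        # A tuner_id with a pending trial gets the identical trial back.
        if tuner_id in self._pending:
            return self._pending[tuner_id]
        trial_id = str(next(self._ids))
        self._pending[tuner_id] = trial_id
        return trial_id

    def complete(self, tuner_id: str) -> None:
        # After completion, the next request for this tuner gets a new trial.
        self._pending.pop(tuner_id, None)


service = ToySuggestionService()
first = service.suggest("tuner0")
repeat = service.suggest("tuner0")  # still pending: same trial
other = service.suggest("tuner1")   # different tuner: different trial
service.complete("tuner0")
after = service.suggest("tuner0")   # previous trial completed: new trial
```

In a multi-worker setup, this is why workers that share one trial would pass the same tuner_id: each of them receives the same pending trial rather than a fresh suggestion.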
end_trial
end_trial(
trial_id: Text,
status: Text = 'COMPLETED'
)
Record the measured objective for a set of parameter values.
get_best_trials
get_best_trials(
num_trials: int = 1
) -> List[trial_module.Trial]
Returns the trials with the best objective values found so far.
Args | |
---|---|
num_trials | Positive int, number of trials to return. |

Returns | |
---|---|
List of KerasTuner Trials. |
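Conceptually, selecting the best trials amounts to ranking trials by their objective value in the objective's direction and keeping the first num_trials. The sketch below shows that idea with hypothetical names (ToyTrial, best_trials); it is not the CloudOracle implementation, which may obtain this information differently.

```python
from typing import List, NamedTuple


class ToyTrial(NamedTuple):
    """Hypothetical trial record: an ID plus its measured objective."""
    trial_id: str
    score: float


def best_trials(trials: List[ToyTrial], num_trials: int = 1,
                direction: str = "max") -> List[ToyTrial]:
    """Returns the top num_trials trials for the given direction."""
    if num_trials < 1:
        raise ValueError("num_trials must be a positive int")
    # Sort descending for "max" objectives, ascending for "min".
    ranked = sorted(trials, key=lambda t: t.score,
                    reverse=(direction == "max"))
    return ranked[:num_trials]


trials = [ToyTrial("1", 0.81), ToyTrial("2", 0.93), ToyTrial("3", 0.88)]
top_two = best_trials(trials, num_trials=2)
```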
get_space
get_space()
Returns the HyperParameters search space.
get_state
get_state()
Returns the current state of this object.
This method is called during save.
get_trial
get_trial(
trial_id
)
Returns the Trial specified by trial_id.
reload
reload()
Reloads this object using set_state.

Arguments | |
---|---|
fname | The file name to restore from. |
remaining_trials
remaining_trials()
save
save()
Saves this object using get_state.

Arguments | |
---|---|
fname | The file name to save to. |
set_state
set_state(
state
)
Sets the current state of this object.
This method is called during reload.

Arguments | |
---|---|
state | Dict. The state to restore for this object. |
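The relationship between save, reload, get_state, and set_state follows a common pattern: save serializes whatever get_state returns, and reload passes the stored dict back to set_state. The following toy class sketches that pattern under stated assumptions: ToyStateful is hypothetical, JSON is chosen only for illustration, and the file name is kept on the instance because the signatures above take no arguments.

```python
import json
import os
import tempfile
from typing import Any, Dict


class ToyStateful:
    """Hypothetical object showing the save/reload round trip."""

    def __init__(self, fname: str) -> None:
        self.fname = fname  # assumption: the file name lives on the instance
        self.trials_started = 0

    def get_state(self) -> Dict[str, Any]:
        return {"trials_started": self.trials_started}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.trials_started = state["trials_started"]

    def save(self) -> None:
        # save() delegates to get_state() for the content to persist.
        with open(self.fname, "w") as f:
            json.dump(self.get_state(), f)

    def reload(self) -> None:
        # reload() delegates to set_state() to restore the content.
        with open(self.fname) as f:
            self.set_state(json.load(f))


path = os.path.join(tempfile.mkdtemp(), "oracle.json")
original = ToyStateful(path)
original.trials_started = 3
original.save()

restored = ToyStateful(path)
restored.reload()
```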
update_space
update_space(
hyperparameters
)
Add new hyperparameters to the tracking space.
Already recorded parameters get ignored.
Args | |
---|---|
hyperparameters | An updated HyperParameters object. |
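The rule that new hyperparameters are added while already recorded ones are ignored can be sketched with plain dictionaries. This is only an illustration: merge_space and the dict-based parameter specs are hypothetical and do not reflect how HyperParameters objects are stored internally.

```python
from typing import Any, Dict


def merge_space(tracked: Dict[str, Any],
                updated: Dict[str, Any]) -> Dict[str, Any]:
    """Adds new parameter names; names already tracked keep their spec."""
    merged = dict(tracked)
    for name, spec in updated.items():
        merged.setdefault(name, spec)
    return merged


tracked = {"learning_rate": {"min": 1e-4, "max": 1e-2}}
updated = {
    "learning_rate": {"min": 1e-3, "max": 1e-1},  # already recorded: ignored
    "units": {"min": 32, "max": 128},             # new: added
}
merged = merge_space(tracked, updated)
```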
update_trial
update_trial(
trial_id: Text,
metrics: Mapping[Text, Union[int, float]],
step: int = 0
)
Used by a worker to report the status of a trial.