tf.keras.wrappers.scikit_learn.KerasRegressor
Implementation of the scikit-learn regressor API for Keras.
Compat alias for migration: `tf.compat.v1.keras.wrappers.scikit_learn.KerasRegressor`.
`tf.keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)`
Methods
check_params
`check_params(params)`
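Checks for user typos in `params`.

| Arguments | |
|---|---|
| `params` | Dictionary of the parameters to be checked. |

| Raises | |
|---|---|
| `ValueError` | If any key in `params` is not a legal argument of `build_fn` or of the `Sequential` methods the wrapper calls (such as `fit`, `predict`, and `evaluate`). |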
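A minimal sketch of the kind of check `check_params` performs, written in plain Python so it runs without TensorFlow. The stand-in functions `fake_fit`, `fake_predict`, `fake_evaluate`, and `build_fn` are hypothetical; the sketch assumes a parameter is legal when at least one of them declares it by name, which approximates, but is not, the wrapper's own implementation.

```python
import inspect


# Hypothetical stand-ins for the functions whose arguments are accepted.
# In the real wrapper these roles are played by Sequential methods and build_fn.
def fake_fit(x, y, batch_size=None, epochs=1, verbose=1):
    pass


def fake_predict(x, batch_size=None, verbose=0):
    pass


def fake_evaluate(x, y, batch_size=None, verbose=1):
    pass


def build_fn(hidden_units=32, learning_rate=0.001):
    pass


LEGAL_FNS = [fake_fit, fake_predict, fake_evaluate, build_fn]


def accepts(fn, name):
    """Return True if `name` is a named parameter of `fn`."""
    return name in inspect.signature(fn).parameters


def check_params(params):
    """Raise ValueError for the first key that no legal function accepts."""
    for name in params:
        if not any(accepts(fn, name) for fn in LEGAL_FNS):
            raise ValueError('{} is not a legal parameter'.format(name))


check_params({'epochs': 10, 'hidden_units': 64})  # passes silently
```

A misspelled key such as `epoch` instead of `epochs` is rejected with a `ValueError` before any model is built or trained.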
filter_sk_params
`filter_sk_params(fn, override=None)`
Filters `sk_params` and returns the entries that are arguments of `fn`.

| Arguments | |
|---|---|
| `fn` | Arbitrary function. |
| `override` | Dictionary of values that override `sk_params`. |

| Returns | |
|---|---|
| `res` | Dictionary containing the variables in both `sk_params` and `fn`'s arguments, updated with the entries of `override`. |
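The filtering step can be sketched in plain Python with `inspect.signature`. This is a hypothetical stand-alone function, not the wrapper's method: it takes `sk_params` as an explicit argument, and `build_fn` and `fake_fit` are illustrative stand-ins. In this sketch the `override` entries are applied last and unfiltered, so they win over matching `sk_params` values.

```python
import inspect


def filter_sk_params(sk_params, fn, override=None):
    """Keep the entries of sk_params that fn accepts, then apply override."""
    override = override or {}
    accepted = inspect.signature(fn).parameters
    res = {name: value for name, value in sk_params.items() if name in accepted}
    # Override entries are applied last, so they take precedence.
    res.update(override)
    return res


# Hypothetical build function and stand-in for a model's fit method.
def build_fn(hidden_units=32):
    pass


def fake_fit(x, y, epochs=1, batch_size=None):
    pass


sk_params = {'hidden_units': 64, 'epochs': 10, 'batch_size': 16}
print(filter_sk_params(sk_params, build_fn))  # {'hidden_units': 64}
print(filter_sk_params(sk_params, fake_fit, {'epochs': 3}))  # {'epochs': 3, 'batch_size': 16}
```

One shared `sk_params` dictionary can therefore carry both model-construction arguments and training arguments; each target function only sees the subset it declares.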
fit
`fit(x, y, **kwargs)`
Constructs a new model with `build_fn` and fits it to `(x, y)`.

| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Training samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `y` | array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`. True target values for `x`. |
| `**kwargs` | Dictionary arguments. Legal arguments are the arguments of `Sequential.fit`; values passed here override matching entries in `sk_params`. |

| Returns | |
|---|---|
| `history` | A `History` object with details about the training history at each epoch. |
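The flow described for `fit` can be sketched with hypothetical stand-ins: `FakeModel` replaces a compiled Keras model and `SketchRegressor` replaces the wrapper. The sketch assumes that each call builds a fresh model from the `sk_params` accepted by `build_fn`, and that keyword arguments passed to `fit` take precedence over matching `sk_params`.

```python
import inspect


class FakeModel:
    """Hypothetical stand-in for a compiled Keras model."""

    def __init__(self, hidden_units):
        self.hidden_units = hidden_units
        self.fit_calls = []

    def fit(self, x, y, epochs=1, batch_size=None):
        # Record the arguments instead of training; return a fake history.
        self.fit_calls.append({'epochs': epochs, 'batch_size': batch_size})
        return {'loss': [0.0] * epochs}


def build_fn(hidden_units=32):
    return FakeModel(hidden_units)


class SketchRegressor:
    """Hypothetical, simplified wrapper illustrating the documented fit flow."""

    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def _filter(self, fn):
        accepted = inspect.signature(fn).parameters
        return {k: v for k, v in self.sk_params.items() if k in accepted}

    def fit(self, x, y, **kwargs):
        # 1. Build a brand-new model from the sk_params that build_fn accepts.
        self.model = self.build_fn(**self._filter(self.build_fn))
        # 2. Start from the sk_params that fit accepts, then let kwargs win.
        fit_args = self._filter(FakeModel.fit)
        fit_args.update(kwargs)
        return self.model.fit(x, y, **fit_args)


reg = SketchRegressor(build_fn, hidden_units=64, epochs=5, batch_size=8)
history = reg.fit([[0.0]], [0.0], epochs=2)  # epochs=2 overrides epochs=5
first_model = reg.model
reg.fit([[0.0]], [0.0])  # builds a new model; previous state is discarded
```

Because every call rebuilds the model, fitting the same wrapper twice does not continue training the earlier model.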
get_params
`get_params(**params)`
Gets parameters for this estimator.
| Arguments | |
|---|---|
| `**params` | Ignored (exists for API compatibility). |

| Returns | |
|---|---|
| Dictionary of parameter names mapped to their values (`sk_params` together with `build_fn`). | |
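scikit-learn utilities such as `sklearn.base.clone` rebuild an estimator by calling its constructor with the output of `get_params(deep=False)`, which is why `get_params` must accept and ignore extra keyword arguments. The hypothetical `SketchEstimator` below assumes the returned dictionary is a copy of `sk_params` plus `build_fn`; it illustrates the contract and is not the wrapper's source.

```python
import copy


class SketchEstimator:
    """Hypothetical estimator mimicking the get_params contract."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, **params):
        # **params (for example deep=False) is ignored; it exists only for
        # API compatibility with scikit-learn callers.
        res = copy.deepcopy(self.sk_params)
        res['build_fn'] = self.build_fn
        return res


def build_fn(hidden_units=32):
    return None  # hypothetical; a real build_fn returns a compiled model


est = SketchEstimator(build_fn, hidden_units=64, epochs=10)
params = est.get_params()
# Passing the parameters back to the constructor yields an equivalent estimator.
twin = SketchEstimator(**params)
```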
predict
`predict(x, **kwargs)`
Returns predictions for the given test data.
| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `**kwargs` | Dictionary arguments. Legal arguments are the arguments of `Sequential.predict`. |

| Returns | |
|---|---|
| `preds` | array-like, shape `(n_samples,)`. Predictions. |
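A single-output Keras regressor typically produces predictions of shape `(n_samples, 1)`, while the documented return shape is `(n_samples,)`. The hypothetical helper below sketches that flattening, with nested Python lists standing in for arrays; it assumes the trailing size-1 axis is dropped and is a simplified stand-in, not the wrapper's code.

```python
def squeeze_single_output(preds):
    """Flatten (n_samples, 1) nested lists to (n_samples,); leave others as-is."""
    if all(len(row) == 1 for row in preds):
        return [row[0] for row in preds]
    # Multi-output predictions keep shape (n_samples, n_outputs).
    return preds


raw = [[1.5], [2.0], [3.25]]  # shape (3, 1), as a Keras model might return
print(squeeze_single_output(raw))  # [1.5, 2.0, 3.25]
```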
score
`score(x, y, **kwargs)`
Returns the negated mean loss on the given test data and targets, so that a higher score means a better fit (the scikit-learn convention).

| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `y` | array-like, shape `(n_samples,)`. True target values for `x`. |
| `**kwargs` | Dictionary arguments. Legal arguments are the arguments of `Sequential.evaluate`. |

| Returns | |
|---|---|
| `score` | float. The negated mean loss of the predictions on `x` with respect to `y`. |
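scikit-learn's model-selection tools, such as `GridSearchCV` and `cross_val_score`, treat a larger score as better, which is why a regressor reports a negated loss rather than the loss itself. The helper `score_from_evaluate` below is hypothetical; it assumes the Keras convention that `evaluate` returns a scalar loss, or a list whose first element is the loss when metrics are configured.

```python
def score_from_evaluate(evaluate_result):
    """Turn a Keras-style evaluate() result into a greater-is-better score."""
    # evaluate() returns a scalar loss, or a list whose first entry is the
    # loss when the model also tracks metrics.
    if isinstance(evaluate_result, list):
        loss = evaluate_result[0]
    else:
        loss = evaluate_result
    return -loss


good = score_from_evaluate(0.12)  # low loss, no metrics
bad = score_from_evaluate([0.80, 0.55])  # high loss plus one metric value
print(good > bad)  # True
```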
set_params
`set_params(**params)`
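Sets the parameters of this estimator.

| Arguments | |
|---|---|
| `**params` | Dictionary of parameter names mapped to their values. |

| Returns | |
|---|---|
| `self` | The estimator itself, so calls can be chained. |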
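The hypothetical `SketchEstimator` below illustrates the `set_params` contract: update the stored parameters and return `self` so calls can be chained, as scikit-learn expects. The fixed `LEGAL` set and the validate-before-update order are assumptions made for the sketch, loosely modeled on `check_params`.

```python
class SketchEstimator:
    """Hypothetical estimator mimicking check-then-update set_params semantics."""

    LEGAL = {'hidden_units', 'epochs', 'batch_size'}  # assumed legal names

    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def set_params(self, **params):
        # Validate first, so a typo leaves the estimator unchanged.
        unknown = [name for name in params if name not in self.LEGAL]
        if unknown:
            raise ValueError('{} is not a legal parameter'.format(unknown[0]))
        self.sk_params.update(params)
        return self  # returning self allows chained calls


est = SketchEstimator(hidden_units=32, epochs=10)
same = est.set_params(epochs=20).set_params(batch_size=16)
print(est.sk_params)  # {'hidden_units': 32, 'epochs': 20, 'batch_size': 16}
```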
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.