tf.keras.wrappers.scikit_learn.KerasClassifier
Implementation of the scikit-learn classifier API for Keras.

    tf.keras.wrappers.scikit_learn.KerasClassifier(
        build_fn=None, **sk_params
    )
Methods
check_params

    check_params(
        params
    )
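Checks for user typos in `params`.

| Arguments | |
|---|---|
| `params` | dictionary; the parameters to be checked |

| Raises | |
|---|---|
| `ValueError` | If any member of `params` is not a legal argument of `Sequential.fit`, `Sequential.predict`, `Sequential.predict_classes`, `Sequential.evaluate`, or `build_fn`. |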
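The check can be approximated without TensorFlow by inspecting function signatures. This sketch assumes, as in the TF 2.3 source, that a parameter is legal when it names an argument of one of the accepted functions. `fake_fit` and `build_model` are hypothetical stand-ins for `Sequential.fit` and a user-supplied `build_fn`. `accepts_arg` and `check_params_sketch` are illustrative helpers, not part of the Keras API.

```python
import inspect


def accepts_arg(fn, name):
    """Return True if `name` is a named parameter of `fn` (illustrative helper)."""
    return name in inspect.signature(fn).parameters


def check_params_sketch(params, legal_fns):
    """Raise ValueError for any key in `params` that no legal function accepts."""
    for name in params:
        if not any(accepts_arg(fn, name) for fn in legal_fns):
            raise ValueError("{} is not a legal parameter".format(name))


# Hypothetical stand-ins for Sequential.fit and a user-supplied build_fn.
def fake_fit(x=None, y=None, batch_size=None, epochs=1, verbose=1):
    pass


def build_model(hidden_units=16, learning_rate=0.001):
    pass


legal = [fake_fit, build_model]
check_params_sketch({"epochs": 5, "hidden_units": 8}, legal)  # passes silently

try:
    check_params_sketch({"epoch": 5}, legal)  # typo: "epoch" instead of "epochs"
except ValueError as err:
    print(err)  # epoch is not a legal parameter
```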
filter_sk_params

    filter_sk_params(
        fn, override=None
    )
Filters `sk_params` and returns only the entries that are arguments of `fn`.

| Arguments | |
|---|---|
| `fn` | arbitrary function |
| `override` | dictionary; values that override `sk_params` |

| Returns | |
|---|---|
| `res` | dictionary containing the variables that are in both `sk_params` and `fn`'s arguments, updated with `override`. |
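A minimal plain-Python sketch of the filtering. It assumes, as in the TF 2.3 source, that `override` is merged in after filtering, so its values win. `filter_sk_params_sketch` and `build_model` are hypothetical names used only for illustration.

```python
import inspect


def filter_sk_params_sketch(sk_params, fn, override=None):
    """Keep the sk_params entries that name a parameter of fn, then apply override."""
    override = override or {}
    accepted = inspect.signature(fn).parameters
    res = {name: value for name, value in sk_params.items() if name in accepted}
    # Override values are merged last, so they replace any filtered value.
    res.update(override)
    return res


# Hypothetical build function and constructor keyword arguments.
def build_model(hidden_units=16, learning_rate=0.001):
    pass


sk_params = {"hidden_units": 8, "epochs": 5, "batch_size": 32}
print(filter_sk_params_sketch(sk_params, build_model))
# {'hidden_units': 8}
print(filter_sk_params_sketch(sk_params, build_model, override={"hidden_units": 64}))
# {'hidden_units': 64}
```

This is how a single set of constructor keywords can feed both `build_fn` and the training call. Each consumer sees only the names its own signature accepts.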
fit

    fit(
        x, y, **kwargs
    )
Constructs a new model with `build_fn` and fits the model to `(x, y)`.

| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Training samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `y` | array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`. True labels for `x`. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.fit`. |

| Returns | |
|---|---|
| `history` | object; details about the training history at each epoch. |

| Raises | |
|---|---|
| `ValueError` | If `y` has an invalid shape. |
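The sketch below reproduces, without TensorFlow, the label handling that `fit` performs before building the model in the TF 2.3 source:

- 1-D labels, or a single column, are replaced by their index in the sorted unique labels.
- A multi-column `y` is treated as one column per class.
- Any other shape raises `ValueError`.

`encode_labels` is a hypothetical helper working on nested Python lists, not the wrapper's own code.

```python
def _is_row(value):
    """True for the nested sequences used here to represent 2-D rows."""
    return isinstance(value, (list, tuple))


def encode_labels(y):
    """Hypothetical helper: return (classes, encoded_y) for nested-list labels.

    Mirrors the shape rules described for fit; raises ValueError otherwise.
    """
    if not any(_is_row(v) for v in y):
        flat = list(y)  # shape (n_samples,)
    else:
        widths = {len(r) if _is_row(r) else -1 for r in y}
        too_deep = any(_is_row(v) for r in y if _is_row(r) for v in r)
        if len(widths) != 1 or too_deep or min(widths) < 1:
            raise ValueError("Invalid shape for y")
        width = widths.pop()
        if width == 1:
            flat = [r[0] for r in y]  # shape (n_samples, 1)
        else:
            # shape (n_samples, n_outputs): columns already index the classes
            return list(range(width)), [list(r) for r in y]
    # Sorted unique labels; each label is replaced by its position.
    classes = sorted(set(flat))
    position = {label: i for i, label in enumerate(classes)}
    return classes, [position[label] for label in flat]
```

For example, `encode_labels(["cat", "dog", "cat", "bird"])` returns `(["bird", "cat", "dog"], [1, 2, 1, 0])`. `predict` later maps predicted indices back through the same sorted class list.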
get_params

    get_params(
        **params
    )
Gets parameters for this estimator.

| Arguments | |
|---|---|
| `**params` | Ignored (exists for API compatibility). |

| Returns | |
|---|---|
| dict | Parameter names mapped to their values (a copy of `sk_params` plus `build_fn`). |
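scikit-learn's `sklearn.base.clone` calls `get_params(deep=False)` and passes the result back to the estimator's constructor. This is why `get_params` accepts and ignores extra keywords. The sketch below shows that round trip using a hypothetical `ToyWrapper` class, not the Keras class. `ToyWrapper` stores `build_fn` and `sk_params` as the constructor signature suggests, and, as in the TF 2.3 source, its `get_params` returns a deep copy of `sk_params` plus `build_fn`.

```python
import copy


class ToyWrapper:
    """Hypothetical stand-in that mimics the wrapper's parameter storage."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, **params):
        # Extra keywords such as deep=False are accepted and ignored.
        res = copy.deepcopy(self.sk_params)
        res["build_fn"] = self.build_fn
        return res


def build_model(hidden_units=16):
    """Hypothetical build function."""


original = ToyWrapper(build_fn=build_model, hidden_units=8, epochs=5)
# Roughly what a clone does: construct a fresh instance from the parameters.
clone = type(original)(**original.get_params(deep=False))
print(clone.get_params() == original.get_params())  # True
```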
predict

    predict(
        x, **kwargs
    )
Returns the class predictions for the given test data.

| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.predict_classes`. |

| Returns | |
|---|---|
| `preds` | array-like, shape `(n_samples,)`. Class predictions. |
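In the TF 2.3 source, `predict` turns the model's output scores into labels drawn from the classes recorded during `fit`:

- With more than one output column, it takes the argmax of each row.
- With a single column, it thresholds the score at 0.5.

The plain-Python sketch below mirrors that rule. `probs_to_labels` is a hypothetical helper, not the wrapper's code.

```python
def probs_to_labels(probs, classes):
    """Map rows of model outputs to entries of `classes`.

    probs: list of rows, each a list of per-class scores (or one score).
    classes: labels in the sorted order fixed during fit.
    """
    labels = []
    for row in probs:
        if len(row) > 1:
            # Multi-class: index of the highest score (first one on ties).
            idx = max(range(len(row)), key=row.__getitem__)
        else:
            # Single output column: second class when the score exceeds 0.5.
            idx = 1 if row[0] > 0.5 else 0
        labels.append(classes[idx])
    return labels


print(probs_to_labels([[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]], ["bird", "cat", "dog"]))
# ['cat', 'bird']
print(probs_to_labels([[0.9], [0.2], [0.5]], ["no", "yes"]))
# ['yes', 'no', 'no']
```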
predict_proba

    predict_proba(
        x, **kwargs
    )
Returns class probability estimates for the given test data.

| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.predict_proba`. |

| Returns | |
|---|---|
| `proba` | array-like, shape `(n_samples, n_outputs)`. Class probability estimates. For binary classification, the method returns an array of shape `(n_samples, 2)` instead of Keras's `(n_samples, 1)`, to match the scikit-learn API. |
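The binary-case expansion amounts to prepending the complementary probability. A single column `p` becomes the pair `(1 - p, p)`, so column 0 holds the probability of the first class and column 1 that of the second. Below is a minimal plain-Python sketch; the helper name `to_sklearn_proba` is hypothetical.

```python
def to_sklearn_proba(probs):
    """Expand single-column probabilities to scikit-learn's two-column layout.

    probs: list of rows of per-class probabilities. Rows with more than one
    column are returned unchanged (as copies).
    """
    if probs and len(probs[0]) == 1:
        # Binary case: column 0 is P(class 0) = 1 - p, column 1 is P(class 1) = p.
        return [[1.0 - row[0], row[0]] for row in probs]
    return [list(row) for row in probs]


print(to_sklearn_proba([[0.25], [0.75]]))
# [[0.75, 0.25], [0.25, 0.75]]
```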
score

    score(
        x, y, **kwargs
    )
Returns the mean accuracy on the given test data and labels.

| Arguments | |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `y` | array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`. True labels for `x`. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.evaluate`. |

| Returns | |
|---|---|
| `score` | float. Mean accuracy of the predictions on `x` with respect to `y`. |

| Raises | |
|---|---|
| `ValueError` | If the underlying model isn't configured to compute accuracy. Pass `metrics=["accuracy"]` to the model's `.compile()` method. |
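In the TF 2.3 source, `score` does not compute accuracy itself. It calls `Sequential.evaluate` and returns the value reported under an accuracy metric name (`"accuracy"` or `"acc"` in `model.metrics_names`), which is why a model compiled without that metric raises `ValueError`. The sketch below imitates that lookup, with hypothetical lists standing in for a real model's metric names and `evaluate` output. `pick_accuracy` is an illustrative helper.

```python
def pick_accuracy(metrics_names, outputs):
    """Return the evaluate() output whose metric name denotes accuracy."""
    if not isinstance(outputs, list):
        # evaluate() returns a bare scalar when only the loss is tracked.
        outputs = [outputs]
    for name, value in zip(metrics_names, outputs):
        if name in ("accuracy", "acc"):
            return value
    raise ValueError(
        "The model is not configured to compute accuracy. "
        'You should pass metrics=["accuracy"] to the model.compile() method.'
    )


print(pick_accuracy(["loss", "accuracy"], [0.41, 0.875]))  # 0.875

try:
    pick_accuracy(["loss"], 0.41)  # as if compiled without metrics=["accuracy"]
except ValueError as err:
    print(err)
```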
set_params

    set_params(
        **params
    )
Sets the parameters of this estimator.

| Arguments | |
|---|---|
| `**params` | Dictionary of parameter names mapped to their values. |

| Returns | |
|---|---|
| `self` | The estimator itself. |
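Because `set_params` returns `self`, calls can be chained, a pattern scikit-learn relies on when it reconfigures a cloned estimator. In the TF 2.3 source, the method first runs `check_params` on the new values and then merges them into `sk_params`. The sketch below uses a hypothetical `ToyWrapper` class with hypothetical `fake_fit` and `build_model` functions, and reduces the check to a signature lookup. It is not the Keras implementation.

```python
import inspect


def fake_fit(x=None, y=None, batch_size=None, epochs=1, verbose=1):
    """Hypothetical stand-in for Sequential.fit's signature."""


def build_model(hidden_units=16):
    """Hypothetical build function."""


class ToyWrapper:
    """Hypothetical stand-in that mimics set_params validation and chaining."""

    legal_fns = (fake_fit, build_model)

    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def check_params(self, params):
        for name in params:
            if not any(name in inspect.signature(fn).parameters for fn in self.legal_fns):
                raise ValueError("{} is not a legal parameter".format(name))

    def set_params(self, **params):
        self.check_params(params)  # reject typos before storing anything
        self.sk_params.update(params)
        return self  # returning self allows chained calls


model = ToyWrapper(hidden_units=8).set_params(epochs=10, batch_size=32)
print(model.sk_params)  # {'hidden_units': 8, 'epochs': 10, 'batch_size': 32}
```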
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.