Implementation of the scikit-learn classifier API for Keras.
tf.keras.wrappers.scikit_learn.KerasClassifier(
    build_fn=None, **sk_params
)
Methods
check_params
check_params(
    params
)
Checks for user typos in params.
| Arguments | Description |
|---|---|
| `params` | dictionary; the parameters to be checked |

| Raises | Description |
|---|---|
| `ValueError` | if any member of `params` is not a valid argument. |
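
The check can be pictured as a minimal sketch: a name counts as valid if at least one "legal" function accepts it as an argument. The stand-ins `fit_stub` and `build_model` are hypothetical placeholders, not Keras APIs. The wrapper itself compares against its own list of legal functions, such as `Sequential.fit` and the `build_fn` signature.

```python
import inspect


def check_params(params, legal_fns):
    """Raise ValueError if a key in `params` is not an argument of any fn in `legal_fns`."""
    for name in params:
        # A name is legal if at least one candidate function accepts it.
        if not any(name in inspect.signature(fn).parameters for fn in legal_fns):
            raise ValueError(f"{name} is not a legal parameter")


# Hypothetical stand-ins for a training method and a model-building function.
def fit_stub(x=None, y=None, epochs=1, batch_size=32):
    pass


def build_model(hidden_units=8):
    pass


check_params({"epochs": 5, "hidden_units": 16}, [fit_stub, build_model])  # passes silently
try:
    check_params({"epoch": 5}, [fit_stub, build_model])  # typo: "epoch"
except ValueError as err:
    print(err)  # epoch is not a legal parameter
```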
filter_sk_params
filter_sk_params(
    fn, override=None
)
Filters sk_params and returns those in fn's arguments.
| Arguments | Description |
|---|---|
| `fn` | arbitrary function |
| `override` | dictionary; values that override `sk_params` |

| Returns | Description |
|---|---|
| `res` | dictionary containing the variables that appear both in `sk_params` and in `fn`'s arguments, updated with `override`. |
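
The following minimal sketch shows the filtering step, assuming "in `fn`'s arguments" means "a parameter name in `fn`'s signature". Entries in `override` are applied last, so they take precedence over `sk_params`. The stand-in `fit_stub` is hypothetical.

```python
import inspect


def filter_sk_params(sk_params, fn, override=None):
    """Keep only the entries of `sk_params` that `fn` accepts, then apply `override`."""
    override = override or {}
    accepted = inspect.signature(fn).parameters
    res = {name: value for name, value in sk_params.items() if name in accepted}
    # Entries in `override` take precedence over values from `sk_params`.
    res.update(override)
    return res


# Hypothetical stand-in for a training method.
def fit_stub(x=None, y=None, epochs=1, batch_size=32):
    pass


sk_params = {"epochs": 10, "batch_size": 32, "hidden_units": 8}
print(filter_sk_params(sk_params, fit_stub))
# {'epochs': 10, 'batch_size': 32}
print(filter_sk_params(sk_params, fit_stub, override={"epochs": 2}))
# {'epochs': 2, 'batch_size': 32}
```

With this routing, a single `sk_params` dictionary can hold both model-building arguments (`hidden_units`) and training arguments (`epochs`), and each call only receives the names it understands.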
fit
fit(
    x, y, **kwargs
)
Constructs a new model with `build_fn` and fits it to `(x, y)`.
| Arguments | Description |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Training samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `y` | array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`. True labels for `x`. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.fit`. |

| Returns | Description |
|---|---|
| `history` | object; details about the training history at each epoch. |

| Raises | Description |
|---|---|
| `ValueError` | If `y` has an invalid shape. |
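
The following minimal sketch shows label handling that would produce this `ValueError`. It assumes the wrapper maps 1-D or single-column labels to indices into their sorted unique values and treats multi-column labels as one-hot, with one class per column. The helper name `encode_labels` is hypothetical.

```python
import numpy as np


def encode_labels(y):
    """Return (classes, encoded_y) or raise ValueError for an unsupported shape."""
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] > 1:
        # Multi-column (one-hot) targets: one class per column, labels left as-is.
        classes = np.arange(y.shape[1])
    elif y.ndim == 1 or (y.ndim == 2 and y.shape[1] == 1):
        # Arbitrary labels (strings, ints): replace each by its index in the sorted unique set.
        classes = np.unique(y)
        y = np.searchsorted(classes, y)
    else:
        raise ValueError(f"Invalid shape for y: {y.shape}")
    return classes, y


classes, encoded = encode_labels(["cat", "dog", "cat", "bird"])
print(classes)  # ['bird' 'cat' 'dog']
print(encoded)  # [1 2 1 0]
```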
get_params
get_params(
    **params
)
Gets parameters for this estimator.
| Arguments | Description |
|---|---|
| `**params` | ignored (exists for API compatibility). |

| Returns |
|---|
| Dictionary of parameter names mapped to their values. |
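
The following minimal sketch shows why the ignored `**params` matters. scikit-learn's `clone()` calls `get_params(deep=False)` and passes the result back to the constructor, so the method must accept a `deep` keyword even if it never uses it. `SketchClassifier` is a hypothetical stand-in that assumes the wrapper returns its stored `sk_params` plus `build_fn`.

```python
class SketchClassifier:
    """Hypothetical stand-in that mimics how the wrapper stores its parameters."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, **params):
        # **params (e.g. deep=False) is ignored; it exists only for scikit-learn compatibility.
        res = self.sk_params.copy()
        res["build_fn"] = self.build_fn
        return res


est = SketchClassifier(build_fn=None, epochs=5, batch_size=16)
print(est.get_params())  # {'epochs': 5, 'batch_size': 16, 'build_fn': None}

# Rebuild an equivalent estimator the way clone() does: constructor(**get_params(deep=False)).
twin = SketchClassifier(**est.get_params(deep=False))
print(twin.get_params() == est.get_params())  # True
```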
predict
predict(
    x, **kwargs
)
Returns the class predictions for the given test data.
| Arguments | Description |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.predict_classes`. |

| Returns | Description |
|---|---|
| `preds` | array-like, shape `(n_samples,)`. Class predictions. |
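
The following minimal sketch shows how probabilities can become class predictions, assuming the rule that `Sequential.predict_classes` applied. When there are several output units, it takes the most probable column. When there is a single sigmoid unit, it thresholds at 0.5. The resulting index then selects from the classes found during `fit`. The helper name is hypothetical, and this sketch flattens the single-unit case to shape `(n_samples,)`.

```python
import numpy as np


def probabilities_to_labels(proba, classes):
    """Map model outputs of shape (n_samples, k) to labels taken from `classes`."""
    proba = np.asarray(proba)
    if proba.shape[-1] > 1:
        idx = proba.argmax(axis=-1)  # several units: pick the most probable column
    else:
        idx = (proba[:, 0] > 0.5).astype(int)  # one sigmoid unit: threshold at 0.5
    return np.asarray(classes)[idx]


classes = np.array(["cat", "dog", "bird"])
print(probabilities_to_labels([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], classes))  # ['dog' 'cat']
print(probabilities_to_labels([[0.8], [0.3]], np.array([0, 1])))  # [1 0]
```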
predict_proba
predict_proba(
    x, **kwargs
)
Returns class probability estimates for the given test data.
| Arguments | Description |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.predict_classes`. |

| Returns | Description |
|---|---|
| `proba` | array-like, shape `(n_samples, n_outputs)`. Class probability estimates. For binary classification, to match the scikit-learn API, an array of shape `(n_samples, 2)` is returned instead of the `(n_samples, 1)` array Keras produces. |
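
The following minimal sketch shows the binary-case expansion, assuming a single sigmoid output `p` is read as P(class 1). The first column is then `1 - p`. The helper name `to_sklearn_proba` is hypothetical.

```python
import numpy as np


def to_sklearn_proba(probs):
    """Expand a single-column (sigmoid) output to the two columns scikit-learn expects."""
    probs = np.asarray(probs, dtype=float)
    if probs.shape[1] == 1:
        # Column 0: P(class 0) = 1 - p; column 1: P(class 1) = p.
        probs = np.hstack([1 - probs, probs])
    return probs


print(to_sklearn_proba([[0.25], [0.5]]).tolist())  # [[0.75, 0.25], [0.5, 0.5]]
```

Multi-column outputs pass through unchanged, so every row still sums to 1 whenever the model's output does.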
score
score(
    x, y, **kwargs
)
Returns the mean accuracy on the given test data and labels.
| Arguments | Description |
|---|---|
| `x` | array-like, shape `(n_samples, n_features)`. Test samples, where `n_samples` is the number of samples and `n_features` is the number of features. |
| `y` | array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`. True labels for `x`. |
| `**kwargs` | dictionary arguments. Legal arguments are the arguments of `Sequential.evaluate`. |

| Returns | Description |
|---|---|
| `score` | float. Mean accuracy of the predictions on `x` with respect to `y`. |

| Raises | Description |
|---|---|
| `ValueError` | If the underlying model isn't configured to compute accuracy. Pass `metrics=["accuracy"]` to the model's `.compile()` method. |
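
The following minimal sketch shows where the `ValueError` comes from. It assumes `score` runs `Sequential.evaluate` and then looks up the accuracy entry by metric name. Here `metrics_names` stands for the model's `metrics_names` attribute, and `outputs` stands for the value `evaluate()` returns. The function name is hypothetical.

```python
def accuracy_from_evaluate(metrics_names, outputs):
    """Pick the accuracy value out of evaluate() results, or raise ValueError."""
    if not isinstance(outputs, list):
        outputs = [outputs]  # evaluate() returns a bare scalar when only the loss is tracked
    for name, value in zip(metrics_names, outputs):
        if name in ("accuracy", "acc"):
            return value
    raise ValueError(
        "The model is not configured to compute accuracy. "
        'Pass metrics=["accuracy"] to model.compile().'
    )


print(accuracy_from_evaluate(["loss", "accuracy"], [0.31, 0.92]))  # 0.92
```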
set_params
set_params(
    **params
)
Sets the parameters of this estimator.
| Arguments | Description |
|---|---|
| `**params` | Dictionary of parameter names mapped to their values. |
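
scikit-learn tools such as `GridSearchCV` use `set_params` to try different hyperparameter values on a cloned estimator. The following minimal sketch assumes the wrapper validates the names (as `check_params` does), merges them into `sk_params`, and returns `self` so calls can be chained. `SketchClassifier` and its fixed `LEGAL` set are hypothetical. The wrapper itself derives legal names from function signatures rather than a hard-coded set.

```python
class SketchClassifier:
    """Hypothetical stand-in mimicking how the wrapper stores and updates sk_params."""

    # Assumed, hard-coded set of legal names (illustration only).
    LEGAL = {"epochs", "batch_size", "verbose", "hidden_units"}

    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def set_params(self, **params):
        unknown = set(params) - self.LEGAL
        if unknown:
            raise ValueError(f"{sorted(unknown)} are not legal parameters")
        self.sk_params.update(params)
        return self  # returning self allows chaining, per scikit-learn convention


est = SketchClassifier(epochs=5).set_params(epochs=20, batch_size=64)
print(est.sk_params)  # {'epochs': 20, 'batch_size': 64}
```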