Decision Forest in a Keras Model.
Usage example:
```python
import pandas as pd
import tensorflow_decision_forests as tfdf
import tf_keras

# Load the dataset in a Pandas dataframe.
train_df = pd.read_csv("project/train.csv")
test_df = pd.read_csv("project/test.csv")

# Convert the dataset into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label")

# Train the model.
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)

# Evaluate the model on another dataset.
model.evaluate(test_ds)

# Show information about the model.
model.summary()

# Export the model in the TF SavedModel format.
model.save("/path/to/my/model")

# ...

# Load the model: it loads as a generic Keras model.
loaded_model = tf_keras.models.load_model("/path/to/my/model")
```
Modules
core
module: Core wrapper.
wrappers
module: Wrapper around each learning algorithm.
Classes
class AdvancedArguments
: Advanced control of the model that most users won't need to use.
class CartModel
: Cart learning algorithm.
class CoreModel
: Keras Model V2 wrapper around a Yggdrasil Learner and Model.
class DistributedGradientBoostedTreesModel
: Distributed Gradient Boosted Trees learning algorithm.
class FeatureSemantic
: Semantic (e.g. numerical, categorical) of an input feature.
class FeatureUsage
: Semantic and hyper-parameters for a single feature.
class GradientBoostedTreesModel
: Gradient Boosted Trees learning algorithm.
class Monotonic
: Monotonic constraint between a feature and the model output.
class MultiTaskItem
: A single task in a multi-task configuration.
class RandomForestModel
: Random Forest learning algorithm.
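Several of the model classes above are tree ensembles. As a rough conceptual sketch (not the Yggdrasil implementation and not a TF-DF API), a random forest typically averages the outputs of independently trained trees, while gradient boosted trees add the outputs of sequentially trained trees on top of an initial prediction; a CART model is a single tree. The `Tree` type, the two stumps, and the feature name `x` are hypothetical and exist only for illustration.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a trained tree: any function mapping a
# feature dict to a number. Not the TF-DF / Yggdrasil tree format.
Tree = Callable[[Dict[str, float]], float]


def random_forest_predict(trees: List[Tree], example: Dict[str, float]) -> float:
    """Averages the per-tree outputs, as a random forest commonly does."""
    if not trees:
        raise ValueError("at least one tree is required")
    return sum(tree(example) for tree in trees) / len(trees)


def gradient_boosted_predict(
    trees: List[Tree], example: Dict[str, float], initial_prediction: float
) -> float:
    """Adds the per-tree outputs to an initial prediction (regression-style)."""
    return initial_prediction + sum(tree(example) for tree in trees)


# Two hand-written decision stumps over the hypothetical feature "x".
def stump_a(example: Dict[str, float]) -> float:
    return 1.0 if example["x"] > 0.5 else 0.0


def stump_b(example: Dict[str, float]) -> float:
    return 0.8 if example["x"] > 0.2 else 0.2


if __name__ == "__main__":
    example = {"x": 0.3}
    print(random_forest_predict([stump_a, stump_b], example))
    print(gradient_boosted_predict([stump_a, stump_b], example, initial_prediction=0.1))
```

The same pair of trees gives different predictions depending on the combination rule, which is the main structural difference between the two ensemble families.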
Functions
build_default_feature_signature(...)
: Gets an example of feature values for the default model signature.
build_default_input_model_signature(...)
get_all_models(...)
: Gets the list of all available models.
get_worker_idx_and_num_workers(...)
: Gets the current worker index and the total number of workers.
pd_dataframe_to_tf_dataset(...)
: Converts a Pandas DataFrame into a TF Dataset compatible with Keras.
set_training_logs_redirection(...)
: Controls the redirection of training logs for display.
yggdrasil_model_to_keras_model(...)
: Converts a Yggdrasil model into a TensorFlow SavedModel / Keras model.
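One way a worker index and a total worker count, such as those described for `get_worker_idx_and_num_workers(...)`, can be used in distributed training is to give each worker a disjoint slice of the data. The sketch below shows round-robin sharding in plain Python. `shard_for_worker` is a hypothetical helper, and the worker index and count are hard-coded assumptions rather than values returned by TF-DF. The selection rule matches the one used by `tf.data.Dataset.shard(num_shards, index)`.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_worker(examples: Sequence[T], worker_idx: int, num_workers: int) -> List[T]:
    """Returns the examples assigned to one worker by round-robin sharding.

    Worker `worker_idx` keeps every `num_workers`-th example, starting at
    position `worker_idx`, so the shards are disjoint and cover all examples.
    """
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    if not 0 <= worker_idx < num_workers:
        raise ValueError("worker_idx must be in [0, num_workers)")
    return list(examples[worker_idx::num_workers])


if __name__ == "__main__":
    # Hypothetical setup: 3 workers and 10 examples. In a real distributed
    # job, the index and count would come from the training framework.
    examples = list(range(10))
    for idx in range(3):
        print(idx, shard_for_worker(examples, idx, 3))
```

Round-robin sharding keeps shard sizes within one example of each other, which helps workers finish an epoch at roughly the same time.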
Other Members

| Name | Value |
|---|---|
| Task | Instance of `google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper` |