Gradient Boosted Trees: Model understanding

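Warning: Estimators are not recommended for new code. Estimators run v1.Session-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our compatibility guarantees, but they will receive no fixes other than for security vulnerabilities. See the migration guide for details.

Note: Modern Keras-based implementations of many state-of-the-art decision forest algorithms are available in TensorFlow Decision Forests.

For an end-to-end walkthrough of training a Gradient Boosting model, see the tutorial on training a Boosted Trees model in TensorFlow. In this tutorial you will:

  • Learn how to interpret a Boosted Trees model both locally and globally
  • Gain intuition for how a Boosted Trees model fits a dataset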

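How to interpret Boosted Trees models both locally and globally

Local interpretability refers to an understanding of a model's predictions at the level of an individual example, while global interpretability refers to an understanding of the model as a whole. Both kinds of technique can help machine learning (ML) practitioners detect bias and bugs during model development.

For local interpretability, you will learn how to create and visualize per-instance contributions. To distinguish them from feature importances, these values are called DFCs (directional feature contributions).

For global interpretability, you will retrieve and visualize gain-based feature importances and permutation feature importances, and also show aggregated DFCs.

Load the titanic dataset

This tutorial uses the titanic dataset, where the goal is to predict passenger survival given characteristics such as sex, age, and passenger class.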

pip install statsmodels
import numpy as np
import pandas as pd
from IPython.display import clear_output

# Load dataset.
dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
y_train = dftrain.pop('survived')
y_eval = dfeval.pop('survived')
import tensorflow as tf
tf.random.set_seed(123)

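For a description of the features, see the previous tutorial.

Create feature columns, input_fn, and train the estimator

Preprocess the data

Create the feature columns, using the original numeric columns as-is and one-hot encoding the categorical variables (such as sex and class).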

fc = tf.feature_column
CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',
                       'embark_town', 'alone']
NUMERIC_COLUMNS = ['age', 'fare']

def one_hot_cat_column(feature_name, vocab):
  return fc.indicator_column(
      fc.categorical_column_with_vocabulary_list(feature_name,
                                                 vocab))
feature_columns = []
for feature_name in CATEGORICAL_COLUMNS:
  # Need to one-hot encode categorical features.
  vocabulary = dftrain[feature_name].unique()
  feature_columns.append(one_hot_cat_column(feature_name, vocabulary))

for feature_name in NUMERIC_COLUMNS:
  feature_columns.append(fc.numeric_column(feature_name,
                                           dtype=tf.float32))
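
To make the one-hot step concrete, here is a minimal plain-Python sketch of what an indicator column over a vocabulary list produces. The `one_hot` helper and the three-entry vocabulary are hypothetical names chosen for illustration. Mapping out-of-vocabulary values to an all-zero row is an assumption about the default behavior, not something stated in this tutorial.

```python
def one_hot(values, vocab):
    """Encode each value as a 0/1 row with a single 1 at its vocabulary index."""
    index = {v: i for i, v in enumerate(vocab)}
    rows = []
    for v in values:
        row = [0.0] * len(vocab)
        if v in index:              # assumed: unknown values stay all-zero
            row[index[v]] = 1.0
        rows.append(row)
    return rows

# Illustrative vocabulary for the 'class' feature (hypothetical ordering).
vocab = ['First', 'Second', 'Third']
encoded = one_hot(['Third', 'First', 'Third'], vocab)
unknown = one_hot(['Crew'], vocab)

for row in encoded:
    print(row)
```

Each categorical feature therefore expands into as many 0/1 columns as it has distinct values, which is why the vocabulary is taken from `dftrain[feature_name].unique()`.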

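Build the input pipeline

Create the input functions using the from_tensor_slices method of the tf.data API to read data directly from Pandas.

# Use the entire dataset as one batch, since the dataset is small.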
NUM_EXAMPLES = len(y_train)

def make_input_fn(X, y, n_epochs=None, shuffle=True):
  def input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))
    if shuffle:
      dataset = dataset.shuffle(NUM_EXAMPLES)
    # For training, cycle through the dataset as many times as needed (n_epochs=None).
    dataset = (dataset
      .repeat(n_epochs)
      .batch(NUM_EXAMPLES))
    return dataset
  return input_fn

# Training and evaluation input functions.
train_input_fn = make_input_fn(dftrain, y_train)
eval_input_fn = make_input_fn(dfeval, y_eval, shuffle=False, n_epochs=1)
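
The ordering of shuffle, repeat, and batch in `make_input_fn` can be hard to picture. The sketch below emulates it with plain Python generators. It is not tf.data: `emulate_input_fn` is a hypothetical helper, and it assumes the shuffle is redone on every pass over the data.

```python
import itertools
import random

def emulate_input_fn(rows, n_epochs=None, shuffle=True, batch_size=None, seed=0):
    """Yield batches as shuffle -> repeat -> batch would, using only the stdlib."""
    rng = random.Random(seed)
    batch_size = batch_size or len(rows)
    # n_epochs=None means "repeat forever", as in the training input function.
    epochs = itertools.count() if n_epochs is None else range(n_epochs)

    def stream():
        for _ in epochs:
            order = list(rows)
            if shuffle:
                rng.shuffle(order)      # assumed: reshuffled on every pass
            yield from order

    it = stream()
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch

rows = list(range(5))
eval_batches = list(emulate_input_fn(rows, n_epochs=1, shuffle=False))
train_batches = list(itertools.islice(emulate_input_fn(rows), 3))
print(eval_batches)
print(train_batches)
```

Because the batch size equals the dataset size, every training batch is one complete, reshuffled pass over the data, and the evaluation function yields exactly one batch.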

Train the model

params = {
  'n_trees': 50,
  'max_depth': 3,
  'n_batches_per_layer': 1,
  # You must enable center_bias = True to get DFCs. This will force the model to
  # make an initial prediction before using any features (e.g. use the mean of
  # the training labels for regression or log odds for classification when
  # using cross entropy loss).
  'center_bias': True
}

est = tf.estimator.BoostedTreesClassifier(feature_columns, **params)
# Train model.
est.train(train_input_fn, max_steps=100)

# Evaluation.
results = est.evaluate(eval_input_fn)
clear_output()
pd.Series(results).to_frame()
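
The comment on `center_bias` mentions starting from the log odds for classification. As a small worked illustration, the sketch below converts a positive-class rate into a log-odds value and back. The rate of 0.375 is an assumed example value, and the code says nothing about how TensorFlow computes the bias internally.

```python
import math

def log_odds(p):
    """Logit of a probability p in (0, 1)."""
    return math.log(p / (1.0 - p))

def sigmoid(x):
    """Inverse of log_odds."""
    return 1.0 / (1.0 + math.exp(-x))

label_mean = 0.375                 # assumed positive rate, for illustration only
bias_logit = log_odds(label_mean)  # feature-free starting prediction in logit space
print(bias_logit)                  # log(0.6), about -0.51
print(sigmoid(bias_logit))         # recovers the positive rate
```

A model that starts from this value already predicts the base rate before any split is made, so the per-feature contributions can be read as movements away from that starting point.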

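For performance reasons, when your data fits in memory, we recommend passing train_in_memory=True to tf.estimator.BoostedTreesClassifier. However, if training time is not a concern, or if you have a very large dataset and want to do distributed training, use the tf.estimator.BoostedTrees API shown above.

When using this method, do not batch your input data, because the method operates on the entire dataset.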

in_memory_params = dict(params)
in_memory_params['n_batches_per_layer'] = 1
# In-memory input_fn does not use batching.
def make_inmemory_train_input_fn(X, y):
  y = np.expand_dims(y, axis=1)
  def input_fn():
    return dict(X), y
  return input_fn
train_input_fn = make_inmemory_train_input_fn(dftrain, y_train)

# Train the model.
est = tf.estimator.BoostedTreesClassifier(
    feature_columns, 
    train_in_memory=True, 
    **in_memory_params)

est.train(train_input_fn)
print(est.evaluate(eval_input_fn))
{'accuracy': 0.81439394, 'accuracy_baseline': 0.625, 'auc': 0.86923784, 'auc_precision_recall': 0.85286695, 'average_loss': 0.41441453, 'label/mean': 0.375, 'loss': 0.41441453, 'precision': 0.7604167, 'prediction/mean': 0.38847554, 'recall': 0.7373737, 'global_step': 153}
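
The reported metrics are tied together by a single confusion matrix. The sketch below shows the arithmetic. The four counts are not printed by the tutorial: they are back-solved from the rounded metrics under the assumption that the evaluation set has 264 rows, so treat them as illustrative.

```python
# Assumed confusion-matrix counts (inferred, not reported by the estimator).
tp, fp, fn, tn = 73, 23, 26, 142
n = tp + fp + fn + tn

accuracy = (tp + tn) / n
precision = tp / (tp + fp)           # of predicted survivors, fraction correct
recall = tp / (tp + fn)              # of actual survivors, fraction found
label_mean = (tp + fn) / n           # share of positive labels
accuracy_baseline = max(label_mean, 1 - label_mean)  # always predict majority class

print(n, accuracy, precision, recall, label_mean, accuracy_baseline)
```

Read this way, `accuracy_baseline` is the score of a model that ignores every feature, so the gap between it and `accuracy` is what the features actually buy.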

Model interpretation and plotting

import matplotlib.pyplot as plt
import seaborn as sns
sns_colors = sns.color_palette('colorblind')

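Local interpretability

Next you will output the directional feature contributions (DFCs) to explain individual predictions, using the approach outlined by Palczewska et al and by Saabas in Interpreting Random Forests (the treeinterpreter package for scikit-learn Random Forests works on the same principle). The DFCs are generated with: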

pred_dicts = list(est.experimental_predict_with_explanations(pred_input_fn))

(Note: The method is named experimental because the API may change before the experimental prefix is dropped.)

pred_dicts = list(est.experimental_predict_with_explanations(eval_input_fn))
# Create DFC Pandas dataframe.
labels = y_eval.values
probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])
df_dfc = pd.DataFrame([pred['dfc'] for pred in pred_dicts])
df_dfc.describe().T

A nice property of DFCs is that the sum of the contributions plus the bias equals the prediction for a given example.

# Sum of DFCs + bias == probability.
bias = pred_dicts[0]['bias']
dfc_prob = df_dfc.sum(axis=1) + bias
np.testing.assert_almost_equal(dfc_prob.values,
                               probs.values)
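
The additivity property is easier to see on a single hand-made tree. The sketch below follows the idea attributed to Saabas: walk the decision path and credit each change in node value to the feature that was split on. The tree, its node values, and the `explain` helper are all hypothetical, and this is not TensorFlow's implementation, which works over a whole ensemble.

```python
# A tiny hypothetical tree. Every node stores the mean prediction of the
# training examples that reach it; internal nodes also store a split.
TREE = {
    'value': 0.40, 'feature': 'sex_is_female', 'threshold': 0.5,
    'left':  {'value': 0.20, 'feature': 'age', 'threshold': 10.0,
              'left':  {'value': 0.55},
              'right': {'value': 0.15}},
    'right': {'value': 0.75, 'feature': 'fare', 'threshold': 20.0,
              'left':  {'value': 0.60},
              'right': {'value': 0.90}},
}

def explain(tree, example):
    """Return (bias, per-feature contributions, prediction) for one example."""
    bias = tree['value']            # root value: prediction before using any feature
    contributions = {}
    node = tree
    while 'feature' in node:
        feat = node['feature']
        child = node['left'] if example[feat] <= node['threshold'] else node['right']
        # Credit the change in value to the feature that caused the branch.
        contributions[feat] = contributions.get(feat, 0.0) + child['value'] - node['value']
        node = child
    return bias, contributions, node['value']

passenger = {'sex_is_female': 1.0, 'age': 30.0, 'fare': 50.0}
bias, dfc, prediction = explain(TREE, passenger)
print(bias, dfc, prediction)
```

Features that never appear on the decision path (here `age`) receive no contribution, and the telescoping sum along the path is what guarantees that bias plus contributions equals the prediction.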

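Plot the DFCs for an individual passenger, coloring each bar by the direction of its contribution and adding the feature values to the figure.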

# Boilerplate code for plotting :)
def _get_color(value):
    """To make positive DFCs plot green, negative DFCs plot red."""
    green, red = sns.color_palette()[2:4]
    if value >= 0: return green
    return red

def _add_feature_values(feature_values, ax):
    """Display feature's values on left of plot."""
    x_coord = ax.get_xlim()[0]
    OFFSET = 0.15
    for y_coord, (feat_name, feat_val) in enumerate(feature_values.items()):
        t = plt.text(x_coord, y_coord - OFFSET, '{}'.format(feat_val), size=12)
        t.set_bbox(dict(facecolor='white', alpha=0.5))
    from matplotlib.font_manager import FontProperties
    font = FontProperties()
    font.set_weight('bold')
    t = plt.text(x_coord, y_coord + 1 - OFFSET, 'feature\nvalue',
    fontproperties=font, size=12)

def plot_example(example):
  TOP_N = 8 # View top 8 features.
  sorted_ix = example.abs().sort_values()[-TOP_N:].index  # Sort by magnitude.
  example = example[sorted_ix]
  colors = example.map(_get_color).tolist()
  ax = example.to_frame().plot(kind='barh',
                          color=colors,
                          legend=None,
                          alpha=0.75,
                          figsize=(10,6))
  ax.grid(False, axis='y')
  ax.set_yticklabels(ax.get_yticklabels(), size=14)

  # Add feature values.
  _add_feature_values(dfeval.iloc[ID][sorted_ix], ax)
  return ax
# Plot results.
ID = 182
example = df_dfc.iloc[ID]  # Choose ith example from evaluation set.
TOP_N = 8  # View top 8 features.
sorted_ix = example.abs().sort_values()[-TOP_N:].index
ax = plot_example(example)
ax.set_title('Feature contributions for example {}\n pred: {:1.2f}; label: {}'.format(ID, probs[ID], labels[ID]))
ax.set_xlabel('Contribution to predicted probability', size=14)
plt.show()

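A contribution of larger magnitude has a larger impact on the model's prediction. A negative contribution means the feature value for this example reduced the model's prediction, while a positive contribution increased it.

You can also plot the example's DFCs against the distribution over the whole evaluation set using a violin plot.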

# Boilerplate plotting code.
def dist_violin_plot(df_dfc, ID):
  # Initialize plot.
  fig, ax = plt.subplots(1, 1, figsize=(10, 6))

  # Create example dataframe.
  TOP_N = 8  # View top 8 features.
  example = df_dfc.iloc[ID]
  ix = example.abs().sort_values()[-TOP_N:].index
  example = example[ix]
  example_df = example.to_frame(name='dfc')

  # Add contributions of entire distribution.
  parts=ax.violinplot([df_dfc[w] for w in ix],
                 vert=False,
                 showextrema=False,
                 widths=0.7,
                 positions=np.arange(len(ix)))
  face_color = sns_colors[0]
  alpha = 0.15
  for pc in parts['bodies']:
      pc.set_facecolor(face_color)
      pc.set_alpha(alpha)

  # Add feature values.
  _add_feature_values(dfeval.iloc[ID][ix], ax)  # Use the local index, not a global.

  # Add local contributions.
  ax.scatter(example,
              np.arange(example.shape[0]),
              color=sns.color_palette()[2],
              s=100,
              marker="s",
              label='contributions for example')

  # Legend
  # Proxy plot, to show violinplot dist on legend.
  ax.plot([0,0], [1,1], label='eval set contributions\ndistributions',
          color=face_color, alpha=alpha, linewidth=10)
  legend = ax.legend(loc='lower right', shadow=True, fontsize='x-large',
                     frameon=True)
  legend.get_frame().set_facecolor('white')

  # Format plot.
  ax.set_yticks(np.arange(example.shape[0]))
  ax.set_yticklabels(example.index)
  ax.grid(False, axis='y')
  ax.set_xlabel('Contribution to predicted probability', size=14)

Plot this example.

dist_violin_plot(df_dfc, ID)
plt.title('Feature contributions for example {}\n pred: {:1.2f}; label: {}'.format(ID, probs[ID], labels[ID]))
plt.show()

Finally, third-party tools such as LIME and shap can also help you understand individual predictions of a model.

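Global feature importances

Additionally, you might want to understand the model as a whole rather than studying individual predictions. Below, you will compute and use:

  • Gain-based feature importances, using est.experimental_feature_importances
  • Permutation feature importances
  • Aggregate DFCs, using est.experimental_predict_with_explanations

Gain-based feature importances measure the change in loss when splitting on a particular feature, while permutation feature importances are computed by evaluating model performance on the evaluation set after shuffling each feature one at a time and attributing the change in performance to the shuffled feature.

In general, permutation feature importance is preferred to gain-based feature importance, though both methods can be unreliable when potential predictor variables vary in their scale of measurement or their number of categories, and when features are correlated (source). See this article for an in-depth overview and discussion of the different feature importance types.

Gain-based feature importances

Gain-based feature importances are built into the TensorFlow Boosted Trees estimators and are available through est.experimental_feature_importances.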

importances = est.experimental_feature_importances(normalize=True)
df_imp = pd.Series(importances)

# Visualize importances.
N = 8
ax = (df_imp.iloc[0:N][::-1]
    .plot(kind='barh',
          color=sns_colors[0],
          title='Gain feature importances',
          figsize=(10, 6)))
ax.grid(False, axis='y')
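
To show what "gain" measures, here is a small NumPy sketch that scores a feature by the largest reduction in squared error a single split on it can achieve. Squared-error reduction is used as a stand-in; the estimator's own gain statistic may be defined differently. The data and the names `sse` and `best_split_gain` are made up for the example.

```python
import numpy as np

def sse(y):
    """Sum of squared errors around the mean (0 for an empty array)."""
    return float(((y - y.mean()) ** 2).sum()) if len(y) else 0.0

def best_split_gain(x, y):
    """Largest squared-error reduction over all thresholds on one feature."""
    parent = sse(y)
    best = 0.0
    for t in np.unique(x)[:-1]:     # skip the max so the right side is never empty
        left, right = y[x <= t], y[x > t]
        best = max(best, parent - sse(left) - sse(right))
    return best

rng = np.random.default_rng(0)
informative = rng.uniform(0, 1, 200)
noise = rng.uniform(0, 1, 200)
y = (informative > 0.5).astype(float)   # label depends only on 'informative'

gains = {'informative': best_split_gain(informative, y),
         'noise': best_split_gain(noise, y)}
total = sum(gains.values())
normalized = {name: g / total for name, g in gains.items()}
print(normalized)
```

Dividing by the total mirrors the `normalize=True` option, so the scores sum to one. Note that even the pure-noise feature receives a small positive gain, one reason gain-based scores should be read with care.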

Average absolute DFCs

You can also average the absolute values of the DFCs to understand impact at a global level.

# Plot.
dfc_mean = df_dfc.abs().mean()
N = 8
sorted_ix = dfc_mean.abs().sort_values()[-N:].index  # Average and sort by absolute.
ax = dfc_mean[sorted_ix].plot(kind='barh',
                       color=sns_colors[1],
                       title='Mean |directional feature contributions|',
                       figsize=(10, 6))
ax.grid(False, axis='y')
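
The absolute value matters here because signed contributions can cancel out across examples. The toy matrix below uses made-up values, with two hypothetical feature columns, to make that visible.

```python
import numpy as np

# Rows are examples, columns are two hypothetical features (made-up DFCs).
dfc = np.array([[ 0.30, 0.02],
                [-0.30, 0.03],
                [ 0.25, 0.01],
                [-0.25, 0.02]])

mean_signed = dfc.mean(axis=0)        # the first feature cancels to about zero
mean_abs = np.abs(dfc).mean(axis=0)   # the first feature is clearly the larger one
print(mean_signed)
print(mean_abs)
```

A feature that pushes predictions strongly up for some passengers and strongly down for others would look unimportant under a signed mean, while the mean of absolute values ranks it first.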

You can also see how DFCs vary as a feature's value varies.

FEATURE = 'fare'
feature = pd.Series(df_dfc[FEATURE].values, index=dfeval[FEATURE].values).sort_index()
ax = sns.regplot(x=feature.index.values, y=feature.values, lowess=True)  # Keyword args avoid the seaborn FutureWarning.
ax.set_ylabel('contribution')
ax.set_xlabel(FEATURE)
ax.set_xlim(0, 100)
plt.show()

Permutation feature importances

def permutation_importances(est, X_eval, y_eval, metric, features):
    """Column by column, shuffle values and observe effect on eval set.

    source: http://explained.ai/rf-importance/index.html
    A similar approach can be done during training. See "Drop-column importance"
    in the above article."""
    baseline = metric(est, X_eval, y_eval)
    imp = []
    for col in features:
        save = X_eval[col].copy()
        X_eval[col] = np.random.permutation(X_eval[col])
        m = metric(est, X_eval, y_eval)
        X_eval[col] = save
        imp.append(baseline - m)
    return np.array(imp)

def accuracy_metric(est, X, y):
    """TensorFlow estimator accuracy."""
    eval_input_fn = make_input_fn(X,
                                  y=y,
                                  shuffle=False,
                                  n_epochs=1)
    return est.evaluate(input_fn=eval_input_fn)['accuracy']
features = CATEGORICAL_COLUMNS + NUMERIC_COLUMNS
importances = permutation_importances(est, dfeval, y_eval, accuracy_metric,
                                      features)
df_imp = pd.Series(importances, index=features)

sorted_ix = df_imp.abs().sort_values().index
ax = df_imp[sorted_ix][-5:].plot(kind='barh', color=sns_colors[2], figsize=(10, 6))
ax.grid(False, axis='y')
ax.set_title('Permutation feature importance')
plt.show()
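
The same procedure can be tried without an estimator. The sketch below is a self-contained NumPy variant under stated assumptions: the "model" is a fixed rule that only looks at column 0, the data is synthetic, and each column is shuffled several times with the results averaged. The repeated shuffles are an addition for stability and are not part of the function above.

```python
import numpy as np

def permutation_importance(predict, X, y, n_repeats=20, seed=0):
    """Mean drop in accuracy when each column of X is shuffled in turn."""
    rng = np.random.default_rng(seed)
    baseline = np.mean(predict(X) == y)
    drops = np.zeros(X.shape[1])
    for col in range(X.shape[1]):
        for _ in range(n_repeats):
            X_perm = X.copy()                 # leave the caller's data untouched
            X_perm[:, col] = rng.permutation(X_perm[:, col])
            drops[col] += baseline - np.mean(predict(X_perm) == y)
    return drops / n_repeats

# Synthetic data: the label depends on column 0 only.
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))
y = (X[:, 0] > 0).astype(int)

def predict(X):
    """Stand-in model that uses only column 0."""
    return (X[:, 0] > 0).astype(int)

imp = permutation_importance(predict, X, y)
print(imp)
```

Shuffling the column the model depends on drops accuracy to roughly chance, while shuffling the ignored column changes nothing, which is the signal that permutation importance reports.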

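Visualizing model fitting

First, simulate training data using the following formula:

$$z = x e^{-x^2 - y^2}$$

where \(z\) is the dependent variable you are trying to predict and \(x\) and \(y\) are the features.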
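
It helps to know the shape of this surface before fitting anything. Setting the partial derivatives to zero gives \(y = 0\) and \(1 - 2x^2 = 0\), so the peak sits at \(x = 1/\sqrt{2}\) with height \(1/\sqrt{2e}\), and by antisymmetry in \(x\) the trough is its mirror image. The short check below confirms this against a brute-force grid; the grid resolution is an arbitrary choice.

```python
import math

def f(x, y):
    """The target function z = x * exp(-x^2 - y^2)."""
    return x * math.exp(-x ** 2 - y ** 2)

x_star = 1 / math.sqrt(2)
z_max = f(x_star, 0.0)                # analytic maximum, 1/sqrt(2e), about 0.4289

# Brute-force check on a 401 x 401 grid over [-2, 2] x [-2, 2].
grid = [-2 + 4 * i / 400 for i in range(401)]
grid_max = max(f(x, y) for x in grid for y in grid)

print(z_max, grid_max)
```

The function therefore has one positive bump and one negative bump on either side of the line \(x = 0\), which a single plane cannot reproduce.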

from numpy.random import uniform, seed
from scipy.interpolate import griddata

# Create fake data
seed(0)
npts = 5000
x = uniform(-2, 2, npts)
y = uniform(-2, 2, npts)
z = x*np.exp(-x**2 - y**2)
xy = np.zeros((2,np.size(x)))
xy[0] = x
xy[1] = y
xy = xy.T
# Prep data for training.
df = pd.DataFrame({'x': x, 'y': y, 'z': z})

xi = np.linspace(-2.0, 2.0, 200)  # No trailing comma: keep these as arrays, not 1-tuples.
yi = np.linspace(-2.1, 2.1, 210)
xi,yi = np.meshgrid(xi, yi)

df_predict = pd.DataFrame({
    'x' : xi.flatten(),
    'y' : yi.flatten(),
})
predict_shape = xi.shape
def plot_contour(x, y, z, **kwargs):
  # Grid the data.
  plt.figure(figsize=(10, 8))
  # Contour the gridded data, plotting dots at the nonuniform data points.
  CS = plt.contour(x, y, z, 15, linewidths=0.5, colors='k')
  CS = plt.contourf(x, y, z, 15,
                    vmax=abs(zi).max(), vmin=-abs(zi).max(), cmap='RdBu_r')
  plt.colorbar()  # Draw colorbar.
  # Plot data points.
  plt.xlim(-2, 2)
  plt.ylim(-2, 2)

You can visualize the function. Redder colors correspond to larger function values.

zi = griddata(xy, z, (xi, yi), method='linear', fill_value=0)  # Numeric fill value, not the string '0'.
plot_contour(xi, yi, zi)
plt.scatter(df.x, df.y, marker='.')
plt.title('Contour on training data')
plt.show()

fc = [tf.feature_column.numeric_column('x'),
      tf.feature_column.numeric_column('y')]
def predict(est):
  """Predictions from a given estimator."""
  predict_input_fn = lambda: tf.data.Dataset.from_tensors(dict(df_predict))
  preds = np.array([p['predictions'][0] for p in est.predict(predict_input_fn)])
  return preds.reshape(predict_shape)

First, try to fit a linear model to the data.

train_input_fn = make_input_fn(df, df.z)
est = tf.estimator.LinearRegressor(fc)
est.train(train_input_fn, max_steps=500);
plot_contour(xi, yi, predict(est))

It is not a very good fit. Next, try to fit a GBDT model to the data and see how the model fits the function.
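
Before moving on, the weakness of the linear fit can be quantified without an estimator. The sketch below fits an ordinary least-squares plane with NumPy and reports the fraction of variance it explains. It uses its own random draw rather than the tutorial's `df`, and closed-form least squares rather than the LinearRegressor's optimizer, so the numbers are only indicative.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, 5000)
y = rng.uniform(-2, 2, 5000)
z = x * np.exp(-x ** 2 - y ** 2)

# Least-squares plane z ~ a*x + b*y + c.
A = np.column_stack([x, y, np.ones_like(x)])
coef, *_ = np.linalg.lstsq(A, z, rcond=None)
residual = z - A @ coef

r2 = 1.0 - residual.var() / z.var()   # fraction of variance explained
print(coef, r2)
```

The best possible plane leaves most of the variance unexplained, so the poor contour plot reflects a limit of the model class rather than of the training procedure.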

n_trees = 37

est = tf.estimator.BoostedTreesRegressor(fc, n_batches_per_layer=1, n_trees=n_trees)
est.train(train_input_fn, max_steps=500)
clear_output()
plot_contour(xi, yi, predict(est))
plt.text(-1.8, 2.1, '# trees: {}'.format(n_trees), color='w', backgroundcolor='black', size=20)
plt.show()

As you increase the number of trees, the model's predictions better approximate the underlying function.
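
The effect of adding trees can be reproduced with a very small boosting loop. The sketch below is a simplified stand-in, not the estimator's algorithm: it boosts depth-1 trees (stumps) on the squared-error residual, searches a fixed set of quantile thresholds, and uses an arbitrary learning rate of 0.3. All names are hypothetical.

```python
import numpy as np

def fit_stump(X, residual):
    """Best single-split regressor over all features, by squared error."""
    best = None
    for j in range(X.shape[1]):
        for t in np.quantile(X[:, j], np.linspace(0.05, 0.95, 19)):
            mask = X[:, j] <= t
            left, right = residual[mask].mean(), residual[~mask].mean()
            err = ((residual - np.where(mask, left, right)) ** 2).sum()
            if best is None or err < best[0]:
                best = (err, j, t, left, right)
    return best[1:]

def boost(X, z, n_trees, learning_rate=0.3):
    """Return the training MSE after 0, 1, ..., n_trees stumps."""
    pred = np.full_like(z, z.mean())          # start from the mean label
    history = [np.mean((z - pred) ** 2)]
    for _ in range(n_trees):
        j, t, left, right = fit_stump(X, z - pred)   # fit the current residual
        pred = pred + learning_rate * np.where(X[:, j] <= t, left, right)
        history.append(np.mean((z - pred) ** 2))
    return history

rng = np.random.default_rng(0)
X = rng.uniform(-2, 2, size=(2000, 2))
z = X[:, 0] * np.exp(-X[:, 0] ** 2 - X[:, 1] ** 2)

mse = boost(X, z, n_trees=40)
print(mse[0], mse[10], mse[-1])
```

Each stump corrects part of what the previous ones got wrong, so the training error never increases from one tree to the next. Stumps alone cannot capture the interaction between the two features, which is one reason the tutorial's model uses deeper trees.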

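Conclusion

In this tutorial, you learned how to interpret Boosted Trees models using directional feature contributions and feature importance techniques. These techniques provide insight into how the features affect a model's predictions. Finally, you also gained intuition for how a Boosted Trees model fits a complex function by viewing the decision surface of several models.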