Inspect and debug decision forest models


In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.

In this colab, you will:

  1. Train a Random Forest model and access its structure programmatically.

  2. Create a Random Forest model by hand and use it as a classical model.

Setup

# Install TensorFlow Decision Forests.
!pip install tensorflow_decision_forests

# Use wurlitzer to capture training logs.
!pip install wurlitzer

import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections

try:
  from wurlitzer import sys_pipes
except ImportError:
  # Fall back to another log capture helper when wurlitzer is unavailable.
  from colabtools.googlelog import CaptureLog as sys_pipes

from IPython.core.magic import register_line_magic
from IPython.display import Javascript
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible.
WARNING:root:TF Parameter Server distributed training not available.

The hidden code cell limits the output height in colab.

Train a simple Random Forest

We train a Random Forest as in the beginner colab:

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Show the first three examples.
print(dataset_df.head(3))

# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
  species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \
0  Adelie  Torgersen            39.1           18.7              181.0   
1  Adelie  Torgersen            39.5           17.4              186.0   
2  Adelie  Torgersen            40.3           18.0              195.0   

   body_mass_g     sex  year  
0       3750.0    male  2007  
1       3800.0  female  2007  
2       3250.0  female  2007
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  features_dataframe = dataframe.drop(label, 1)
6/6 [==============================] - 4s 17ms/step
[INFO kernel.cc:736] Start Yggdrasil model training
[INFO kernel.cc:737] Collect training examples
[INFO kernel.cc:392] Number of batches: 6
[INFO kernel.cc:393] Number of examples: 344
[INFO kernel.cc:759] Dataset:
Number of records: 344
Number of columns: 8

Number of columns by type:
    NUMERICAL: 5 (62.5%)
    CATEGORICAL: 3 (37.5%)

Columns:

NUMERICAL: 5 (62.5%)
    0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719
    1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516
    2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781
    3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411
    6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166

CATEGORICAL: 3 (37.5%)
    4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%)
    5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%)
    7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO kernel.cc:762] Configure learner
[INFO kernel.cc:787] Training config:
learner: "RANDOM_FOREST"
features: "bill_depth_mm"
features: "bill_length_mm"
features: "body_mass_g"
features: "flipper_length_mm"
features: "island"
features: "sex"
features: "year"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    num_candidate_attributes_ratio: -1
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: true
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
}

[INFO kernel.cc:790] Deployment config:
num_threads: 6

[INFO kernel.cc:817] Train model
[INFO random_forest.cc:315] Training random forest on 344 example(s) and 7 feature(s).
[INFO random_forest.cc:628] Training of tree  1/300 (tree index:0) done accuracy:0.964286 logloss:1.28727
[INFO random_forest.cc:628] Training of tree  11/300 (tree index:10) done accuracy:0.956268 logloss:0.584301
[INFO random_forest.cc:628] Training of tree  22/300 (tree index:21) done accuracy:0.965116 logloss:0.378823
[INFO random_forest.cc:628] Training of tree  35/300 (tree index:34) done accuracy:0.968023 logloss:0.178185
[INFO random_forest.cc:628] Training of tree  46/300 (tree index:45) done accuracy:0.973837 logloss:0.170304
[INFO random_forest.cc:628] Training of tree  58/300 (tree index:57) done accuracy:0.973837 logloss:0.171223
[INFO random_forest.cc:628] Training of tree  70/300 (tree index:69) done accuracy:0.979651 logloss:0.169564
[INFO random_forest.cc:628] Training of tree  83/300 (tree index:82) done accuracy:0.976744 logloss:0.17074
[INFO random_forest.cc:628] Training of tree  96/300 (tree index:95) done accuracy:0.976744 logloss:0.0736925
[INFO random_forest.cc:628] Training of tree  106/300 (tree index:105) done accuracy:0.976744 logloss:0.0748649
[INFO random_forest.cc:628] Training of tree  117/300 (tree index:116) done accuracy:0.976744 logloss:0.074671
[INFO random_forest.cc:628] Training of tree  130/300 (tree index:129) done accuracy:0.976744 logloss:0.0736275
[INFO random_forest.cc:628] Training of tree  140/300 (tree index:139) done accuracy:0.976744 logloss:0.0727718
[INFO random_forest.cc:628] Training of tree  152/300 (tree index:151) done accuracy:0.976744 logloss:0.0715068
[INFO random_forest.cc:628] Training of tree  162/300 (tree index:161) done accuracy:0.976744 logloss:0.0708994
[INFO random_forest.cc:628] Training of tree  173/300 (tree index:172) done accuracy:0.976744 logloss:0.069447
[INFO random_forest.cc:628] Training of tree  184/300 (tree index:183) done accuracy:0.976744 logloss:0.0695926
[INFO random_forest.cc:628] Training of tree  195/300 (tree index:194) done accuracy:0.976744 logloss:0.0690138
[INFO random_forest.cc:628] Training of tree  205/300 (tree index:204) done accuracy:0.976744 logloss:0.0694597
[INFO random_forest.cc:628] Training of tree  217/300 (tree index:216) done accuracy:0.976744 logloss:0.068122
[INFO random_forest.cc:628] Training of tree  229/300 (tree index:228) done accuracy:0.976744 logloss:0.0687641
[INFO random_forest.cc:628] Training of tree  239/300 (tree index:238) done accuracy:0.976744 logloss:0.067988
[INFO random_forest.cc:628] Training of tree  250/300 (tree index:249) done accuracy:0.976744 logloss:0.0690187
[INFO random_forest.cc:628] Training of tree  260/300 (tree index:259) done accuracy:0.976744 logloss:0.0690134
[INFO random_forest.cc:628] Training of tree  270/300 (tree index:269) done accuracy:0.976744 logloss:0.0689877
[INFO random_forest.cc:628] Training of tree  280/300 (tree index:279) done accuracy:0.976744 logloss:0.0689845
[INFO random_forest.cc:628] Training of tree  290/300 (tree index:288) done accuracy:0.976744 logloss:0.0690742
[INFO random_forest.cc:628] Training of tree  300/300 (tree index:299) done accuracy:0.976744 logloss:0.068949
[INFO random_forest.cc:696] Final OOB metrics: accuracy:0.976744 logloss:0.068949
[INFO kernel.cc:828] Export model in log directory: /tmp/tmpoqki9pfl
[INFO kernel.cc:836] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s).
[INFO abstract_model.cc:993] Engine "RandomForestGeneric" built
[INFO kernel.cc:848] Use fast generic engine
<keras.callbacks.History at 0x7f09eaa9cb90>

Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB) variable importance during training. This is a popular permutation variable importance for Random Forest models.

Computing the OOB variable importance does not affect the final model, but it slows down training on large datasets.
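As a rough illustration of what a permutation variable importance measures, the sketch below shuffles one feature column at a time and reports the resulting drop in accuracy. It is not the TF-DF implementation: `toy_model`, `accuracy` and `permutation_importance` are hypothetical names, the data is made up, and a single evaluation set stands in for the out-of-bag examples.

```python
import random


def accuracy(model, rows, labels):
    """Fraction of rows for which the model predicts the expected label."""
    hits = sum(1 for row, label in zip(rows, labels) if model(row) == label)
    return hits / len(rows)


def permutation_importance(model, rows, labels, feature, seed=0):
    """Decrease in accuracy after shuffling the values of one feature."""
    rng = random.Random(seed)
    baseline = accuracy(model, rows, labels)
    # Shuffle a copy of the feature column; the other columns stay untouched.
    shuffled = [row[feature] for row in rows]
    rng.shuffle(shuffled)
    permuted_rows = [dict(row, **{feature: value})
                     for row, value in zip(rows, shuffled)]
    return baseline - accuracy(model, permuted_rows, labels)


# Hypothetical stand-in model: it only reads "bill_length_mm" and ignores "year".
def toy_model(row):
    return "long" if row["bill_length_mm"] >= 43.0 else "short"


# Made-up evaluation data, labelled by the toy model itself.
rows = [{"bill_length_mm": 35.0 + i, "year": 2007 + i % 3} for i in range(20)]
labels = [toy_model(row) for row in rows]

importance_used = permutation_importance(toy_model, rows, labels, "bill_length_mm")
importance_unused = permutation_importance(toy_model, rows, labels, "year")
print("bill_length_mm:", importance_used)
print("year:", importance_unused)
```

A feature the model never reads gets an importance of exactly zero, because shuffling it cannot change any prediction. On real data, small negative values can also appear by chance.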

Check the model summary:

%set_cell_height 300

model.summary()
<IPython.core.display.Javascript object>
Model: "random_forest_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (7):
    bill_depth_mm
    bill_length_mm
    body_mass_g
    flipper_length_mm
    island
    sex
    year

No weights

Variable Importance: MEAN_DECREASE_IN_ACCURACY:

    1.    "bill_length_mm"  0.151163 ################
    2.            "island"  0.008721 #
    3.     "bill_depth_mm"  0.000000 
    4.       "body_mass_g"  0.000000 
    5.               "sex"  0.000000 
    6.              "year"  0.000000 
    7. "flipper_length_mm" -0.002907 

Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083305 ################
    2.            "island"  0.007664 #
    3. "flipper_length_mm"  0.003400 
    4.     "bill_depth_mm"  0.002741 
    5.       "body_mass_g"  0.000722 
    6.               "sex"  0.000644 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:

    1.    "bill_length_mm"  0.508510 ################
    2.            "island"  0.023487 
    3.     "bill_depth_mm"  0.007744 
    4. "flipper_length_mm"  0.006008 
    5.       "body_mass_g"  0.003017 
    6.               "sex"  0.001537 
    7.              "year" -0.000245 

Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:

    1.            "island"  0.002192 ################
    2.    "bill_length_mm"  0.001572 ############
    3.     "bill_depth_mm"  0.000497 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000890 

Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.071306 ################
    2.            "island"  0.007299 #
    3. "flipper_length_mm"  0.004506 #
    4.     "bill_depth_mm"  0.002124 
    5.       "body_mass_g"  0.000548 
    6.               "sex"  0.000480 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.108642 ################
    2.            "island"  0.014493 ##
    3.     "bill_depth_mm"  0.007406 #
    4. "flipper_length_mm"  0.005195 
    5.       "body_mass_g"  0.001012 
    6.               "sex"  0.000480 
    7.              "year" -0.000053 

Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:

    1.            "island"  0.002126 ################
    2.    "bill_length_mm"  0.001393 ###########
    3.     "bill_depth_mm"  0.000293 #####
    4.               "sex"  0.000000 ###
    5.              "year"  0.000000 ###
    6.       "body_mass_g" -0.000037 ###
    7. "flipper_length_mm" -0.000550 

Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083122 ################
    2.            "island"  0.010887 ##
    3. "flipper_length_mm"  0.003425 
    4.     "bill_depth_mm"  0.002731 
    5.       "body_mass_g"  0.000719 
    6.               "sex"  0.000641 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.497611 ################
    2.            "island"  0.024045 
    3.     "bill_depth_mm"  0.007734 
    4. "flipper_length_mm"  0.006017 
    5.       "body_mass_g"  0.003000 
    6.               "sex"  0.001528 
    7.              "year" -0.000243 

Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:

    1.            "island"  0.002187 ################
    2.    "bill_length_mm"  0.001568 ############
    3.     "bill_depth_mm"  0.000495 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000886 

Variable Importance: MEAN_MIN_DEPTH:

    1.           "__LABEL"  3.479602 ################
    2.              "year"  3.463891 ###############
    3.               "sex"  3.430498 ###############
    4.       "body_mass_g"  2.898112 ###########
    5.            "island"  2.388925 ########
    6.     "bill_depth_mm"  2.336100 #######
    7.    "bill_length_mm"  1.282960 
    8. "flipper_length_mm"  1.270079 

Variable Importance: NUM_AS_ROOT:

    1. "flipper_length_mm" 157.000000 ################
    2.    "bill_length_mm" 76.000000 #######
    3.     "bill_depth_mm" 52.000000 #####
    4.            "island" 12.000000 
    5.       "body_mass_g"  3.000000 

Variable Importance: NUM_NODES:

    1.    "bill_length_mm" 778.000000 ################
    2.     "bill_depth_mm" 463.000000 #########
    3. "flipper_length_mm" 414.000000 ########
    4.            "island" 342.000000 ######
    5.       "body_mass_g" 338.000000 ######
    6.               "sex" 36.000000 
    7.              "year" 19.000000 

Variable Importance: SUM_SCORE:

    1.    "bill_length_mm" 36515.793787 ################
    2. "flipper_length_mm" 35120.434174 ###############
    3.            "island" 14669.408395 ######
    4.     "bill_depth_mm" 14515.446617 ######
    5.       "body_mass_g" 3485.330881 #
    6.               "sex" 354.201073 
    7.              "year" 49.737758 



Winner take all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949
Number of trees: 300
Total number of nodes: 5080

Number of nodes by tree:
Count: 300 Average: 16.9333 StdDev: 3.10197
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12)  6   2.00%   2.00% #
[ 12, 13)  0   0.00%   2.00%
[ 13, 14) 46  15.33%  17.33% #####
[ 14, 15)  0   0.00%  17.33%
[ 15, 16) 70  23.33%  40.67% ########
[ 16, 17)  0   0.00%  40.67%
[ 17, 18) 84  28.00%  68.67% ##########
[ 18, 19)  0   0.00%  68.67%
[ 19, 20) 46  15.33%  84.00% #####
[ 20, 21)  0   0.00%  84.00%
[ 21, 22) 30  10.00%  94.00% ####
[ 22, 23)  0   0.00%  94.00%
[ 23, 24) 13   4.33%  98.33% ##
[ 24, 25)  0   0.00%  98.33%
[ 25, 26)  2   0.67%  99.00%
[ 26, 27)  0   0.00%  99.00%
[ 27, 28)  2   0.67%  99.67%
[ 28, 29)  0   0.00%  99.67%
[ 29, 30)  0   0.00%  99.67%
[ 30, 31]  1   0.33% 100.00%

Depth by leafs:
Count: 2690 Average: 3.53271 StdDev: 1.06789
Min: 2 Max: 7 Ignored: 0
----------------------------------------------
[ 2, 3) 545  20.26%  20.26% ######
[ 3, 4) 747  27.77%  48.03% ########
[ 4, 5) 888  33.01%  81.04% ##########
[ 5, 6) 444  16.51%  97.55% #####
[ 6, 7)  62   2.30%  99.85% #
[ 7, 7]   4   0.15% 100.00%

Number of training obs by leaf:
Count: 2690 Average: 38.3643 StdDev: 44.8651
Min: 5 Max: 155 Ignored: 0
----------------------------------------------
[   5,  12) 1474  54.80%  54.80% ##########
[  12,  20)  124   4.61%  59.41% #
[  20,  27)   48   1.78%  61.19%
[  27,  35)   74   2.75%  63.94% #
[  35,  42)   58   2.16%  66.10%
[  42,  50)   85   3.16%  69.26% #
[  50,  57)   96   3.57%  72.83% #
[  57,  65)   87   3.23%  76.06% #
[  65,  72)   49   1.82%  77.88%
[  72,  80)   23   0.86%  78.74%
[  80,  88)   30   1.12%  79.85%
[  88,  95)   23   0.86%  80.71%
[  95, 103)   42   1.56%  82.27%
[ 103, 110)   62   2.30%  84.57%
[ 110, 118)  115   4.28%  88.85% #
[ 118, 125)  115   4.28%  93.12% #
[ 125, 133)   98   3.64%  96.77% #
[ 133, 140)   49   1.82%  98.59%
[ 140, 148)   31   1.15%  99.74%
[ 148, 155]    7   0.26% 100.00%

Attribute in nodes:
    778 : bill_length_mm [NUMERICAL]
    463 : bill_depth_mm [NUMERICAL]
    414 : flipper_length_mm [NUMERICAL]
    342 : island [CATEGORICAL]
    338 : body_mass_g [NUMERICAL]
    36 : sex [CATEGORICAL]
    19 : year [NUMERICAL]

Attribute in nodes with depth <= 0:
    157 : flipper_length_mm [NUMERICAL]
    76 : bill_length_mm [NUMERICAL]
    52 : bill_depth_mm [NUMERICAL]
    12 : island [CATEGORICAL]
    3 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 1:
    250 : bill_length_mm [NUMERICAL]
    244 : flipper_length_mm [NUMERICAL]
    183 : bill_depth_mm [NUMERICAL]
    170 : island [CATEGORICAL]
    53 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 2:
    462 : bill_length_mm [NUMERICAL]
    320 : flipper_length_mm [NUMERICAL]
    310 : bill_depth_mm [NUMERICAL]
    287 : island [CATEGORICAL]
    162 : body_mass_g [NUMERICAL]
    9 : sex [CATEGORICAL]
    5 : year [NUMERICAL]

Attribute in nodes with depth <= 3:
    669 : bill_length_mm [NUMERICAL]
    410 : bill_depth_mm [NUMERICAL]
    383 : flipper_length_mm [NUMERICAL]
    328 : island [CATEGORICAL]
    286 : body_mass_g [NUMERICAL]
    32 : sex [CATEGORICAL]
    10 : year [NUMERICAL]

Attribute in nodes with depth <= 5:
    778 : bill_length_mm [NUMERICAL]
    462 : bill_depth_mm [NUMERICAL]
    413 : flipper_length_mm [NUMERICAL]
    342 : island [CATEGORICAL]
    338 : body_mass_g [NUMERICAL]
    36 : sex [CATEGORICAL]
    19 : year [NUMERICAL]

Condition type in nodes:
    2012 : HigherCondition
    378 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
    288 : HigherCondition
    12 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
    730 : HigherCondition
    170 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
    1259 : HigherCondition
    296 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
    1758 : HigherCondition
    360 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
    2010 : HigherCondition
    378 : ContainsBitmapCondition
Node format: NOT_SET

Training OOB:
    trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727
    trees: 11, Out-of-bag evaluation: accuracy:0.956268 logloss:0.584301
    trees: 22, Out-of-bag evaluation: accuracy:0.965116 logloss:0.378823
    trees: 35, Out-of-bag evaluation: accuracy:0.968023 logloss:0.178185
    trees: 46, Out-of-bag evaluation: accuracy:0.973837 logloss:0.170304
    trees: 58, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171223
    trees: 70, Out-of-bag evaluation: accuracy:0.979651 logloss:0.169564
    trees: 83, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17074
    trees: 96, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736925
    trees: 106, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0748649
    trees: 117, Out-of-bag evaluation: accuracy:0.976744 logloss:0.074671
    trees: 130, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736275
    trees: 140, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727718
    trees: 152, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0715068
    trees: 162, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708994
    trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.069447
    trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695926
    trees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690138
    trees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694597
    trees: 217, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068122
    trees: 229, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0687641
    trees: 239, Out-of-bag evaluation: accuracy:0.976744 logloss:0.067988
    trees: 250, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690187
    trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690134
    trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689877
    trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689845
    trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690742
    trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949

Note the multiple variable importances named MEAN_DECREASE_IN_*.

Plotting the model

Next, plot the model.

A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, plot only the first tree, and limit the nodes to depth 3.

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)

Inspect the model structure

The model structure and metadata are available through the inspector created by make_inspector().

inspector = model.make_inspector()

For our model, the available inspector fields are:

[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME',
 'dataspec',
 'evaluation',
 'export_to_tensorboard',
 'extract_all_trees',
 'extract_tree',
 'features',
 'iterate_on_nodes',
 'label',
 'label_classes',
 'model_type',
 'num_trees',
 'objective',
 'specialized_header',
 'task',
 'training_logs',
 'variable_importances',
 'winner_take_all_inference']

Remember to check the API reference, or use ? for the built-in documentation.

?inspector.model_type

Some of the model metadata:

print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]

evaluation() returns the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.

inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None)
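To make the two reported numbers more concrete, here is a minimal sketch of how accuracy and log loss can be computed from predicted class probabilities. The helper names and the four-example data are made up for illustration, and the exact conventions of the library (for example, how ties or zero probabilities are handled) are assumptions.

```python
import math


def accuracy(probabilities, labels):
    """Share of examples whose most probable class is the true class."""
    correct = 0
    for probs, label in zip(probabilities, labels):
        predicted = max(range(len(probs)), key=probs.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


def log_loss(probabilities, labels, eps=1e-12):
    """Mean negative log-probability assigned to the true class."""
    total = 0.0
    for probs, label in zip(probabilities, labels):
        # Clip to avoid log(0) when the true class received zero probability.
        total -= math.log(max(probs[label], eps))
    return total / len(labels)


# Made-up predictions for four examples and three classes.
probabilities = [
    [0.9, 0.05, 0.05],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
labels = [0, 1, 2, 1]

toy_accuracy = accuracy(probabilities, labels)
toy_log_loss = log_loss(probabilities, labels)
print("accuracy:", toy_accuracy)
print("logloss:", toy_log_loss)
```

Accuracy only looks at the winning class, while log loss also penalizes confident mistakes, which is why the two metrics can move differently as trees are added.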

The variable importances are:

print(f"Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances:
     MEAN_DECREASE_IN_AUC_3_VS_OTHERS
     NUM_AS_ROOT
     MEAN_DECREASE_IN_AUC_2_VS_OTHERS
     MEAN_DECREASE_IN_AP_2_VS_OTHERS
     MEAN_DECREASE_IN_ACCURACY
     SUM_SCORE
     MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS
     MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
     MEAN_DECREASE_IN_AP_3_VS_OTHERS
     MEAN_DECREASE_IN_AUC_1_VS_OTHERS
     MEAN_MIN_DEPTH
     MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
     NUM_NODES
     MEAN_DECREASE_IN_AP_1_VS_OTHERS

Different variable importances have different semantics. For example, a feature with a mean decrease in AUC of 0.05 means that removing this feature from the training dataset would reduce (hurt) the AUC by 0.05, i.e. 5 percentage points.

# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389),
 ("island" (4; #4), 0.007298519736842035),
 ("flipper_length_mm" (1; #3), 0.004505893640351366),
 ("bill_depth_mm" (1; #0), 0.0021244517543865804),
 ("body_mass_g" (1; #2), 0.0005482456140351033),
 ("sex" (4; #5), 0.00047971491228060437),
 ("year" (1; #6), 0.0)]
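Each importance is a list of (feature, value) pairs, which is easy to post-process. The sketch below turns such a list into a dictionary and extracts the top features. `Feature` is a hypothetical stand-in that only models the `name` attribute of the real feature objects, and the values are copied (rounded) from the output above.

```python
from collections import namedtuple

# Hypothetical stand-in for the feature objects returned by the inspector.
Feature = namedtuple("Feature", ["name"])

# Values copied (rounded) from the MEAN_DECREASE_IN_AUC_1_VS_OTHERS output.
importances = [
    (Feature("bill_length_mm"), 0.071306),
    (Feature("island"), 0.007299),
    (Feature("flipper_length_mm"), 0.004506),
    (Feature("bill_depth_mm"), 0.002124),
    (Feature("body_mass_g"), 0.000548),
    (Feature("sex"), 0.000480),
    (Feature("year"), 0.0),
]


def importance_by_name(pairs):
    """Map each feature name to its importance value."""
    return {feature.name: value for feature, value in pairs}


def top_features(pairs, k):
    """Names of the k features with the largest importance."""
    ranked = sorted(pairs, key=lambda pair: pair[1], reverse=True)
    return [feature.name for feature, _ in ranked[:k]]


table = importance_by_name(importances)
best = top_features(importances, 2)
print(best)
print(table["year"])
```

Sorting explicitly avoids relying on the order in which the pairs are returned.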

Finally, access the actual tree structure:

inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0)), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0)), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0)), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})

Extracting a tree is not efficient. If speed is important, the model can be inspected with the iterate_on_nodes() method instead. This method is a depth-first pre-order traversal iterator over all the nodes of the model.
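For readers unfamiliar with the term, a depth-first pre-order traversal visits a node before its children and fully explores one branch before moving to the next. The sketch below shows the idea on a small made-up tree. `Leaf`, `Split` and `iterate_pre_order` are hypothetical stand-ins rather than TF-DF classes, and visiting the negative branch first is an assumption about ordering.

```python
from dataclasses import dataclass
from typing import Iterator, Tuple, Union


@dataclass
class Leaf:
    value: str


@dataclass
class Split:
    condition: str
    pos_child: "Node"
    neg_child: "Node"


Node = Union[Leaf, Split]


def iterate_pre_order(node: Node, depth: int = 0) -> Iterator[Tuple[Node, int]]:
    """Yield (node, depth) pairs: the node itself first, then its subtrees."""
    yield node, depth
    if isinstance(node, Split):
        # Assumption: the negative branch is visited before the positive one.
        yield from iterate_pre_order(node.neg_child, depth + 1)
        yield from iterate_pre_order(node.pos_child, depth + 1)


tree = Split(
    condition="f1 >= 1.5",
    pos_child=Split(
        condition="f2 in [cat, dog]",
        pos_child=Leaf("mostly red"),
        neg_child=Leaf("mostly blue")),
    neg_child=Leaf("mostly green"))

visited = []
for node, depth in iterate_pre_order(tree):
    label = node.condition if isinstance(node, Split) else node.value
    visited.append((depth, label))
    print("  " * depth + label)
```

Because the traversal is a generator, nodes are produced one at a time and no full tree copy needs to be held by the caller.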

The following example computes how many times each feature is used (this is a kind of structural variable importance):

# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)

# Iterate over all the nodes in a depth-first pre-order traversal.
for node_iter in inspector.iterate_on_nodes():

  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes
    continue

  # Iterate over all the features used in the condition.
  # By default, models are "axis-aligned" i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1

print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features:
     bill_length_mm : 778
     bill_depth_mm : 463
     flipper_length_mm : 414
     island : 342
     body_mass_g : 338
     year : 19
     sex : 36

Creating a model by hand

In this section, you will create a small Random Forest model by hand. To keep things simple, the model will contain only one simple tree:

3 label classes: Red, blue and green.
2 features: f1 (numerical) and f2 (string categorical)

f1>=1.5
    ├─(pos)─ f2 in ["cat","dog"]
    │         ├─(pos)─ value: [0.8, 0.1, 0.1]
    │         └─(neg)─ value: [0.1, 0.8, 0.1]
    └─(neg)─ value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))

Trees are added one by one.

# Short aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2",type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))

Finish writing the tree:

builder.close()
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
2021-11-08 12:19:14.555155: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
INFO:tensorflow:Assets written to: /tmp/manual_model/assets

Now you can open the model as a regular Keras model, and make predictions:

manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
        "f1": [1.0, 2.0, 3.0],
        "f2": ["cat", "cat", "bird"]
    }).batch(2)

predictions = manual_model.predict(examples)

print("predictions:\n",predictions)
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
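These predictions can be checked by walking the hand-built tree manually. The sketch below re-implements the single tree as plain Python conditions. `predict_one` is a hypothetical helper, and treating a missing value (None) as failing the condition is an assumption based on missing_evaluation=False.

```python
def predict_one(f1, f2):
    """Return [P(red), P(blue), P(green)] for one example."""
    # Root condition: f1 >= 1.5. A missing value goes to the negative branch.
    if f1 is not None and f1 >= 1.5:
        # Second condition: f2 in ["cat", "dog"].
        if f2 is not None and f2 in ("cat", "dog"):
            return [0.8, 0.1, 0.1]
        return [0.1, 0.8, 0.1]
    return [0.1, 0.1, 0.8]


# Same three examples as in the dataset above.
examples = [(1.0, "cat"), (2.0, "cat"), (3.0, "bird")]
manual_predictions = [predict_one(f1, f2) for f1, f2 in examples]

for row in manual_predictions:
    print(row)
```

The first example fails the root condition, the second passes both conditions, and the third passes the root condition but fails the categorical one, which matches the three rows printed by the model.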

Access the structure:

yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)

inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]

And of course, you can plot this manually constructed model:

tfdf.model_plotter.plot_model_in_colab(manual_model)