In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.
In this colab, you will:
Train a Random Forest model and access its structure programmatically.
Create a Random Forest model by hand and use it as a classical model.
Setup
# Install TensorFlow Decision Forests.
!pip install tensorflow_decision_forests
# Use wurlitzer to capture training logs.
!pip install wurlitzer
import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections
try:
    from wurlitzer import sys_pipes
except ImportError:
    # Fallback used in Google-internal Colab environments.
    from colabtools.googlelog import CaptureLog as sys_pipes
from IPython.core.magic import register_line_magic
from IPython.display import Javascript, display
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible. WARNING:root:TF Parameter Server distributed training not available.
The hidden code cell below limits the output height in Colab.
# Some of the model training logs can cover the full
# screen if not compressed to a smaller viewport.
# This magic allows setting a max height for a cell.
@register_line_magic
def set_cell_height(size):
    display(
        Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                   str(size) + "})"))
Train a simple Random Forest
We train a Random Forest as in the beginner colab:
# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv
# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")
# Show the first three examples.
print(dataset_df.head(3))
# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")
# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species island bill_length_mm bill_depth_mm flipper_length_mm \ 0 Adelie Torgersen 39.1 18.7 181.0 1 Adelie Torgersen 39.5 17.4 186.0 2 Adelie Torgersen 40.3 18.0 195.0 body_mass_g sex year 0 3750.0 male 2007 1 3800.0 female 2007 2 3250.0 female 2007 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only features_dataframe = dataframe.drop(label, 1) 6/6 [==============================] - 4s 17ms/step [INFO kernel.cc:736] Start Yggdrasil model training [INFO kernel.cc:737] Collect training examples [INFO kernel.cc:392] Number of batches: 6 [INFO kernel.cc:393] Number of examples: 344 [INFO kernel.cc:759] Dataset: Number of records: 344 Number of columns: 8 Number of columns by type: NUMERICAL: 5 (62.5%) CATEGORICAL: 3 (37.5%) Columns: NUMERICAL: 5 (62.5%) 0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719 1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516 2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781 3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411 6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166 CATEGORICAL: 3 (37.5%) 4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%) 5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%) 7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item Terminology: nas: Number of non-available (i.e. missing) values. ood: Out of dictionary. manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred. tokenized: The attribute value is obtained through tokenization. 
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string. vocab-size: Number of unique values. [INFO kernel.cc:762] Configure learner [INFO kernel.cc:787] Training config: learner: "RANDOM_FOREST" features: "bill_depth_mm" features: "bill_length_mm" features: "body_mass_g" features: "flipper_length_mm" features: "island" features: "sex" features: "year" label: "__LABEL" task: CLASSIFICATION [yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] { num_trees: 300 decision_tree { max_depth: 16 min_examples: 5 in_split_min_examples_check: true missing_value_policy: GLOBAL_IMPUTATION allow_na_conditions: false categorical_set_greedy_forward { sampling: 0.1 max_num_items: -1 min_item_frequency: 1 } growing_strategy_local { } categorical { cart { } } num_candidate_attributes_ratio: -1 axis_aligned_split { } internal { sorting_strategy: PRESORTED } } winner_take_all_inference: true compute_oob_performances: true compute_oob_variable_importances: true adapt_bootstrap_size_ratio_for_maximum_training_duration: false } [INFO kernel.cc:790] Deployment config: num_threads: 6 [INFO kernel.cc:817] Train model [INFO random_forest.cc:315] Training random forest on 344 example(s) and 7 feature(s). 
[INFO random_forest.cc:628] Training of tree 1/300 (tree index:0) done accuracy:0.964286 logloss:1.28727 [INFO random_forest.cc:628] Training of tree 11/300 (tree index:10) done accuracy:0.956268 logloss:0.584301 [INFO random_forest.cc:628] Training of tree 22/300 (tree index:21) done accuracy:0.965116 logloss:0.378823 [INFO random_forest.cc:628] Training of tree 35/300 (tree index:34) done accuracy:0.968023 logloss:0.178185 [INFO random_forest.cc:628] Training of tree 46/300 (tree index:45) done accuracy:0.973837 logloss:0.170304 [INFO random_forest.cc:628] Training of tree 58/300 (tree index:57) done accuracy:0.973837 logloss:0.171223 [INFO random_forest.cc:628] Training of tree 70/300 (tree index:69) done accuracy:0.979651 logloss:0.169564 [INFO random_forest.cc:628] Training of tree 83/300 (tree index:82) done accuracy:0.976744 logloss:0.17074 [INFO random_forest.cc:628] Training of tree 96/300 (tree index:95) done accuracy:0.976744 logloss:0.0736925 [INFO random_forest.cc:628] Training of tree 106/300 (tree index:105) done accuracy:0.976744 logloss:0.0748649 [INFO random_forest.cc:628] Training of tree 117/300 (tree index:116) done accuracy:0.976744 logloss:0.074671 [INFO random_forest.cc:628] Training of tree 130/300 (tree index:129) done accuracy:0.976744 logloss:0.0736275 [INFO random_forest.cc:628] Training of tree 140/300 (tree index:139) done accuracy:0.976744 logloss:0.0727718 [INFO random_forest.cc:628] Training of tree 152/300 (tree index:151) done accuracy:0.976744 logloss:0.0715068 [INFO random_forest.cc:628] Training of tree 162/300 (tree index:161) done accuracy:0.976744 logloss:0.0708994 [INFO random_forest.cc:628] Training of tree 173/300 (tree index:172) done accuracy:0.976744 logloss:0.069447 [INFO random_forest.cc:628] Training of tree 184/300 (tree index:183) done accuracy:0.976744 logloss:0.0695926 [INFO random_forest.cc:628] Training of tree 195/300 (tree index:194) done accuracy:0.976744 logloss:0.0690138 [INFO random_forest.cc:628] 
Training of tree 205/300 (tree index:204) done accuracy:0.976744 logloss:0.0694597 [INFO random_forest.cc:628] Training of tree 217/300 (tree index:216) done accuracy:0.976744 logloss:0.068122 [INFO random_forest.cc:628] Training of tree 229/300 (tree index:228) done accuracy:0.976744 logloss:0.0687641 [INFO random_forest.cc:628] Training of tree 239/300 (tree index:238) done accuracy:0.976744 logloss:0.067988 [INFO random_forest.cc:628] Training of tree 250/300 (tree index:249) done accuracy:0.976744 logloss:0.0690187 [INFO random_forest.cc:628] Training of tree 260/300 (tree index:259) done accuracy:0.976744 logloss:0.0690134 [INFO random_forest.cc:628] Training of tree 270/300 (tree index:269) done accuracy:0.976744 logloss:0.0689877 [INFO random_forest.cc:628] Training of tree 280/300 (tree index:279) done accuracy:0.976744 logloss:0.0689845 [INFO random_forest.cc:628] Training of tree 290/300 (tree index:288) done accuracy:0.976744 logloss:0.0690742 [INFO random_forest.cc:628] Training of tree 300/300 (tree index:299) done accuracy:0.976744 logloss:0.068949 [INFO random_forest.cc:696] Final OOB metrics: accuracy:0.976744 logloss:0.068949 [INFO kernel.cc:828] Export model in log directory: /tmp/tmpoqki9pfl [INFO kernel.cc:836] Save model in resources [INFO kernel.cc:988] Loading model from path [INFO decision_forest.cc:590] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s). [INFO abstract_model.cc:993] Engine "RandomForestGeneric" built [INFO kernel.cc:848] Use fast generic engine <keras.callbacks.History at 0x7f09eaa9cb90>
Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB) variable importance during training. This is a popular permutation variable importance for Random Forest models.
Computing the OOB variable importance does not affect the final model, but it slows down training on large datasets.
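Permutation variable importance measures how much a quality metric drops when the values of one feature are shuffled, which breaks the link between that feature and the label. The following is a minimal NumPy sketch of the idea on a hypothetical toy predictor; it is not TF-DF's implementation, which evaluates on the out-of-bag examples of each tree. The helper `permutation_importance`, the function `toy_predict`, and the toy data are illustrative assumptions.

```python
import numpy as np


def permutation_importance(predict, x, y, num_repeats=10, seed=0):
    """Mean decrease in accuracy when each feature column is shuffled.

    Hypothetical helper for illustration: `predict` maps a 2-D feature
    array to predicted labels. Here the evaluation set is simply (x, y);
    the OOB variant would use each tree's out-of-bag examples instead.
    """
    rng = np.random.default_rng(seed)
    baseline = np.mean(predict(x) == y)
    importances = []
    for j in range(x.shape[1]):
        drops = []
        for _ in range(num_repeats):
            x_perm = x.copy()
            # Shuffle one column to break its link with the label.
            x_perm[:, j] = rng.permutation(x_perm[:, j])
            drops.append(baseline - np.mean(predict(x_perm) == y))
        importances.append(float(np.mean(drops)))
    return importances


def toy_predict(features):
    # Hypothetical "model": reads feature 0 only and ignores feature 1.
    return (features[:, 0] > 0).astype(int)


# Toy data: the label depends only on feature 0; feature 1 is noise.
data_rng = np.random.default_rng(42)
x = data_rng.normal(size=(200, 2))
y = (x[:, 0] > 0).astype(int)

print(permutation_importance(toy_predict, x, y))
```

Feature 0 gets a large importance (roughly 0.5, since shuffling it reduces the toy predictor to chance), while feature 1, which the predictor never reads, gets exactly 0. Small negative values, such as the one for flipper_length_mm in the summary below, can appear when a shuffle happens to help the metric by chance.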
Check the model summary:
%set_cell_height 300
model.summary()
<IPython.core.display.Javascript object> Model: "random_forest_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= ================================================================= Total params: 1 Trainable params: 0 Non-trainable params: 1 _________________________________________________________________ Type: "RANDOM_FOREST" Task: CLASSIFICATION Label: "__LABEL" Input Features (7): bill_depth_mm bill_length_mm body_mass_g flipper_length_mm island sex year No weights Variable Importance: MEAN_DECREASE_IN_ACCURACY: 1. "bill_length_mm" 0.151163 ################ 2. "island" 0.008721 # 3. "bill_depth_mm" 0.000000 4. "body_mass_g" 0.000000 5. "sex" 0.000000 6. "year" 0.000000 7. "flipper_length_mm" -0.002907 Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS: 1. "bill_length_mm" 0.083305 ################ 2. "island" 0.007664 # 3. "flipper_length_mm" 0.003400 4. "bill_depth_mm" 0.002741 5. "body_mass_g" 0.000722 6. "sex" 0.000644 7. "year" 0.000000 Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS: 1. "bill_length_mm" 0.508510 ################ 2. "island" 0.023487 3. "bill_depth_mm" 0.007744 4. "flipper_length_mm" 0.006008 5. "body_mass_g" 0.003017 6. "sex" 0.001537 7. "year" -0.000245 Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS: 1. "island" 0.002192 ################ 2. "bill_length_mm" 0.001572 ############ 3. "bill_depth_mm" 0.000497 ####### 4. "sex" 0.000000 #### 5. "year" 0.000000 #### 6. "body_mass_g" -0.000053 #### 7. "flipper_length_mm" -0.000890 Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS: 1. "bill_length_mm" 0.071306 ################ 2. "island" 0.007299 # 3. "flipper_length_mm" 0.004506 # 4. "bill_depth_mm" 0.002124 5. "body_mass_g" 0.000548 6. "sex" 0.000480 7. "year" 0.000000 Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS: 1. "bill_length_mm" 0.108642 ################ 2. "island" 0.014493 ## 3. 
"bill_depth_mm" 0.007406 # 4. "flipper_length_mm" 0.005195 5. "body_mass_g" 0.001012 6. "sex" 0.000480 7. "year" -0.000053 Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS: 1. "island" 0.002126 ################ 2. "bill_length_mm" 0.001393 ########### 3. "bill_depth_mm" 0.000293 ##### 4. "sex" 0.000000 ### 5. "year" 0.000000 ### 6. "body_mass_g" -0.000037 ### 7. "flipper_length_mm" -0.000550 Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS: 1. "bill_length_mm" 0.083122 ################ 2. "island" 0.010887 ## 3. "flipper_length_mm" 0.003425 4. "bill_depth_mm" 0.002731 5. "body_mass_g" 0.000719 6. "sex" 0.000641 7. "year" 0.000000 Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS: 1. "bill_length_mm" 0.497611 ################ 2. "island" 0.024045 3. "bill_depth_mm" 0.007734 4. "flipper_length_mm" 0.006017 5. "body_mass_g" 0.003000 6. "sex" 0.001528 7. "year" -0.000243 Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS: 1. "island" 0.002187 ################ 2. "bill_length_mm" 0.001568 ############ 3. "bill_depth_mm" 0.000495 ####### 4. "sex" 0.000000 #### 5. "year" 0.000000 #### 6. "body_mass_g" -0.000053 #### 7. "flipper_length_mm" -0.000886 Variable Importance: MEAN_MIN_DEPTH: 1. "__LABEL" 3.479602 ################ 2. "year" 3.463891 ############### 3. "sex" 3.430498 ############### 4. "body_mass_g" 2.898112 ########### 5. "island" 2.388925 ######## 6. "bill_depth_mm" 2.336100 ####### 7. "bill_length_mm" 1.282960 8. "flipper_length_mm" 1.270079 Variable Importance: NUM_AS_ROOT: 1. "flipper_length_mm" 157.000000 ################ 2. "bill_length_mm" 76.000000 ####### 3. "bill_depth_mm" 52.000000 ##### 4. "island" 12.000000 5. "body_mass_g" 3.000000 Variable Importance: NUM_NODES: 1. "bill_length_mm" 778.000000 ################ 2. "bill_depth_mm" 463.000000 ######### 3. "flipper_length_mm" 414.000000 ######## 4. "island" 342.000000 ###### 5. "body_mass_g" 338.000000 ###### 6. "sex" 36.000000 7. 
"year" 19.000000 Variable Importance: SUM_SCORE: 1. "bill_length_mm" 36515.793787 ################ 2. "flipper_length_mm" 35120.434174 ############### 3. "island" 14669.408395 ###### 4. "bill_depth_mm" 14515.446617 ###### 5. "body_mass_g" 3485.330881 # 6. "sex" 354.201073 7. "year" 49.737758 Winner take all: true Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949 Number of trees: 300 Total number of nodes: 5080 Number of nodes by tree: Count: 300 Average: 16.9333 StdDev: 3.10197 Min: 11 Max: 31 Ignored: 0 ---------------------------------------------- [ 11, 12) 6 2.00% 2.00% # [ 12, 13) 0 0.00% 2.00% [ 13, 14) 46 15.33% 17.33% ##### [ 14, 15) 0 0.00% 17.33% [ 15, 16) 70 23.33% 40.67% ######## [ 16, 17) 0 0.00% 40.67% [ 17, 18) 84 28.00% 68.67% ########## [ 18, 19) 0 0.00% 68.67% [ 19, 20) 46 15.33% 84.00% ##### [ 20, 21) 0 0.00% 84.00% [ 21, 22) 30 10.00% 94.00% #### [ 22, 23) 0 0.00% 94.00% [ 23, 24) 13 4.33% 98.33% ## [ 24, 25) 0 0.00% 98.33% [ 25, 26) 2 0.67% 99.00% [ 26, 27) 0 0.00% 99.00% [ 27, 28) 2 0.67% 99.67% [ 28, 29) 0 0.00% 99.67% [ 29, 30) 0 0.00% 99.67% [ 30, 31] 1 0.33% 100.00% Depth by leafs: Count: 2690 Average: 3.53271 StdDev: 1.06789 Min: 2 Max: 7 Ignored: 0 ---------------------------------------------- [ 2, 3) 545 20.26% 20.26% ###### [ 3, 4) 747 27.77% 48.03% ######## [ 4, 5) 888 33.01% 81.04% ########## [ 5, 6) 444 16.51% 97.55% ##### [ 6, 7) 62 2.30% 99.85% # [ 7, 7] 4 0.15% 100.00% Number of training obs by leaf: Count: 2690 Average: 38.3643 StdDev: 44.8651 Min: 5 Max: 155 Ignored: 0 ---------------------------------------------- [ 5, 12) 1474 54.80% 54.80% ########## [ 12, 20) 124 4.61% 59.41% # [ 20, 27) 48 1.78% 61.19% [ 27, 35) 74 2.75% 63.94% # [ 35, 42) 58 2.16% 66.10% [ 42, 50) 85 3.16% 69.26% # [ 50, 57) 96 3.57% 72.83% # [ 57, 65) 87 3.23% 76.06% # [ 65, 72) 49 1.82% 77.88% [ 72, 80) 23 0.86% 78.74% [ 80, 88) 30 1.12% 79.85% [ 88, 95) 23 0.86% 80.71% [ 95, 103) 42 1.56% 82.27% [ 103, 110) 62 2.30% 84.57% [ 110, 118) 115 
4.28% 88.85% # [ 118, 125) 115 4.28% 93.12% # [ 125, 133) 98 3.64% 96.77% # [ 133, 140) 49 1.82% 98.59% [ 140, 148) 31 1.15% 99.74% [ 148, 155] 7 0.26% 100.00% Attribute in nodes: 778 : bill_length_mm [NUMERICAL] 463 : bill_depth_mm [NUMERICAL] 414 : flipper_length_mm [NUMERICAL] 342 : island [CATEGORICAL] 338 : body_mass_g [NUMERICAL] 36 : sex [CATEGORICAL] 19 : year [NUMERICAL] Attribute in nodes with depth <= 0: 157 : flipper_length_mm [NUMERICAL] 76 : bill_length_mm [NUMERICAL] 52 : bill_depth_mm [NUMERICAL] 12 : island [CATEGORICAL] 3 : body_mass_g [NUMERICAL] Attribute in nodes with depth <= 1: 250 : bill_length_mm [NUMERICAL] 244 : flipper_length_mm [NUMERICAL] 183 : bill_depth_mm [NUMERICAL] 170 : island [CATEGORICAL] 53 : body_mass_g [NUMERICAL] Attribute in nodes with depth <= 2: 462 : bill_length_mm [NUMERICAL] 320 : flipper_length_mm [NUMERICAL] 310 : bill_depth_mm [NUMERICAL] 287 : island [CATEGORICAL] 162 : body_mass_g [NUMERICAL] 9 : sex [CATEGORICAL] 5 : year [NUMERICAL] Attribute in nodes with depth <= 3: 669 : bill_length_mm [NUMERICAL] 410 : bill_depth_mm [NUMERICAL] 383 : flipper_length_mm [NUMERICAL] 328 : island [CATEGORICAL] 286 : body_mass_g [NUMERICAL] 32 : sex [CATEGORICAL] 10 : year [NUMERICAL] Attribute in nodes with depth <= 5: 778 : bill_length_mm [NUMERICAL] 462 : bill_depth_mm [NUMERICAL] 413 : flipper_length_mm [NUMERICAL] 342 : island [CATEGORICAL] 338 : body_mass_g [NUMERICAL] 36 : sex [CATEGORICAL] 19 : year [NUMERICAL] Condition type in nodes: 2012 : HigherCondition 378 : ContainsBitmapCondition Condition type in nodes with depth <= 0: 288 : HigherCondition 12 : ContainsBitmapCondition Condition type in nodes with depth <= 1: 730 : HigherCondition 170 : ContainsBitmapCondition Condition type in nodes with depth <= 2: 1259 : HigherCondition 296 : ContainsBitmapCondition Condition type in nodes with depth <= 3: 1758 : HigherCondition 360 : ContainsBitmapCondition Condition type in nodes with depth <= 5: 2010 : HigherCondition 378 
: ContainsBitmapCondition Node format: NOT_SET Training OOB: trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727 trees: 11, Out-of-bag evaluation: accuracy:0.956268 logloss:0.584301 trees: 22, Out-of-bag evaluation: accuracy:0.965116 logloss:0.378823 trees: 35, Out-of-bag evaluation: accuracy:0.968023 logloss:0.178185 trees: 46, Out-of-bag evaluation: accuracy:0.973837 logloss:0.170304 trees: 58, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171223 trees: 70, Out-of-bag evaluation: accuracy:0.979651 logloss:0.169564 trees: 83, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17074 trees: 96, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736925 trees: 106, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0748649 trees: 117, Out-of-bag evaluation: accuracy:0.976744 logloss:0.074671 trees: 130, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736275 trees: 140, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727718 trees: 152, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0715068 trees: 162, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708994 trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.069447 trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695926 trees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690138 trees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694597 trees: 217, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068122 trees: 229, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0687641 trees: 239, Out-of-bag evaluation: accuracy:0.976744 logloss:0.067988 trees: 250, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690187 trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690134 trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689877 trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689845 trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690742 trees: 300, Out-of-bag evaluation: 
accuracy:0.976744 logloss:0.068949
Note the multiple variable importances with names MEAN_DECREASE_IN_*.
Plotting the model
Next, plot the model.
A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, plot only the first tree, and limit the nodes to depth 3.
tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)
Inspect the model structure
The model structure and meta-data are available through the inspector created by make_inspector().
inspector = model.make_inspector()
For our model, the available inspector fields are:
[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME', 'dataspec', 'evaluation', 'export_to_tensorboard', 'extract_all_trees', 'extract_tree', 'features', 'iterate_on_nodes', 'label', 'label_classes', 'model_type', 'num_trees', 'objective', 'specialized_header', 'task', 'training_logs', 'variable_importances', 'winner_take_all_inference']
Remember to see the API reference, or use ? for the built-in documentation.
?inspector.model_type
Some of the model meta-data:
print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]
evaluation() is the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.
inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None)
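As a quick check on how to read this evaluation, the accuracy above corresponds to a whole number of correctly classified out-of-bag examples. The arithmetic below uses only the two numbers printed in the Evaluation output.

```python
# Numbers copied from the Evaluation printed above.
num_examples = 344
accuracy = 0.9767441860465116

# Accuracy is the fraction of correctly classified examples.
num_correct = round(accuracy * num_examples)
num_errors = num_examples - num_correct
print(num_correct, num_errors)  # 336 8
```

In other words, 8 of the 344 penguins are misclassified in the out-of-bag evaluation (336 / 344 = 0.97674...).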
The variable importances are:
print("Available variable importances:")
for importance in inspector.variable_importances().keys():
    print("\t", importance)
Available variable importances:
	 MEAN_DECREASE_IN_AUC_3_VS_OTHERS
	 NUM_AS_ROOT
	 MEAN_DECREASE_IN_AUC_2_VS_OTHERS
	 MEAN_DECREASE_IN_AP_2_VS_OTHERS
	 MEAN_DECREASE_IN_ACCURACY
	 SUM_SCORE
	 MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS
	 MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
	 MEAN_DECREASE_IN_AP_3_VS_OTHERS
	 MEAN_DECREASE_IN_AUC_1_VS_OTHERS
	 MEAN_MIN_DEPTH
	 MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
	 NUM_NODES
	 MEAN_DECREASE_IN_AP_1_VS_OTHERS
Different variable importances have different semantics. For example, a feature with a mean decrease in AUC of 0.05 means that removing the information carried by this feature (by permuting its values) would reduce the AUC by 0.05, i.e. 5 percentage points.
# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389), ("island" (4; #4), 0.007298519736842035), ("flipper_length_mm" (1; #3), 0.004505893640351366), ("bill_depth_mm" (1; #0), 0.0021244517543865804), ("body_mass_g" (1; #2), 0.0005482456140351033), ("sex" (4; #5), 0.00047971491228060437), ("year" (1; #6), 0.0)]
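To make the semantics of MEAN_DECREASE_IN_AUC_1_VS_OTHERS more concrete, here is a small pure-Python sketch of the quantity it is built on: the ROC AUC of "class 1 vs the others", computed before and after shuffling one feature. The data, the single-feature scorer, and the fixed shuffle are hypothetical; the real computation uses the model and its out-of-bag examples, and its exact procedure is not reproduced here.

```python
def auc(scores, is_positive):
    """ROC AUC: chance that a random positive scores above a random negative.

    Ties count as half a win. Quadratic in the number of examples, which is
    fine for a toy example.
    """
    pos = [s for s, flag in zip(scores, is_positive) if flag]
    neg = [s for s, flag in zip(scores, is_positive) if not flag]
    wins = sum(
        1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg
    )
    return wins / (len(pos) * len(neg))


# Hypothetical "class 1 vs the others" task scored by a single feature.
is_class_1 = [True, True, True, False, False, False]
feature = [0.9, 0.8, 0.7, 0.3, 0.2, 0.1]

# One fixed shuffle of the same six values (hypothetical; a real
# permutation importance draws random permutations).
shuffled = [0.3, 0.8, 0.1, 0.9, 0.2, 0.7]

auc_before = auc(feature, is_class_1)  # 1.0: every positive outranks every negative
auc_after = auc(shuffled, is_class_1)  # 3 winning pairs out of 9
print("decrease in AUC:", auc_before - auc_after)
```

The decrease here is large (about 0.67) because the toy score depends on nothing but this feature; a value such as 0.071 for bill_length_mm above indicates a much smaller loss of ranking quality.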
Finally, access the actual tree structure:
inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0)), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0)), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0)), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})
Extracting a tree is not efficient. If speed is important, the model inspection can be done with the iterate_on_nodes() method instead. This method is a depth-first pre-order traversal iterator over all the nodes of the model.
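A depth-first pre-order traversal visits a node before its children and fully explores one subtree before moving on to the next. The sketch below shows this order on a hypothetical tree built from plain dataclasses (`Leaf`, `Split`, and `iterate_preorder` are illustrative names, not TF-DF classes). It visits the positive child first; that is a choice of this sketch, not a claim about the child order used by iterate_on_nodes().

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Leaf:
    value: List[float]


@dataclass
class Split:
    feature: str
    pos_child: "Node"
    neg_child: "Node"


Node = Union[Leaf, Split]


def iterate_preorder(node, depth=0):
    """Yield (node, depth) pairs: the node first, then each subtree in turn."""
    yield node, depth
    if isinstance(node, Split):
        yield from iterate_preorder(node.pos_child, depth + 1)
        yield from iterate_preorder(node.neg_child, depth + 1)


# Hypothetical tree with three condition nodes and four leaves.
tree = Split(
    "x1",
    pos_child=Split("x2", Leaf([0.0, 1.0, 0.0]), Leaf([0.0, 0.0, 1.0])),
    neg_child=Split("x3", Leaf([1.0, 0.0, 0.0]), Leaf([0.0, 0.0, 1.0])),
)

# Indent each node by its depth to make the visiting order visible.
for node, depth in iterate_preorder(tree):
    label = node.feature if isinstance(node, Split) else f"leaf {node.value}"
    print("  " * depth + label)
```

Because it is a generator, the traversal never materializes the whole tree as a list, which is the property that makes an iterator cheaper than extracting full tree objects.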
The following example computes how many times each feature is used (this is a kind of structural variable importance):
# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(int)
# Iterate over all the nodes in a depth-first pre-order traversal.
for node_iter in inspector.iterate_on_nodes():
    if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
        # Skip the leaf nodes.
        continue
    # Iterate over all the features used in the condition.
    # By default, models are axis-aligned (not "oblique"), i.e. each node tests a single feature.
    for feature in node_iter.node.condition.features():
        number_of_use[feature] += 1
print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
    print("\t", feature.name, ":", count)
Number of condition nodes per features:
	 bill_length_mm : 778
	 bill_depth_mm : 463
	 flipper_length_mm : 414
	 island : 342
	 body_mass_g : 338
	 year : 19
	 sex : 36
Creating a model by hand
In this section, you will create a small Random Forest model by hand. To make it extra easy, the model will only contain one simple tree:
3 label classes: red, blue, and green.
2 features: f1 (numerical) and f2 (string categorical)
f1>=1.5
├─(pos)─ f2 in ["cat","dog"]
│ ├─(pos)─ value: [0.8, 0.1, 0.1]
│ └─(neg)─ value: [0.1, 0.8, 0.1]
└─(neg)─ value: [0.1, 0.1, 0.8]
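Before building it with TF-DF, it can help to spell out what this tree computes. The plain-Python function below (`predict_manual_tree` is a hypothetical name, not part of TF-DF) hard-codes the diagram with the class order red, blue, green, and does not handle missing values. Applied to the three examples used later in this colab, it yields the same probabilities as the Keras model's predictions.

```python
def predict_manual_tree(f1, f2):
    """Hand-coded version of the one-tree model in the diagram above.

    Returns class probabilities in the order [red, blue, green].
    Missing values are not handled in this sketch.
    """
    if f1 >= 1.5:
        if f2 in ("cat", "dog"):
            return [0.8, 0.1, 0.1]
        return [0.1, 0.8, 0.1]
    return [0.1, 0.1, 0.8]


examples = {"f1": [1.0, 2.0, 3.0], "f2": ["cat", "cat", "bird"]}
predictions = [
    predict_manual_tree(a, b) for a, b in zip(examples["f1"], examples["f2"])
]
print(predictions)
# [[0.1, 0.1, 0.8], [0.8, 0.1, 0.1], [0.1, 0.8, 0.1]]
```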
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
path="/tmp/manual_model",
objective=tfdf.py_tree.objective.ClassificationObjective(
label="color", classes=["red", "blue", "green"]))
Each tree is added one by one.
# Some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2", type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))
Finish writing the tree:
builder.close()
[INFO kernel.cc:988] Loading model from path [INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s). [INFO kernel.cc:848] Use fast generic engine 2021-11-08 12:19:14.555155: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/manual_model/assets INFO:tensorflow:Assets written to: /tmp/manual_model/assets
You can now load the model as a regular Keras model and make predictions:
manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path [INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s). [INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
"f1": [1.0, 2.0, 3.0],
"f2": ["cat", "cat", "bird"]
}).batch(2)
predictions = manual_model.predict(examples)
print("predictions:\n",predictions)
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
Access the structure:
yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)
inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]
And, of course, you can plot this manually constructed model:
tfdf.model_plotter.plot_model_in_colab(manual_model)