View on TensorFlow.org | Run in Google Colab | View on GitHub | Download notebook
In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.
In this colab, you will:
Train a Random Forest model and access its structure programmatically.
Create a Random Forest model by hand and use it as a classical model.
Setup
# Install TensorFlow Decision Forests.
!pip install tensorflow_decision_forests
# Use wurlitzer to capture training logs.
!pip install wurlitzer
import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections
try:
  from wurlitzer import sys_pipes
except:
  from colabtools.googlelog import CaptureLog as sys_pipes
from IPython.core.magic import register_line_magic
from IPython.display import Javascript
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible.
WARNING:root:TF Parameter Server distributed training not available.
The hidden code cell limits the output height in colab.
# Some of the model training logs can cover the full
# screen if not compressed to a smaller viewport.
# This magic allows setting a max height for a cell.
@register_line_magic
def set_cell_height(size):
  display(
      Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                 str(size) + "})"))
Train a simple Random Forest
We train a Random Forest as in the beginner colab:
# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv
# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")
# Show the first three examples.
print(dataset_df.head(3))
# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")
# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species island bill_length_mm bill_depth_mm flipper_length_mm \
0 Adelie Torgersen 39.1 18.7 181.0
1 Adelie Torgersen 39.5 17.4 186.0
2 Adelie Torgersen 40.3 18.0 195.0
body_mass_g sex year
0 3750.0 male 2007
1 3800.0 female 2007
2 3250.0 female 2007
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
features_dataframe = dataframe.drop(label, 1)
6/6 [==============================] - 4s 17ms/step
[INFO kernel.cc:736] Start Yggdrasil model training
[INFO kernel.cc:737] Collect training examples
[INFO kernel.cc:392] Number of batches: 6
[INFO kernel.cc:393] Number of examples: 344
[INFO kernel.cc:759] Dataset:
Number of records: 344
Number of columns: 8
Number of columns by type:
NUMERICAL: 5 (62.5%)
CATEGORICAL: 3 (37.5%)
Columns:
NUMERICAL: 5 (62.5%)
0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719
1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516
2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781
3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411
6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166
CATEGORICAL: 3 (37.5%)
4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%)
5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%)
7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item
Terminology:
nas: Number of non-available (i.e. missing) values.
ood: Out of dictionary.
manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
tokenized: The attribute value is obtained through tokenization.
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
vocab-size: Number of unique values.
[INFO kernel.cc:762] Configure learner
[INFO kernel.cc:787] Training config:
learner: "RANDOM_FOREST"
features: "bill_depth_mm"
features: "bill_length_mm"
features: "body_mass_g"
features: "flipper_length_mm"
features: "island"
features: "sex"
features: "year"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
num_trees: 300
decision_tree {
max_depth: 16
min_examples: 5
in_split_min_examples_check: true
missing_value_policy: GLOBAL_IMPUTATION
allow_na_conditions: false
categorical_set_greedy_forward {
sampling: 0.1
max_num_items: -1
min_item_frequency: 1
}
growing_strategy_local {
}
categorical {
cart {
}
}
num_candidate_attributes_ratio: -1
axis_aligned_split {
}
internal {
sorting_strategy: PRESORTED
}
}
winner_take_all_inference: true
compute_oob_performances: true
compute_oob_variable_importances: true
adapt_bootstrap_size_ratio_for_maximum_training_duration: false
}
[INFO kernel.cc:790] Deployment config:
num_threads: 6
[INFO kernel.cc:817] Train model
[INFO random_forest.cc:315] Training random forest on 344 example(s) and 7 feature(s).
[INFO random_forest.cc:628] Training of tree 1/300 (tree index:0) done accuracy:0.964286 logloss:1.28727
[INFO random_forest.cc:628] Training of tree 11/300 (tree index:10) done accuracy:0.956268 logloss:0.584301
[INFO random_forest.cc:628] Training of tree 22/300 (tree index:21) done accuracy:0.965116 logloss:0.378823
[INFO random_forest.cc:628] Training of tree 35/300 (tree index:34) done accuracy:0.968023 logloss:0.178185
[INFO random_forest.cc:628] Training of tree 46/300 (tree index:45) done accuracy:0.973837 logloss:0.170304
[INFO random_forest.cc:628] Training of tree 58/300 (tree index:57) done accuracy:0.973837 logloss:0.171223
[INFO random_forest.cc:628] Training of tree 70/300 (tree index:69) done accuracy:0.979651 logloss:0.169564
[INFO random_forest.cc:628] Training of tree 83/300 (tree index:82) done accuracy:0.976744 logloss:0.17074
[INFO random_forest.cc:628] Training of tree 96/300 (tree index:95) done accuracy:0.976744 logloss:0.0736925
[INFO random_forest.cc:628] Training of tree 106/300 (tree index:105) done accuracy:0.976744 logloss:0.0748649
[INFO random_forest.cc:628] Training of tree 117/300 (tree index:116) done accuracy:0.976744 logloss:0.074671
[INFO random_forest.cc:628] Training of tree 130/300 (tree index:129) done accuracy:0.976744 logloss:0.0736275
[INFO random_forest.cc:628] Training of tree 140/300 (tree index:139) done accuracy:0.976744 logloss:0.0727718
[INFO random_forest.cc:628] Training of tree 152/300 (tree index:151) done accuracy:0.976744 logloss:0.0715068
[INFO random_forest.cc:628] Training of tree 162/300 (tree index:161) done accuracy:0.976744 logloss:0.0708994
[INFO random_forest.cc:628] Training of tree 173/300 (tree index:172) done accuracy:0.976744 logloss:0.069447
[INFO random_forest.cc:628] Training of tree 184/300 (tree index:183) done accuracy:0.976744 logloss:0.0695926
[INFO random_forest.cc:628] Training of tree 195/300 (tree index:194) done accuracy:0.976744 logloss:0.0690138
[INFO random_forest.cc:628] Training of tree 205/300 (tree index:204) done accuracy:0.976744 logloss:0.0694597
[INFO random_forest.cc:628] Training of tree 217/300 (tree index:216) done accuracy:0.976744 logloss:0.068122
[INFO random_forest.cc:628] Training of tree 229/300 (tree index:228) done accuracy:0.976744 logloss:0.0687641
[INFO random_forest.cc:628] Training of tree 239/300 (tree index:238) done accuracy:0.976744 logloss:0.067988
[INFO random_forest.cc:628] Training of tree 250/300 (tree index:249) done accuracy:0.976744 logloss:0.0690187
[INFO random_forest.cc:628] Training of tree 260/300 (tree index:259) done accuracy:0.976744 logloss:0.0690134
[INFO random_forest.cc:628] Training of tree 270/300 (tree index:269) done accuracy:0.976744 logloss:0.0689877
[INFO random_forest.cc:628] Training of tree 280/300 (tree index:279) done accuracy:0.976744 logloss:0.0689845
[INFO random_forest.cc:628] Training of tree 290/300 (tree index:288) done accuracy:0.976744 logloss:0.0690742
[INFO random_forest.cc:628] Training of tree 300/300 (tree index:299) done accuracy:0.976744 logloss:0.068949
[INFO random_forest.cc:696] Final OOB metrics: accuracy:0.976744 logloss:0.068949
[INFO kernel.cc:828] Export model in log directory: /tmp/tmpoqki9pfl
[INFO kernel.cc:836] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s).
[INFO abstract_model.cc:993] Engine "RandomForestGeneric" built
[INFO kernel.cc:848] Use fast generic engine
<keras.callbacks.History at 0x7f09eaa9cb90>
Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB) variable importances during training. This is a popular permutation variable importance for Random Forest models.
Computing OOB variable importances does not impact the final model, but it slows down training on large datasets.
Inspect the model summary:
%set_cell_height 300
model.summary()
<IPython.core.display.Javascript object>
Model: "random_forest_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (7):
bill_depth_mm
bill_length_mm
body_mass_g
flipper_length_mm
island
sex
year
No weights
Variable Importance: MEAN_DECREASE_IN_ACCURACY:
1. "bill_length_mm" 0.151163 ################
2. "island" 0.008721 #
3. "bill_depth_mm" 0.000000
4. "body_mass_g" 0.000000
5. "sex" 0.000000
6. "year" 0.000000
7. "flipper_length_mm" -0.002907
Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:
1. "bill_length_mm" 0.083305 ################
2. "island" 0.007664 #
3. "flipper_length_mm" 0.003400
4. "bill_depth_mm" 0.002741
5. "body_mass_g" 0.000722
6. "sex" 0.000644
7. "year" 0.000000
Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:
1. "bill_length_mm" 0.508510 ################
2. "island" 0.023487
3. "bill_depth_mm" 0.007744
4. "flipper_length_mm" 0.006008
5. "body_mass_g" 0.003017
6. "sex" 0.001537
7. "year" -0.000245
Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:
1. "island" 0.002192 ################
2. "bill_length_mm" 0.001572 ############
3. "bill_depth_mm" 0.000497 #######
4. "sex" 0.000000 ####
5. "year" 0.000000 ####
6. "body_mass_g" -0.000053 ####
7. "flipper_length_mm" -0.000890
Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:
1. "bill_length_mm" 0.071306 ################
2. "island" 0.007299 #
3. "flipper_length_mm" 0.004506 #
4. "bill_depth_mm" 0.002124
5. "body_mass_g" 0.000548
6. "sex" 0.000480
7. "year" 0.000000
Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:
1. "bill_length_mm" 0.108642 ################
2. "island" 0.014493 ##
3. "bill_depth_mm" 0.007406 #
4. "flipper_length_mm" 0.005195
5. "body_mass_g" 0.001012
6. "sex" 0.000480
7. "year" -0.000053
Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:
1. "island" 0.002126 ################
2. "bill_length_mm" 0.001393 ###########
3. "bill_depth_mm" 0.000293 #####
4. "sex" 0.000000 ###
5. "year" 0.000000 ###
6. "body_mass_g" -0.000037 ###
7. "flipper_length_mm" -0.000550
Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:
1. "bill_length_mm" 0.083122 ################
2. "island" 0.010887 ##
3. "flipper_length_mm" 0.003425
4. "bill_depth_mm" 0.002731
5. "body_mass_g" 0.000719
6. "sex" 0.000641
7. "year" 0.000000
Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:
1. "bill_length_mm" 0.497611 ################
2. "island" 0.024045
3. "bill_depth_mm" 0.007734
4. "flipper_length_mm" 0.006017
5. "body_mass_g" 0.003000
6. "sex" 0.001528
7. "year" -0.000243
Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:
1. "island" 0.002187 ################
2. "bill_length_mm" 0.001568 ############
3. "bill_depth_mm" 0.000495 #######
4. "sex" 0.000000 ####
5. "year" 0.000000 ####
6. "body_mass_g" -0.000053 ####
7. "flipper_length_mm" -0.000886
Variable Importance: MEAN_MIN_DEPTH:
1. "__LABEL" 3.479602 ################
2. "year" 3.463891 ###############
3. "sex" 3.430498 ###############
4. "body_mass_g" 2.898112 ###########
5. "island" 2.388925 ########
6. "bill_depth_mm" 2.336100 #######
7. "bill_length_mm" 1.282960
8. "flipper_length_mm" 1.270079
Variable Importance: NUM_AS_ROOT:
1. "flipper_length_mm" 157.000000 ################
2. "bill_length_mm" 76.000000 #######
3. "bill_depth_mm" 52.000000 #####
4. "island" 12.000000
5. "body_mass_g" 3.000000
Variable Importance: NUM_NODES:
1. "bill_length_mm" 778.000000 ################
2. "bill_depth_mm" 463.000000 #########
3. "flipper_length_mm" 414.000000 ########
4. "island" 342.000000 ######
5. "body_mass_g" 338.000000 ######
6. "sex" 36.000000
7. "year" 19.000000
Variable Importance: SUM_SCORE:
1. "bill_length_mm" 36515.793787 ################
2. "flipper_length_mm" 35120.434174 ###############
3. "island" 14669.408395 ######
4. "bill_depth_mm" 14515.446617 ######
5. "body_mass_g" 3485.330881 #
6. "sex" 354.201073
7. "year" 49.737758
Winner take all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949
Number of trees: 300
Total number of nodes: 5080
Number of nodes by tree:
Count: 300 Average: 16.9333 StdDev: 3.10197
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12) 6 2.00% 2.00% #
[ 12, 13) 0 0.00% 2.00%
[ 13, 14) 46 15.33% 17.33% #####
[ 14, 15) 0 0.00% 17.33%
[ 15, 16) 70 23.33% 40.67% ########
[ 16, 17) 0 0.00% 40.67%
[ 17, 18) 84 28.00% 68.67% ##########
[ 18, 19) 0 0.00% 68.67%
[ 19, 20) 46 15.33% 84.00% #####
[ 20, 21) 0 0.00% 84.00%
[ 21, 22) 30 10.00% 94.00% ####
[ 22, 23) 0 0.00% 94.00%
[ 23, 24) 13 4.33% 98.33% ##
[ 24, 25) 0 0.00% 98.33%
[ 25, 26) 2 0.67% 99.00%
[ 26, 27) 0 0.00% 99.00%
[ 27, 28) 2 0.67% 99.67%
[ 28, 29) 0 0.00% 99.67%
[ 29, 30) 0 0.00% 99.67%
[ 30, 31] 1 0.33% 100.00%
Depth by leafs:
Count: 2690 Average: 3.53271 StdDev: 1.06789
Min: 2 Max: 7 Ignored: 0
----------------------------------------------
[ 2, 3) 545 20.26% 20.26% ######
[ 3, 4) 747 27.77% 48.03% ########
[ 4, 5) 888 33.01% 81.04% ##########
[ 5, 6) 444 16.51% 97.55% #####
[ 6, 7) 62 2.30% 99.85% #
[ 7, 7] 4 0.15% 100.00%
Number of training obs by leaf:
Count: 2690 Average: 38.3643 StdDev: 44.8651
Min: 5 Max: 155 Ignored: 0
----------------------------------------------
[ 5, 12) 1474 54.80% 54.80% ##########
[ 12, 20) 124 4.61% 59.41% #
[ 20, 27) 48 1.78% 61.19%
[ 27, 35) 74 2.75% 63.94% #
[ 35, 42) 58 2.16% 66.10%
[ 42, 50) 85 3.16% 69.26% #
[ 50, 57) 96 3.57% 72.83% #
[ 57, 65) 87 3.23% 76.06% #
[ 65, 72) 49 1.82% 77.88%
[ 72, 80) 23 0.86% 78.74%
[ 80, 88) 30 1.12% 79.85%
[ 88, 95) 23 0.86% 80.71%
[ 95, 103) 42 1.56% 82.27%
[ 103, 110) 62 2.30% 84.57%
[ 110, 118) 115 4.28% 88.85% #
[ 118, 125) 115 4.28% 93.12% #
[ 125, 133) 98 3.64% 96.77% #
[ 133, 140) 49 1.82% 98.59%
[ 140, 148) 31 1.15% 99.74%
[ 148, 155] 7 0.26% 100.00%
Attribute in nodes:
778 : bill_length_mm [NUMERICAL]
463 : bill_depth_mm [NUMERICAL]
414 : flipper_length_mm [NUMERICAL]
342 : island [CATEGORICAL]
338 : body_mass_g [NUMERICAL]
36 : sex [CATEGORICAL]
19 : year [NUMERICAL]
Attribute in nodes with depth <= 0:
157 : flipper_length_mm [NUMERICAL]
76 : bill_length_mm [NUMERICAL]
52 : bill_depth_mm [NUMERICAL]
12 : island [CATEGORICAL]
3 : body_mass_g [NUMERICAL]
Attribute in nodes with depth <= 1:
250 : bill_length_mm [NUMERICAL]
244 : flipper_length_mm [NUMERICAL]
183 : bill_depth_mm [NUMERICAL]
170 : island [CATEGORICAL]
53 : body_mass_g [NUMERICAL]
Attribute in nodes with depth <= 2:
462 : bill_length_mm [NUMERICAL]
320 : flipper_length_mm [NUMERICAL]
310 : bill_depth_mm [NUMERICAL]
287 : island [CATEGORICAL]
162 : body_mass_g [NUMERICAL]
9 : sex [CATEGORICAL]
5 : year [NUMERICAL]
Attribute in nodes with depth <= 3:
669 : bill_length_mm [NUMERICAL]
410 : bill_depth_mm [NUMERICAL]
383 : flipper_length_mm [NUMERICAL]
328 : island [CATEGORICAL]
286 : body_mass_g [NUMERICAL]
32 : sex [CATEGORICAL]
10 : year [NUMERICAL]
Attribute in nodes with depth <= 5:
778 : bill_length_mm [NUMERICAL]
462 : bill_depth_mm [NUMERICAL]
413 : flipper_length_mm [NUMERICAL]
342 : island [CATEGORICAL]
338 : body_mass_g [NUMERICAL]
36 : sex [CATEGORICAL]
19 : year [NUMERICAL]
Condition type in nodes:
2012 : HigherCondition
378 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
288 : HigherCondition
12 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
730 : HigherCondition
170 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
1259 : HigherCondition
296 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
1758 : HigherCondition
360 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
2010 : HigherCondition
378 : ContainsBitmapCondition
Node format: NOT_SET
Training OOB:
trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727
trees: 11, Out-of-bag evaluation: accuracy:0.956268 logloss:0.584301
trees: 22, Out-of-bag evaluation: accuracy:0.965116 logloss:0.378823
trees: 35, Out-of-bag evaluation: accuracy:0.968023 logloss:0.178185
trees: 46, Out-of-bag evaluation: accuracy:0.973837 logloss:0.170304
trees: 58, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171223
trees: 70, Out-of-bag evaluation: accuracy:0.979651 logloss:0.169564
trees: 83, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17074
trees: 96, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736925
trees: 106, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0748649
trees: 117, Out-of-bag evaluation: accuracy:0.976744 logloss:0.074671
trees: 130, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736275
trees: 140, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727718
trees: 152, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0715068
trees: 162, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708994
trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.069447
trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695926
trees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690138
trees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694597
trees: 217, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068122
trees: 229, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0687641
trees: 239, Out-of-bag evaluation: accuracy:0.976744 logloss:0.067988
trees: 250, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690187
trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690134
trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689877
trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689845
trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690742
trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949
Note the multiple variable importances with the name MEAN_DECREASE_IN_*.
Plot the model
Next, plot the model.
A Random Forest is a large model (this one has 300 trees and ~5k nodes; see the summary above). Therefore, only plot the first tree, and limit the nodes to a depth of 3.
tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)
Inspect the model structure
The model structure and meta-data are available through the inspector created by make_inspector().
inspector = model.make_inspector()
For our model, the available inspector fields are:
[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME', 'dataspec', 'evaluation', 'export_to_tensorboard', 'extract_all_trees', 'extract_tree', 'features', 'iterate_on_nodes', 'label', 'label_classes', 'model_type', 'num_trees', 'objective', 'specialized_header', 'task', 'training_logs', 'variable_importances', 'winner_take_all_inference']
Remember to check the API reference, or use ? for the built-in documentation.
?inspector.model_type
Some of the model meta-data:
print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]
evaluation() is the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.
inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None)
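The training logs shown earlier are also available programmatically through the inspector's training_logs() field (listed above). As a minimal sketch, assuming matplotlib is available and that each log entry exposes num_trees and an evaluation like the one returned by evaluation(), you can plot the OOB metrics against the number of trees:
import matplotlib.pyplot as plt

logs = inspector.training_logs()

plt.figure(figsize=(12, 4))

# Out-of-bag accuracy as trees are added.
plt.subplot(1, 2, 1)
plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Out-of-bag accuracy")

# Out-of-bag logloss as trees are added.
plt.subplot(1, 2, 2)
plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Out-of-bag logloss")

plt.show()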
The available variable importances are:
print(f"Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances:
MEAN_DECREASE_IN_AUC_3_VS_OTHERS
NUM_AS_ROOT
MEAN_DECREASE_IN_AUC_2_VS_OTHERS
MEAN_DECREASE_IN_AP_2_VS_OTHERS
MEAN_DECREASE_IN_ACCURACY
SUM_SCORE
MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS
MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
MEAN_DECREASE_IN_AP_3_VS_OTHERS
MEAN_DECREASE_IN_AUC_1_VS_OTHERS
MEAN_MIN_DEPTH
MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
NUM_NODES
MEAN_DECREASE_IN_AP_1_VS_OTHERS
Different variable importances have different semantics. For example, a feature with a mean decrease in AUC of 0.05 means that removing this feature from the training dataset would reduce/hurt the AUC by 5%.
# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389),
("island" (4; #4), 0.007298519736842035),
("flipper_length_mm" (1; #3), 0.004505893640351366),
("bill_depth_mm" (1; #0), 0.0021244517543865804),
("body_mass_g" (1; #2), 0.0005482456140351033),
("sex" (4; #5), 0.00047971491228060437),
("year" (1; #6), 0.0)]
Finally, access the actual tree structure:
inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0)), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0)), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0)), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})
Extracting a tree is not efficient. If speed is important, the model inspection can be done with the iterate_on_nodes() method instead. This method is a Depth-First Pre-order traversal iterator over all the nodes of the model.
The following example computes how many times each feature is used (this is a kind of structural variable importance):
# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)

# Iterate over all the nodes in a Depth-First Pre-order traversal.
for node_iter in inspector.iterate_on_nodes():

  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes.
    continue

  # Iterate over all the features used in the condition.
  # By default, models are "axis-aligned", i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1

print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features:
bill_length_mm : 778
bill_depth_mm : 463
flipper_length_mm : 414
island : 342
body_mass_g : 338
year : 19
sex : 36
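The same iterator can answer other structural questions. For example, the following sketch counts leaf and non-leaf nodes over the whole forest; the two counts should add up to the 5080 nodes reported in the summary above:
num_leaves = 0
num_non_leaves = 0
for node_iter in inspector.iterate_on_nodes():
  if isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    num_non_leaves += 1
  else:
    num_leaves += 1
print("Non-leaf nodes:", num_non_leaves)
print("Leaf nodes:", num_leaves)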
Creating a model by hand
In this section, you will create a small Random Forest model by hand. To make it extra easy, the model will only contain one simple tree:
3 label classes: Red, blue and green.
2 features: f1 (numerical) and f2 (string categorical)
f1>=1.5
├─(pos)─ f2 in ["cat","dog"]
│ ├─(pos)─ value: [0.8, 0.1, 0.1]
│ └─(neg)─ value: [0.1, 0.8, 0.1]
└─(neg)─ value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))
Each tree is added one by one.
# Some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2", type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(
                    value=ProbabilityValue(
                        probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(
                    value=ProbabilityValue(
                        probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(
                value=ProbabilityValue(
                    probability=[0.1, 0.1, 0.8], num_examples=30)))))
Conclude the tree writing:
builder.close()
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
2021-11-08 12:19:14.555155: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
You can now open the model as a regular Keras model and make predictions:
manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
    "f1": [1.0, 2.0, 3.0],
    "f2": ["cat", "cat", "bird"]
}).batch(2)
predictions = manual_model.predict(examples)
print("predictions:\n",predictions)
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
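Each row matches the probability value stored in the leaf reached by the corresponding example. As a quick sanity check, a minimal sketch using numpy:
import numpy as np

# f1=1.0 fails the root condition f1 >= 1.5, so the first example
# lands in the negative leaf whose value is [0.1, 0.1, 0.8].
np.testing.assert_allclose(predictions[0], [0.1, 0.1, 0.8], atol=1e-6)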
And access its structure:
yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)
inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]
And of course, you can plot this manually constructed model:
tfdf.model_plotter.plot_model_in_colab(manual_model)
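As a last check, you can extract the hand-written tree back from the inspector created above, using the same extract_tree() method as earlier (a minimal sketch):
# Recover the single hand-written tree from the loaded model.
print(inspector.extract_tree(tree_idx=0))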