View on TensorFlow.org | Run in Google Colab | View on GitHub | Download notebook
In this colab, you will learn how to directly inspect and create the structure of a model. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.
In this colab, you will:
Train a Random Forest model and access its structure programmatically.
Create a Random Forest model by hand and use it as a classical model.
Setup
# Install TensorFlow Decision Forests.
!pip install tensorflow_decision_forests
# Use wurlitzer to capture training logs.
!pip install wurlitzer
import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections
try:
  from wurlitzer import sys_pipes
except:
  from colabtools.googlelog import CaptureLog as sys_pipes
from IPython.core.magic import register_line_magic
from IPython.display import Javascript
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible.
WARNING:root:TF Parameter Server distributed training not available.
The hidden code cell limits the output height in colab.
# Some of the model training logs can cover the full
# screen if not compressed to a smaller viewport.
# This magic allows setting a max height for a cell.
@register_line_magic
def set_cell_height(size):
  display(
      Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                 str(size) + "})"))
Train a simple Random Forest
We train a Random Forest as in the beginner colab:
# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv
# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")
# Show the first three examples.
print(dataset_df.head(3))
# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")
# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species island bill_length_mm bill_depth_mm flipper_length_mm \
0 Adelie Torgersen 39.1 18.7 181.0
1 Adelie Torgersen 39.5 17.4 186.0
2 Adelie Torgersen 40.3 18.0 195.0
body_mass_g sex year
0 3750.0 male 2007
1 3800.0 female 2007
2 3250.0 female 2007
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
features_dataframe = dataframe.drop(label, 1)
6/6 [==============================] - 4s 17ms/step
[INFO kernel.cc:736] Start Yggdrasil model training
[INFO kernel.cc:737] Collect training examples
[INFO kernel.cc:392] Number of batches: 6
[INFO kernel.cc:393] Number of examples: 344
[INFO kernel.cc:759] Dataset:
Number of records: 344
Number of columns: 8
Number of columns by type:
NUMERICAL: 5 (62.5%)
CATEGORICAL: 3 (37.5%)
Columns:
NUMERICAL: 5 (62.5%)
0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719
1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516
2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781
3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411
6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166
CATEGORICAL: 3 (37.5%)
4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%)
5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%)
7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item
Terminology:
nas: Number of non-available (i.e. missing) values.
ood: Out of dictionary.
manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
tokenized: The attribute value is obtained through tokenization.
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
vocab-size: Number of unique values.
[INFO kernel.cc:762] Configure learner
[INFO kernel.cc:787] Training config:
learner: "RANDOM_FOREST"
features: "bill_depth_mm"
features: "bill_length_mm"
features: "body_mass_g"
features: "flipper_length_mm"
features: "island"
features: "sex"
features: "year"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
num_trees: 300
decision_tree {
max_depth: 16
min_examples: 5
in_split_min_examples_check: true
missing_value_policy: GLOBAL_IMPUTATION
allow_na_conditions: false
categorical_set_greedy_forward {
sampling: 0.1
max_num_items: -1
min_item_frequency: 1
}
growing_strategy_local {
}
categorical {
cart {
}
}
num_candidate_attributes_ratio: -1
axis_aligned_split {
}
internal {
sorting_strategy: PRESORTED
}
}
winner_take_all_inference: true
compute_oob_performances: true
compute_oob_variable_importances: true
adapt_bootstrap_size_ratio_for_maximum_training_duration: false
}
[INFO kernel.cc:790] Deployment config:
num_threads: 6
[INFO kernel.cc:817] Train model
[INFO random_forest.cc:315] Training random forest on 344 example(s) and 7 feature(s).
[INFO random_forest.cc:628] Training of tree 1/300 (tree index:0) done accuracy:0.964286 logloss:1.28727
[INFO random_forest.cc:628] Training of tree 11/300 (tree index:10) done accuracy:0.956268 logloss:0.584301
[INFO random_forest.cc:628] Training of tree 22/300 (tree index:21) done accuracy:0.965116 logloss:0.378823
[INFO random_forest.cc:628] Training of tree 35/300 (tree index:34) done accuracy:0.968023 logloss:0.178185
[INFO random_forest.cc:628] Training of tree 46/300 (tree index:45) done accuracy:0.973837 logloss:0.170304
[INFO random_forest.cc:628] Training of tree 58/300 (tree index:57) done accuracy:0.973837 logloss:0.171223
[INFO random_forest.cc:628] Training of tree 70/300 (tree index:69) done accuracy:0.979651 logloss:0.169564
[INFO random_forest.cc:628] Training of tree 83/300 (tree index:82) done accuracy:0.976744 logloss:0.17074
[INFO random_forest.cc:628] Training of tree 96/300 (tree index:95) done accuracy:0.976744 logloss:0.0736925
[INFO random_forest.cc:628] Training of tree 106/300 (tree index:105) done accuracy:0.976744 logloss:0.0748649
[INFO random_forest.cc:628] Training of tree 117/300 (tree index:116) done accuracy:0.976744 logloss:0.074671
[INFO random_forest.cc:628] Training of tree 130/300 (tree index:129) done accuracy:0.976744 logloss:0.0736275
[INFO random_forest.cc:628] Training of tree 140/300 (tree index:139) done accuracy:0.976744 logloss:0.0727718
[INFO random_forest.cc:628] Training of tree 152/300 (tree index:151) done accuracy:0.976744 logloss:0.0715068
[INFO random_forest.cc:628] Training of tree 162/300 (tree index:161) done accuracy:0.976744 logloss:0.0708994
[INFO random_forest.cc:628] Training of tree 173/300 (tree index:172) done accuracy:0.976744 logloss:0.069447
[INFO random_forest.cc:628] Training of tree 184/300 (tree index:183) done accuracy:0.976744 logloss:0.0695926
[INFO random_forest.cc:628] Training of tree 195/300 (tree index:194) done accuracy:0.976744 logloss:0.0690138
[INFO random_forest.cc:628] Training of tree 205/300 (tree index:204) done accuracy:0.976744 logloss:0.0694597
[INFO random_forest.cc:628] Training of tree 217/300 (tree index:216) done accuracy:0.976744 logloss:0.068122
[INFO random_forest.cc:628] Training of tree 229/300 (tree index:228) done accuracy:0.976744 logloss:0.0687641
[INFO random_forest.cc:628] Training of tree 239/300 (tree index:238) done accuracy:0.976744 logloss:0.067988
[INFO random_forest.cc:628] Training of tree 250/300 (tree index:249) done accuracy:0.976744 logloss:0.0690187
[INFO random_forest.cc:628] Training of tree 260/300 (tree index:259) done accuracy:0.976744 logloss:0.0690134
[INFO random_forest.cc:628] Training of tree 270/300 (tree index:269) done accuracy:0.976744 logloss:0.0689877
[INFO random_forest.cc:628] Training of tree 280/300 (tree index:279) done accuracy:0.976744 logloss:0.0689845
[INFO random_forest.cc:628] Training of tree 290/300 (tree index:288) done accuracy:0.976744 logloss:0.0690742
[INFO random_forest.cc:628] Training of tree 300/300 (tree index:299) done accuracy:0.976744 logloss:0.068949
[INFO random_forest.cc:696] Final OOB metrics: accuracy:0.976744 logloss:0.068949
[INFO kernel.cc:828] Export model in log directory: /tmp/tmpoqki9pfl
[INFO kernel.cc:836] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s).
[INFO abstract_model.cc:993] Engine "RandomForestGeneric" built
[INFO kernel.cc:848] Use fast generic engine
<keras.callbacks.History at 0x7f09eaa9cb90>
Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB) variable importance during training. This is a popular permutation variable importance for Random Forest models.
Computing the OOB variable importance does not affect the final model, but it slows down training on large datasets.
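To make the idea of a permutation variable importance concrete, here is a minimal pure-Python sketch. It is not how TensorFlow Decision Forests computes the OOB importances internally: it shuffles one feature column on a fixed toy dataset and reports the resulting drop in accuracy, whereas the library evaluates each tree on its own out-of-bag examples. The names accuracy, permutation_importance, toy_predict and the toy rows are hypothetical and exist only for this illustration.

```python
import random


def accuracy(predict, rows, labels):
  """Fraction of rows for which predict(row) equals the label."""
  hits = sum(1 for row, label in zip(rows, labels) if predict(row) == label)
  return hits / len(rows)


def permutation_importance(predict, rows, labels, feature, seed=0):
  """Decrease in accuracy after shuffling the values of one feature."""
  baseline = accuracy(predict, rows, labels)
  shuffled_values = [row[feature] for row in rows]
  random.Random(seed).shuffle(shuffled_values)
  # Copy each row, replacing only the permuted feature.
  permuted_rows = [dict(row, **{feature: value})
                   for row, value in zip(rows, shuffled_values)]
  return baseline - accuracy(predict, permuted_rows, labels)


# Hypothetical toy model: it only looks at "bill_length_mm".
def toy_predict(row):
  return "long" if row["bill_length_mm"] >= 43.0 else "short"


toy_rows = [
    {"bill_length_mm": 39.1, "year": 2007},
    {"bill_length_mm": 46.5, "year": 2008},
    {"bill_length_mm": 40.3, "year": 2009},
    {"bill_length_mm": 50.0, "year": 2007},
]
toy_labels = [toy_predict(row) for row in toy_rows]

importance_year = permutation_importance(
    toy_predict, toy_rows, toy_labels, "year")
importance_bill = permutation_importance(
    toy_predict, toy_rows, toy_labels, "bill_length_mm")
print("year:", importance_year)
print("bill_length_mm:", importance_bill)
```

A feature the model never reads (here, year) always gets an importance of zero, which mirrors the near-zero values reported for year in the model summary.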
Check the model summary:
%set_cell_height 300
model.summary()
<IPython.core.display.Javascript object>
Model: "random_forest_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (7):
bill_depth_mm
bill_length_mm
body_mass_g
flipper_length_mm
island
sex
year
No weights
Variable Importance: MEAN_DECREASE_IN_ACCURACY:
1. "bill_length_mm" 0.151163 ################
2. "island" 0.008721 #
3. "bill_depth_mm" 0.000000
4. "body_mass_g" 0.000000
5. "sex" 0.000000
6. "year" 0.000000
7. "flipper_length_mm" -0.002907
Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:
1. "bill_length_mm" 0.083305 ################
2. "island" 0.007664 #
3. "flipper_length_mm" 0.003400
4. "bill_depth_mm" 0.002741
5. "body_mass_g" 0.000722
6. "sex" 0.000644
7. "year" 0.000000
Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:
1. "bill_length_mm" 0.508510 ################
2. "island" 0.023487
3. "bill_depth_mm" 0.007744
4. "flipper_length_mm" 0.006008
5. "body_mass_g" 0.003017
6. "sex" 0.001537
7. "year" -0.000245
Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:
1. "island" 0.002192 ################
2. "bill_length_mm" 0.001572 ############
3. "bill_depth_mm" 0.000497 #######
4. "sex" 0.000000 ####
5. "year" 0.000000 ####
6. "body_mass_g" -0.000053 ####
7. "flipper_length_mm" -0.000890
Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:
1. "bill_length_mm" 0.071306 ################
2. "island" 0.007299 #
3. "flipper_length_mm" 0.004506 #
4. "bill_depth_mm" 0.002124
5. "body_mass_g" 0.000548
6. "sex" 0.000480
7. "year" 0.000000
Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:
1. "bill_length_mm" 0.108642 ################
2. "island" 0.014493 ##
3. "bill_depth_mm" 0.007406 #
4. "flipper_length_mm" 0.005195
5. "body_mass_g" 0.001012
6. "sex" 0.000480
7. "year" -0.000053
Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:
1. "island" 0.002126 ################
2. "bill_length_mm" 0.001393 ###########
3. "bill_depth_mm" 0.000293 #####
4. "sex" 0.000000 ###
5. "year" 0.000000 ###
6. "body_mass_g" -0.000037 ###
7. "flipper_length_mm" -0.000550
Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:
1. "bill_length_mm" 0.083122 ################
2. "island" 0.010887 ##
3. "flipper_length_mm" 0.003425
4. "bill_depth_mm" 0.002731
5. "body_mass_g" 0.000719
6. "sex" 0.000641
7. "year" 0.000000
Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:
1. "bill_length_mm" 0.497611 ################
2. "island" 0.024045
3. "bill_depth_mm" 0.007734
4. "flipper_length_mm" 0.006017
5. "body_mass_g" 0.003000
6. "sex" 0.001528
7. "year" -0.000243
Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:
1. "island" 0.002187 ################
2. "bill_length_mm" 0.001568 ############
3. "bill_depth_mm" 0.000495 #######
4. "sex" 0.000000 ####
5. "year" 0.000000 ####
6. "body_mass_g" -0.000053 ####
7. "flipper_length_mm" -0.000886
Variable Importance: MEAN_MIN_DEPTH:
1. "__LABEL" 3.479602 ################
2. "year" 3.463891 ###############
3. "sex" 3.430498 ###############
4. "body_mass_g" 2.898112 ###########
5. "island" 2.388925 ########
6. "bill_depth_mm" 2.336100 #######
7. "bill_length_mm" 1.282960
8. "flipper_length_mm" 1.270079
Variable Importance: NUM_AS_ROOT:
1. "flipper_length_mm" 157.000000 ################
2. "bill_length_mm" 76.000000 #######
3. "bill_depth_mm" 52.000000 #####
4. "island" 12.000000
5. "body_mass_g" 3.000000
Variable Importance: NUM_NODES:
1. "bill_length_mm" 778.000000 ################
2. "bill_depth_mm" 463.000000 #########
3. "flipper_length_mm" 414.000000 ########
4. "island" 342.000000 ######
5. "body_mass_g" 338.000000 ######
6. "sex" 36.000000
7. "year" 19.000000
Variable Importance: SUM_SCORE:
1. "bill_length_mm" 36515.793787 ################
2. "flipper_length_mm" 35120.434174 ###############
3. "island" 14669.408395 ######
4. "bill_depth_mm" 14515.446617 ######
5. "body_mass_g" 3485.330881 #
6. "sex" 354.201073
7. "year" 49.737758
Winner take all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949
Number of trees: 300
Total number of nodes: 5080
Number of nodes by tree:
Count: 300 Average: 16.9333 StdDev: 3.10197
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12) 6 2.00% 2.00% #
[ 12, 13) 0 0.00% 2.00%
[ 13, 14) 46 15.33% 17.33% #####
[ 14, 15) 0 0.00% 17.33%
[ 15, 16) 70 23.33% 40.67% ########
[ 16, 17) 0 0.00% 40.67%
[ 17, 18) 84 28.00% 68.67% ##########
[ 18, 19) 0 0.00% 68.67%
[ 19, 20) 46 15.33% 84.00% #####
[ 20, 21) 0 0.00% 84.00%
[ 21, 22) 30 10.00% 94.00% ####
[ 22, 23) 0 0.00% 94.00%
[ 23, 24) 13 4.33% 98.33% ##
[ 24, 25) 0 0.00% 98.33%
[ 25, 26) 2 0.67% 99.00%
[ 26, 27) 0 0.00% 99.00%
[ 27, 28) 2 0.67% 99.67%
[ 28, 29) 0 0.00% 99.67%
[ 29, 30) 0 0.00% 99.67%
[ 30, 31] 1 0.33% 100.00%
Depth by leafs:
Count: 2690 Average: 3.53271 StdDev: 1.06789
Min: 2 Max: 7 Ignored: 0
----------------------------------------------
[ 2, 3) 545 20.26% 20.26% ######
[ 3, 4) 747 27.77% 48.03% ########
[ 4, 5) 888 33.01% 81.04% ##########
[ 5, 6) 444 16.51% 97.55% #####
[ 6, 7) 62 2.30% 99.85% #
[ 7, 7] 4 0.15% 100.00%
Number of training obs by leaf:
Count: 2690 Average: 38.3643 StdDev: 44.8651
Min: 5 Max: 155 Ignored: 0
----------------------------------------------
[ 5, 12) 1474 54.80% 54.80% ##########
[ 12, 20) 124 4.61% 59.41% #
[ 20, 27) 48 1.78% 61.19%
[ 27, 35) 74 2.75% 63.94% #
[ 35, 42) 58 2.16% 66.10%
[ 42, 50) 85 3.16% 69.26% #
[ 50, 57) 96 3.57% 72.83% #
[ 57, 65) 87 3.23% 76.06% #
[ 65, 72) 49 1.82% 77.88%
[ 72, 80) 23 0.86% 78.74%
[ 80, 88) 30 1.12% 79.85%
[ 88, 95) 23 0.86% 80.71%
[ 95, 103) 42 1.56% 82.27%
[ 103, 110) 62 2.30% 84.57%
[ 110, 118) 115 4.28% 88.85% #
[ 118, 125) 115 4.28% 93.12% #
[ 125, 133) 98 3.64% 96.77% #
[ 133, 140) 49 1.82% 98.59%
[ 140, 148) 31 1.15% 99.74%
[ 148, 155] 7 0.26% 100.00%
Attribute in nodes:
778 : bill_length_mm [NUMERICAL]
463 : bill_depth_mm [NUMERICAL]
414 : flipper_length_mm [NUMERICAL]
342 : island [CATEGORICAL]
338 : body_mass_g [NUMERICAL]
36 : sex [CATEGORICAL]
19 : year [NUMERICAL]
Attribute in nodes with depth <= 0:
157 : flipper_length_mm [NUMERICAL]
76 : bill_length_mm [NUMERICAL]
52 : bill_depth_mm [NUMERICAL]
12 : island [CATEGORICAL]
3 : body_mass_g [NUMERICAL]
Attribute in nodes with depth <= 1:
250 : bill_length_mm [NUMERICAL]
244 : flipper_length_mm [NUMERICAL]
183 : bill_depth_mm [NUMERICAL]
170 : island [CATEGORICAL]
53 : body_mass_g [NUMERICAL]
Attribute in nodes with depth <= 2:
462 : bill_length_mm [NUMERICAL]
320 : flipper_length_mm [NUMERICAL]
310 : bill_depth_mm [NUMERICAL]
287 : island [CATEGORICAL]
162 : body_mass_g [NUMERICAL]
9 : sex [CATEGORICAL]
5 : year [NUMERICAL]
Attribute in nodes with depth <= 3:
669 : bill_length_mm [NUMERICAL]
410 : bill_depth_mm [NUMERICAL]
383 : flipper_length_mm [NUMERICAL]
328 : island [CATEGORICAL]
286 : body_mass_g [NUMERICAL]
32 : sex [CATEGORICAL]
10 : year [NUMERICAL]
Attribute in nodes with depth <= 5:
778 : bill_length_mm [NUMERICAL]
462 : bill_depth_mm [NUMERICAL]
413 : flipper_length_mm [NUMERICAL]
342 : island [CATEGORICAL]
338 : body_mass_g [NUMERICAL]
36 : sex [CATEGORICAL]
19 : year [NUMERICAL]
Condition type in nodes:
2012 : HigherCondition
378 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
288 : HigherCondition
12 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
730 : HigherCondition
170 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
1259 : HigherCondition
296 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
1758 : HigherCondition
360 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
2010 : HigherCondition
378 : ContainsBitmapCondition
Node format: NOT_SET
Training OOB:
trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727
trees: 11, Out-of-bag evaluation: accuracy:0.956268 logloss:0.584301
trees: 22, Out-of-bag evaluation: accuracy:0.965116 logloss:0.378823
trees: 35, Out-of-bag evaluation: accuracy:0.968023 logloss:0.178185
trees: 46, Out-of-bag evaluation: accuracy:0.973837 logloss:0.170304
trees: 58, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171223
trees: 70, Out-of-bag evaluation: accuracy:0.979651 logloss:0.169564
trees: 83, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17074
trees: 96, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736925
trees: 106, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0748649
trees: 117, Out-of-bag evaluation: accuracy:0.976744 logloss:0.074671
trees: 130, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736275
trees: 140, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727718
trees: 152, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0715068
trees: 162, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708994
trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.069447
trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695926
trees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690138
trees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694597
trees: 217, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068122
trees: 229, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0687641
trees: 239, Out-of-bag evaluation: accuracy:0.976744 logloss:0.067988
trees: 250, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690187
trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690134
trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689877
trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689845
trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690742
trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949
Note the multiple variable importances named MEAN_DECREASE_IN_*.
Plotting the model
Next, plot the model.
A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, only plot the first tree, and limit the nodes to depth 3.
tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)
Inspect the model structure
The model structure and metadata are available through the inspector created by make_inspector().
inspector = model.make_inspector()
For our model, the available inspector fields are:
[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME', 'dataspec', 'evaluation', 'export_to_tensorboard', 'extract_all_trees', 'extract_tree', 'features', 'iterate_on_nodes', 'label', 'label_classes', 'model_type', 'num_trees', 'objective', 'specialized_header', 'task', 'training_logs', 'variable_importances', 'winner_take_all_inference']
Remember to see the API reference, or use ? for the built-in documentation.
?inspector.model_type
Some of the model metadata:
print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]
evaluation() returns the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.
inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None)
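As background on the out-of-bag dataset mentioned above, the following sketch shows where out-of-bag examples come from in a generic Random Forest: each tree is trained on a bootstrap sample drawn with replacement, and the examples that were never drawn can be used to evaluate that tree. This is a generic illustration, not the library's sampling code; bootstrap_split is a hypothetical helper, and 344 is simply the number of examples reported in the training logs.

```python
import random


def bootstrap_split(num_examples, seed):
  """Returns (in_bag, out_of_bag) example indices for one tree."""
  rng = random.Random(seed)
  # Sample with replacement, with as many draws as there are examples.
  in_bag = [rng.randrange(num_examples) for _ in range(num_examples)]
  # The examples that were never drawn are "out-of-bag" for this tree.
  out_of_bag = sorted(set(range(num_examples)) - set(in_bag))
  return in_bag, out_of_bag


in_bag, out_of_bag = bootstrap_split(num_examples=344, seed=0)
print("distinct in-bag examples:", len(set(in_bag)))
print("out-of-bag examples:", len(out_of_bag))
```

With this sampling scheme, roughly a third of the examples (about 1/e in expectation) end up out-of-bag for any given tree, so every tree has held-out data without a separate validation split.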
The variable importances are:
print("Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances:
MEAN_DECREASE_IN_AUC_3_VS_OTHERS
NUM_AS_ROOT
MEAN_DECREASE_IN_AUC_2_VS_OTHERS
MEAN_DECREASE_IN_AP_2_VS_OTHERS
MEAN_DECREASE_IN_ACCURACY
SUM_SCORE
MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS
MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
MEAN_DECREASE_IN_AP_3_VS_OTHERS
MEAN_DECREASE_IN_AUC_1_VS_OTHERS
MEAN_MIN_DEPTH
MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
NUM_NODES
MEAN_DECREASE_IN_AP_1_VS_OTHERS
Different variable importances have different semantics. For example, a feature with a mean decrease in AUC of 0.05 means that removing this feature from the training dataset would reduce (hurt) the AUC by 0.05, that is, 5 percentage points.
# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389),
("island" (4; #4), 0.007298519736842035),
("flipper_length_mm" (1; #3), 0.004505893640351366),
("bill_depth_mm" (1; #0), 0.0021244517543865804),
("body_mass_g" (1; #2), 0.0005482456140351033),
("sex" (4; #5), 0.00047971491228060437),
("year" (1; #6), 0.0)]
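The "_1_VS_OTHERS" suffix refers to a one-vs-others metric: one class is treated as positive and all the other classes as negative. The sketch below shows one common way to compute such an AUC in plain Python, by counting correctly ranked (positive, negative) pairs. It is an illustration under that assumption, not the library's implementation; auc_one_vs_others and the toy labels and scores are hypothetical.

```python
def auc_one_vs_others(labels, scores, positive_class):
  """AUC of `positive_class` against all the other classes.

  scores[i] is the predicted probability of `positive_class` for example i.
  """
  positives = [s for s, l in zip(scores, labels) if l == positive_class]
  negatives = [s for s, l in zip(scores, labels) if l != positive_class]
  if not positives or not negatives:
    raise ValueError("Need at least one positive and one negative example.")
  # Count the (positive, negative) pairs ranked correctly; ties count half.
  wins = 0.0
  for p in positives:
    for n in negatives:
      if p > n:
        wins += 1.0
      elif p == n:
        wins += 0.5
  return wins / (len(positives) * len(negatives))


# Toy data: integer class ids and the predicted probability of class 1.
toy_labels = [1, 2, 1, 3]
toy_scores = [0.9, 0.2, 0.4, 0.6]
toy_auc = auc_one_vs_others(toy_labels, toy_scores, positive_class=1)
print("AUC of class 1 vs others:", toy_auc)
```

A "mean decrease" importance then compares this metric before and after the values of a feature have been permuted.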
Finally, access the actual tree structure:
inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0)), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0)), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0)), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})
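One way to read the printed tree: every node carries a ProbabilityValue with a class distribution and a number of examples n. The numbers above are consistent with each non-leaf node holding the example-weighted average of its two children. The sketch below checks this for the root node, using values copied from the output; weighted_average is a hypothetical helper written for this check, not part of the library.

```python
# Values copied from the printed tree above: the root and its two children.
pos_child = {
    "n": 171.0,
    "probability": [0.005847953216374269, 0.3567251461988304,
                    0.6374269005847953],
}
neg_child = {
    "n": 173.0,
    "probability": [0.930635838150289, 0.03468208092485549,
                    0.03468208092485549],
}
root = {
    "n": 344.0,
    "probability": [0.47093023255813954, 0.19476744186046513,
                    0.33430232558139533],
}


def weighted_average(children):
  """Class distribution obtained by pooling the examples of the children."""
  total = sum(child["n"] for child in children)
  num_classes = len(children[0]["probability"])
  return [
      sum(child["n"] * child["probability"][c] for child in children) / total
      for c in range(num_classes)
  ]


combined = weighted_average([pos_child, neg_child])
print("combined children:", combined)
print("root value:       ", root["probability"])
```

The two printed distributions agree up to floating-point rounding, and the example counts add up as well (171 + 173 = 344).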
Extracting a tree is not efficient. If speed is important, the model inspection can be done with the iterate_on_nodes() method instead. This method is a Depth First Pre-order traversal iterator over all the nodes of the model.
The following example computes how many times each feature is used (this is a kind of structural variable importance):
# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)
# Iterate over all the nodes in a Depth First Pre-order traversal.
for node_iter in inspector.iterate_on_nodes():
  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes
    continue
  # Iterate over all the features used in the condition.
  # By default, models are "axis-aligned", i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1
print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features:
bill_length_mm : 778
bill_depth_mm : 463
flipper_length_mm : 414
island : 342
body_mass_g : 338
year : 19
sex : 36
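To make the traversal order concrete, here is a small pure-Python sketch of a depth-first pre-order iterator on a toy tree, independent of TensorFlow Decision Forests. The Leaf and Split types and the iterate_pre_order function are hypothetical stand-ins, and visiting the positive child before the negative one is an assumption of this sketch, not a statement about the order used by iterate_on_nodes().

```python
import collections

# Hypothetical toy node types; they only mimic the shape of a decision tree.
Leaf = collections.namedtuple("Leaf", ["value"])
Split = collections.namedtuple("Split", ["feature", "pos_child", "neg_child"])


def iterate_pre_order(node, depth=0):
  """Yields (node, depth): the node itself first, then its two subtrees."""
  yield node, depth
  if isinstance(node, Split):
    yield from iterate_pre_order(node.pos_child, depth + 1)
    yield from iterate_pre_order(node.neg_child, depth + 1)


toy_tree = Split(
    "f1",
    pos_child=Split("f2", pos_child=Leaf("a"), neg_child=Leaf("b")),
    neg_child=Leaf("c"))

# Order in which the nodes are visited.
visit_order = [
    node.feature if isinstance(node, Split) else node.value
    for node, _ in iterate_pre_order(toy_tree)
]
# Same counting logic as above: one increment per non-leaf node.
feature_counts = collections.Counter(
    node.feature for node, _ in iterate_pre_order(toy_tree)
    if isinstance(node, Split))
print(visit_order)
print(dict(feature_counts))
```

Because the iterator is a generator, nodes are produced one at a time and the whole tree never needs to be materialized as a list.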
Creating a model by hand
In this section, you will create a small Random Forest model by hand. To keep things simple, the model will contain only one simple tree:
3 label classes: Red, blue and green.
2 features: f1 (numerical) and f2 (string categorical)
f1>=1.5
├─(pos)─ f2 in ["cat","dog"]
│ ├─(pos)─ value: [0.8, 0.1, 0.1]
│ └─(neg)─ value: [0.1, 0.8, 0.1]
└─(neg)─ value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))
Each tree is added one by one.
# Some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2",type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))
Conclude the tree writing:
builder.close()
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
2021-11-08 12:19:14.555155: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
Now you can open the model as a regular Keras model and make predictions:
manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
    "f1": [1.0, 2.0, 3.0],
    "f2": ["cat", "cat", "bird"]
}).batch(2)
predictions = manual_model.predict(examples)
print("predictions:\n",predictions)
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
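These predictions can be checked by hand by routing each example through the tree defined earlier. The sketch below re-implements that single tree as plain Python conditionals; predict_manual_tree is a hypothetical helper written for this check, and it ignores missing values (the missing_evaluation arguments), which do not occur in the three examples.

```python
def predict_manual_tree(f1, f2):
  """Routes one example through the hand-written tree.

  Returns the class probabilities in the order [red, blue, green].
  """
  if f1 >= 1.5:
    if f2 in ("cat", "dog"):
      return [0.8, 0.1, 0.1]
    return [0.1, 0.8, 0.1]
  return [0.1, 0.1, 0.8]


# Same three examples as in the prediction cell above.
toy_examples = [(1.0, "cat"), (2.0, "cat"), (3.0, "bird")]
toy_predictions = [predict_manual_tree(f1, f2) for f1, f2 in toy_examples]
for example, prediction in zip(toy_examples, toy_predictions):
  print(example, "->", prediction)
```

The first example fails the root condition (1.0 < 1.5) and lands in the negative leaf. The other two pass it and are separated by the categorical condition on f2.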
Access the structure:
yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)
inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]
And of course, you can plot this manually constructed model:
tfdf.model_plotter.plot_model_in_colab(manual_model)