
Decision tree: structure, training and Python code

What a decision tree is

A decision tree is a non-parametric supervised machine learning model that predicts outputs by extracting conditions from the data that an input either meets or does not meet.

Decision tree structure

  • The root node is the whole input dataset.
  • Splitting is dividing an input dataset into subgroups (called nodes).
  • Branches are the possible outcomes of the condition tested at a node. They start from a node; by convention, the left branch corresponds to the condition being true and the right one to it being false.
  • Leaves (or terminal nodes) represent the final predictions associated with a certain path of decisions.

The CART algorithm: how decision trees work

Problem statement

We have a labeled dataset and want to predict the labels by building a tree-like structure.

1. Extract thresholds and split nodes

For each feature in the dataset, the algorithm iterates through each value and uses it as a threshold to split the dataset into two nodes.

If the current feature is numerical, the algorithm assigns data points to one node or the other according to whether their feature value is less than or equal to the threshold.

If the current feature is categorical, the algorithm assigns data points according to whether their feature value is equal to the threshold (i.e., belongs to that category).
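
As a rough illustration (a simplified sketch with hypothetical helper names, not scikit-learn's actual implementation), splitting a node on one candidate threshold could look like this:

def split_node(rows, feature, threshold, is_numerical):
    """Divide the data points of a node into two sub-nodes for one candidate threshold."""
    left, right = [], []
    for row in rows:
        value = row[feature]
        # Numerical feature: left if the value is smaller than or equal to the threshold.
        # Categorical feature: left if the value is equal to the threshold (a category).
        goes_left = value <= threshold if is_numerical else value == threshold
        (left if goes_left else right).append(row)
    return left, right

# Example: split on the numerical feature "Building area" with threshold 100.
rows = [{"Building area": 95, "Sale price": 300_000},
        {"Building area": 150, "Sale price": 650_000}]
left, right = split_node(rows, "Building area", 100, is_numerical=True)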

2. Choose the optimal thresholds

For each threshold, the algorithm calculates the impurity of its sub-nodes.

The impurity is a quantity that expresses how dissimilar the outputs within a node are: the lower the impurity, the more similar they are.

Why is impurity important?

Impurity is used to measure the quality of a split.
If the impurity is low, the samples in the node are similar to each other, so the threshold is good: it highlights a characteristic that these data points share.

So when we are trying to predict an unseen value, and we see that it has this characteristic, it makes sense to predict an output similar to the other samples in the node.

2.1 Impurity in classification

In classification, impurity is calculated with the Gini impurity, which measures how mixed the classes in a node are (a low value means the node is homogeneous).

\[ \text{Gini impurity} = 1 - \left(\frac{N_1}{N}\right)^2 - \left(\frac{N_2}{N}\right)^2 \]

Where:

  • N = total number of samples in the node.
  • N_n = number of samples in the node belonging to class n.

Curiosity: why do we square the probabilities?

We square the fractions because otherwise their sum would always equal 1, and the impurity would always come out to 0 regardless of how the classes are distributed.

Intuitively, the lower the Gini impurity of a node, the more the examples in that node belong to the same class.
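
For example, the Gini impurity of a node can be computed directly from its labels (a minimal, illustrative sketch):

def gini_impurity(labels):
    """Gini impurity of a node, given the list of class labels it contains."""
    n = len(labels)
    impurity = 1.0
    for cls in set(labels):
        impurity -= (labels.count(cls) / n) ** 2
    return impurity

print(gini_impurity(["yes", "yes", "yes", "yes"]))  # 0.0 -> pure node
print(gini_impurity(["yes", "yes", "no", "no"]))    # 0.5 -> maximally mixed (two classes)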

2.2 Impurity in regression

The impurity of a node in a regression problem is equal to the mean squared residual.

For each data point in the node, the squared residual is:

\[ \text{Squared residual} = (y - \bar{y})^2 \]

Where:

  • y = output of the data point,
  • ȳ = average value of all the outputs in the node.
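
A node's regression impurity is then the mean of these squared residuals over all its data points; here is a minimal sketch (the helper name is hypothetical):

def regression_impurity(outputs):
    """Mean squared residual of a node's outputs (its impurity in regression)."""
    mean_y = sum(outputs) / len(outputs)
    return sum((y - mean_y) ** 2 for y in outputs) / len(outputs)

print(regression_impurity([100, 110, 120]))  # small: the outputs are similar
print(regression_impurity([100, 500, 900]))  # large: the outputs are spread out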

Choose the threshold

The quality of a split is summarized by the weighted impurity of its two sub-nodes:

\[ \text{Weighted impurity} = \text{impurity}_1 \cdot \frac{N_1}{N} + \text{impurity}_2 \cdot \frac{N_2}{N} \]

Where:

  • N = total number of data points in the parent node,
  • N_1, N_2 = number of data points in each of the two sub-nodes,
  • impurity_1, impurity_2 = impurity of each sub-node.

The threshold with the lowest weighted impurity is chosen to split the dataset.
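
Putting the pieces together, a simplified (hypothetical) search for the best threshold on a single numerical feature could look like this, reusing the regression_impurity function sketched above (the Gini version works the same way for classification):

def weighted_impurity(left, right, impurity_fn):
    """Weighted impurity of the two sub-nodes produced by a candidate split."""
    n = len(left) + len(right)
    return impurity_fn(left) * len(left) / n + impurity_fn(right) * len(right) / n

def best_threshold(values, targets, impurity_fn):
    """Return the threshold with the lowest weighted impurity for one numerical feature."""
    best_t, best_score = None, float("inf")
    for t in sorted(set(values)):
        left = [y for v, y in zip(values, targets) if v <= t]
        right = [y for v, y in zip(values, targets) if v > t]
        if not left or not right:
            continue  # skip splits that leave one sub-node empty
        score = weighted_impurity(left, right, impurity_fn)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# Example: best split of "Building area" for a regression target.
areas = [75, 95, 105, 150]
prices = [75_000, 300_000, 130_000, 650_000]
print(best_threshold(areas, prices, regression_impurity))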

3. Stop the tree from growing

The splitting process of a node ends when its impurity is minimal, or when a further split would leave a sub-node with fewer data points than the value of the hyperparameter min_samples_leaf. At this point, the node becomes a leaf.

The learning process stops entirely (the tree stops growing) when the value of max_depth or max_leaf_nodes is reached.

4. Predict outputs

For each input, the model follows a decision path. The prediction is the average output of the leaf in which the input falls (in classification, it is the majority class of that leaf).
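
As an illustration, here is a hypothetical trained tree stored as nested dictionaries, and a prediction that simply walks it from the root to a leaf (again a sketch, not how scikit-learn stores trees internally):

# Hypothetical tiny trained tree: internal nodes hold a condition, leaves hold a value.
tiny_tree = {
    "feature": "Building area", "threshold": 100,
    "left": {"leaf": 180_000},   # building area <= 100
    "right": {                   # building area > 100
        "feature": "Year Built", "threshold": 2000,
        "left": {"leaf": 350_000},
        "right": {"leaf": 500_000},
    },
}

def predict(node, sample):
    """Follow the decision path until a leaf is reached and return its value."""
    while "leaf" not in node:
        branch = "left" if sample[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node["leaf"]

print(predict(tiny_tree, {"Building area": 150, "Year Built": 2015}))  # 500000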

Model main hyperparameters

These are the model’s main hyperparameters taken from the scikit-learn documentation.

  • max_depth: The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than min_samples_split samples.
  • max_leaf_nodes: The maximum number of leaf nodes in the tree. If None, the number of leaves is unlimited.
  • min_samples_split: The minimum number of samples required to split an internal node.
  • min_samples_leaf: The minimum number of samples required to be at a leaf node.
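
In scikit-learn these map directly to constructor arguments (the leaf-count limit is called max_leaf_nodes there); the values below are examples only, not recommendations:

from sklearn.tree import DecisionTreeRegressor

model = DecisionTreeRegressor(
    max_depth=6,            # the tree may be at most 6 levels deep
    max_leaf_nodes=50,      # the tree may have at most 50 leaves
    min_samples_split=20,   # a node needs at least 20 samples to be split
    min_samples_leaf=5,     # every leaf must contain at least 5 samples
)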

Decision tree advantages and disadvantages

Advantages

  • Simple to understand and to interpret. Trees can be visualized.
  • Requires little feature engineering. Other techniques often require data normalization, the creation of dummy variables and the removal of blank values. Some tree and algorithm combinations support missing values.
  • Able to handle both numerical and categorical data.
  • Able to handle multi-output problems.
  • Uses a white box model. If a given situation is observable in a model, the explanation for the condition is easily explained by boolean logic (true or false). By contrast, results may be more difficult to interpret in a black box model (e.g., in an artificial neural network).
  • Possible to validate a model using statistical tests. That makes it possible to account for the reliability of the model.
  • Decision trees make it easy to evaluate feature importance, or contribution, to the model.

    There are a few ways to evaluate feature importance, like the Gini importance of a feature, which equals the total reduction of impurity brought by the splits on that feature (weighted by the number of samples they affect); a scikit-learn example is shown right after this list.
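
In scikit-learn, these impurity-based importances are exposed by the feature_importances_ attribute of a fitted tree; for example (assuming the model and input_variables defined later in this article, after the model has been fitted):

for name, importance in zip(input_variables, model.feature_importances_):
    print(f"{name}: {importance:.3f}")  # the importances sum to 1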

Disadvantages

  • Decision-tree learners can create over-complex trees that do not generalize the data well. This is called overfitting. Mechanisms such as pruning, setting the minimum number of samples required at a leaf node or setting the maximum depth of the tree are necessary to avoid this problem.
  • Decision trees can be oversensitive and unstable because small variations in the data might result in a completely different tree being generated. This problem is mitigated by using decision trees within an ensemble (such as the random forest model or XGBoost).
  • Decision tree learners create biased trees if some classes dominate. It is therefore recommended to balance the dataset before fitting with the decision tree (a lightweight scikit-learn alternative is sketched below).

These aspects are taken from the general scikit-learn documentation about decision trees.
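
As one lightweight alternative to resampling for the class-imbalance issue above, classification trees in scikit-learn accept a class_weight parameter; a minimal sketch:

from sklearn.tree import DecisionTreeClassifier

# 'balanced' weighs each class inversely to its frequency in the training data,
# so a dominant class does not automatically dominate the splits.
clf = DecisionTreeClassifier(class_weight="balanced", max_depth=6)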

Programming a decision tree

In this section you are going to learn how to build a decision tree model in Python using the scikit-learn library.

1. Import necessary libraries

The libraries used in this project are:

  • Pandas for handling input and output data.
  • Math for the square root function.
  • Sklearn for importing the decision tree algorithm and the validation metric.
  • Matplotlib for visualizing the model structure.
import pandas
import math

from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from matplotlib import pyplot as plt

2. Load the dataset

#load the dataset
dataset = pandas.read_csv("C:\\...\\realestate_dataset.csv")

The data used to train this model look something like this:

Rooms | Building area | Year Built | Sale price
------|---------------|------------|-----------
  8   |      150      |    1987    |   650 000
  5   |       95      |    2015    |   300 000
  6   |      105      |    1967    |   130 000
  4   |       75      |    2001    |    75 000

The dataset I used is a real estate dataset that reports the sales values of properties with their respective building characteristics.

3. Select input and output features and split the data

#define the features and the label
input_variables = ['LotArea', 'OverallQual', 'YearBuilt', '1stFlrSF', '2ndFlrSF', 'FullBath', 'BedroomAbvGr', 'KitchenAbvGr', 'TotRmsAbvGrd']
X = dataset[input_variables]
y = dataset[["SalePrice"]]

train_X, val_X, train_y, val_y = train_test_split(X, y) #split the data into training and validation data

4. Train and validate the model

#load and train the model
model = DecisionTreeRegressor(max_depth=6)
model.fit(train_X, train_y)

#evaluate the model's performance using the root mean squared error
RMSE = math.sqrt(
    mean_squared_error(val_y, model.predict(val_X))
)

print(RMSE)

Wow! The root mean squared error of our model is about 41 034 (the exact value changes from run to run because the train/validation split is random). Roughly speaking, it means that the real sale price differs from the predicted price by about $41 034 on average. For guessing house prices this isn't a bad result.

5. Visualize the model

#show the model structure as text and as a plot
text_representation = tree.export_text(model)
print(text_representation)

plt.figure(figsize=(25,20))
tree.plot_tree(model, feature_names=input_variables, filled=True)
plt.show()

[Figure: visual representation of the decision tree splits]

This shows a smaller tree than the one we have just created; otherwise, the text would be too small to read.

Decision tree full code

import pandas
import math

from sklearn import tree
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt


#load the dataset
dataset = pandas.read_csv("C:\\...\\realestate_dataset.csv")


#define the features and the label
input_variables = ['LotArea', 'OverallQual', 'YearBuilt', '1stFlrSF', '2ndFlrSF', 'FullBath', 'BedroomAbvGr', 'KitchenAbvGr', 'TotRmsAbvGrd']
X = dataset[input_variables]
y = dataset[["SalePrice"]]

train_X, val_X, train_y, val_y = train_test_split(X, y) #split the data into training and validation data


#load and train the model
model = DecisionTreeRegressor(max_depth=6)
model.fit(train_X, train_y)


#evaluate the model's performance using the root mean squared error
RMSE = math.sqrt(
    mean_squared_error(val_y, model.predict(val_X))
)
print(RMSE)


#show the model structure as text and as a plot
text_representation = tree.export_text(model)
print(text_representation)

plt.figure(figsize=(25,20))
tree.plot_tree(model, feature_names=input_variables, filled=True)
plt.show()
