In this article, I’ll show you how to program a decision tree in Python using ONLY MATH. In this way, I will help you understand how this algorithm works deep down.
Let’s get started.
Table of Contents
Disclaimer
I have already written an article that discusses decision trees thoroughly, explaining the mathematical concepts and the steps of the algorithm with pictures and examples.
I suggest you read it before continuing.
Decision tree in Python
Problem statement
We want to solve a regression problem with only numerical features by fitting a decision tree to the data.
1. Import necessary libraries
import numpy
The only library used here is NumPy, and only to compute the mean of a list without writing an explicit loop.
2. Define a dataset
# Training features: one list per feature, one list position per house.
train_X = dict(
    LotArea=[50, 70, 100],
    Quality=[8, 7.5, 9],
)
# Training target: the sale price of each house, aligned by position with train_X.
train_y = dict(
    SalePrice=[100, 105, 180],
)
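I store a small house-price dataset in dictionaries: train_X maps each feature name (LotArea, Quality) to a list of values, and train_y holds the target, SalePrice. The same list position always refers to the same house.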
3. Define nodes
def create_node(datapoints):
    """Build a tree node for the training rows listed in ``datapoints``.

    Args:
        datapoints: list of row indexes into the global training data
            (``train_X`` / ``train_y``).

    Returns:
        A dict describing the node. Its prediction (``mean_value``) is the mean
        sale price of its rows, and its impurity (``mean_variance``) is the
        variance of those prices. A node with zero variance is flagged as a
        leaf, because splitting it cannot help. The split fields start as
        ``None``. An empty ``datapoints`` list gives NaN statistics (NumPy
        warns about the mean of an empty list).
    """
    prices = [train_y["SalePrice"][row] for row in datapoints]
    mean_value = numpy.mean(prices)
    mean_variance = numpy.mean([(price - mean_value) ** 2 for price in prices])
    return {
        # row indexes of the data points in the node
        "datapoints": datapoints,
        # mean output value of the data points in the node (its prediction)
        "mean_value": mean_value,
        # impurity of the node: variance of the output values
        "mean_variance": mean_variance,
        # leaf state: a pure node (zero variance) is never split
        "leaf": bool(mean_variance == 0),
        # best feature and threshold (a training-row index) to split the node
        "feature": None,
        "treshold": None,
        # the left sub-node of the node
        "left": None,
        # the right sub-node of the node
        "right": None,
    }
The function create_node takes a list of row indexes into the dataset and returns a node stored as a dictionary. The comments above explain each field. The node's impurity is the variance of the sale prices it contains, and a node with zero variance is marked as a leaf.
4. Define the split function
def split_node(node, feature, treshold):
    """Partition the rows of ``node`` into two new child nodes.

    ``treshold`` is a row index into ``train_X``; the cut value is that row's
    value of ``feature``. Rows whose value is <= the cut go to the left child,
    rows whose value is > the cut go to the right child.

    Returns:
        ``(node_left, node_right)``, both built with ``create_node``.
    """
    def goes_left(row):
        # the row's value does not exceed the cut value
        return train_X[feature][row] <= train_X[feature][treshold]

    def goes_right(row):
        # the row's value is strictly above the cut value
        return train_X[feature][row] > train_X[feature][treshold]

    rows = node["datapoints"]
    left_child = create_node(list(filter(goes_left, rows)))
    right_child = create_node(list(filter(goes_right, rows)))
    return left_child, right_child
5. Find the best threshold for a node
def find_treshold(node):
    """Choose the best split for ``node`` and record it on the node.

    Every feature in ``train_X`` is tried with every data point of the node as
    the threshold. Rows whose feature value is <= the threshold row's value go
    left; the rest go right. A candidate is discarded when it leaves one side
    empty, or when its weighted impurity does not drop below the node's own
    impurity.

    The weighted impurity is the sum of the children's variances, each
    weighted by the share of the node's rows that child receives. The
    candidate with the smallest weighted impurity wins; ties go to the first
    candidate found.

    On success, ``node["feature"]`` and ``node["treshold"]`` (a row index into
    ``train_X``) are set. When no candidate qualifies, the node is marked as a
    leaf and the split fields are left untouched, instead of raising. This
    happens, for example, when the node is already pure or all its rows share
    identical feature values.

    Args:
        node: node dict created by ``create_node``.
    """
    n_rows = len(node["datapoints"])
    best = None  # (weighted_impurity, feature, treshold) of the best candidate so far
    for feature in train_X:
        for treshold in node["datapoints"]:
            node_left, node_right = split_node(node, feature, treshold)
            # All rows on one side means the candidate does not split anything.
            if node_left["datapoints"] == node["datapoints"] or node_right["datapoints"] == node["datapoints"]:
                continue
            weighted_impurity = (
                node_left["mean_variance"] * len(node_left["datapoints"]) / n_rows
                + node_right["mean_variance"] * len(node_right["datapoints"]) / n_rows
            )
            # Only keep splits that actually reduce the node's (weighted) impurity.
            if weighted_impurity >= node["mean_variance"]:
                continue
            # Strict "<" keeps the first candidate among equally good ones.
            if best is None or weighted_impurity < best[0]:
                best = (weighted_impurity, feature, treshold)
    if best is None:
        # Nothing improves the node: it stays unsplit.
        node["leaf"] = True
        return
    node["feature"] = best[1]
    node["treshold"] = best[2]
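For every feature, the function tries the value of each data point in the node as a threshold and skips splits that leave one side empty. Among the valid splits, it stores in the node the feature and threshold with the lowest weighted impurity. The threshold is saved as the index of that data point. The weighted impurity is the sum of the children's variances, each weighted by the fraction of data points that child receives.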
6. Build the tree structure
def build_tree(node, max_depth, depth=1):
    """Grow the tree below ``node`` in place, recursively.

    Args:
        node: node dict created by ``create_node``; it becomes the root of the
            (sub)tree.
        max_depth: maximum number of node levels, counting ``node`` at the
            default ``depth`` as level 1. Nodes at this depth are not split.
        depth: level of ``node``; callers normally leave the default.
    """
    # Pure nodes (zero variance) and nodes at the depth limit become leaves.
    if node["leaf"] or depth >= max_depth:
        node["leaf"] = True
        return
    find_treshold(node)
    # Treat a node for which no split was recorded as a leaf.
    if node["leaf"] or node["feature"] is None:
        node["leaf"] = True
        return
    node["left"], node["right"] = split_node(node, node["feature"], node["treshold"])
    for child in (node["left"], node["right"]):
        # Pure children are already leaves; only impure ones need growing.
        if not child["leaf"]:
            build_tree(child, max_depth, depth + 1)
This recursive function grows the tree from a root node. It splits the node, then calls itself on each child that is not yet pure, and stops once the depth reaches max_depth.
7. Define the predict function
def predict(tree, X):
    """Predict a value for every row of ``X`` by walking down ``tree``.

    Args:
        tree: root node grown by ``build_tree`` (or a single leaf node).
        X: dict mapping feature names to equally long lists of values, one
            list position per row. It must contain every feature the tree
            splits on.

    Returns:
        List with one prediction per row: the ``mean_value`` of the leaf the
        row ends up in. An empty ``X`` gives an empty list.
    """
    # Any column gives the row count, so X need not contain a particular feature.
    n_rows = len(next(iter(X.values()), []))
    y = []
    for row in range(n_rows):
        current_node = tree
        while not current_node["leaf"]:
            feature = current_node["feature"]
            # The stored threshold is a training-row index; its feature value is the cut.
            cut = train_X[feature][current_node["treshold"]]
            if X[feature][row] <= cut:
                current_node = current_node["left"]
            else:
                current_node = current_node["right"]
        y.append(current_node["mean_value"])
    return y
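For each input row, this function walks down the tree until it reaches a leaf, and returns that leaf's mean value. At each node it goes left when the row's feature value is less than or equal to the node's threshold value, and right otherwise.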
8. Fit a tree to the data
# Build the tree from all three training houses (row indexes 0, 1, 2).
# max_depth counts node levels: 2 means the root plus one split.
max_depth = 2
node1 = create_node([0, 1, 2])
build_tree(node1, max_depth)
9. Predict unseen values
# Two unseen houses; each list position is one house.
val_X = {
    "LotArea": [50, 90],
    "Quality": [7.5, 7.5],
}
print(predict(node1, val_X))
Decision tree in Python full code
import numpy
# Training features: one list per feature, one list position per house.
train_X = dict(
    LotArea=[50, 70, 100],
    Quality=[8, 7.5, 9],
)
# Training target: the sale price of each house, aligned by position with train_X.
train_y = dict(
    SalePrice=[100, 105, 180],
)
def create_node(datapoints):
    """Build a tree node for the training rows listed in ``datapoints``.

    Args:
        datapoints: list of row indexes into the global training data
            (``train_X`` / ``train_y``).

    Returns:
        A dict describing the node. Its prediction (``mean_value``) is the mean
        sale price of its rows, and its impurity (``mean_variance``) is the
        variance of those prices. A node with zero variance is flagged as a
        leaf, because splitting it cannot help. The split fields start as
        ``None``. An empty ``datapoints`` list gives NaN statistics (NumPy
        warns about the mean of an empty list).
    """
    prices = [train_y["SalePrice"][row] for row in datapoints]
    mean_value = numpy.mean(prices)
    mean_variance = numpy.mean([(price - mean_value) ** 2 for price in prices])
    return {
        # row indexes of the data points in the node
        "datapoints": datapoints,
        # mean output value of the data points in the node (its prediction)
        "mean_value": mean_value,
        # impurity of the node: variance of the output values
        "mean_variance": mean_variance,
        # leaf state: a pure node (zero variance) is never split
        "leaf": bool(mean_variance == 0),
        # best feature and threshold (a training-row index) to split the node
        "feature": None,
        "treshold": None,
        # the left sub-node of the node
        "left": None,
        # the right sub-node of the node
        "right": None,
    }
def split_node(node, feature, treshold):
    """Partition the rows of ``node`` into two new child nodes.

    ``treshold`` is a row index into ``train_X``; the cut value is that row's
    value of ``feature``. Rows whose value is <= the cut go to the left child,
    rows whose value is > the cut go to the right child.

    Returns:
        ``(node_left, node_right)``, both built with ``create_node``.
    """
    def goes_left(row):
        # the row's value does not exceed the cut value
        return train_X[feature][row] <= train_X[feature][treshold]

    def goes_right(row):
        # the row's value is strictly above the cut value
        return train_X[feature][row] > train_X[feature][treshold]

    rows = node["datapoints"]
    left_child = create_node(list(filter(goes_left, rows)))
    right_child = create_node(list(filter(goes_right, rows)))
    return left_child, right_child
def find_treshold(node):
    """Choose the best split for ``node`` and record it on the node.

    Every feature in ``train_X`` is tried with every data point of the node as
    the threshold. Rows whose feature value is <= the threshold row's value go
    left; the rest go right. A candidate is discarded when it leaves one side
    empty, or when its weighted impurity does not drop below the node's own
    impurity.

    The weighted impurity is the sum of the children's variances, each
    weighted by the share of the node's rows that child receives. The
    candidate with the smallest weighted impurity wins; ties go to the first
    candidate found.

    On success, ``node["feature"]`` and ``node["treshold"]`` (a row index into
    ``train_X``) are set. When no candidate qualifies, the node is marked as a
    leaf and the split fields are left untouched, instead of raising. This
    happens, for example, when the node is already pure or all its rows share
    identical feature values.

    Args:
        node: node dict created by ``create_node``.
    """
    n_rows = len(node["datapoints"])
    best = None  # (weighted_impurity, feature, treshold) of the best candidate so far
    for feature in train_X:
        for treshold in node["datapoints"]:
            node_left, node_right = split_node(node, feature, treshold)
            # All rows on one side means the candidate does not split anything.
            if node_left["datapoints"] == node["datapoints"] or node_right["datapoints"] == node["datapoints"]:
                continue
            weighted_impurity = (
                node_left["mean_variance"] * len(node_left["datapoints"]) / n_rows
                + node_right["mean_variance"] * len(node_right["datapoints"]) / n_rows
            )
            # Only keep splits that actually reduce the node's (weighted) impurity.
            if weighted_impurity >= node["mean_variance"]:
                continue
            # Strict "<" keeps the first candidate among equally good ones.
            if best is None or weighted_impurity < best[0]:
                best = (weighted_impurity, feature, treshold)
    if best is None:
        # Nothing improves the node: it stays unsplit.
        node["leaf"] = True
        return
    node["feature"] = best[1]
    node["treshold"] = best[2]
def build_tree(node, max_depth, depth=1):
    """Grow the tree below ``node`` in place, recursively.

    Args:
        node: node dict created by ``create_node``; it becomes the root of the
            (sub)tree.
        max_depth: maximum number of node levels, counting ``node`` at the
            default ``depth`` as level 1. Nodes at this depth are not split.
        depth: level of ``node``; callers normally leave the default.
    """
    # Pure nodes (zero variance) and nodes at the depth limit become leaves.
    if node["leaf"] or depth >= max_depth:
        node["leaf"] = True
        return
    find_treshold(node)
    # Treat a node for which no split was recorded as a leaf.
    if node["leaf"] or node["feature"] is None:
        node["leaf"] = True
        return
    node["left"], node["right"] = split_node(node, node["feature"], node["treshold"])
    for child in (node["left"], node["right"]):
        # Pure children are already leaves; only impure ones need growing.
        if not child["leaf"]:
            build_tree(child, max_depth, depth + 1)
def predict(tree, X):
    """Predict a value for every row of ``X`` by walking down ``tree``.

    Args:
        tree: root node grown by ``build_tree`` (or a single leaf node).
        X: dict mapping feature names to equally long lists of values, one
            list position per row. It must contain every feature the tree
            splits on.

    Returns:
        List with one prediction per row: the ``mean_value`` of the leaf the
        row ends up in. An empty ``X`` gives an empty list.
    """
    # Any column gives the row count, so X need not contain a particular feature.
    n_rows = len(next(iter(X.values()), []))
    y = []
    for row in range(n_rows):
        current_node = tree
        while not current_node["leaf"]:
            feature = current_node["feature"]
            # The stored threshold is a training-row index; its feature value is the cut.
            cut = train_X[feature][current_node["treshold"]]
            if X[feature][row] <= cut:
                current_node = current_node["left"]
            else:
                current_node = current_node["right"]
        y.append(current_node["mean_value"])
    return y
# Build the tree from all three training houses (row indexes 0, 1, 2).
# max_depth counts node levels: 2 means the root plus one split.
max_depth = 2
node1 = create_node([0, 1, 2])
build_tree(node1, max_depth)
# Two unseen houses; each list position is one house.
val_X = {
    "LotArea": [50, 90],
    "Quality": [7.5, 7.5],
}
print(predict(node1, val_X))