Decision tree in Python from scratch

In this article, I’ll show you how to program a decision tree in Python using ONLY MATH. In this way, I will help you understand how this algorithm works deep down.

Let’s get started.

Disclaimer

I have already written an article that discusses decision trees thoroughly, explaining the mathematical concepts and the steps of the algorithm with pictures and examples.

I suggest you read it before continuing.

Decision tree in Python

Problem statement

We want to solve a regression problem with only numerical features by fitting a decision tree to the data.

1. Import necessary libraries

import numpy

In this code I only use NumPy, a numerical computing library, to compute the mean of a list in a single call instead of writing the loop by hand.

2. Define a dataset

# features: one list per column, aligned by row index
train_X = {
    "LotArea": [50, 70, 100],
    "Quality": [8, 7.5, 9],
}

# target: the sale price of each row
train_y = {"SalePrice": [100, 105, 180]}

I use a dictionary structure to store my dataset about house prices.

3. Define nodes

def create_node(datapoints):
    """Build a tree node from a list of training-row indexes.

    Args:
        datapoints: indexes into ``train_y["SalePrice"]`` of the rows that
            belong to this node.

    Returns:
        A dict describing the node; the meaning of each key is given in the
        comments below.
    """
    targets = [train_y["SalePrice"][index] for index in datapoints]

    if targets:
        mean_value = numpy.mean(targets)
        mean_variance = numpy.mean([(target - mean_value) ** 2 for target in targets])
    else:
        # numpy.mean([]) emits a RuntimeWarning and returns nan; an empty node
        # has no statistics, so record nan directly and skip the warning
        mean_value = mean_variance = float("nan")

    # a node cannot be split further when it is empty or all its targets agree
    leaf = not targets or bool(mean_variance == 0)

    return {
        # datapoints index in the node
        "datapoints": datapoints,
        # mean output value of the data points in the node
        "mean_value": mean_value,
        # impurity of the node (population variance of its targets)
        "mean_variance": mean_variance,
        # leaf state
        "leaf": leaf,
        # best feature and threshold to split the node; both stay None
        # until a split is chosen for the node
        "feature": None,
        "treshold": None,
        # the left sub-node of the node
        "left": None,
        # the right sub-node of the node
        "right": None,
    }

The function create_node takes a list of row indexes into the dataset and returns a node, represented as a dictionary. The meaning of each key is given in the comment above it.

4. Define the split function

def split_node(node, feature, treshold):
    """Split a node in two around one training row's value of a feature.

    Args:
        node: node dict whose "datapoints" list holds the row indexes to split.
        feature: key of ``train_X`` naming the column to compare on.
        treshold: index of the training row whose value in that column is
            used as the cut point (an index, not a feature value).

    Returns:
        A ``(left, right)`` tuple of nodes built with ``create_node``: the
        left node receives the rows whose value is <= the cut point, the
        right node the rows whose value is greater.
    """
    column = train_X[feature]
    # the cut point is the same for every row, so look it up once
    cut_value = column[treshold]

    left_rows = []
    right_rows = []
    for row in node["datapoints"]:
        if column[row] <= cut_value:
            left_rows.append(row)
        elif column[row] > cut_value:
            # explicit second comparison: a value that satisfies neither
            # test (e.g. NaN) is placed in neither child
            right_rows.append(row)

    node_left = create_node(left_rows)
    node_right = create_node(right_rows)
    return node_left, node_right

5. Find the best threshold for a node

def find_treshold(node):
    """Choose the split of ``node`` with the lowest weighted impurity.

    Every pair of a ``train_X`` feature and a row index taken from the node's
    own datapoints is tried as a candidate split. The best candidate is
    written to ``node["feature"]`` and ``node["treshold"]`` (the latter is a
    training-row index, not a feature value). On ties the first candidate
    found wins.

    If no candidate lowers the impurity of the node, the node is marked as a
    leaf and "feature"/"treshold" are left untouched, instead of failing on
    ``min()`` of an empty list.

    Args:
        node: node dict to find a split for; modified in place.
    """
    total = len(node["datapoints"])
    best_impurity = None
    best_feature = None
    best_treshold = None

    for feature in train_X:

        for treshold in node["datapoints"]:
            node_left, node_right = split_node(node, feature, treshold)

            # a split that leaves every row on one side separates nothing
            if node_left["datapoints"] == node["datapoints"] or node_right["datapoints"] == node["datapoints"]:
                continue

            weighted_impurity = (
                node_left["mean_variance"] * len(node_left["datapoints"]) / total
                + node_right["mean_variance"] * len(node_right["datapoints"]) / total
            )

            # reject splits that do not reduce impurity; the check uses the
            # same size-weighted measure that ranks the candidates
            if weighted_impurity >= node["mean_variance"]:
                continue

            # strict comparison keeps the first of several equally good splits
            if best_impurity is None or weighted_impurity < best_impurity:
                best_impurity = weighted_impurity
                best_feature = feature
                best_treshold = treshold

    if best_feature is None:
        # nothing useful to split on: the node stays as it is and becomes a leaf
        node["leaf"] = True
        return

    node["feature"] = best_feature
    node["treshold"] = best_treshold

The algorithm calculates all the possible splits and selects the threshold with the minimal weighted impurity to split the node.

6. Build the tree structure

def build_tree(node, max_depth, depth=1):
    """Recursively grow the tree below ``node``, modifying it in place.

    Args:
        node: node dict to grow from (the root on the first call).
        max_depth: depth at which splitting stops; a node reached at this
            depth is turned into a leaf.
        depth: depth of ``node``; the root is at depth 1.
    """
    # if the current depth has reached max_depth, stop the splitting process
    if depth >= max_depth:
        node["leaf"] = True
        return

    # a node that is already a leaf has nothing to split; searching for a
    # threshold on it would find no valid candidate
    if node["leaf"]:
        return

    find_treshold(node)

    # no split was selected for this node, so it cannot be divided further
    if node["feature"] is None:
        node["leaf"] = True
        return

    node["left"], node["right"] = split_node(node, node["feature"], node["treshold"])

    # re-execute this function on each child that is not a leaf yet,
    # left child first
    for child in (node["left"], node["right"]):
        if not child["leaf"]:
            build_tree(child, max_depth, depth + 1)

This is a recursive function to build a tree based on a root node and a max_depth parameter.

7. Define the predict function

def predict(tree, X):
    """Predict one output value per sample in ``X``.

    Each sample is walked down from ``tree`` until a leaf is reached; the
    leaf's "mean_value" is the prediction.

    Args:
        tree: root node dict of a built tree.
        X: dict mapping feature names to equally long lists of values, one
            entry per sample. It must contain every feature the tree
            splits on.

    Returns:
        A list with one prediction per sample, in input order.
    """
    # every column has one entry per sample, so the first column gives the
    # sample count; this avoids depending on a particular feature name
    sample_count = len(next(iter(X.values()), []))

    y = []

    for index in range(sample_count):
        current_node = tree

        while not current_node["leaf"]:
            feature = current_node["feature"]
            # "treshold" stores a training-row index; the actual cut value
            # is that row's value of the split feature in train_X
            cut_value = train_X[feature][current_node["treshold"]]

            if X[feature][index] <= cut_value:
                current_node = current_node["left"]
            else:
                current_node = current_node["right"]

        y.append(current_node["mean_value"])

    return y

This function walks each input sample down the tree to the leaf node it falls into and returns the mean value of that leaf.

8. Fit a tree to the data

# maximum depth of the tree (the root node is at depth 1)
max_depth = 2

# the root node holds all three training rows
node1 = create_node([0, 1, 2])

# grow the tree in place, starting from the root
build_tree(node1, max_depth)

9. Predict unseen values

# unseen houses to predict, laid out like train_X (one list per feature)
val_X = {
    "LotArea": [50, 90],
    "Quality": [7.5, 7.5],
}

print(predict(node1, val_X))

Decision tree in Python full code

import numpy


# features: one list per column, aligned by row index
train_X = {
    "LotArea": [50, 70, 100],
    "Quality": [8, 7.5, 9],
}

# target: the sale price of each row
train_y = {"SalePrice": [100, 105, 180]}


def create_node(datapoints):
    """Build a tree node from a list of training-row indexes.

    Args:
        datapoints: indexes into ``train_y["SalePrice"]`` of the rows that
            belong to this node.

    Returns:
        A dict describing the node; the meaning of each key is given in the
        comments below.
    """
    targets = [train_y["SalePrice"][index] for index in datapoints]

    if targets:
        mean_value = numpy.mean(targets)
        mean_variance = numpy.mean([(target - mean_value) ** 2 for target in targets])
    else:
        # numpy.mean([]) emits a RuntimeWarning and returns nan; an empty node
        # has no statistics, so record nan directly and skip the warning
        mean_value = mean_variance = float("nan")

    # a node cannot be split further when it is empty or all its targets agree
    leaf = not targets or bool(mean_variance == 0)

    return {
        # datapoints index in the node
        "datapoints": datapoints,
        # mean output value of the data points in the node
        "mean_value": mean_value,
        # impurity of the node (population variance of its targets)
        "mean_variance": mean_variance,
        # leaf state
        "leaf": leaf,
        # best feature and threshold to split the node; both stay None
        # until a split is chosen for the node
        "feature": None,
        "treshold": None,
        # the left sub-node of the node
        "left": None,
        # the right sub-node of the node
        "right": None,
    }


def split_node(node, feature, treshold):
    """Split a node in two around one training row's value of a feature.

    Args:
        node: node dict whose "datapoints" list holds the row indexes to split.
        feature: key of ``train_X`` naming the column to compare on.
        treshold: index of the training row whose value in that column is
            used as the cut point (an index, not a feature value).

    Returns:
        A ``(left, right)`` tuple of nodes built with ``create_node``: the
        left node receives the rows whose value is <= the cut point, the
        right node the rows whose value is greater.
    """
    column = train_X[feature]
    # the cut point is the same for every row, so look it up once
    cut_value = column[treshold]

    left_rows = []
    right_rows = []
    for row in node["datapoints"]:
        if column[row] <= cut_value:
            left_rows.append(row)
        elif column[row] > cut_value:
            # explicit second comparison: a value that satisfies neither
            # test (e.g. NaN) is placed in neither child
            right_rows.append(row)

    node_left = create_node(left_rows)
    node_right = create_node(right_rows)
    return node_left, node_right


def find_treshold(node):
    """Choose the split of ``node`` with the lowest weighted impurity.

    Every pair of a ``train_X`` feature and a row index taken from the node's
    own datapoints is tried as a candidate split. The best candidate is
    written to ``node["feature"]`` and ``node["treshold"]`` (the latter is a
    training-row index, not a feature value). On ties the first candidate
    found wins.

    If no candidate lowers the impurity of the node, the node is marked as a
    leaf and "feature"/"treshold" are left untouched, instead of failing on
    ``min()`` of an empty list.

    Args:
        node: node dict to find a split for; modified in place.
    """
    total = len(node["datapoints"])
    best_impurity = None
    best_feature = None
    best_treshold = None

    for feature in train_X:

        for treshold in node["datapoints"]:
            node_left, node_right = split_node(node, feature, treshold)

            # a split that leaves every row on one side separates nothing
            if node_left["datapoints"] == node["datapoints"] or node_right["datapoints"] == node["datapoints"]:
                continue

            weighted_impurity = (
                node_left["mean_variance"] * len(node_left["datapoints"]) / total
                + node_right["mean_variance"] * len(node_right["datapoints"]) / total
            )

            # reject splits that do not reduce impurity; the check uses the
            # same size-weighted measure that ranks the candidates
            if weighted_impurity >= node["mean_variance"]:
                continue

            # strict comparison keeps the first of several equally good splits
            if best_impurity is None or weighted_impurity < best_impurity:
                best_impurity = weighted_impurity
                best_feature = feature
                best_treshold = treshold

    if best_feature is None:
        # nothing useful to split on: the node stays as it is and becomes a leaf
        node["leaf"] = True
        return

    node["feature"] = best_feature
    node["treshold"] = best_treshold


def build_tree(node, max_depth, depth=1):
    """Recursively grow the tree below ``node``, modifying it in place.

    Args:
        node: node dict to grow from (the root on the first call).
        max_depth: depth at which splitting stops; a node reached at this
            depth is turned into a leaf.
        depth: depth of ``node``; the root is at depth 1.
    """
    # if the current depth has reached max_depth, stop the splitting process
    if depth >= max_depth:
        node["leaf"] = True
        return

    # a node that is already a leaf has nothing to split; searching for a
    # threshold on it would find no valid candidate
    if node["leaf"]:
        return

    find_treshold(node)

    # no split was selected for this node, so it cannot be divided further
    if node["feature"] is None:
        node["leaf"] = True
        return

    node["left"], node["right"] = split_node(node, node["feature"], node["treshold"])

    # re-execute this function on each child that is not a leaf yet,
    # left child first
    for child in (node["left"], node["right"]):
        if not child["leaf"]:
            build_tree(child, max_depth, depth + 1)


def predict(tree, X):
    """Predict one output value per sample in ``X``.

    Each sample is walked down from ``tree`` until a leaf is reached; the
    leaf's "mean_value" is the prediction.

    Args:
        tree: root node dict of a built tree.
        X: dict mapping feature names to equally long lists of values, one
            entry per sample. It must contain every feature the tree
            splits on.

    Returns:
        A list with one prediction per sample, in input order.
    """
    # every column has one entry per sample, so the first column gives the
    # sample count; this avoids depending on a particular feature name
    sample_count = len(next(iter(X.values()), []))

    y = []

    for index in range(sample_count):
        current_node = tree

        while not current_node["leaf"]:
            feature = current_node["feature"]
            # "treshold" stores a training-row index; the actual cut value
            # is that row's value of the split feature in train_X
            cut_value = train_X[feature][current_node["treshold"]]

            if X[feature][index] <= cut_value:
                current_node = current_node["left"]
            else:
                current_node = current_node["right"]

        y.append(current_node["mean_value"])

    return y


# the root node holds all three training rows
node1 = create_node([0, 1, 2])

# grow the tree in place, at most 2 levels deep
build_tree(node1, 2)


# unseen houses to predict, laid out like train_X (one list per feature)
val_X = {
    "LotArea": [50, 90],
    "Quality": [7.5, 7.5],
}

print(predict(node1, val_X))
Share the knowledge