Implementing a classification tree with Gini Impurity from scratch in Python

Pandula Weerasooriya
5 min read · Dec 27, 2019


Even though they are classified (pun intended) as weak classifiers, decision trees play a huge role in machine learning. They provide the basis for the family of ML algorithms known as ensemble learning, which includes algorithms such as random forests and boosting. While most of these algorithms have been abstracted away in Python, R and various BI/stats tools, implementing them from scratch gives an inquisitive person a good understanding of their underlying mechanisms.

I chose to implement decision trees first because, whenever I do hyper-parameter optimization on an ensemble method, it requires knowledge of decision tree parameters such as max depth, split criterion, max_leaf_nodes and so on. Therefore, I thought of implementing them first before diving into the aforementioned ensemble methods.

Decision trees using the CART algorithm

I was mainly inspired to do this after watching a short video on decision trees from StatQuest. If you have not subscribed to this channel, I urge you to do so, as it contains the most intuitive explanations of statistical topics I've come across, which can be pretty rare in statistics.

So, a decision tree is built much like a binary search tree: it splits nodes based on some criterion. However, the splitting criterion can vary depending on the data and the splitting method you are using.

A tree consists of three types of nodes: a root node, intermediary nodes and leaf nodes. The image below shows how a decision tree is applied to a simple dataset on heart disease.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk

The first question is how to choose the right variable to split on at the root node. This is where the splitting criterion comes into play. There are various methods used to quantify it, such as:

  • Entropy method
  • Information gain
  • Gini impurity

I have always found the Gini impurity method to be the least threatening and most intuitive of these. But before calculating it, we need to separate our data as shown below.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk

It is evident that we cannot decide on the splitting variable straight away, so the Gini impurity is calculated for each variable. Note that if a node contains only one class of the target variable, the Gini value becomes zero; such nodes are called pure nodes. The higher the Gini value, the higher the impurity of the node.
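Concretely, for a node whose classes occur with proportions p_i, the Gini impurity is 1 minus the sum of the squared proportions. A minimal sketch of that calculation (my own helper, not the post's exact code) could look like this:

import numpy as np

def gini_impurity(labels):
    """Gini impurity of a node: 1 - sum of squared class proportions.
    Zero for a pure node; higher values mean a more mixed node."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / labels.size
    return 1.0 - np.sum(proportions ** 2)

print(gini_impurity([1, 1, 1, 1]))   # 0.0 -> a pure node
print(gini_impurity([1, 0, 1, 0]))   # 0.5 -> maximally impure (binary case)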

The image below shows how you can calculate the Gini impurity for the left node of the chest pain split, using the distribution of the target variable among patients who have chest pain.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk

We can do the same calculation for the right node, where the Gini value is calculated over the target variable for patients without chest pain, the opposite condition to the left node.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk

After calculating the Gini impurity for both the left and right nodes, we take a weighted average of the two, weighted by the number of samples in each node. The variable with the lowest weighted Gini value is then chosen as the best splitting variable.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk
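As a sketch of that weighted average, building on gini_impurity from above (the counts here are illustrative, not the figures from the video):

def weighted_gini(left_labels, right_labels):
    """Weighted average of the left and right node impurities,
    weighted by the number of samples that fall into each node."""
    n_left, n_right = len(left_labels), len(right_labels)
    n_total = n_left + n_right
    return ((n_left / n_total) * gini_impurity(left_labels)
            + (n_right / n_total) * gini_impurity(right_labels))

# Hypothetical counts for a single split on chest pain.
left = [1] * 105 + [0] * 39     # e.g. patients with chest pain
right = [1] * 34 + [0] * 125    # e.g. patients without chest pain
print(round(weighted_gini(left, right), 3))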

The good news is that we can follow the exact same steps at each iteration of building the tree:

  • Calculate the Gini impurity scores for all remaining variables.
  • If the node itself has the lowest score, there is no point in separating the patients any further, and it becomes a leaf node.
  • If separating the data results in an improvement, pick the separation with the lowest impurity value.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk

As per the above image, not having blocked circulation already separates the target better than splitting that node further on chest pain, and therefore it becomes a leaf node.
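Putting the three steps above together for binary predicate variables, the per-node decision might be sketched as follows (gini_impurity and weighted_gini are the helpers from the sketches above; data is assumed to be a pandas DataFrame):

def best_split(data, target, candidate_columns):
    """Per-node decision: compare every candidate split's weighted
    impurity against the node's own impurity, and give up (leaf node)
    when no split improves on it."""
    node_gini = gini_impurity(data[target])
    best = None
    for col in candidate_columns:
        mask = data[col] == 1        # assumes binary predicates coded 0/1
        if mask.all() or not mask.any():
            continue                 # split would leave one side empty
        score = weighted_gini(data.loc[mask, target], data.loc[~mask, target])
        if score < node_gini and (best is None or score < best[1]):
            best = (col, score)
    return best                      # None -> make this node a leaf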

A point to stress here is that we have only looked at binary predicate variables so far, while other types of variables can appear in a dataset: numerical, multi-class, ordinal and so on. However, to keep things simple, I have only considered binary and multi-class variables. If you are wondering how to come up with a decision criterion for a multi-class variable, we have to consider all possible combinations of the available classes, as shown below.

Screenshot taken from https://www.youtube.com/watch?v=7VeUPuFGJHk
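One way to enumerate those combinations is with itertools; this is my own sketch, not the post's exact code:

from itertools import combinations

def candidate_partitions(classes):
    """Every way to send a non-empty, proper subset of classes to the
    left child; the complement goes right. A subset and its complement
    describe the same split, so mirrored pairs are skipped."""
    classes = sorted(classes)
    splits = []
    for r in range(1, len(classes)):
        for left in combinations(classes, r):
            right = tuple(c for c in classes if c not in left)
            if (right, left) not in splits:
                splits.append((left, right))
    return splits

print(candidate_partitions(["low", "medium", "high"]))
# [(('high',), ('low', 'medium')), (('low',), ('high', 'medium')),
#  (('medium',), ('high', 'low'))]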

Python Implementation

As I've mentioned above, I have only used categorical predictor variables, and I have only tried to implement a decision tree for a binary classification task. Regression trees are another branch of decision trees, and this video provides a good explanation of them: https://www.youtube.com/watch?v=g9c66TUylZ4.

I have used the Titanic dataset for the classification, which is known as the "Hello World" of Kaggle datasets. Furthermore, only the training part (building the tree) is provided here, since prediction would require traversing the binary tree, and that could be presented in another chapter along with other applications of decision trees.

First, I imported the relevant packages and the Titanic dataset, which can be downloaded from https://www.kaggle.com/c/titanic/data. A few pre-processing steps were done to extract only three categorical variables.
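The original code gist is not reproduced here; a minimal sketch of this step might look like the following. Note that the post does not name the three variables it kept, so Sex, Pclass and Embarked are my assumption:

import pandas as pd

# train.csv comes from the Kaggle link above and is assumed to be in the
# working directory.
df = pd.read_csv("train.csv")
data = df[["Sex", "Pclass", "Embarked", "Survived"]].dropna()
print(data.head())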

I then built a class for the nodes and initialized its properties. Methods to identify the data type, calculate the Gini impurity and find the best combination of categories for each categorical variable are also defined within the class.
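Again, the gist itself is not shown here; a stripped-down sketch of such a Node class, reusing candidate_partitions from the sketch above, might look like this:

class Node:
    """One node of the tree. Holds the rows that reached it and, once
    evaluated, the column and category subset used to split them."""

    def __init__(self, data, target):
        self.data = data             # rows that reached this node
        self.target = target         # name of the label column
        self.split_column = None     # chosen variable; None for a leaf
        self.split_values = None     # categories routed to the left child
        self.left = None
        self.right = None

    def gini(self, labels):
        """Gini impurity of a pandas Series of labels."""
        proportions = labels.value_counts(normalize=True)
        return 1.0 - (proportions ** 2).sum()

    def best_combination(self, column):
        """Try every left/right partition of the column's categories and
        return the one with the lowest weighted impurity."""
        best_score, best_left = float("inf"), None
        for left_vals, _ in candidate_partitions(self.data[column].unique()):
            mask = self.data[column].isin(left_vals)
            n, n_left = len(self.data), mask.sum()
            score = ((n_left / n) * self.gini(self.data.loc[mask, self.target])
                     + ((n - n_left) / n) * self.gini(self.data.loc[~mask, self.target]))
            if score < best_score:
                best_score, best_left = score, left_vals
        return best_left, best_score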

Finally, methods for node evaluation and node insertion are implemented. Note that node insertion is done recursively; further clarification can be found at https://medium.com/@stephenagrice/how-to-implement-a-binary-search-tree-in-python-e1cdba29c533.
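A sketch of what that recursive build could look like, using the Node class and data frame from the sketches above (max_depth and min_samples are my illustrative stopping parameters, not necessarily the post's):

def build_tree(node, min_samples=2, max_depth=5, depth=0):
    """Recursively grow the tree: stop when the node is pure, too small,
    too deep, or no split beats the node's own impurity."""
    labels = node.data[node.target]
    node_gini = node.gini(labels)
    if node_gini == 0 or len(node.data) < min_samples or depth >= max_depth:
        return                                     # leaf node

    # Pick the column/partition pair with the lowest weighted impurity.
    best = None
    for column in node.data.columns.drop(node.target):
        left_vals, score = node.best_combination(column)
        if left_vals is not None and score < node_gini:
            if best is None or score < best[2]:
                best = (column, left_vals, score)
    if best is None:
        return                                     # no improvement -> leaf

    column, left_vals, _ = best
    node.split_column, node.split_values = column, left_vals
    mask = node.data[column].isin(left_vals)
    node.left = Node(node.data[mask], node.target)
    node.right = Node(node.data[~mask], node.target)
    build_tree(node.left, min_samples, max_depth, depth + 1)
    build_tree(node.right, min_samples, max_depth, depth + 1)

root = Node(data, "Survived")
build_tree(root)
print(root.split_column, root.split_values)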


Written by Pandula Weerasooriya

A fullstack engineer who's passionate about building data intensive products and distributed systems. My stack includes Golang, Rust, React, NodeJS and Python.
