What are Decision Trees in Machine Learning (Classification And Regression)

Last updated on
06th Jun, 2022
26th Aug, 2019
Introduction to Machine Learning and its types

Machine Learning is an interdisciplinary field of study and a sub-domain of Artificial Intelligence. It gives computers the ability to learn and draw inferences from large amounts of data without being explicitly programmed.

Introduction to Machine Learning and types of Machine Learning

Types of Machine Learning: Machine Learning can broadly be classified into three types:

  • Supervised Learning: If the available dataset has predefined features and labels on which the machine learning models are trained, the type of learning is known as Supervised Machine Learning. Supervised Machine Learning models can broadly be classified into two sub-parts: Classification and Regression. Both are discussed in detail below.

Supervised machine learning

  • Unsupervised Learning: If the available dataset has predefined features but lacks labels, then the Machine Learning algorithms perform operations on this data to assign labels to it or to reduce the dimensionality of the data. There are several types of Unsupervised Learning Models, the most common of them being: Principal Component Analysis (PCA) and Clustering.

Unsupervised machine learning

  • Reinforcement Learning: Reinforcement Learning is a more advanced type of learning, where the model learns from “Experience”. Here, features and labels are not clearly defined. The model is just given a “Situation” and is rewarded or penalized based on the “Outcome”. The model thus learns how to act in a given “Situation” so as to maximize the Rewards, and hence improves the “Outcome” with “Experience”.

Reinforcement machine learning


Classification

Classification is the process of predicting the category to which a data point belongs. A Supervised Learning algorithm learns to draw inferences from the features of a given dataset and predicts the class, group or category of each data point.

Example of Classification: Let’s assume that we are given a few images of handwritten digits (0-9). The problem statement is to “teach” the machine to classify correctly which image corresponds to which digit. A small sample of the dataset is given below:

Example of Classification

The machine has to be trained such that, given any such hand-written digit as input, it correctly identifies which digit the image represents. This is called classification of hand-written digits.

Looking at another example which is not image-based, we have 2D data (x1 and x2) which is plotted in the form of a graph shown below.

2D data example of Classification in Machine Learning

The red and green dots represent two different classes or categories of data. The main goal of the classifier is that given one such “dot” of unknown class, based on its “features”, the algorithm should be able to correctly classify if that dot belongs to the red or green class. This is also shown by the line going through the middle, which correctly classifies the majority of the dots.

Applications of Classification: Listed below are some of the real-world applications of classification Algorithms.

  • Face Recognition: Face recognition finds its applications in our smartphones and any other place with Biometric security. Face Recognition is nothing but face detection followed by classification. The classification algorithm determines if the face in the image matches with the registered user or not.
  • Medical Image Classification: Given the data of patients, a model that is well trained is often used to classify if the patient has a malignant tumor (cancer), heart ailments, fractures, etc.


Regression

Regression is also a type of supervised learning. Unlike classification, it does not predict the class of the given data. Instead, it predicts a continuous numerical value for each data point based on the “features” it encounters.

Example of Regression: For this, we will look at a dataset consisting of California Housing Prices. The contents of this dataset are shown below.


Here, there are several columns. Each of the columns shows the “features” based on which the machine learning algorithm predicts the housing price (shown by yellow highlight). The primary goal of the regression algorithm is that, given the features of a given house, it should be able to correctly estimate the price of the house. This is called a regression problem. It is similar to curve fitting and is often confused with the same.

Applications of Regression: Listed below are some of the real-world applications of regression Algorithms.

  • Stock Market Prediction: Regression algorithms are used to predict the future price of stocks based on certain past features like time of the day or festival time, etc. Stock Market Prediction also falls under a subdomain of study called Time Series Analysis.
  • Object Detection Algorithms: Object Detection is the process of detection of the location of a given object in an image or video. This process returns the coordinates of the pixel values stating the location of the object in the image. These coordinates are determined by using regression algorithms alongside classification.

Classification vs. Regression:

  • Classification: Assigns specific classes to the data based on its features. The prediction is discrete or categorical in nature.
  • Regression: Predicts values based on the features of the dataset. The prediction is continuous in nature.
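
The difference between the two kinds of prediction can be seen in a minimal scikit-learn sketch. The toy data below (hours studied, a pass/fail label and an exam score) is made up purely for illustration and does not come from any dataset used in this article.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Hypothetical toy data: one feature (hours studied) per sample
X = np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y_class = np.array([0, 0, 0, 1, 1, 1])                    # discrete label: fail (0) / pass (1)
y_value = np.array([35.0, 42.0, 50.0, 61.0, 70.0, 82.0])  # continuous target: exam score

clf = DecisionTreeClassifier(random_state=0).fit(X, y_class)
reg = DecisionTreeRegressor(random_state=0).fit(X, y_value)

X_new = np.array([[2.4], [5.6]])
class_pred = clf.predict(X_new)  # each prediction is one of the known class labels
value_pred = reg.predict(X_new)  # each prediction is a real-valued estimate

print("Predicted classes:", class_pred)
print("Predicted values: ", value_pred)
```

The classifier can only ever answer with a label it saw during training, while the regressor returns a number on a continuous scale.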

Introduction to the building blocks of Decision Trees

In order to get started with Decision Trees, it is important to understand the basic building blocks of decision trees. Hence, we start building the concepts slowly with some basic theory.

1. Entropy

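Definition: Entropy is a commonly used concept in Information Theory. It measures the impurity (disorder) of an arbitrary collection of examples: it is zero when all the examples belong to one class and highest when the classes are evenly mixed.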

Mathematical Equation:

$$\mathrm{Entropy}(S) = -p_{+}\log_2 p_{+} - p_{-}\log_2 p_{-}$$

Here, given a collection S containing positive and negative examples, the Entropy of S is given by the above equation, where $p_{+}$ and $p_{-}$ are the proportions of positive and negative examples in S.

In a more generalized form, Entropy is given by the following equation:

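$$\mathrm{Entropy}(S) = \sum_{i=1}^{c} -p_i \log_2 p_i$$

where c is the number of classes and $p_i$ is the proportion of S that belongs to class i.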

Example: Consider a sample S that contains 14 data samples, of which 9 are positive and 5 are negative. This is denoted by the notation [9+, 5–].

Thus, Entropy of the given sample can be calculated as follows:

$$\mathrm{Entropy}([9+, 5-]) = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940$$
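
The same calculation can be checked with a few lines of Python. This is a minimal sketch: the function name `entropy` and its count-based interface are choices made here for illustration, not part of any library.

```python
import math

def entropy(class_counts):
    """Entropy (in bits) of a collection, given the number of examples in each class."""
    total = sum(class_counts)
    # Classes with zero examples are skipped (0 * log2(0) is taken as 0)
    return sum(-(n / total) * math.log2(n / total) for n in class_counts if n > 0)

print(round(entropy([9, 5]), 3))  # the [9+, 5-] sample from the example
print(entropy([7, 7]))            # evenly mixed collection: highest impurity for two classes
print(entropy([14, 0]))           # pure collection: no impurity
```

Because the function sums over any number of class counts, it also covers the generalized, multi-class form of the equation.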

2. Information Gain

Definition: Information Gain measures how much an attribute tells us about the class of a sample. It is the expected reduction in Entropy obtained by partitioning the sample according to the values of that attribute.

Mathematical Equation:

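$$\mathrm{Gain}(S, A) = \mathrm{Entropy}(S) - \sum_{v \in \mathrm{Values}(A)} \frac{|S_v|}{|S|}\,\mathrm{Entropy}(S_v)$$

Here, Gain(S, A) is the Information Gain of an attribute A relative to a sample S, Values(A) is the set of all possible values for attribute A, and $S_v$ is the subset of S for which attribute A has the value v.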

Example: Let S be a collection of 14 training examples, and let the attribute be Wind, which takes the values Weak and Strong. Continuing the previous example, assume that out of the 9 positive and 5 negative samples, 6 positive and 2 negative samples have Wind=Weak, and the remaining 3 positive and 3 negative samples have Wind=Strong. Under these circumstances, the information gained by the attribute Wind is shown below.

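$$
\begin{aligned}
S &= [9+, 5-], \quad S_{Weak} = [6+, 2-], \quad S_{Strong} = [3+, 3-] \\
\mathrm{Gain}(S, Wind) &= \mathrm{Entropy}(S) - \tfrac{8}{14}\,\mathrm{Entropy}(S_{Weak}) - \tfrac{6}{14}\,\mathrm{Entropy}(S_{Strong}) \\
&= 0.940 - \tfrac{8}{14}(0.811) - \tfrac{6}{14}(1.00) \\
&\approx 0.048
\end{aligned}
$$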
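
As a quick check of the Wind example, the gain can be computed directly from the class counts given in the text. The helper names `entropy` and `information_gain` below are illustrative choices, not library functions.

```python
import math

def entropy(class_counts):
    """Entropy (in bits) of a collection, given the number of examples in each class."""
    total = sum(class_counts)
    return sum(-(n / total) * math.log2(n / total) for n in class_counts if n > 0)

def information_gain(parent_counts, subsets):
    """Entropy of the parent minus the size-weighted entropy of its subsets."""
    total = sum(parent_counts)
    remainder = sum(sum(s) / total * entropy(s) for s in subsets)
    return entropy(parent_counts) - remainder

# [positives, negatives], taken from the Wind example in the text
S = [9, 5]
wind_subsets = [[6, 2], [3, 3]]  # Wind=Weak, Wind=Strong

gain_wind = information_gain(S, wind_subsets)
print(f"Gain(S, Wind) = {gain_wind:.3f}")
```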

Decision Tree

Introduction: Since we have the basic building blocks out of the way, let’s try to understand what exactly is a Decision Tree. As the name suggests, it is a Tree which is developed based on certain decisions taken by the algorithm in accordance with the given data that it has been trained on.

In simple words, a Decision Tree uses the features in the given data to perform Supervised Learning and develop a tree-like structure (data structure) whose branches are developed in such a way that given the feature-set, the decision tree can predict the expected output relatively accurately.

Example:  Let us look at the structure of a decision tree. For this, we will take up an example dataset called the “PlayTennis” dataset. A sample of the dataset is shown below.

Play Tennis Data set used to look at the structure of a Decision Tree.

In summary, the target of the model is to predict if the weather conditions are suitable to play tennis or not, as guided by the dataset shown above.

As it can be seen in the dataset, it contains certain information (features) for each day. In this, we have the feature-attributes: Outlook, Temperature, Humidity and Wind and the target-attribute PlayTennis. Each of these attributes can take up certain values, for example, the attribute Outlook has the values Sunny, Rain and Overcast.

With a clear idea of the dataset, jumping a bit forward, let us look at the structure of the learned Decision Tree as developed from the above dataset.

The structure of the learned Decision Tree as developed from the "PlayTennis" Data Set

As shown above, it can clearly be seen that, given certain values for each of the attributes, the learned decision tree is capable of giving a clear answer as to whether the weather is suitable for Tennis or not.

Algorithm: With the overall intuition of decision trees, let us look at the formal Algorithm:

  1. ID3(Samples, Target_attribute, Attributes):
  2. Create a Root node for the tree
  3. If all the Samples are positive, return the single-node tree Root, with label = +
  4. If all the Samples are negative, return the single-node tree Root, with label = –
  5. If Attributes is empty, return the single-node tree Root, with label = most common value of Target_attribute among the Samples
  6. Otherwise:
    • A ← the attribute from Attributes that best classifies the Samples (the one with the highest Information Gain)
    • The decision attribute for Root ← A
    • For each possible value vi of A:
      • Add a new tree branch below Root, corresponding to the test A = vi
      • Let Samples_vi be the subset of Samples that have value vi for A
      • If Samples_vi is empty:
        • Then below the new branch add a leaf node with label = most common value of Target_attribute among the Samples
      • Else:
        • Below the new branch add the subtree ID3(Samples_vi, Target_attribute, Attributes – {A})
  7. End
  8. Return Root
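
The steps above can be sketched in plain Python. This is an illustrative implementation, not a reference one. It assumes the commonly published 14-row version of the PlayTennis table, represents the tree as nested dictionaries, and uses Information Gain as the measure of which attribute "best classifies" the samples. All function and variable names are choices made here.

```python
import math
from collections import Counter

COLUMNS = ["Outlook", "Temperature", "Humidity", "Wind", "PlayTennis"]
# Assumed data: the commonly published 14-row PlayTennis table
RAW = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]
rows = [dict(zip(COLUMNS, r)) for r in RAW]

def entropy(labels):
    """Entropy (in bits) of a list of class labels."""
    total = len(labels)
    return sum(-(n / total) * math.log2(n / total) for n in Counter(labels).values())

def information_gain(samples, attribute, target):
    """Expected reduction in entropy from splitting samples on attribute."""
    remainder = 0.0
    for value in {s[attribute] for s in samples}:
        subset = [s[target] for s in samples if s[attribute] == value]
        remainder += len(subset) / len(samples) * entropy(subset)
    return entropy([s[target] for s in samples]) - remainder

def id3(samples, target, attributes, domains):
    """Return a leaf label, or a nested dict {attribute: {value: subtree}}."""
    labels = [s[target] for s in samples]
    majority = Counter(labels).most_common(1)[0][0]
    if len(set(labels)) == 1:   # all samples share one label
        return labels[0]
    if not attributes:          # nothing left to split on
        return majority
    best = max(attributes, key=lambda a: information_gain(samples, a, target))
    branches = {}
    for value in domains[best]:  # every possible value, even if absent from this subset
        subset = [s for s in samples if s[best] == value]
        if not subset:
            branches[value] = majority
        else:
            remaining = [a for a in attributes if a != best]
            branches[value] = id3(subset, target, remaining, domains)
    return {best: branches}

attributes = COLUMNS[:-1]
domains = {a: sorted({r[a] for r in rows}) for a in attributes}
tree = id3(rows, "PlayTennis", attributes, domains)
print(tree)
```

On this table the root of the resulting tree is Outlook, with Humidity tested under Sunny and Wind tested under Rain.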

Connecting the dots: Since the overall idea of decision trees has been explained, let’s try to figure out how Entropy and Information Gain fit into this entire process.

Entropy (E) is used to calculate Information Gain, which in turn is used to identify which attribute of a given dataset provides the highest amount of information. The attribute that provides the highest amount of information for the given dataset is considered to contribute the most to the outcome of the classifier and hence is given higher priority in the tree, that is, it is tested closer to the root.

For example, in the PlayTennis dataset, if we calculate the Information Gain for the two attributes Humidity and Wind, we find that Humidity provides more information about whether to play tennis or not. Hence, of these two, Humidity is the better attribute to split on. The detailed calculation is shown in the figure below:

Calculating the Information Gain for 2 corresponding attributes Humidity and Wind.
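
The comparison can be reproduced numerically. The sketch below assumes the class counts of the commonly published 14-row PlayTennis table: Humidity=High covers [3+, 4–] and Humidity=Normal covers [6+, 1–]. The Wind counts are the ones used in the earlier example. The helper names are illustrative.

```python
import math

def entropy(class_counts):
    """Entropy (in bits) of a collection, given the number of examples in each class."""
    total = sum(class_counts)
    return sum(-(n / total) * math.log2(n / total) for n in class_counts if n > 0)

def information_gain(parent_counts, subsets):
    """Entropy of the parent minus the size-weighted entropy of its subsets."""
    total = sum(parent_counts)
    return entropy(parent_counts) - sum(sum(s) / total * entropy(s) for s in subsets)

S = [9, 5]  # [positives, negatives] in the full sample
splits = {
    "Humidity": [[3, 4], [6, 1]],  # High, Normal (assumed from the standard table)
    "Wind": [[6, 2], [3, 3]],      # Weak, Strong (as in the earlier example)
}

gains = {name: information_gain(S, subsets) for name, subsets in splits.items()}
for name, gain in gains.items():
    print(f"Gain(S, {name}) = {gain:.3f}")

best = max(gains, key=gains.get)
print("Better attribute to split on:", best)
```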

Applications of Decision Tree

With the basic idea out of the way, let’s look at where decision trees can be used:

  1. Select a flight to travel: Decision trees are very good at classification and hence can be used to select which flight would yield the best “bang-for-the-buck”. There are a lot of parameters to consider, such as whether the flight is connecting or non-stop, or how reliable the service record of the given airline is.
  2. Selecting alternative products: Often in companies, it is important to determine which product will be more profitable at launch. Given the sales attributes such as market conditions, competition, price, availability of raw materials, demand, etc. a Decision Tree classifier can be used to accurately determine which of the products would maximize the profits.
  3. Sentiment Analysis: Sentiment Analysis is the determination of the overall opinion of a given piece of text and is especially used to determine if the writer’s comment towards a given product/service is positive, neutral or negative. Decision trees are very versatile classifiers and are used for sentiment analysis in many Natural Language Processing (NLP) applications.
  4. Energy Consumption: It is very important for electricity supply boards to correctly predict the amount of energy consumption in the near future for a particular region. This is to make sure that un-used power can be diverted towards an area with a higher demand to keep a regular and uninterrupted supply of power throughout the grid. Decision Trees are often used to determine which region is expected to require more or less power in the up-coming time-frame.
  5. Fault Diagnosis: In the Engineering domain, one of the widely used applications of decision trees is the determination of faults. In the case of load-bearing rotatory machines, it is important to determine which of the component(s) have failed and which ones can directly or indirectly be affected by the failure. This is determined by a set of measurements that are taken. Unfortunately, there are numerous measurements to take and among them, there are some measurements which are not relevant to the detection of the fault. A Decision Tree classifier can be used to quickly determine which of these measurements are relevant in the determination of the fault.

Advantages of Decision Tree

Listed below are some of the advantages of Decision Trees:

  1. Comprehensive: A significant advantage of a decision tree is that it forces the algorithm to take into consideration all the possible outcomes of a decision and traces each path to a conclusion.
  2. Specific: The output of decision trees is very specific and reduces uncertainty in the prediction. Hence, they are considered really good classifiers.
  3. Easy to use: Decision Trees are one of the simplest, yet most versatile algorithms in Machine Learning. They are based on simple math rather than complex formulas, and they are easy to visualize, understand and explain.
  4. Versatile: A lot of business problems can be solved using Decision Trees. They find applications in Engineering, Management, Medicine and, more generally, any situation where data is available and a decision needs to be taken under uncertain conditions.
  5. Resistant to data abnormalities: Data is never perfect, and datasets often contain abnormalities such as outliers, missing data and noise. While many Machine Learning algorithms are sensitive to even a small number of such abnormalities, Decision Trees are comparatively resilient and are able to handle a fair percentage of them without a large change in the results.
  6. Visualization of the decision taken: Often in Machine Learning models, data scientists struggle to reason as to why a certain model is giving a certain set of outputs. Unfortunately, for most of the algorithms, it is not possible to clearly determine and visualize the actual process of classification that leads to the final outcome. However, decision trees are very easy to visualize. Once the tree is trained, it can be visualized and the programmer can see exactly how and why the conclusion was reached. It is also easy to explain the outcome to a non-technical team with the “tree” type visualization. This is why many organizations prefer to use decision trees over other Machine Learning Algorithms.

Limitations of Decision Tree

Listed below are some of the limitations of Decision Trees:

  1. Sensitivity to hyperparameter tuning: Decision Trees are very sensitive to hyperparameter tuning. Hyperparameters are the parameters that are under the control of the programmer, such as the maximum depth of the tree, and can be tuned to get better performance out of a given model. Unfortunately, the output of a decision tree can vary drastically if the hyperparameters are poorly tuned.
  2. Overfitting: Decision trees are prone to overfitting. Overfitting occurs when the model learns the training data too well, noise included, and hence performs well on the training dataset but fails to perform on the testing dataset. Decision trees are prone to overfitting if the breadth and depth of the tree are allowed to grow very large for a simple dataset.
  3. Underfitting: Decision trees are also prone to the opposite problem, underfitting. Underfitting occurs when the model is too simple to learn the dataset effectively. A decision tree suffers from underfitting if its breadth and depth, or the number of nodes, are set too low. This does not allow the model to fit the data properly, and hence it fails to learn.
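
One way to see these limitations in practice is to compare training and test accuracy for trees of different depths. The sketch below uses synthetic data from scikit-learn's `make_classification`. The sample sizes, noise level and depth values are arbitrary choices for illustration, and the exact scores will depend on them.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, deliberately noisy data (flip_y assigns a random class to 10% of the samples)
X, y = make_classification(n_samples=600, n_features=10, n_informative=4,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)

scores = {}
for depth in (1, 4, None):  # None lets the tree grow until its leaves are pure
    model = DecisionTreeClassifier(max_depth=depth, random_state=0)
    model.fit(X_train, y_train)
    scores[depth] = (model.score(X_train, y_train), model.score(X_test, y_test))
    print(f"max_depth={depth}: train={scores[depth][0]:.2f}, test={scores[depth][1]:.2f}")
```

A large gap between training and test accuracy for the unrestricted tree points to overfitting, while low accuracy on both sets for the depth-1 tree points to underfitting.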

Code Examples

With the theory out of the way, let’s look at the practical implementation of decision tree classifiers and regressors.

1. Classification

To demonstrate classification, a diabetes dataset from Kaggle is used. The code below expects it to be available as the file pima-indians-diabetes.csv.

  • The initial step for any data science application is data visualization. Hence, the dataset is shown below:

Example Data set for Classification in Machine Learning

The highlighted column is the target value that the model is expected to predict, given the parameters.

  • Load the libraries. We will be using pandas to load and manipulate data, and scikit-learn (sklearn) to apply Machine Learning models to the data.
# Load libraries
import pandas as pd
from sklearn.tree import DecisionTreeClassifier  # Decision Tree classifier
from sklearn.model_selection import train_test_split  # train/test splitting helper
from sklearn import metrics  # metrics module, used for accuracy calculation
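  • Load the data. Pandas is used to read the data from the CSV.
# Column names to assign to the data
col_names = ['pregnant', 'glucose', 'bp', 'skin', 'insulin', 'bmi', 'pedigree', 'age', 'label']
# Load dataset. header=None assumes the CSV file has no header row;
# if your copy already has one, pass header=0 instead so it is not read as data.
pima = pd.read_csv("pima-indians-diabetes.csv", header=None, names=col_names)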
  • Feature Selection: The relevant features are selected for the classification.
# Split dataset into features and target variable
feature_cols = ['pregnant', 'insulin', 'bmi', 'age', 'glucose', 'bp', 'pedigree']
X = pima[feature_cols]  # Features
y = pima.label  # Target variable
  • Splitting the data: The dataset needs to be split into training and testing data. The training data is used to train the model, while the testing data is used to test the model’s performance on unseen data.
# Split dataset into training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)  # 70% training and 30% test
  • Building the decision tree. These few lines initialize, train and predict on the dataset.
# Create Decision Tree classifier object
clf = DecisionTreeClassifier()

# Train Decision Tree Classifier
clf = clf.fit(X_train,y_train)

# Predict the response for test dataset
y_pred = clf.predict(X_test)
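  • The model’s accuracy is evaluated by using Sklearn’s metrics library. Because no random_state is passed to DecisionTreeClassifier, the exact value may differ slightly between runs.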
# Model Accuracy, how often is the classifier correct?
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
Output: Accuracy: 0.6753246753246753
  • Once trained, the decision tree can be visualized. The tree learned from this dataset is shown in the following image:

The Decision Tree generated from the Dataset by Classification method in Machine Learning.
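
The steps above do not include the code that draws the tree. One way to inspect a fitted tree is scikit-learn's `export_text`, sketched below. So that the snippet runs without the CSV file, it trains on synthetic stand-in data with placeholder feature names. The `max_depth=3` limit is an assumption made here only to keep the printout short.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Stand-in data; the feature names are hypothetical placeholders
feature_names = ["f0", "f1", "f2", "f3"]
X, y = make_classification(n_samples=200, n_features=4, n_informative=2,
                           n_redundant=0, random_state=1)

# max_depth=3 keeps the printed tree small enough to read
clf = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y)

# Text rendering of the learned rules, one line per node
rules = export_text(clf, feature_names=feature_names)
print(rules)
```

With the classifier trained in this article, the equivalent call would be `export_text(clf, feature_names=feature_cols)`. For a graphical rendering, `sklearn.tree.plot_tree` draws the same structure with matplotlib.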

2. Regression

For regression, no external dataset is needed. The data is generated synthetically.

  • For this example, we will generate a Numpy Array which will simulate a scatter plot resembling a sine wave with a few randomly added noise elements. 
# Import the necessary modules and libraries
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
  • This time we create two regression models to experiment and see what overfitting looks like for a decision tree. Hence, we initialize the two Decision Tree Regression objects and train them on the given data.
# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
  • After fitting the model, we predict on a custom test dataset and plot the results to see how it performed.
# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

# Plot the results
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()  # show which curve belongs to which model
plt.show()    # display the figure

The graph that is thus generated is shown below. Here we can clearly see that, for a simple dataset, when we used max_depth=5 (green), the model started to overfit and learned the patterns of the noise along with the sine wave. Such models tend to generalize poorly to unseen data. Meanwhile, the model with max_depth=2 (blue) follows the overall shape of the sine wave without chasing the noisy points.
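
Besides inspecting the plot, the two models can be compared numerically. The sketch below regenerates the same synthetic data and reports the number of leaves, the error on the noisy training points and the error against the noise-free sine curve. Using the clean sine as the reference is an assumption made here for illustration, and no particular outcome is claimed for which depth scores better against it.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.tree import DecisionTreeRegressor

# Same synthetic data as in the example above
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

# Noise-free sine values on a dense grid serve as the reference signal
X_grid = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_true = np.sin(X_grid).ravel()

results = {}
for depth in (2, 5):
    regr = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X, y)
    results[depth] = {
        "leaves": regr.get_n_leaves(),
        "train_mse": mean_squared_error(y, regr.predict(X)),
        "grid_mse": mean_squared_error(y_true, regr.predict(X_grid)),
    }
    print(f"max_depth={depth}: {results[depth]}")
```

A deeper tree has more leaves and therefore more freedom to follow individual noisy points, which shows up as a lower training error.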

Decision Tree by Regression in Machine Learning.


Conclusion

In this article, we tried to build an intuition by starting from the basics of the theory behind the working of a decision tree classifier. However, covering every detail is beyond the scope of this article. Hence, it is suggested to go through this book to dive deeper into the specifics. Moving on, the code snippets introduce the “Hello World” of how to use both real-world data and artificially generated data to train a Decision Tree model and predict with it. This will allow any novice to get a balanced theoretical and practical idea of the workings of Classification and Regression Trees and their implementation.


Animikh Aich

Computer Vision Engineer

Animikh Aich is a Deep Learning enthusiast, currently working as a Computer Vision Engineer. His work includes three International Conference publications and several projects based on Computer Vision and Machine Learning.