Machine Learning in Python: Decision Tree Classification


Welcome to this blog post on Decision Tree Classification! Let’s share with you the ins and outs of one of the most popular and widely used algorithms in machine learning. Decision tree classification is a powerful tool that can be used to solve a wide range of problems, from predicting customer churn to detecting fraud. In this blog post, we will cover the basics of decision tree classification, including how it works, how to build a decision tree model using Python programming language, and tips for optimizing your model’s performance. Whether you are new to machine learning or an experienced practitioner looking to improve your skills, this blog post has something for everyone. So let’s get started!

Use Cases for Decision Tree Classification

• Decision Tree Classification is commonly used in the field of finance for credit scoring. It can be used to predict whether or not a customer will default on their loan based on various factors such as income, employment history, and credit score.

• Another unique use case for Decision Tree Classification is in the field of healthcare. It can be used to predict a patient’s risk of developing certain diseases based on factors such as age, gender, family history, and lifestyle choices. This information can then be used to develop personalized treatment plans and preventative measures.

• In the field of marketing, Decision Tree Classification can be used to identify potential customers who are most likely to make a purchase or respond positively to a marketing campaign. By analyzing data such as demographics, past purchase behavior, and online activity, businesses can tailor their marketing efforts to target these high-value customers more effectively.

• Decision Tree Classification is also useful in the field of fraud detection. By analyzing patterns in financial transactions and account activity, it can help identify potential instances of fraud before they occur. This is particularly important in industries such as banking and insurance where fraudulent activity can have significant financial consequences.

• Finally, Decision Tree Classification can be used in the field of environmental science to predict the likelihood of natural disasters such as hurricanes or earthquakes based on factors such as weather patterns and geological data. This information can then be used to inform disaster preparedness plans and evacuation strategies.

How Decision Tree Classification Works

Decision Tree Classification is a popular machine learning algorithm that works by constructing a tree-like model to classify data. This algorithm is widely used in various fields such as finance, healthcare, and marketing.

The decision tree classification algorithm follows the following steps:

  1. Data Preparation: Before building a decision tree model, it is essential to prepare the data. The data should be cleaned and formatted correctly so that it can be used for training and testing the model.
  2. Splitting the Data: The next step is to split the dataset into two parts- training data and test data. The training set is used for building the decision tree model, while the test set is used for evaluating its performance.
  3. Building Decision Tree Model: In this step, the decision tree model is constructed using an iterative process called recursive partitioning or top-down induction of decision trees. At each node of the tree, we choose a feature that best splits the data based on some criteria such as Gini Index or Information Gain.
  4. Evaluating Model Performance: Once we have built our decision tree model, we need to evaluate its performance using test data. We do this by measuring metrics such as Accuracy, Precision, Recall or F1 Score.
  5. Tuning Parameters: Finally, we can tune parameters such as maximum depth of the tree or minimum samples required at each leaf node to improve our model’s performance further.

Decision Tree Classification works by constructing a hierarchical structure of nodes that represent decisions based on features in input data until arriving at a prediction or classification label at each leaf node. This algorithm is simple yet powerful and can handle both categorical and numerical features with high accuracy if tuned properly.

Decision Tree Classification Python Example

Here is a long piece of example Python code that implements Decision Tree Classification for the famous iris dataset using scikit-learn library. It also includes visualization code using Matplotlib.

# Importing necessary libraries
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Loading iris dataset from scikit-learn library and creating a pandas dataframe
iris = load_iris()
df = pd.DataFrame(, columns=iris.feature_names)

# Adding target variable 'species' to the dataframe
df['species'] = pd.Categorical.from_codes(, iris.target_names)

# Splitting the dataset into train and test sets
train = df.sample(frac=0.8, random_state=1)
test = df.drop(train.index)

# Creating decision tree classifier object and fitting it on the training set
clf = DecisionTreeClassifier(random_state=1)[iris.feature_names], train['species'])

# Visualizing the decision tree using plot_tree function from scikit-learn library and matplotlib for styling
plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)

# Predicting species on test set and calculating accuracy score
predictions = clf.predict(test[iris.feature_names])
accuracy_score = clf.score(test[iris.feature_names], test['species'])
print('Accuracy Score:', accuracy_score)

In this example code, we start by importing necessary libraries including pandas for data manipulation, scikit-learn for machine learning algorithms like Decision Tree Classifier, and Matplotlib for visualization.

Next step is loading iris dataset from scikit-learn library which contains information about three different types of iris flowers – Setosa, Versicolour and Virginica. We create a pandas dataframe from this dataset and split it into train and test sets.

Then we create a Decision Tree Classifier object with random state as 1 which ensures that the results are reproducible. We fit this classifier on the training set and visualize the decision tree using plot_tree function provided by scikit-learn library.

Finally, we predict species on test set using our trained model and calculate accuracy score which is printed to the console.

The visualization code using Matplotlib creates a beautiful decision tree plot which helps us understand how the model is making decisions based on features like petal width, petal length, sepal width and sepal length.

Here’s an example of what the decision tree visualization might look like when run:

This code shows how Decision Tree Classification can be implemented in Python along with visualization for better understanding of how the model is making decisions.

Pros of Decision Tree Classification

• Decision tree classification is a simple and interpretable algorithm that is easy to understand, visualize and explain. It is intuitive to use for beginners, making it a popular choice for data analysts and business professionals.

• Decision trees can handle both categorical and numerical data. This means that they can be used in a variety of applications such as market segmentation, risk analysis, fraud detection, customer profiling, etc.

• Decision tree classification does not require any assumptions about the underlying distribution of the data. This makes it robust to outliers and missing values in the dataset. In contrast, other algorithms like logistic regression or Naive Bayes assume that the data follows a certain distribution.

• Decision trees are capable of identifying complex relationships between variables by recursively splitting the dataset based on the most informative features. This allows them to capture nonlinear interactions between predictors that may be missed by linear models.

• Decision tree classification can be easily adapted to handle imbalanced datasets by adjusting the class weights or using ensemble methods such as random forests or boosting. This ensures that rare events are not overlooked and improve its accuracy in real-world scenarios.

Cons of Decision Tree Classification

• Overfitting: One of the main disadvantages of Decision Tree Classification is overfitting, which occurs when the model is too complex and fits the training data too closely. This can result in poor generalization performance and prediction accuracy on new data.

• Instability: Decision Trees can be unstable because small variations in the data can result in a completely different tree being generated. This makes it difficult to interpret and analyze the results, as well as leading to unpredictable behavior.

• Bias: Another disadvantage of Decision Trees is that they can be biased towards certain classes or attributes that are more frequently represented in the training data. This can lead to inaccurate predictions for underrepresented classes or attributes, which may be important in real-world applications.

• Limited Applicability: Decision Trees are not suitable for all types of problems or datasets. For example, they tend to work best with categorical or binary data, rather than continuous variables or mixed types of features. They also do not handle missing values well, which can limit their usefulness in some contexts.

• Lack of Robustness: Finally, Decision Trees are not very robust when faced with lots of noise in the data. These anomalies can have a disproportionate impact on the structure of the tree and lead to incorrect predictions. As a result, it’s important to preprocess and clean the data carefully before using decision trees for classification tasks.


In conclusion, decision tree classification is a powerful machine learning algorithm that is widely used in various industries. It provides a simple yet effective way to classify data based on their attributes and can handle both categorical and numerical data. Decision trees are easy to interpret and can provide valuable insight into the underlying patterns of the data. However, care must be taken to avoid overfitting and to choose appropriate hyperparameters to ensure optimal performance. With its wide range of applications, from finance to healthcare, decision tree classification has become an essential tool for any data scientist or machine learning practitioner. By understanding its strengths and limitations, we can leverage this algorithm to discover hidden patterns in our data and make informed decisions for better business outcomes.

Want to learn more?

Want to learn more about Machine Learning and Python or becoming a data scientist? Check out our training offerings, including Python for Machine Learning!

We also offer a free guide below, helping you unlock the skillsets of a data scientist!

How to Become a Data Scientist PDF

Your FREE Guide to Become a Data Scientist

Discover the path to becoming a data scientist with our comprehensive FREE guide! Unlock your potential in this in-demand field and access valuable resources to kickstart your journey.

Don’t wait, download now and transform your career!

Pierian Training
Pierian Training
Pierian Training is a leading provider of high-quality technology training, with a focus on data science and cloud computing. Pierian Training offers live instructor-led training, self-paced online video courses, and private group and cohort training programs to support enterprises looking to upskill their employees.

You May Also Like

Data Science, Tutorials

Guide to NLTK – Natural Language Toolkit for Python

Introduction Natural Language Processing (NLP) lies at the heart of countless applications we use every day, from voice assistants to spam filters and machine translation. It allows machines to understand, interpret, and generate human language, bridging the gap between humans and computers. Within the vast landscape of NLP tools and techniques, the Natural Language Toolkit […]

Machine Learning, Tutorials

GridSearchCV with Scikit-Learn and Python

Introduction In the world of machine learning, finding the optimal set of hyperparameters for a model can significantly impact its performance and accuracy. However, searching through all possible combinations manually can be an incredibly time-consuming and error-prone process. This is where GridSearchCV, a powerful tool provided by Scikit-Learn library in Python, comes to the rescue. […]

Python Basics, Tutorials

Plotting Time Series in Python: A Complete Guide

Introduction Time series data is a type of data that is collected over time at regular intervals. It can be used to analyze trends, patterns, and behaviors over time. In order to effectively analyze time series data, it is important to visualize it in a way that is easy to understand. This is where plotting […]