Confusion Matrix with Scikit-Learn and Python

Introduction

A confusion matrix is a useful tool for evaluating the performance of a classification model. It shows how well the model has classified the data by comparing its predictions to the actual values. Understanding and interpreting confusion matrices can be challenging, especially for beginners in machine learning. However, it is crucial to understand what each cell represents, because that is what tells you where your model is strong and where it is weak.

In this blog post, we will discuss confusion matrices in depth using the Scikit-learn library in Python. We will explain what a confusion matrix is, why it’s important, and how to interpret its results effectively. Along the way, we’ll build two simple classification models as examples: one that classifies breast tumors as malignant or benign, and one that classifies iris flower species based on their measurements.

What is a Confusion Matrix

A confusion matrix is a table used to evaluate the performance of a machine learning algorithm. It shows how many samples were correctly or incorrectly classified by the algorithm in each class.

The confusion matrix has two dimensions: actual and predicted. In binary classification, where there are only two classes (positive and negative), it looks like this:

| | Predicted Positive | Predicted Negative |
|---|---|---|
| Actual Positive | True Positive (TP) | False Negative (FN) |
| Actual Negative | False Positive (FP) | True Negative (TN) |

Let’s consider a binary classification problem where we have two classes, “Positive” and “Negative”.

  • True Positive (TP): The model predicts the positive class, and the instance actually is positive. TP is the number of positive instances that the model correctly predicts as positive.
  • True Negative (TN): The model predicts the negative class, and the instance actually is negative. TN is the number of negative instances that the model correctly predicts as negative.
  • False Positive (FP): The model predicts the positive class, but the instance actually is negative. FP is the number of negative instances that the model incorrectly predicts as positive.
  • False Negative (FN): The model predicts the negative class, but the instance actually is positive. FN is the number of positive instances that the model incorrectly predicts as negative.
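One detail worth knowing before using Scikit-Learn: its confusion_matrix sorts the labels in ascending order, so with 0/1 labels the first row and column belong to the negative class (0), the reverse of the table above. The sketch below uses made-up labels (hypothetical, for illustration only) to show how to unpack the four counts with ravel(), and how the labels parameter reorders the matrix to match the table.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels: 1 = positive class, 0 = negative class
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 1])
y_pred = np.array([1, 0, 1, 1, 0, 1, 1, 0, 0, 1])

# Default layout sorts labels ascending: [[TN, FP], [FN, TP]]
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()
print(f"TP={tp}, TN={tn}, FP={fp}, FN={fn}")  # TP=4, TN=3, FP=2, FN=1

# labels=[1, 0] puts the positive class first: [[TP, FN], [FP, TN]]
print(confusion_matrix(y_true, y_pred, labels=[1, 0]))
```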

Confusion Matrix Use Cases

A confusion matrix is a commonly used tool in machine learning for evaluating the performance of a classification model. Here are some real-world or business use cases where a confusion matrix can be helpful:

  1. Fraud Detection: A bank uses a machine learning model to identify fraudulent transactions. The confusion matrix shows how many fraudulent transactions the model catches (true positives), how many it misses (false negatives), and how many legitimate transactions it wrongly flags as fraud (false positives).
  2. Medical Diagnosis: A hospital uses a machine learning model to diagnose patients with a certain disease. The confusion matrix shows doctors how many sick patients are correctly diagnosed, how many are missed (false negatives), and how many healthy patients are wrongly diagnosed (false positives).
  3. Customer Churn Prediction: A company uses a machine learning model to predict which customers are likely to churn (stop using its service). The confusion matrix shows how many churning customers the model identifies and how many loyal customers it wrongly flags as likely to churn.
  4. Sentiment Analysis: A social media platform uses a machine learning model to classify user comments as positive or negative. The confusion matrix shows how often positive comments are labeled negative and vice versa.
  5. Image Classification: An e-commerce website uses a machine learning model to automatically sort product images into categories such as apparel or electronics. A multi-class confusion matrix shows, for each category, how many images are classified correctly and which categories the model confuses with each other.

Binary Classification Example with Scikit-Learn

Let’s work through an example of binary classification on the Breast Cancer data set from Scikit-Learn.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

# Load the breast cancer dataset
data = load_breast_cancer()

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, random_state=0)

# Train a logistic regression model on the training data
# (max_iter is raised so the default lbfgs solver converges on these unscaled features)
model = LogisticRegression(max_iter=10000)
model.fit(X_train, y_train)

# Predict the test data using the trained model
y_pred = model.predict(X_test)

# Calculate the confusion matrix of actual vs. predicted values
cm = confusion_matrix(y_test, y_pred)

print("Confusion Matrix:")
print(cm)

# Output (exact counts can vary with the scikit-learn version and solver settings):
# Confusion Matrix:
# [[51  2]
#  [ 6 84]]
```
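The printed array does not label its rows or columns. In Scikit-Learn's breast cancer dataset, target 0 is "malignant" and target 1 is "benign" (see data.target_names), and confusion_matrix orders labels ascending, so the first row and column correspond to malignant tumors. This also means Scikit-Learn's binary metrics treat benign (label 1) as the positive class by default. The sketch below repeats the training steps so it runs on its own and prints each cell next to its class names; the larger max_iter is our own choice to avoid a convergence warning.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# Same data and split as the example above
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, random_state=0)

# Assumption: max_iter raised so the default lbfgs solver converges on unscaled features
model = LogisticRegression(max_iter=10000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

cm = confusion_matrix(y_test, y_pred)
names = data.target_names  # index 0 -> 'malignant', index 1 -> 'benign'

# Row i = actual class i, column j = predicted class j
for i, actual in enumerate(names):
    for j, predicted in enumerate(names):
        print(f"actual {actual:>9} | predicted {predicted:>9}: {cm[i, j]}")
```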

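Calling confusion_matrix from Scikit-Learn only returns the confusion matrix as a NumPy array. If we want to plot it, we can use Scikit-Learn's ConfusionMatrixDisplay class. (Older tutorials use the plot_confusion_matrix function, which was deprecated in Scikit-Learn 1.0 and removed in version 1.2.)

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

# Plot the confusion matrix of the trained model on the test set
ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)
plt.show()
```

This draws the confusion matrix as a Matplotlib plot.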

Multi-Class Confusion Matrix with Scikit-Learn

Now let’s explore an example of a confusion matrix with more than two classes.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt

# Load the iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Split the data into training and testing sets (random_state makes the split repeatable)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train a decision tree classifier on the training set
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# Make predictions on the test set
y_pred = clf.predict(X_test)

# Calculate the confusion matrix and plot it with ConfusionMatrixDisplay
cm = confusion_matrix(y_test, y_pred)
ConfusionMatrixDisplay(confusion_matrix=cm).plot()
plt.show()
```

This displays the matrix as a heatmap, with the true labels on the vertical axis and the predicted labels on the horizontal axis.

The rows represent the true classes and the columns represent the predicted classes. The diagonal elements (top left to bottom right) show the number of correct predictions for each class, and the off-diagonal elements show misclassifications. For example, in one run of this example, all 15 test samples from class 0 were predicted correctly, while class 1 had two samples misclassified as class 2. Because the counts depend on the train/test split, your numbers may differ.

We can use this information to evaluate how well our model is performing on each class and identify any patterns or biases in its predictions.
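To analyze each class separately, you can treat every class in turn as the "positive" class. The sketch below uses a small set of hypothetical labels (not the iris results) and reads per-class true positives from the diagonal, false positives from the column sums, and false negatives from the row sums, using NumPy.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical 3-class labels (0, 1, 2), for illustration only
y_true = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2, 2, 1]

cm = confusion_matrix(y_true, y_pred)  # rows = actual, columns = predicted
tp = np.diag(cm)          # correctly predicted samples of each class
fp = cm.sum(axis=0) - tp  # predicted as this class, but actually another class
fn = cm.sum(axis=1) - tp  # actually this class, but predicted as another class

for label in range(cm.shape[0]):
    print(f"class {label}: TP={tp[label]}, FP={fp[label]}, FN={fn[label]}")
# class 0: TP=3, FP=0, FN=0
# class 1: TP=2, FP=1, FN=2
# class 2: TP=2, FP=2, FN=1
```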

Understanding Classification Error Metrics

The counts in a confusion matrix (true and false positives and negatives) can be combined into metrics that summarize a model’s classification performance. The most commonly used metrics are accuracy, precision, recall, and F1-score. Let’s explore each of these!

Accuracy

Accuracy is defined as the ratio of correctly classified samples to the total number of samples in the dataset.

Mathematically, accuracy can be expressed as:

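Accuracy = (Number of correctly classified samples) / (Total number of samples)

For binary classification, in terms of the confusion matrix cells, this is Accuracy = (TP + TN) / (TP + TN + FP + FN).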

For example, if we have a dataset with 1000 samples and our model correctly classifies 900 out of those 1000 samples, then the accuracy of our model would be:

Accuracy = 900/1000 = 0.9 or 90%

In other words, our model has an accuracy rate of 90%, meaning that it is able to correctly classify 90% of the samples in the dataset.

While accuracy is a commonly used metric for evaluating classification models, it may not always be the best metric to use. For instance, if we have an imbalanced dataset where there are significantly more examples from one class than another, then even a simple model that always predicts the majority class will have high accuracy but may not be useful in practice. Therefore, it’s important to consider other metrics such as precision and recall alongside accuracy when evaluating classification models.
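Both points can be checked with a few lines of code. The sketch below uses hypothetical labels: first 1,000 samples with 900 classified correctly, as in the example above, where accuracy equals the diagonal of the confusion matrix divided by its total; then an imbalanced set of 950 negatives and 50 positives, where a "model" that always predicts the majority class reaches 95% accuracy while finding none of the positives.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# Hypothetical labels: 1000 samples, 900 of them classified correctly
y_true = np.array([1] * 500 + [0] * 500)
y_pred = y_true.copy()
y_pred[:100] = 0  # 100 positive samples wrongly predicted as negative

cm = confusion_matrix(y_true, y_pred)
acc_from_cm = np.trace(cm) / cm.sum()  # diagonal = correct predictions
print(f"{acc_from_cm:.2f} {accuracy_score(y_true, y_pred):.2f}")  # 0.90 0.90

# Imbalanced case: a baseline that always predicts the majority class (0)
y_true_imb = np.array([0] * 950 + [1] * 50)
y_pred_majority = np.zeros_like(y_true_imb)
print(f"accuracy={accuracy_score(y_true_imb, y_pred_majority):.2f}")  # accuracy=0.95
print(f"recall={recall_score(y_true_imb, y_pred_majority):.2f}")      # recall=0.00
```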

Precision

Precision is defined as the ratio of true positive samples to all the predicted positive samples.

Mathematically, precision can be expressed as:

Precision = (True positives) / (True positives + False positives)

where True positives are the number of correctly classified positive samples and False positives are the number of negative samples that were incorrectly classified as positive.

For example, if our model predicts 100 samples as positive, of which 80 are actually positive and 20 are actually negative, then the precision of our model would be:

Precision = 80/(80+20) = 0.8 or 80%

This means that of all the samples our model predicted as positive, 80% were actually positive.
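The same calculation can be done from a confusion matrix. The sketch below builds hypothetical labels that reproduce the example (100 predicted positives: 80 true positives and 20 false positives, plus 50 samples predicted negative) and compares the hand-computed ratio with Scikit-Learn's precision_score.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score

# Hypothetical labels: 100 predicted positives (80 truly positive, 20 truly negative)
# followed by 50 samples predicted negative that are truly negative
y_pred = np.array([1] * 100 + [0] * 50)
y_true = np.array([1] * 80 + [0] * 20 + [0] * 50)

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(f"{tp / (tp + fp):.2f}")                    # 0.80
print(f"{precision_score(y_true, y_pred):.2f}")   # 0.80
```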

Precision is an important metric when we want to avoid false positives. For instance, in a medical diagnosis scenario where we want to identify patients who have a particular disease, we may want to ensure that we don’t falsely classify healthy patients as having the disease. In such cases, we would aim for a high precision value.

However, optimizing for high precision may lead to low recall (the ability to detect all relevant cases), and vice versa. Therefore, it’s important to consider both metrics together when evaluating classification models.
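One common place this trade-off appears is the decision threshold applied to a model’s predicted scores. The sketch below uses made-up labels and scores (hypothetical, not from a real model): raising the threshold makes the model more selective, which raises precision and lowers recall.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Hypothetical true labels and model scores (e.g. predicted probabilities)
y_true = np.array([0, 0, 1, 0, 1, 1, 0, 1])
scores = np.array([0.1, 0.3, 0.35, 0.4, 0.6, 0.7, 0.8, 0.9])

for threshold in (0.2, 0.5, 0.85):
    y_pred = (scores >= threshold).astype(int)  # positive if score reaches the threshold
    p = precision_score(y_true, y_pred)
    r = recall_score(y_true, y_pred)
    print(f"threshold={threshold}: precision={p:.2f}, recall={r:.2f}")
# threshold=0.2: precision=0.57, recall=1.00
# threshold=0.5: precision=0.75, recall=0.75
# threshold=0.85: precision=1.00, recall=0.25
```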

Recall

Recall measures how many of the actual positive samples in a dataset the model manages to identify. It is defined as the ratio of true positive samples to all actual positive samples.

Mathematically, recall can be expressed as:

Recall = (True positives) / (True positives + False negatives)

where True positives are the number of correctly classified positive samples and False negatives are the number of positive samples that were incorrectly classified as negative.

For example, if our dataset contains 100 positive samples and our model correctly identifies 80 of them while missing the other 20, then the recall would be:

Recall = 80/(80+20) = 0.8 or 80%

This means that our model was able to correctly identify 80% of all the actual positive cases in our dataset.
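As with precision, recall can be read straight off the confusion matrix. The sketch below builds hypothetical labels matching the example (100 actual positives, 80 found and 20 missed, plus 50 negatives) and compares the hand-computed ratio with Scikit-Learn's recall_score.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, recall_score

# Hypothetical labels: 100 actual positives (80 found, 20 missed) and 50 actual negatives
y_true = np.array([1] * 100 + [0] * 50)
y_pred = np.array([1] * 80 + [0] * 20 + [0] * 50)

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(f"{tp / (tp + fn):.2f}")                 # 0.80
print(f"{recall_score(y_true, y_pred):.2f}")   # 0.80
```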

Recall is an important metric when we want to avoid false negatives. For instance, in a medical diagnosis scenario where we want to identify patients who have a particular disease, we may want to ensure that we don’t miss any patients who actually have the disease. In such cases, we would aim for a high recall value.

F1-Score

F1-score is a metric that combines both precision and recall into a single score. It is the harmonic mean of precision and recall, and provides a balance between the two metrics.

Mathematically, F1-score can be expressed as:

F1-score = 2 * ((Precision * Recall) / (Precision + Recall))

where Precision is the ratio of true positive samples to all predicted positive samples, and Recall is the ratio of true positive samples to all actual positive samples.

For example, if our model has precision of 0.8 (80%) and recall of 0.85 (85%), then the F1-score would be:

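F1-score = 2 * ((0.8 * 0.85) / (0.8 + 0.85)) = 1.36 / 1.65 ≈ 0.82 or 82%

This means that our model scores about 82% on a single measure that balances precision and recall. Because the harmonic mean is pulled toward the smaller of the two values, a high F1-score requires both precision and recall to be high.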
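The arithmetic above can be checked in code. The sketch below first applies the formula directly to precision 0.8 and recall 0.85, then builds hypothetical labels with TP=68, FP=17, FN=12 and TN=103 (chosen so that precision is 68/85 = 0.8 and recall is 68/80 = 0.85) and confirms that Scikit-Learn's f1_score gives the same value.

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

# F1 from the formula, using the example's precision and recall
precision, recall = 0.8, 0.85
f1 = 2 * (precision * recall) / (precision + recall)
print(f"{f1:.3f}")  # 0.824

# Hypothetical labels with TP=68, FP=17, FN=12, TN=103
y_true = np.array([1] * 68 + [0] * 17 + [1] * 12 + [0] * 103)
y_pred = np.array([1] * 68 + [1] * 17 + [0] * 12 + [0] * 103)
p, r = precision_score(y_true, y_pred), recall_score(y_true, y_pred)
print(f"{p:.2f} {r:.2f}")                   # 0.80 0.85
print(f"{f1_score(y_true, y_pred):.3f}")    # 0.824
```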

F1-score is useful when we want to have a balance between precision and recall, especially in cases where both false positives and false negatives are equally important to avoid. For instance, in spam email classification where we want to avoid classifying legitimate emails as spam (false positive) as well as missing out on actual spam emails (false negative), we would aim for a high F1-score value.

Conclusion

In conclusion, a confusion matrix is an essential tool for evaluating the performance of a classification model. It shows exactly which kinds of correct and incorrect predictions the model makes, and its counts are the basis for metrics such as accuracy, precision, recall, and F1-score, allowing developers to identify areas that need improvement. By understanding the elements that make up a confusion matrix, developers can effectively evaluate their models’ performance and make informed decisions on how to optimize them. While confusion matrices may seem complex at first, taking the time to understand them will go a long way toward building accurate and effective machine learning models.

You can learn more about all of this in our Python for Machine Learning course!


Pierian Training
Pierian Training is a leading provider of high-quality technology training, with a focus on data science and cloud computing. Pierian Training offers live instructor-led training, self-paced online video courses, and private group and cohort training programs to support enterprises looking to upskill their employees.
