3D Scatter Plots in Python

Introduction

Python is a powerful programming language that has become increasingly popular for data analysis and visualization. One of its most widely used visualization tools is Matplotlib, a library for creating a wide range of plots and charts. If you want to visualize three-dimensional data, a 3D scatter plot is an excellent way to do so. In this blog post, we’ll walk you through creating a 3D scatter plot with Python and Matplotlib. Whether you’re new to Python or an experienced user, this tutorial will give you everything you need to create stunning 3D scatter plots of your data. We’ll even learn how to make interactive 3D scatter plots, like the one shown below:

So, let’s dive in!

Understanding 3D Scatter Plots in Python

Let’s work through a basic example of a 3D scatter plot in Python, using NumPy to generate some random data.

import matplotlib.pyplot as plt
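# Importing Axes3D registers the '3d' projection; this is only required on Matplotlib versions older than 3.2
from mpl_toolkits.mplot3d import Axes3D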
import numpy as np

# Generate random data
x = np.random.normal(0, 1, 100)
y = np.random.normal(0, 1, 100)
z = np.random.normal(0, 1, 100)

# Create a 3D scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, c=z)  # color each point by its z value

# Set labels and title
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.title('3D Scatter Plot')

plt.show()

This will produce a plot that looks like this:

Here are the steps needed to create the plot:

  1. First, we import the necessary libraries: matplotlib.pyplot for plotting, mpl_toolkits.mplot3d for 3D projections, and numpy for generating random data.
  2. We generate random data for the x, y, and z coordinates using the numpy.random.normal() function. This function generates random numbers from a normal (Gaussian) distribution with a given mean and standard deviation. In this case, we generate 100 random numbers for each coordinate with a mean of 0 and a standard deviation of 1.
  3. We create a figure with plt.figure() and add a 3D Axes to it with fig.add_subplot(111, projection='3d'). Setting the projection parameter to '3d' tells Matplotlib that we want a 3D plot.
  4. We use the scatter() method to plot the data points in 3D space. The first three arguments are arrays containing the x, y, and z coordinates of each point, and the c parameter is set to the z array so that the color of each point is determined by its z-coordinate.
  5. We set the labels for the x, y, and z axes using the set_xlabel(), set_ylabel(), and set_zlabel() methods.
  6. Finally, we set the title of the plot with plt.title() and display the plot with plt.show().
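Step 4 colors the points by z, but without a color bar a reader cannot tell which color corresponds to which z value. Here is a minimal sketch, assuming Matplotlib 3.1 or newer, that adds a color bar and uses a seeded NumPy generator so the data is reproducible. The function name make_colored_scatter and the cmap='viridis' choice are illustrative assumptions, not part of the original example.

```python
import matplotlib.pyplot as plt
import numpy as np


def make_colored_scatter(n=100, seed=0):
    """Build a 3D scatter plot whose colors encode z, plus a color bar."""
    rng = np.random.default_rng(seed)  # seeded generator for reproducible data
    x, y, z = rng.normal(0, 1, size=(3, n))  # three rows of n normal samples

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    # c=z maps each point's z value through the colormap given by cmap
    sc = ax.scatter(x, y, z, c=z, cmap="viridis")
    # The color bar translates colors back into z values
    fig.colorbar(sc, ax=ax, shrink=0.6, label="z value")

    ax.set_xlabel("X Label")
    ax.set_ylabel("Y Label")
    ax.set_zlabel("Z Label")
    ax.set_title("3D Scatter Plot")
    return fig, ax, sc, z


if __name__ == "__main__":
    make_colored_scatter()
    plt.show()
```

Returning the figure and axes, rather than calling plt.show() inside the function, makes it easy to tweak the plot further or save it with fig.savefig() before displaying it.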

2D Data Combined with 3D Python Plot

We can also combine 2D data with a 3D Python plot to show points on a plane along with a curve. For example, consider the following code:

import numpy as np
import matplotlib.pyplot as plt

# Create a 3D plot
ax = plt.figure(dpi=150).add_subplot(projection='3d')

# Plot a sine curve using the x and y axes.
x = np.linspace(0, 1, 100)
y = np.sin(x * 3 * np.pi) / 2 + 0.6
ax.plot(x, y, zs=0, zdir='z', label='curve in (x, y)')

# Plot scatterplot data (20 2D points per colour) on the x and z axes.
colors = ('#0000FF', '#FF00FF', '#800080', '#8A2BE2')

# Fixing random state for reproducibility
np.random.seed(19680801)

# Generate random data points
x = np.random.sample(20 * len(colors))
y = np.random.sample(20 * len(colors))

# Create a list of colors corresponding to the data points
c_list = []
for c in colors:
    c_list.extend([c] * 20)

# By using zdir='y', the y value of these points is fixed to the zs value 0
# and the (x, y) points are plotted on the x and z axes.
ax.scatter(x, y, zs=0, zdir='y', c=c_list, label='points in (x, z)')

# Make legend, set axes limits and labels
ax.legend(loc=(1,1))
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_zlim(0, 1)

# Customize the viewing angle (the roll argument requires Matplotlib 3.6 or newer)
ax.view_init(elev=20., azim=-35, roll=0)

# Show the plot
plt.show()

This would result in a plot that looks like this:
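To make the zdir mechanics concrete, here is a small NumPy sketch of where 2D points end up in 3D space for each zdir value. The helper place_2d_points is hypothetical (it is not a Matplotlib function); it assumes zdir is one of 'x', 'y' or 'z' and mimics the reordering described in the code comments above: the zs value fills the axis named by zdir, and the 2D (x, y) data fills the remaining two axes in order.

```python
import numpy as np


def place_2d_points(x, y, zs=0.0, zdir="z"):
    """Hypothetical helper: map 2D (x, y) data to 3D (X, Y, Z) arrays,
    mimicking how Matplotlib's zdir argument places the points."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    fixed = np.full_like(x, zs)  # constant coordinate along the zdir axis
    if zdir == "x":
        return fixed, x, y  # points lie on the y-z plane at X = zs
    if zdir == "y":
        return x, fixed, y  # points lie on the x-z plane at Y = zs
    if zdir == "z":
        return x, y, fixed  # points lie on the x-y plane at Z = zs
    raise ValueError("zdir must be 'x', 'y' or 'z'")
```

For example, with zdir="y" the returned Y array is all zeros and the original y values become the Z coordinates, so ax.scatter(*place_2d_points(x, y, zs=0, zdir="y")) should place the points where ax.scatter(x, y, zs=0, zdir='y') does.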

Saving and Sharing Python 3D Scatter Plots

In the examples above the plots end up as static images, but wouldn’t it be nice to make them interactive? With the Plotly Express library we can pass in the data as a pandas DataFrame and get a plot that users can rotate, zoom, and hover over.

First, make sure Plotly (which includes Plotly Express) and pandas are installed:

pip install plotly pandas

Then we adjust our code to put the data into a pandas DataFrame and pass that DataFrame to Plotly Express’s scatter_3d() function. For example:

import numpy as np
import pandas as pd
import plotly.express as px

# Generate random data
x = np.random.normal(0, 1, 100)
y = np.random.normal(0, 1, 100)
z = np.random.normal(0, 1, 100)

# Create a DataFrame with the data
data = {'x': x, 'y': y, 'z': z}
df = pd.DataFrame(data)

# Create the 3D scatter plot using plotly.express
fig = px.scatter_3d(df, x='x', y='y', z='z', color='z')

# Set labels and title
fig.update_layout(scene=dict(
    xaxis_title='X Label',
    yaxis_title='Y Label',
    zaxis_title='Z Label'
))
fig.update_layout(title='3D Scatter Plot')

# Display the plot
fig.show()

This will allow you to have an interactive graph as shown in the video below:

Conclusion

In conclusion, 3D scatter plots are a powerful data visualization tool, and Python libraries such as Matplotlib and Plotly make them easy to create. By adding a third dimension, we can visualize relationships in our data that may not be evident in 2D plots. With Matplotlib’s customization options (colors, labels, viewing angles) and Plotly’s interactivity, we can make our plots both informative and visually appealing. Whether you’re exploring a new dataset or presenting insights to colleagues, 3D scatter plots are a great way to showcase your findings. So why not give it a try and see what insights you can uncover? If you’re interested in learning more, check out our Python for Data Science course!


Pierian Training
Pierian Training is a leading provider of high-quality technology training, with a focus on data science and cloud computing. Pierian Training offers live instructor-led training, self-paced online video courses, and private group and cohort training programs to support enterprises looking to upskill their employees.
