How to Plot a 3D PCA Scatterplot (with Python Example)

As part of the series of tutorials on PCA with Python, we will learn how to plot a 3D PCA graph (scatter plot) on the Iris Dataset with Python, Scikit-learn and Matplotlib.

What is a 3D PCA Scatter Plot?

A 3D PCA (Principal Component Analysis) scatter plot is a PCA visualization that shows the distribution of data points in a 3D space after reducing a dataset to 3 PCA features.

How to Plot a 3D PCA Graph in Python?

To plot a 3D PCA scatter plot in Python, reduce the number of features to 3 principal components. Then, use Matplotlib to generate a three-dimensional scatterplot from the data.


Here are the detailed steps to plot a 3D PCA scatter plot in Python:

  1. Load the required Python Libraries
  2. Load your Dataset
  3. Scale the Data and Reduce it to 3 Features Using PCA
  4. Set up a 3D plotting environment
  5. Assign PCA Features to their own Axes of the Scatter Plot
  6. Plot the 3D PCA Graph using scatter3D
  7. Interpret the 3D PCA Scatterplot

1. Loading the Required Python Libraries

import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd 
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
plt.style.use('default')

2. Loading the Iris Dataset in Python

To start, let’s load the Iris dataset in Python.

# load features and targets separately
iris = datasets.load_iris()
X = iris.data
y = iris.target
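Before going further, it can help to confirm what was loaded. The snippet below is a small sketch that only inspects the dataset; the `counts` variable is introduced here for illustration and is not used later in the tutorial.

```python
import numpy as np
from sklearn import datasets

# Load features and targets separately, as in the tutorial
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Number of samples and features
print(X.shape)

# Names of the 4 original features and of the 3 species
print(iris.feature_names)
print(iris.target_names)

# Number of samples per species (illustrative helper variable)
counts = np.bincount(y)
print(counts)
```

The feature matrix has 4 columns, which is why PCA is needed before the data can be drawn on 3 axes.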

We will use this data to plot the 3D PCA graph with Python.

3. Scale and Reduce the Number of Features Using PCA

Next, preprocess the data and use PCA to reduce the dataset to 3 features. Scale the data before applying PCA, because PCA is sensitive to the scale of each feature, and set n_components to 3. To learn what this means, follow our tutorial on PCA with Python.

# Data Scaling
x_scaled = StandardScaler().fit_transform(X)

# Reduce from 4 to 3 features with PCA
pca = PCA(n_components=3)

# Fit and transform data
pca_features = pca.fit_transform(x_scaled)
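As a quick sanity check on this step, the sketch below verifies that scaling worked and that the PCA output has the expected shape. The variable names `means` and `stds` are illustrative and do not appear elsewhere in the tutorial.

```python
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

X = datasets.load_iris().data

# Scale, then reduce from 4 to 3 features, as in the tutorial
x_scaled = StandardScaler().fit_transform(X)
pca = PCA(n_components=3)
pca_features = pca.fit_transform(x_scaled)

# After standardization each column should have mean ~0 and std ~1
means = x_scaled.mean(axis=0)
stds = x_scaled.std(axis=0)
print(means.round(6))
print(stds.round(6))

# One row per flower, one column per principal component
print(pca_features.shape)
```

If the means or standard deviations are far from 0 and 1, the scaler was probably not applied to the data passed to PCA.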

4. Set up a 3D Plotting Environment in Matplotlib

Set up a 3D plotting environment in matplotlib using plt.axes(projection='3d').

ax = plt.axes(projection='3d')

In practice, create the figure first and then add the 3D axes to it.

# Prepare 3D graph
fig = plt.figure()
ax = plt.axes(projection='3d')
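An equivalent way to set up the 3D axes is to attach them to the figure explicitly with fig.add_subplot. This is a sketch of that alternative, assuming a Matplotlib version recent enough to register the '3d' projection without extra imports; the figsize value is an arbitrary choice for illustration.

```python
import matplotlib.pyplot as plt

# Create the figure explicitly (figsize is an illustrative choice)
fig = plt.figure(figsize=(7, 6))

# Attach 3D axes to that specific figure
ax = fig.add_subplot(projection='3d')

# 3D axes provide z-axis methods such as set_zlabel and set_zticks
print(type(ax).__name__)
```

Working from the fig object rather than the global plt state makes it clearer which figure the axes belong to when several figures are open.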

5. Assign PCA Features to their own Axes of the Scatter Plot

Before we can plot the data, we need to set up the data for the x, y and z axes of the 3D scatter plot. Each principal component will be on its own axis.

# Assign each principal component to its own axis
xdata = pca_features[:,0]
ydata = pca_features[:,1]
zdata = pca_features[:,2]
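The three slices above can also be written in one line by unpacking the transposed array. This is only a stylistic sketch; it produces the same three arrays.

```python
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

X = datasets.load_iris().data
x_scaled = StandardScaler().fit_transform(X)
pca_features = PCA(n_components=3).fit_transform(x_scaled)

# Transposing turns the (150, 3) array into 3 rows of 150 values,
# which unpack directly into one array per principal component
xdata, ydata, zdata = pca_features.T

print(xdata.shape, ydata.shape, zdata.shape)
```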

6. Plot the 3D PCA Graph using scatter3D

To plot the 3D PCA graph in Python, use ax.scatter3D with the x, y and z data as its arguments, mapping each PCA feature to its own axis in the scatter plot.

# Draw the 3D scatter plot, colored by the third principal component
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='viridis')

# Set the title of the graph
plt.title('3D Scatter of Iris')

# Set evenly spaced ticks on the x, y and z axes
ticks = np.linspace(-3, 3, num=5)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_zticks(ticks)

# Label each axis with its principal component
# (the axes show PCA features, not the original Iris measurements)
ax.set_xlabel('PC1', rotation=150)
ax.set_ylabel('PC2')
ax.set_zlabel('PC3', rotation=60)
plt.show()

[Figure: 3D scatterplot of the Iris dataset]
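The plot above colors each point by its value on the third axis. A common variation is to color the points by species instead, using the targets loaded earlier. The following is a minimal sketch of that idea, not part of the original walkthrough; the legend title is an illustrative choice.

```python
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

iris = datasets.load_iris()
x_scaled = StandardScaler().fit_transform(iris.data)
pca_features = PCA(n_components=3).fit_transform(x_scaled)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# Draw one scatter per species so each gets its own color and legend entry
for label, name in enumerate(iris.target_names):
    mask = iris.target == label
    ax.scatter(
        pca_features[mask, 0],
        pca_features[mask, 1],
        pca_features[mask, 2],
        label=name,
    )

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
ax.legend(title='Species')
```

Call plt.show() afterwards to display the figure. Coloring by class makes it easier to see how well the principal components separate the three species.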

7. Interpret the 3D PCA Scatterplot

When plotting a 3D graph, it is clearer that there is far less variance along the third principal component than along the first two, so the points lie almost on a flat 2D plane inside the 3D graph. That suggests that the intrinsic dimension of the data is essentially 2 instead of 4.

Reducing these 3 PCA features to 2 would not only make a downstream model faster but also make the visualizations more informative, without losing too much information.
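One way to check this reading of the plot is to look at how much variance each principal component explains. The sketch below fits PCA with all 4 components, which is an addition to the tutorial's 3-component fit, and prints the individual and cumulative explained variance ratios.

```python
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

X = datasets.load_iris().data
x_scaled = StandardScaler().fit_transform(X)

# Keep every component so the ratios sum to 1
pca_full = PCA().fit(x_scaled)
ratios = pca_full.explained_variance_ratio_
cumulative = np.cumsum(ratios)

# Share of total variance per component, and the running total
for i, (ratio, total) in enumerate(zip(ratios, cumulative), start=1):
    print(f'PC{i}: {ratio:.1%} of variance, {total:.1%} cumulative')
```

If the cumulative value after the second component is already close to 100%, the third axis of the scatter plot adds little information, which matches the nearly flat shape seen in the graph.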

Next Steps

After plotting a 3D PCA scatterplot, a natural next step is to learn how to plot a 3D PCA biplot.

Full Code

import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd 
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
plt.style.use('default')

# load features and targets separately
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Data Scaling
x_scaled = StandardScaler().fit_transform(X)

# Dimension Reduction
pca = PCA(n_components=3)
pca_features = pca.fit_transform(x_scaled)

# Prepare 3D graph
fig = plt.figure()
ax = plt.axes(projection='3d')

# Assign each principal component to its own axis
xdata = pca_features[:,0]
ydata = pca_features[:,1]
zdata = pca_features[:,2]

# Draw the 3D scatter plot, colored by the third principal component
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='viridis')

# Set the title of the graph
plt.title('3D Scatter of Iris')

# Set evenly spaced ticks on the x, y and z axes
ticks = np.linspace(-3, 3, num=5)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_zticks(ticks)

# Label each axis with its principal component
# (the axes show PCA features, not the original Iris measurements)
ax.set_xlabel('PC1', rotation=150)
ax.set_ylabel('PC2')
ax.set_zlabel('PC3', rotation=60)
plt.show()