As part of the series of tutorials on PCA with Python, we will learn how to plot a 3D PCA graph (scatter plot) on the Iris Dataset with Python, Scikit-learn and Matplotlib.


## What is a 3D PCA Scatter Plot?

A 3D PCA (Principal Component Analysis) scatter plot is a PCA visualization that shows the distribution of data points in a 3D space after reducing a dataset to 3 PCA features.
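To make "reducing a dataset to 3 PCA features" concrete, here is a minimal NumPy-only sketch. It assumes hypothetical random toy data (`X_toy`, not the Iris dataset) and uses an SVD of the centered data as one common way to compute PCA; it is meant as an illustration, not as what Scikit-learn does internally.

```python
import numpy as np

# Hypothetical toy data: 100 samples with 5 features (not the Iris dataset)
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(100, 5))

# Center each feature, since PCA works on mean-centered data
X_centered = X_toy - X_toy.mean(axis=0)

# Rows of Vt are the principal directions, ordered by decreasing variance
U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)

# Project onto the first 3 directions: each sample now has 3 PCA features
toy_features = X_centered @ Vt[:3].T

print(toy_features.shape)
```

Each row of `toy_features` can then be drawn as one point in a 3D scatter plot.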

## How to Plot a 3D PCA Graph in Python?

To plot a 3D PCA scatter plot in Python, reduce the number of features to 3 principal components. Then use `matplotlib` to generate a three-dimensional scatter plot from the data.

Here are the detailed steps to plot a 3D PCA scatter plot in Python:

- Load the required Python Libraries
- Load your Dataset
- Scale and Reduce the Number of Features Using PCA
- Set up a 3D plotting environment
- Assign PCA Features to their own Axes of the Scatter Plot
- Plot the 3D PCA Graph using scatter3D
- Interpret the 3D PCA Scatterplot

### 1. Loading the Required Python Libraries

```
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
plt.style.use('default')
```

### 2. Loading the Iris Dataset in Python

To start, let’s load the Iris dataset in Python.

```
# load features and targets separately
iris = datasets.load_iris()
X = iris.data
y = iris.target
```

From this data, we will learn various ways to plot the 3D PCA graph with Python.

### 3. Scale and Reduce the Number of Features Using PCA

Next, preprocess the data and use PCA to reduce the dataset to 3 features. Scale the data before applying PCA so that every feature contributes on the same scale, and set `n_components` to 3. To learn what this means, follow our tutorial on PCA with Python.

```
# Data Scaling
x_scaled = StandardScaler().fit_transform(X)
# Reduce from 4 to 3 features with PCA
pca = PCA(n_components=3)
# Fit and transform data
pca_features = pca.fit_transform(x_scaled)
```
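Before plotting, it can help to check how much of the variance each of the 3 components captures. The sketch below repeats the scaling and PCA steps so that it runs on its own, then reads Scikit-learn's `explained_variance_ratio_` attribute; the name `ratios` is only illustrative.

```python
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Same preprocessing as in the tutorial: scale, then fit a 3-component PCA
X = datasets.load_iris().data
x_scaled = StandardScaler().fit_transform(X)
pca = PCA(n_components=3).fit(x_scaled)

# Share of the total variance captured by each principal component
ratios = pca.explained_variance_ratio_
for i, ratio in enumerate(ratios, start=1):
    print(f"PC{i}: {ratio:.1%}")
print(f"Total: {ratios.sum():.1%}")
```

A component with a very small ratio adds little spread along its axis, which is worth keeping in mind when reading the 3D plot.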

### 4. Set up a 3D Plotting Environment in Matplotlib

Set up a 3D plotting environment in `matplotlib` by creating a figure and calling `plt.axes(projection='3d')`, which returns a 3D axes object that the PCA features will be plotted on.

```
# Prepare 3D graph
fig = plt.figure()
ax = plt.axes(projection='3d')
```
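As an alternative sketch, the 3D axes can also be created from the figure object with `fig.add_subplot(projection='3d')`. This assumes a reasonably recent Matplotlib version; the `figsize` value is an arbitrary choice for illustration.

```python
import matplotlib.pyplot as plt

# Create the figure first, then attach a 3D axes to it
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(projection='3d')

# A 3D axes exposes z-axis methods such as set_zlabel and set_zticks
print(ax.name)
```

Either approach gives an `ax` object that supports `scatter3D`, so the rest of the tutorial works unchanged.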

### 5. Assign PCA Features to their own Axes of the Scatter Plot

Before we can plot the data, we need to set up the data for the x, y and z axes of the 3D scatter plot. Each PCA feature will be on its own axis.

```
# Assign each PCA feature to its own axis
xdata = pca_features[:,0]
ydata = pca_features[:,1]
zdata = pca_features[:,2]
```

### 6. Plot the 3D PCA Graph using scatter3D

To plot the 3D PCA graph in Python, use `ax.scatter3D` with the x, y and z data as its arguments, mapping each PCA feature to its own axis in the scatter plot.

```
# Plot the 3D scatter, coloring each point by its third PCA feature
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='viridis')
# Set the title of the graph
plt.title('3D Scatter of Iris')
# Use the same evenly spaced ticks on the x, y and z axes
ticks = np.linspace(-3, 3, num=5)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_zticks(ticks)
# Label each axis with its principal component (not an original feature)
ax.set_xlabel('PC1', rotation=150)
ax.set_ylabel('PC2')
ax.set_zlabel('PC3', rotation=60)
plt.show()
```
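The plot above colors points by their position on the z axis. A variation that is often easier to read is to color the points by Iris species instead. The sketch below is one way to do that, assuming one `ax.scatter` call per species so that each gets a legend entry; it rebuilds the PCA features so that it runs on its own.

```python
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Rebuild the 3 PCA features from the scaled Iris data
iris = datasets.load_iris()
x_scaled = StandardScaler().fit_transform(iris.data)
pca_features = PCA(n_components=3).fit_transform(x_scaled)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# One scatter call per species, so each species gets its own color and label
for label, name in enumerate(iris.target_names):
    mask = iris.target == label
    ax.scatter(
        pca_features[mask, 0],
        pca_features[mask, 1],
        pca_features[mask, 2],
        label=str(name),
    )

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
ax.legend()
```

Call `plt.show()` afterwards to display the figure.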

### 7. Interpret the 3D PCA Scatterplot

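In the 3D graph, it is clear that there is far less variance along the third principal component (the z axis) than along the first two (the x and y axes). The points almost form a flat 2D plane inside the 3D graph. Note that the axes are principal components, which are combinations of the original features, not individual measurements such as sepal length or petal length. The flat shape suggests that the intrinsic dimension of the data is essentially 2 instead of 4.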

Reducing these 3 PCA features to 2 would not only make a downstream model faster but also make the visualizations easier to read, without losing too much information.
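To check that claim, one option is to fit a 2-component PCA and see how much of the variance it retains. This is a minimal sketch; the names `pca_2d`, `features_2d` and `retained` are illustrative.

```python
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Scale the Iris features as before
X = datasets.load_iris().data
x_scaled = StandardScaler().fit_transform(X)

# Keep only the first 2 principal components
pca_2d = PCA(n_components=2)
features_2d = pca_2d.fit_transform(x_scaled)

# Fraction of the total variance kept by the 2 components
retained = pca_2d.explained_variance_ratio_.sum()
print(features_2d.shape)
print(f"Variance retained by 2 components: {retained:.1%}")
```

If the retained share is close to 100%, a regular 2D scatter plot of `features_2d` shows nearly the same structure as the 3D graph.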

## Next Steps

After plotting a 3D PCA scatter plot, a good next step is to learn how to plot a 3D PCA biplot.

## Full Code

```
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
plt.style.use('default')
# load features and targets separately
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Data Scaling
x_scaled = StandardScaler().fit_transform(X)
# Dimension reduction
pca = PCA(n_components=3)
pca_features = pca.fit_transform(x_scaled)
# Prepare 3D graph
fig = plt.figure()
ax = plt.axes(projection='3d')
# Assign each PCA feature to its own axis
xdata = pca_features[:,0]
ydata = pca_features[:,1]
zdata = pca_features[:,2]
# Plot the 3D scatter, coloring each point by its third PCA feature
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='viridis')
# Set the title of the graph
plt.title('3D Scatter of Iris')
# Use the same evenly spaced ticks on the x, y and z axes
ticks = np.linspace(-3, 3, num=5)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_zticks(ticks)
# Label each axis with its principal component (not an original feature)
ax.set_xlabel('PC1', rotation=150)
ax.set_ylabel('PC2')
ax.set_zlabel('PC3', rotation=60)
plt.show()
```

SEO Strategist at Tripadvisor, ex- Seek (Melbourne, Australia). Specialized in technical SEO. Writer in Python, Information Retrieval, SEO and machine learning. Guest author at SearchEngineJournal, SearchEngineLand and OnCrawl.