Lots of buzzwords floating around here: figures, axes, subplots, and probably a couple hundred more.

Get pumped!! Get excited!! We’re going to crush the mystery around how pandas uses matplotlib!

Our data

We’re going to be working with OECD data, specifically unemployment from 1980 to the present for Japan, Australia, USA, and Germany. I’ve edited the data so it looks a little cleaner and is a little easier to work with.

%matplotlib inline
import pandas as pd
df = pd.read_csv("../unemployment.csv")
df.head(3)
country year unemployment
0 Australia 1980 6.106246
1 Australia 1981 5.783571
2 Australia 1982 7.156132

Appearance of the mysterious variable ax

Let’s say we try to plot a line for each country over time. We’ll groupby country and give it some axes, and cry about the result.

df.groupby('country').plot(x='year', y='unemployment')
country
Australia    Axes(0.125,0.125;0.775x0.775)
Germany      Axes(0.125,0.125;0.775x0.775)
Japan        Axes(0.125,0.125;0.775x0.775)
USA          Axes(0.125,0.125;0.775x0.775)
dtype: object

png

png

png

png

So many charts! I’ve already told you how to fix it, although I didn’t explain it. It goes like this:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
df.groupby('country').plot(x='year', y='unemployment', ax=ax, legend=False)
country
Australia    Axes(0.125,0.125;0.775x0.775)
Germany      Axes(0.125,0.125;0.775x0.775)
Japan        Axes(0.125,0.125;0.775x0.775)
USA          Axes(0.125,0.125;0.775x0.775)
dtype: object

png

Explaining fig, ax=plt.subplots

plt is PYPLOT

fig, ax = plt.subplots()

pyplot is matplotlib’s graphing framework. It’s kind of based on the programming language/environment MATLAB (which you’ve hopefully never heard of) but also kind of not based on MATLAB. This quasi-dichotomy is where all our troubles will come from.

But what’s it really do? All you need to know is it’s the graph-y and math-y parts of matplotlib and you might as well think of it as matplotlib itself. It does everything important.

When you import pyplot, you traditionally rename it as plt because programming are too lazy to type more letters than that. If you’re making a list of everything you should always import when you start a new project, it’s probably a good one.

import matplotlib.pyplot as plt

We’re going to skip to fig - we’ll talk subplots later!

fig is the FIGURE

fig, ax = plt.subplots()

First, we’ll start with fig. fig means figure, and it’s your entire graphic. You’ve actually seen it before, when doing things like figsize=(x,y)

df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', figsize=(10,2))
<matplotlib.axes._subplots.AxesSubplot at 0x10e1ef358>

png

It’s the whole image, the whole shebang, all 10x2 of it.

The weird part is, though, your actual graph is not the figure. The figure is the part around your graph. Your chart sits on top of the figure. So what’s your visualization?

Your graph is what’s called a subplot or axis. Or, technically, an AxesSubplot.

Let’s talk about that.

ax is the AXIS or SUBPLOT

fig, ax = plt.subplots()

When we think about axes, we think about them as the lines on the left and right of our graphs. STOP DOING THAT RIGHT NOW or you’re going to get confused! In this case, an axis - or an ax, and we’ll call it - is an entire graph.

df.groupby('country').plot(x='year', y='unemployment', ax=ax, legend=False)

When we pass ax=ax to our plot, we’re saying “hey, we already have a graph made up! Please just use it instead” and then pandas/matplotlib does, instead of using a brand-new image for each.

So what’s the difference between a figure and an axis/subplot? That’s easy - figures are made up of subplots. Figures are the table that the dinner plates of subplots go on. So far we’ve only seen them made up of just one, but we can do more!

We’ll start by doing the long-hand version of what we just did.

fig = plt.figure()
ax = fig.add_subplot(111)

df.groupby('country').plot(x='year', y='unemployment', ax=ax, legend=False)
country
Australia    Axes(0.125,0.125;0.775x0.775)
Germany      Axes(0.125,0.125;0.775x0.775)
Japan        Axes(0.125,0.125;0.775x0.775)
USA          Axes(0.125,0.125;0.775x0.775)
dtype: object

png

Let’s break it down piece by piece. First, we create a new figure. That’s our canvas we’ll put our charts on.

fig = plt.figure()

Then we say hey, add a subplot in there - an actual graph! We’ll talk about the numbers in a second.

ax = fig.add_subplot(111)

And finally we give it to pandas to draw inside of so it doesn’t make a new axis/subplot each time.

df.groupby('country').plot(x='year', y='unemployment', ax=ax, legend=False)

But what’s the .add_subplot(111) thing? It means, “figure, please build a 1x1 grid and put me in the first section.”

If we want to get picky, it’s actually a shorthand version of .add_subplot(1,1,1). You don’t normally have that many subplots, though, so most people use the shorthand.

By adjusting the numbers, we can actually add multiple subplots. Let’s create a 2x1 grid and put something in the first subplot and something in the second subplot.

fig = plt.figure()

# Divide the figure into a 2x1 grid, and give me the first section
ax1 = fig.add_subplot(211)

# Divide the figure into a 2x1 grid, and give me the second section
ax2 = fig.add_subplot(212)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
<matplotlib.axes._subplots.AxesSubplot at 0x10ed2d1d0>

png

Why does 2x1 give you two TALL and one WIDE? I don’t know, it doesn’t make sense at all, but it’s how life is. We could change the numbers to be 121 and 122 - “build a 1x2 grid, and give me the 1st and 2nd spaces”

fig = plt.figure()

# Divide the figure into a 1x2 grid, and give me the first section
ax1 = fig.add_subplot(121)

# Divide the figure into a 1x2 grid, and give me the second section
ax2 = fig.add_subplot(122)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
<matplotlib.axes._subplots.AxesSubplot at 0x10f03a668>

png

Yes, yes, it doesn’t always look nice. What about a 2x2 grid, and we only take up the 1st and 4th boxes?

fig = plt.figure()

# Divide the figure into a 1x2 grid, and give me the first section
ax1 = fig.add_subplot(221)

# Divide the figure into a 1x2 grid, and give me the second section
ax2 = fig.add_subplot(224)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
<matplotlib.axes._subplots.AxesSubplot at 0x10f43bda0>

png

If we get tired of doing the 221 style shorthand, we could break it out into separate elements (height, width, box number), and then ask for multiple boxes. Down below we’ll build a 2x2 square and put the first graph into the first AND second boxes, and then stick the poor second graph into the fourth box.

fig = plt.figure()

# Divide the figure into a 1x2 grid, and give me the first section
ax1 = fig.add_subplot(2, 2, (1,2))

# Divide the figure into a 1x2 grid, and give me the second section
ax2 = fig.add_subplot(224)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
<matplotlib.axes._subplots.AxesSubplot at 0x10f8f7e80>

png

Doing it all at once

Normally you don’t use add_subplot. Why? I don’t know, people are lazy. They like to do it all at once. That’s why plt.subplots() exists.

When you do plt.subplots(), you get back the figure and any axes you created. Below we’ve made a 2x1 grid, and the axes split it evenly.

fig, (ax1, ax2) = plt.subplots(2, 1)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
<matplotlib.axes._subplots.AxesSubplot at 0x10fbaf400>

png

We could do the same thing side-by-side instead of top-to-bottom

fig, (ax1, ax2) = plt.subplots(1, 2)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
<matplotlib.axes._subplots.AxesSubplot at 0x1109770f0>

png

You can also give them a little more space to breathe and clean them up a bit.

fig, (ax1, ax2) = plt.subplots(2, 1)

df.groupby('country').plot(x='year', y='unemployment', ax=ax1, legend=False)
ax1.set_ylim((0,12))
ax1.yaxis.set_ticks([5, 10])
ax1.xaxis.set_ticks([1980,1990, 2000, 2010])
ax1.xaxis.set_label_text("")
ax1.set_title("Unemployment over time")

df.groupby('country')['unemployment'].mean().sort_values().plot(kind='barh', ax=ax2)
ax2.yaxis.set_label_text("")
ax2.set_title("Unemployment on average")
fig.subplots_adjust(hspace=0.5)

png

Not that these charts make sense, but at least you now know a little more about figures, axes and subplots.