import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# Read in our data
df = pd.read_csv("../country-data.csv")
df.head(3)
       Country Continent  Year  GDP_per_capita  life_expectancy  Population
0  Afghanistan      Asia  1950      757.318795           26.674     8151455
1  Afghanistan      Asia  1951      766.752197           26.932     8276820
2  Afghanistan      Asia  1952      779.445314           27.448     8407148

Small multiples with plt.subplots

There are a few ways to make small multiples using pandas/matplotlib.

We’ve been using plt.subplots so far to yell at matplotlib, “hey, prepare a graph!”. Then when we use df.plot we pass ax to put all of our data into that one particular graph.

# Have one subplot
fig, ax = plt.subplots()
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, legend=False)
ax.set_title("Bhutan")
<matplotlib.text.Text at 0x10b0d04a8>

[Figure: Bhutan's GDP per capita by year, on a single plot]

Passing ax around

If we use .plot twice but give them both the same ax, the elements will be plotted on the same graph.

# One subplot again
fig, ax = plt.subplots()

# Use ax for both
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, label='Bhutan')
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', ax=ax, label='Iran')
ax.set_title("Iran and Bhutan")
<matplotlib.text.Text at 0x10d3d8e48>

[Figure: Iran and Bhutan's GDP per capita by year, both lines on one plot]

Having multiple ax

We can receive multiple ax elements from .subplots. Below we’re using nrows= and ncols= to ask for two rows of graphics, each row having one column.

Note: The next one is nicer than this one because it shares x and y axes.

# Asking for TWO subplots, ax1 and ax2.
# Be sure to put them in parentheses
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

[Figure: two stacked plots, Bhutan above Iran, each with its own axis ranges]

See how it looks like they’re both making a lot of money in the end? Unfortunately that’s not true. If you look at the y-axis labels, you’ll see Iran peaks at a GDP per capita of around $13k, while Bhutan only gets up to about $6k. To make the x and y axes match up, pass sharex=True and sharey=True to plt.subplots.

# Receive ax1 and ax2 - note that they go in parens
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True)

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

[Figure: two stacked plots, Bhutan above Iran, sharing x and y axes]
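
To see what sharey=True is doing without needing the CSV, here's a minimal sketch. The numbers are made-up placeholders, not values from country-data.csv. It just checks that both axes end up reporting the same y-limits.

```python
import matplotlib.pyplot as plt

# Made-up placeholder values, NOT taken from country-data.csv
years = [1950, 1980, 2010]
small = [1, 2, 6]     # a series that tops out low
large = [2, 9, 13]    # a series that tops out high

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True)
ax1.plot(years, small)
ax2.plot(years, large)

# With sharey=True both axes report the same y-limits,
# wide enough to cover the larger series
print(ax1.get_ylim())
print(ax2.get_ylim())
```

If you drop sharey=True, each axis picks limits from its own data, which is how the two lines ended up looking similar in the earlier version.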

Expanding with nrows and ncols

You could do this with a million different graphics!
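
Before going bigger, it can help to look at what plt.subplots actually hands back. As a rough sketch (no country data needed): when you ask for more than one subplot, the second return value is a NumPy array of axes, and the parentheses are just unpacking that array.

```python
import matplotlib.pyplot as plt

# Two rows, one column: a 1-D array holding two axes
fig, axes = plt.subplots(nrows=2, ncols=1)
print(type(axes).__name__, axes.shape)

# Unpacking the array by hand is the same thing that
# `fig, (ax1, ax2) = plt.subplots(...)` does in one step
ax1, ax2 = axes

# Two rows, three columns: a 2-D array, one inner row per row of plots,
# which is why a grid needs nested parentheses to unpack
fig2, grid = plt.subplots(nrows=2, ncols=3)
print(grid.shape)
```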

# Because I'm asking for two rows of three columns each,
# I need to separate them out with even MORE parentheses
# Using figsize to make the figure a little bigger, 10"x5"
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(10,5))

# Doing each of these manually (ugh)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")
df[df['Country'] == 'France'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax3)
ax3.set_title("France")
df[df['Country'] == 'Ireland'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax4)
ax4.set_title("Ireland")
df[df['Country'] == 'Kazakhstan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax5)
ax5.set_title("Kazakhstan")
df[df['Country'] == 'United Arab Emirates'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax6)
ax6.set_title("United Arab Emirates")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

[Figure: 2x3 grid of plots: Bhutan, Iran, France, Ireland, Kazakhstan, United Arab Emirates]

Simplifying

That’s a little too complicated for my tastes, though. How are you going to get all of those separately named axes into a loop? Short answer: not easily. Let’s try it a different way.

Instead of getting all of the subplots at once, we’ll get them one at a time by using plt.subplot, the singular version of plt.subplots.

# 1 row, 1 column, and we'd like the first element.
ax = plt.subplot(1, 1, 1)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, legend=False)
ax.set_title("Bhutan")
<matplotlib.text.Text at 0x10d567f60>

[Figure: Bhutan's plot filling the whole figure]

# 1 row, 2 columns, and we'd like the first element.
ax = plt.subplot(1, 2, 1)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, legend=False)
ax.set_title("Bhutan")
<matplotlib.text.Text at 0x10d2edf60>

[Figure: Bhutan's plot in the left half of the figure, right half empty]

# 1 row, 2 columns, and we'd like the first element.
ax1 = plt.subplot(1, 2, 1)
df[df['Country'] == 'Belarus'].plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Belarus")

# 1 row, 2 columns, and we'd like the second element.
ax2 = plt.subplot(1, 2, 2)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Bhutan")
<matplotlib.text.Text at 0x10dde6588>

[Figure: Belarus on the left, Bhutan on the right]
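
The third number in plt.subplot(rows, columns, index) counts slots left to right, then top to bottom, starting at 1. Here's a small sketch of that numbering on a 2x3 grid. The slot_position helper is a made-up name for illustration, not part of matplotlib.

```python
import matplotlib.pyplot as plt

def slot_position(index, ncols):
    """Hypothetical helper: map a 1-based plt.subplot index
    to a 0-based (row, column) pair."""
    return divmod(index - 1, ncols)

fig = plt.figure()
for index in range(1, 7):
    # 2 rows, 3 columns, and we'd like slot number `index`
    ax = plt.subplot(2, 3, index)
    row, col = slot_position(index, ncols=3)
    ax.set_title(f"slot {index}: row {row}, col {col}")
plt.tight_layout()
```

Slots 1 through 3 fill the top row and 4 through 6 fill the bottom row, so a single counter is enough to walk the whole grid.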

Make it a loop

len(df.groupby("Country"))
188

So we need 188 different graphs. If we put 15 columns on each row, that’s 12.53 rows - round that up to 13.
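
That rounding is just a ceiling division. A quick sketch of the arithmetic, with the 188 copied from the output above:

```python
import math

n_countries = 188   # from len(df.groupby("Country")) above
ncols = 15

# Round up so the last, partly filled row still gets created
nrows = math.ceil(n_countries / ncols)
print(nrows)                          # 13
print(nrows * ncols - n_countries)    # 7 slots left empty
```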

# Make the graph 20 inches by 40 inches
plt.figure(figsize=(20,40), facecolor='white')

# plot numbering starts at 1, not 0
plot_number = 1
for countryname, selection in df.groupby("Country"):
    # Inside of an image that's a grid of 13 rows by 15 columns,
    # put this graph in the plot_number slot.
    ax = plt.subplot(13, 15, plot_number)
    selection.plot(x='Year', y='GDP_per_capita', ax=ax, label=countryname, legend=False)
    ax.set_title(countryname)
    # Go to the next plot for the next loop
    plot_number = plot_number + 1
plt.tight_layout()
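
The counter approach works fine. Another option, sketched here as an alternative rather than the method used above, is to ask plt.subplots for the whole grid at once and walk through it as a flat list of axes, which also gets you sharex and sharey. The df_demo frame is a tiny made-up stand-in so the example runs without country-data.csv.

```python
import math
import matplotlib.pyplot as plt
import pandas as pd

# Tiny stand-in for country-data.csv, with made-up values
df_demo = pd.DataFrame({
    "Country": ["A", "A", "B", "B", "C", "C"],
    "Year": [1950, 2000] * 3,
    "GDP_per_capita": [1.0, 2.0, 3.0, 5.0, 2.0, 8.0],
})

groups = df_demo.groupby("Country")
ncols = 2
nrows = math.ceil(len(groups) / ncols)

# squeeze=False keeps axes 2-D even if the grid has one row or column
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                         sharex=True, sharey=True, squeeze=False)

# ravel() lines the grid up left to right, top to bottom
flat_axes = axes.ravel()
for ax, (countryname, selection) in zip(flat_axes, groups):
    selection.plot(x="Year", y="GDP_per_capita", ax=ax, legend=False)
    ax.set_title(countryname)

# Hide any slots left over in the last row
for ax in flat_axes[len(groups):]:
    ax.set_visible(False)
plt.tight_layout()
```

With the real data you would swap df_demo for df and pick ncols to taste. The trade-off is that every panel shares one y-axis range, so countries with a small GDP per capita will look flat next to the big ones.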