import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# Read in our data
df = pd.read_csv("../country-data.csv")
df.head(3)
       Country  Continent  Year  GDP_per_capita  life_expectancy  Population
0  Afghanistan       Asia  1950      757.318795           26.674     8151455
1  Afghanistan       Asia  1951      766.752197           26.932     8276820
2  Afghanistan       Asia  1952      779.445314           27.448     8407148

Small multiples with plt.subplots

There are a few ways to make small multiples using pandas/matplotlib.

We’ve been using plt.subplots so far to yell at matplotlib, “hey, prepare a graph!”. Then when we use df.plot we pass ax to put all of our data into that one particular graph.

# Have one subplot
fig, ax = plt.subplots()
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, legend=False)
ax.set_title("Bhutan")
<matplotlib.text.Text at 0x10b0d04a8>

png

Passing ax around

If we use .plot twice but give them both the same ax, the elements will be plotted on the same graph.

# One subplot again
fig, ax = plt.subplots()

# Use ax for both
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, label='Bhutan')
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', ax=ax, label='Iran')
ax.set_title("Iran and Bhutan")
<matplotlib.text.Text at 0x10d3d8e48>

png
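
To check that both calls really land on the same graph, here is a minimal sketch that counts the lines on the ax afterwards. The toy DataFrame and the country names "A" and "B" are made up for illustration and are not taken from country-data.csv.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs as a script; skip in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up numbers, NOT taken from country-data.csv
toy = pd.DataFrame({
    "Country": ["A", "A", "A", "B", "B", "B"],
    "Year": [1950, 1951, 1952, 1950, 1951, 1952],
    "GDP_per_capita": [100.0, 110.0, 125.0, 400.0, 390.0, 420.0],
})

fig, ax = plt.subplots()
for name in ["A", "B"]:
    toy[toy["Country"] == name].plot(x="Year", y="GDP_per_capita", ax=ax, label=name)

# Both .plot calls drew onto the same ax, so it now holds two lines
line_count = len(ax.get_lines())
print(line_count)  # 2
```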

Having multiple ax

We can receive multiple ax elements from plt.subplots. Below we’re using nrows and ncols to ask for two rows of graphics, each row having one column.

Note: The next one is nicer than this one because it shares x and y axes.

# Asking for TWO subplots, ax1 and ax2.
# Be sure to put them in parentheses
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

png

See how it looks like they’re both making a lot of money in the end? Unfortunately that’s not true. If you look at the y-axis labels, you’ll see Iran peaks at a GDP per capita of around $13k, while Bhutan only gets up to about $6k. To make the x and y axes match up, you need to pass sharex and sharey to plt.subplots.

# Receive ax1 and ax2 - note that they go in parens
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True)

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

png
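
If you want to confirm what sharey does without squinting at tick labels, a small sketch like the one below compares the y-limits of the two subplots with and without sharing. The two toy DataFrames and the y_limits helper are hypothetical and exist only for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs as a script; skip in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up series on very different scales
years = [1950, 1951, 1952]
small = pd.DataFrame({"Year": years, "GDP_per_capita": [500.0, 550.0, 600.0]})
large = pd.DataFrame({"Year": years, "GDP_per_capita": [9000.0, 11000.0, 13000.0]})

def y_limits(sharey):
    """Plot both series in two stacked subplots and return their y-limits."""
    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=sharey)
    small.plot(x="Year", y="GDP_per_capita", legend=False, ax=ax1)
    large.plot(x="Year", y="GDP_per_capita", legend=False, ax=ax2)
    limits = (ax1.get_ylim(), ax2.get_ylim())
    plt.close(fig)
    return limits

independent = y_limits(sharey=False)
shared = y_limits(sharey=True)

print(independent[0] == independent[1])  # False: each subplot picked its own scale
print(shared[0] == shared[1])            # True: both subplots use one scale
```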

Expanding with nrows and ncols

You could do this with a million different graphics!

# Because I'm asking for two rows of three columns each,
# I need to separate them out with even MORE parentheses
# Using figsize to make the figure a little bigger, 10"x5"
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(10,5))

# Doing each of these manually (ugh)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")
df[df['Country'] == 'France'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax3)
ax3.set_title("France")
df[df['Country'] == 'Ireland'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax4)
ax4.set_title("Ireland")
df[df['Country'] == 'Kazakhstan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax5)
ax5.set_title("Kazakhstan")
df[df['Country'] == 'United Arab Emirates'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax6)
ax6.set_title("United Arab Emirates")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

png

Simplifying

That’s a little too complicated for my tastes, though. How are you going to get all of those into a loop? Short answer: you aren’t. Let’s try it a different way.

Instead of getting all of the subplots at once, we’ll get them one at a time by using plt.subplot, the singular version of plt.subplots.

# 1 row, 1 column, and we'd like the first element.
ax = plt.subplot(1, 1, 1)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, legend=False)
ax.set_title("Bhutan")
<matplotlib.text.Text at 0x10d567f60>

png

# 1 row, 2 columns, and we'd like the first element.
ax = plt.subplot(1, 2, 1)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, legend=False)
ax.set_title("Bhutan")
<matplotlib.text.Text at 0x10d2edf60>

png

# 1 row, 2 columns, and we'd like the first element.
ax1 = plt.subplot(1, 2, 1)
df[df['Country'] == 'Belarus'].plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Belarus")

# 1 row, 2 columns, and we'd like the second element.
ax2 = plt.subplot(1, 2, 2)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Bhutan")
<matplotlib.text.Text at 0x10dde6588>

png

Make it a loop

len(df.groupby("Country"))
188

So we need 188 different graphs. If we put 13 columns on each row, that’s 14.46 rows - round that up to 15.
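
The rounding-up step can be done in code instead of by hand. This is a minimal sketch using math.ceil; rows_needed is a hypothetical helper name, not something from pandas or matplotlib. Keep in mind that plt.subplot takes the number of rows first and the number of columns second.

```python
import math

def rows_needed(n_plots, ncols):
    """Smallest number of rows that fits n_plots into a grid ncols wide."""
    return math.ceil(n_plots / ncols)

print(rows_needed(188, 15))  # 13
print(rows_needed(188, 13))  # 15
```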

# Make the graph 20 inches by 40 inches
plt.figure(figsize=(20,40), facecolor='white')

# plot numbering starts at 1, not 0
plot_number = 1
for countryname, selection in df.groupby("Country"):
    # Inside of an image that's a grid of 15 rows by 13 columns,
    # put this graph in the plot_number slot.
    ax = plt.subplot(15, 13, plot_number)
    selection.plot(x='Year', y='GDP_per_capita', ax=ax, label=countryname, legend=False)
    ax.set_title(countryname)
    # Go to the next plot for the next loop
    plot_number = plot_number + 1
plt.tight_layout()

png

I take it back

Maybe the best way to do this is actually to use subplots! With the sharex and sharey it certainly seems more effective. We’ll just need a way to pull off one subplot at a time, instead of doing a huge big ((ax1, ax...)) disaster.

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True)
axes
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x10fc3dda0>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10ff06b70>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10ff8e2e8>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x10ffc6ba8>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x110115b00>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x1101545f8>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x11019c390>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x1101ac160>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x110228828>]], dtype=object)

png

Right now it’s 3 lists of 3 axes, which will be hard to loop over.

axes
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x10fc3dda0>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10ff06b70>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10ff8e2e8>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x10ffc6ba8>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x110115b00>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x1101545f8>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x11019c390>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x1101ac160>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x110228828>]], dtype=object)

Luckily we can easily convert it to just one long list using a weird list comprehension

# http://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
[item for sublist in axes for item in sublist] 
[<matplotlib.axes._subplots.AxesSubplot at 0x10fc3dda0>,
 <matplotlib.axes._subplots.AxesSubplot at 0x10ff06b70>,
 <matplotlib.axes._subplots.AxesSubplot at 0x10ff8e2e8>,
 <matplotlib.axes._subplots.AxesSubplot at 0x10ffc6ba8>,
 <matplotlib.axes._subplots.AxesSubplot at 0x110115b00>,
 <matplotlib.axes._subplots.AxesSubplot at 0x1101545f8>,
 <matplotlib.axes._subplots.AxesSubplot at 0x11019c390>,
 <matplotlib.axes._subplots.AxesSubplot at 0x1101ac160>,
 <matplotlib.axes._subplots.AxesSubplot at 0x110228828>]
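
Since axes is a numpy array, its own .flatten() method should give the same result as the list comprehension. The sketch below compares the two; the variable names are made up for the comparison.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs as a script; skip in a notebook
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True)

# The list comprehension from above
by_comprehension = [item for sublist in axes for item in sublist]

# numpy's built-in way of turning a 2-D array into a 1-D one
by_flatten = list(axes.flatten())

# Same objects, in the same row-by-row order
same_order = all(a is b for a, b in zip(by_comprehension, by_flatten))
print(len(by_flatten), same_order)  # 9 True
```

One thing to watch: with a single row or a single column, plt.subplots hands back a 1-D array, and the list comprehension no longer works because it expects rows to loop through. Passing squeeze=False to plt.subplots keeps the result 2-D in every case.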

And take parts off one at a time

axes_list = [item for sublist in axes for item in sublist] 
# We have 9
len(axes_list)
9
# Remove the first one, save it as 'ax'
ax = axes_list.pop(0)
ax
<matplotlib.axes._subplots.AxesSubplot at 0x10fc3dda0>
# Only 8 left now
len(axes_list)
8
ax = axes_list.pop(0)
ax
<matplotlib.axes._subplots.AxesSubplot at 0x10ff06b70>
# Only have 7 left now
len(axes_list)
7

And we can just keep on going down, plucking them off one at a time.

Putting it together

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True, figsize=(10,5))
axes_list = [item for sublist in axes for item in sublist] 

for countryname, selection in df.head(1200).groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

png
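
An alternative to popping axes off the list is to pair each group with an ax using zip, then remove whatever is left over. This is only a sketch of that variation, not a replacement for the approach above, and the three-country toy DataFrame is invented for the example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs as a script; skip in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up data for three invented countries
toy = pd.DataFrame({
    "Country": ["A"] * 3 + ["B"] * 3 + ["C"] * 3,
    "Year": [1950, 1951, 1952] * 3,
    "GDP_per_capita": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
})

fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
axes_list = list(axes.flatten())
groups = toy.groupby("Country")

# zip stops at the shorter input, so each group gets exactly one ax
for ax, (countryname, selection) in zip(axes_list, groups):
    selection.plot(x="Year", y="GDP_per_capita", ax=ax, legend=False)
    ax.set_title(countryname)

# Everything past the number of groups was never drawn on
for ax in axes_list[len(groups):]:
    ax.remove()

remaining_titles = [ax.get_title() for ax in fig.axes]
print(remaining_titles)  # ['A', 'B', 'C']
```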

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True, figsize=(10,5))
axes_list = [item for sublist in axes for item in sublist] 

for countryname, selection in df.head(1200).groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_ylim((0, 15000))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

png

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=2, ncols=5, sharex=False, sharey='row', figsize=(10,5))
axes_list = [item for sublist in axes for item in sublist] 

for countryname, selection in df.head(1200).groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False, clip_on=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xlim((1950, 2020))
    ax.set_xlabel("")
    ax.set_xticks(range(1950, 2015, 25))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

plt.tight_layout()

png

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=4, ncols=5, sharex=True, sharey='row', figsize=(10,10))
axes_list = [item for sublist in axes for item in sublist] 

for countryname, selection in df.head(1200).groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False, clip_on=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xlim((1950, 2020))
    ax.set_xlabel("")
    ax.set_xticks(range(1950, 2015, 25))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

plt.tight_layout()

png

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=True, figsize=(10,10))
axes_list = [item for sublist in axes for item in sublist] 

for countryname, selection in df.head(1200).groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False, clip_on=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xlim((1950, 2020))
    ax.set_xlabel("")
    ax.set_xticks(range(1950, 2015, 25))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

plt.tight_layout()

png

Let’s put them in order

sample_df = df.head(1200)
# We could order by the max or the min but I'm
# going to get the LAST VALUE. using... .last().
# You could also use .max() or whatever.
sample_df.groupby("Country")['GDP_per_capita'].last().sort_values(ascending=False)
Country
Austria                36731.628770
Australia              36064.737280
Belgium                32585.011970
Bahrain                24472.896240
Bahamas                21418.229810
Barbados               16233.413790
Argentina              15714.103180
Antigua and Barbuda    13722.986160
Belarus                13515.161030
Azerbaijan              9291.026270
Belize                  7550.138241
Albania                 6969.306283
Algeria                 6419.127829
Bhutan                  6130.862355
Angola                  5838.155376
Armenia                 5059.087964
Bolivia                 2677.326347
Bangladesh              1792.550235
Benin                   1464.138255
Afghanistan             1349.696941
Name: GDP_per_capita, dtype: float64
# This is an ordered list of the names
sample_df.groupby("Country")['GDP_per_capita'].last().sort_values(ascending=False).index
Index(['Austria', 'Australia', 'Belgium', 'Bahrain', 'Bahamas', 'Barbados',
       'Argentina', 'Antigua and Barbuda', 'Belarus', 'Azerbaijan', 'Belize',
       'Albania', 'Algeria', 'Bhutan', 'Angola', 'Armenia', 'Bolivia',
       'Bangladesh', 'Benin', 'Afghanistan'],
      dtype='object', name='Country')
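
Which aggregation you pick changes the order. The sketch below uses a made-up three-country DataFrame to show .last() and .max() disagreeing. Note that .last() takes the last row of each group as it appears in the data, so it only means "most recent year" if the rows are already sorted by Year.

```python
import pandas as pd

# Made-up data: country A peaks early and then drops
toy = pd.DataFrame({
    "Country": ["A", "A", "B", "B", "C", "C"],
    "Year": [1950, 1951, 1950, 1951, 1950, 1951],
    "GDP_per_capita": [900.0, 100.0, 200.0, 300.0, 50.0, 800.0],
})

grouped = toy.groupby("Country")
by_last = grouped["GDP_per_capita"].last().sort_values(ascending=False)
by_max = grouped["GDP_per_capita"].max().sort_values(ascending=False)

print(list(by_last.index))  # ['C', 'B', 'A']
print(list(by_max.index))   # ['A', 'C', 'B']
```
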
# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=True, figsize=(10,10))
axes_list = [item for sublist in axes for item in sublist] 

# Now instead of looping through the groupby
# you CREATE the groupby
# you LOOP through the ordered names
# and you use .get_group to get the right group
grouped = df.head(1200).groupby("Country")

# The ordering is built from grouped, so it has to come after it
ordered_country_names = grouped['GDP_per_capita'].last().sort_values(ascending=False).index

for countryname in ordered_country_names:
    selection = grouped.get_group(countryname)
    
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xlim((1950, 2016))
    ax.set_xlabel("")
    ax.set_xticks(range(1950, 2015, 25))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

plt.tight_layout()

png

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=6, ncols=8, sharex=True, sharey=True, figsize=(18,10))
axes_list = [item for sublist in axes for item in sublist] 

# Now instead of looping through the groupby
# you CREATE the groupby
# you LOOP through the ordered names
# and you use .get_group to get the right group
grouped = df.head(3000).groupby("Country")

# The ordering is built from grouped, so it has to come after it
ordered_country_names = grouped['GDP_per_capita'].last().sort_values(ascending=False).index

first_year = df['Year'].min()
last_year = df['Year'].max()

for countryname in ordered_country_names:
    selection = grouped.get_group(countryname)
    
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xlim((first_year, last_year))
    ax.set_xlabel("")
    ax.set_xticks((first_year, last_year))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

plt.tight_layout()

# Adjust the spacing AFTER tight_layout, otherwise tight_layout overrides it
plt.subplots_adjust(hspace=1)

png

df[df['Country'] == 'Cape Verde']['Year'].max()
2012
# Combine both conditions with & so pandas filters in a single step
df[(df['Country'] == 'Cape Verde') & (df['Year'] == 2012)]
         Country  Continent  Year  GDP_per_capita  life_expectancy  Population
1952  Cape Verde     Africa  2012     3896.041139           74.771      505335
# So we want to plot a marker at the last year, at that year's GDP value
max_year = df[df['Country'] == 'Cape Verde']['Year'].max()
gdp_value = float(df.loc[(df['Country'] == 'Cape Verde') & (df['Year'] == max_year), 'GDP_per_capita'].iloc[0])
ax = df[df['Country'] == 'Cape Verde'].plot(x='Year', y='GDP_per_capita', legend=False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.tick_params(
    which='both',
    bottom=False,
    left=False,
    right=False,
    top=False
)
plt.scatter(x= [max_year], y=[gdp_value], s=70, clip_on=False, linewidth=0)
plt.annotate(int(gdp_value), xy=[max_year, gdp_value], xytext=[7, -2], textcoords='offset points')
<matplotlib.text.Annotation at 0x111e4b240>

png
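
Another way to find the point to annotate is to look up the row with the largest Year once, using idxmax, and read both values from it. This is a sketch on a made-up three-row DataFrame; the numbers are illustrative and not taken from country-data.csv.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs as a script; skip in a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up values for one invented country
toy = pd.DataFrame({
    "Year": [2010, 2011, 2012],
    "GDP_per_capita": [3500.0, 3700.0, 3896.0],
})

# Row with the largest Year, looked up once instead of filtering twice
last_row = toy.loc[toy["Year"].idxmax()]
max_year = int(last_row["Year"])
gdp_value = float(last_row["GDP_per_capita"])

ax = toy.plot(x="Year", y="GDP_per_capita", legend=False)
ax.scatter(x=[max_year], y=[gdp_value], s=70, clip_on=False, linewidth=0)
label = ax.annotate(str(int(gdp_value)), xy=(max_year, gdp_value),
                    xytext=(7, -2), textcoords="offset points")

print(max_year, gdp_value, label.get_text())  # 2012 3896.0 3896
```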

# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=6, ncols=8, sharex=True, sharey=True, figsize=(18,10))
axes_list = [item for sublist in axes for item in sublist] 

# Now instead of looping through the groupby
# you CREATE the groupby
# you LOOP through the ordered names
# and you use .get_group to get the right group
grouped = df.head(3000).groupby("Country")

# The ordering is built from grouped, so it has to come after it
ordered_country_names = grouped['GDP_per_capita'].last().sort_values(ascending=False).index

first_year = df['Year'].min()
last_year = df['Year'].max()

for countryname in ordered_country_names:
    selection = grouped.get_group(countryname)

    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xlim((first_year, last_year))
    ax.set_xlabel("")
    ax.set_xticks((first_year, last_year))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    max_year = selection['Year'].max()
    gdp_value = float(selection.loc[selection['Year'] == max_year, 'GDP_per_capita'].iloc[0])
    ax.set_ylim((0, 100000))
    ax.scatter(x=[max_year], y=[gdp_value], s=70, clip_on=False, linewidth=0)
    ax.annotate(str(int(gdp_value / 1000)) + "k", xy=[max_year, gdp_value], xytext=[7, -2], textcoords='offset points')

# Now use the matplotlib .remove() method to 
# delete anything we didn't use
for ax in axes_list:
    ax.remove()
    
plt.tight_layout()
plt.subplots_adjust(hspace=1)

png
