Issue
I need to create a bar chart using Python in which the x-axis is grouped by two variables. The data I need to plot is the following:
Country | City | Number of Universities |
---|---|---|
Germany | Berlin | 30 |
Germany | Munich | 20 |
Germany | Hamburg | 10 |
France | Paris | 40 |
France | Marseille | 5 |
France | Lyon | 10 |
France | Nice | 5 |
Spain | Madrid | 25 |
Spain | Barcelona | 15 |
Spain | Valencia | 10 |
Spain | Seville | 7 |
Denmark | Copenhagen | 10 |
Denmark | Aarhus | 5 |
Italy | Rome | 20 |
Italy | Milan | 15 |
Italy | Naples | 8 |
Italy | Florence | 7 |
Austria | Vienna | 12 |
Austria | Salzburg | 4 |
I share the code to create it below:
import pandas as pd
# Raw columns: each country name is repeated once per city it contributes.
data = {
    'Country': (
        ['Germany'] * 3
        + ['France'] * 4
        + ['Spain'] * 4
        + ['Denmark'] * 2
        + ['Italy'] * 4
        + ['Austria'] * 2
    ),
    'City': [
        'Berlin', 'Munich', 'Hamburg',
        'Paris', 'Marseille', 'Lyon', 'Nice',
        'Madrid', 'Barcelona', 'Valencia', 'Seville',
        'Copenhagen', 'Aarhus',
        'Rome', 'Milan', 'Naples', 'Florence',
        'Vienna', 'Salzburg',
    ],
}

# Build the table from the two label columns first.
df = pd.DataFrame(data)

# Universities per city, listed in the same row order as the cities above;
# appended in place as the last column.
universities_count = [
    30, 20, 10,
    40, 5, 10, 5,
    25, 15, 10, 7,
    10, 5,
    20, 15, 8, 7,
    12, 4,
]
df.insert(df.shape[1], 'Number of Universities', universities_count)

# Display the finished table.
print(df)
The objective is to generate the following graph:
Thank you very much in advance for your help.
Solution
Depending on how closely you want to replicate this, you'll need an approach that can calculate the necessary height/placement of your group labels and tick lines.
I wrote a helper function for this: it calculates the current height (in points) of the x-axis of a given `Axes`.
After that, we effectively need 3 sets of `Axes.xaxis` to work with:
- one for the individual 'City' level ticks,
- one for the grouped 'Country' level ticks,
- one for the lines that span from the `Axes` down to the bottom of the 'Country' level labels.
import pandas as pd
import numpy as np
# One (country, city, university count) record per row, in the order given in
# the question.
records = [
    ('Germany', 'Berlin', 30),
    ('Germany', 'Munich', 20),
    ('Germany', 'Hamburg', 10),
    ('France', 'Paris', 40),
    ('France', 'Marseille', 5),
    ('France', 'Lyon', 10),
    ('France', 'Nice', 5),
    ('Spain', 'Madrid', 25),
    ('Spain', 'Barcelona', 15),
    ('Spain', 'Valencia', 10),
    ('Spain', 'Seville', 7),
    ('Denmark', 'Copenhagen', 10),
    ('Denmark', 'Aarhus', 5),
    ('Italy', 'Rome', 20),
    ('Italy', 'Milan', 15),
    ('Italy', 'Naples', 8),
    ('Italy', 'Florence', 7),
    ('Austria', 'Vienna', 12),
    ('Austria', 'Salzburg', 4),
]
df = pd.DataFrame(
    records,
    columns=['Country', 'City', 'Number of Universities'],
)

# Sorting by country first keeps each country's cities contiguous, which the
# grouped x-axis below relies on; cities are alphabetical within a country.
plot_df = df.sort_values(by=['Country', 'City'])
###
def get_xaxis_height(ax):
    """Return the vertical space, in points, used by the x-axis of ``ax``.

    The space used by the x-axis of every axes in ``ax.child_axes`` (for
    example secondary x-axes stacked underneath) is included, so the result
    can be used as a tick ``length`` that reaches the bottom of everything
    drawn below the axes so far.

    Parameters
    ----------
    ax : matplotlib Axes
        The parent axes. Its own figure supplies the dpi, so the function
        does not depend on a module-level ``fig`` variable.

    Returns
    -------
    float
        Summed x-axis bounding-box heights (converted from pixels to points)
        plus the summed tick-label pads (already in points).

    Raises
    ------
    KeyError
        If an inspected axis has no explicit ``'pad'`` among its tick
        parameters.
    """
    dpi = ax.figure.dpi
    bbox_pixels = 0
    pad_points = 0
    # A separate loop variable keeps the ``ax`` argument intact.
    for member in [ax, *ax.child_axes]:
        bbox = member.xaxis.get_tightbbox()
        if bbox is None:
            # An invisible axis reports no bounding box and takes no space.
            continue
        bbox_pixels += bbox.height
        pad_points += member.xaxis.get_tick_params()['pad']
    # Only the bounding-box heights are in pixels; the pads are in points
    # already and must not be rescaled by the dpi.
    return bbox_pixels * 72 / dpi + pad_points
from matplotlib import pyplot as plt
plt.rcParams['font.size'] = 12

## Create base chart
fig, ax = plt.subplots(figsize=(12, 8))

# Hide every spine except the bottom one, which stays as the baseline.
for side in ('left', 'top', 'right'):
    ax.spines[side].set_visible(False)

# One bar per city, in the row order of plot_df, each annotated with its value.
bc = ax.bar('City', 'Number of Universities', data=plot_df, width=.6)
ax.bar_label(bc)

# City labels: rotated, with no tick marks. Adjust ``pad`` to move the
# individual labels further from / closer to the bottom spine.
ax.tick_params(axis='x', rotation=90, bottom=False, length=0, pad=1)

# No tick marks on the y-axis either.
ax.tick_params(axis='y', left=False)
## Add group labels underneath existing rotated labels
# Number the bars 0..n-1 in plotting order, then take the mean position per
# country: that is the horizontal centre of each country's run of bars.
city_positions = plot_df.assign(tick_loc=np.arange(len(plot_df)))
label_locs = city_positions['tick_loc'].groupby(city_positions['Country']).mean()

# Space currently used below the axes; used as the (undrawn) tick length so
# the country labels start below the city labels.
ax_bottom = get_xaxis_height(ax)

group_label_ax = ax.secondary_xaxis(location='bottom')
group_label_ax.set_xticks(label_locs, labels=label_locs.index, ha='center')
# Adjust ``pad`` to move the group labels further from / closer to the
# individual labels.
group_label_ax.tick_params(bottom=False, pad=10, length=ax_bottom)
## add long tick lines where needed
# Number the bars 0..n-1 in plotting order and find the first bar of each
# country (the rows whose country differs from the previous row's).
tick_frame = plot_df.assign(tick_loc=np.arange(len(plot_df)))
starts_group = tick_frame['Country'] != tick_frame['Country'].shift()

# A separator sits half a bar-slot to the left of each group's first bar ...
line_locs = (tick_frame.loc[starts_group, 'tick_loc'] - 0.5).tolist()
# ... plus one closing separator after the last bar. The count comes from
# plot_df (the frame actually plotted), like every other position here.
line_locs.append(len(plot_df) - .5)

# Measured again: by now this presumably also counts the group-label axis
# created earlier in the script, so the lines reach the bottom of the
# country labels.
ax_bottom = get_xaxis_height(ax)

# A third x-axis that draws only the long tick lines, without labels.
tickline_ax = ax.secondary_xaxis(location='bottom')
tickline_ax.set_xticks(line_locs)
tickline_ax.tick_params(labelbottom=False, length=ax_bottom, pad=0)

# Make the outermost separators coincide with the edges of the axes.
ax.set_xlim(-.5, len(ax.containers[0]) - .5)
## adjust spine & tick colors
# The separator lines share the baseline's colour and are drawn at twice its
# line width.
bottom_spine = ax.spines['bottom']
bottom_spine.set_color('gainsboro')
tickline_ax.tick_params(
    axis='x',
    color='gainsboro',
    labelcolor='black',
    width=bottom_spine.get_linewidth() * 2,
)

## adjust y-ticks to be multiples of 5
from matplotlib.ticker import MultipleLocator
ax.yaxis.set_major_locator(MultipleLocator(5))

# Extra vertical margin so the bar labels have room above the bars.
ax.margins(y=.2)

fig.tight_layout()
plt.show()
Answered By - Cameron Riddell
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.