Issue
I am making box plots from the "iris.csv" data. I split the data into multiple dataframes by measurement (i.e. petal-length, petal-width, sepal-length, sepal-width) and then draw one box plot per dataframe in a for loop, adding each as a subplot.
Finally, I want to add a single common legend for all the box plots, but I have not been able to. I have tried several tutorials and methods from several Stack Overflow questions, without success.
Here is my code:
import pandas as pd
import seaborn as sns
from matplotlib import pyplot
from matplotlib import patches as mpatches
from pandas import read_csv

iris_data = "iris.csv"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = read_csv(iris_data, names=names)

# Reindex the dataset by species so it can be pivoted for each species
reindexed_dataset = dataset.set_index(dataset.groupby('class').cumcount())
cols_to_pivot = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width']

# Collect one pivoted frame per measurement, then concatenate them
# (DataFrame.append was removed in pandas 2.0)
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='class', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement'):
    grouped_dfs_02.append(group[1])

## make the box plot of several measured variables, compared between species
pyplot.figure(figsize=(20, 5), dpi=80)
pyplot.suptitle('Distribution of floral traits in the species of iris')
sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')
my_pal = {"Iris-versicolor": "g", "Iris-setosa": "r", "Iris-virginica": "b"}
plt_index = 0
# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):
    axi = pyplot.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')
    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    pyplot.title(group_name)
    plt_index += 1

# Move the legend to an empty part of the plot
pyplot.legend(title='species', labels=sp_name,
              handles=[setosa, versi, virgi], bbox_to_anchor=(19, 4),
              fancybox=True, shadow=True, ncol=5)
pyplot.show()
How do I add a common legend to the main figure, outside the main frame, beside the main suptitle?
Solution
To position the legend, it is important to set the loc parameter, which is the anchor point: the corner of the legend box that is pinned to the bbox_to_anchor position. (The default loc is 'best', which means you don't know beforehand where it would end up.) The positions are measured from 0,0, the lower left of the current ax, to 1,1, the upper right of the current ax. This doesn't include the padding for titles etc., so the values can go a bit outside the 0, 1 range. The "current ax" is the last one that was activated.
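As a minimal sketch of how loc and bbox_to_anchor work together in axes coordinates (the function name legend_above_axes and the species labels are only illustrative, not part of the original code):

```python
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches


def legend_above_axes():
    """Pin a legend just above the upper right corner of a single axes."""
    fig, ax = plt.subplots(figsize=(4, 3))
    handles = [mpatches.Patch(color=c) for c in ("red", "green", "blue")]
    labels = ["setosa", "versicolor", "virginica"]
    # loc names the corner of the legend box that gets pinned;
    # bbox_to_anchor says where that corner goes, in axes coordinates
    # (0, 0 = lower left of the axes, 1, 1 = upper right).
    legend = ax.legend(handles=handles, labels=labels, title="species",
                       loc="lower right", bbox_to_anchor=(1, 1.02), ncol=3)
    return fig, ax, legend
```

Here the lower right corner of the legend sits slightly above the upper right corner of the axes. Call plt.show() or fig.savefig(...) on the returned figure to see the result.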
Note that instead of plt.legend (which uses an axes), you could also use plt.gcf().legend, which uses the "figure". Then the coordinates are 0,0 in the lower left corner of the complete plot (the "figure") and 1,1 in the upper right. One drawback is that no extra space would be created for the legend, so you'd need to set a top padding manually (e.g. plt.gcf().subplots_adjust(top=0.8)). Another drawback is that you can't use plt.tight_layout() anymore, and that it would be harder to align the legend with the axes.
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import patches as mpatches
import pandas as pd

dataset = sns.load_dataset("iris")

# Reindex the dataset by species so it can be pivoted for each species
reindexed_dataset = dataset.set_index(dataset.groupby('species').cumcount())
cols_to_pivot = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Collect one pivoted frame per measurement, then concatenate them
# (DataFrame.append was removed in pandas 2.0)
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='species', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement'):
    grouped_dfs_02.append(group[1])

## make the box plot of several measured variables, compared between species
plt.figure(figsize=(20, 5), dpi=80)
plt.suptitle('Distribution of floral traits in the species of iris')
sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')
my_pal = {"versicolor": "g", "setosa": "r", "virginica": "b"}
plt_index = 0
# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):
    axi = plt.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')
    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    plt.title(group_name)
    plt_index += 1

# Move the legend above the last (current) axes: the legend's upper right
# corner is pinned at x=1, y=1.23 in that axes' coordinates
plt.legend(title='species', labels=sp_name,
           handles=[setosa, versi, virgi], bbox_to_anchor=(1, 1.23),
           fancybox=True, shadow=True, ncol=5, loc='upper right')
plt.tight_layout()
plt.show()
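The figure-level alternative mentioned above could look roughly like the sketch below. It is not part of the original answer's code: the function name figure_level_legend, the empty placeholder axes, and the anchor position (0.99, 0.99) are assumptions chosen for illustration, and the top padding of 0.8 is simply the example value from the text.

```python
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches


def figure_level_legend():
    """Attach one shared legend to the figure instead of to an axes."""
    # Four empty axes stand in for the four box plots
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle("Distribution of floral traits in the species of iris")
    colors = {"setosa": "red", "versicolor": "green", "virginica": "blue"}
    handles = [mpatches.Patch(color=c, label=name) for name, c in colors.items()]
    # Figure coordinates: 0, 0 is the lower left of the whole figure,
    # 1, 1 the upper right. The legend's upper right corner is pinned
    # near the figure's upper right corner, beside the suptitle.
    legend = fig.legend(handles=handles, title="species",
                        loc="upper right", bbox_to_anchor=(0.99, 0.99), ncol=3)
    # No space is reserved automatically, so make room at the top by hand
    fig.subplots_adjust(top=0.8)
    return fig, legend
```

Because the legend belongs to the figure, it does not matter which axes was activated last; the trade-off is the manual subplots_adjust call in place of plt.tight_layout().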
Answered By - JohanC