Issue
I am exploring the breast cancer classification dataset in Python. I am trying to plot a histogram for each feature. How can I arrange those histograms into three groups, as in the following screenshot?
What I am trying to achieve
Here is the code I used:
from sklearn.datasets import load_breast_cancer # sample data
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
data = load_breast_cancer()
# Turn the feature data into a dataframe
df = pd.DataFrame(data.data, columns = data.feature_names)
# Add the target columns, and fill it with the target data
df["target"] = data.target
# display(df.head())
Output of df.head(), shown transposed (one feature per row, one column per row index):

feature                          0         1         2         3         4
mean radius                  17.99     20.57     19.69     11.42     20.29
mean texture                 10.38     17.77     21.25     20.38     14.34
mean perimeter              122.80    132.90    130.00     77.58    135.10
mean area                   1001.0    1326.0    1203.0     386.1    1297.0
mean smoothness            0.11840   0.08474   0.10960   0.14250   0.10030
mean compactness           0.27760   0.07864   0.15990   0.28390   0.13280
mean concavity              0.3001    0.0869    0.1974    0.2414    0.1980
mean concave points        0.14710   0.07017   0.12790   0.10520   0.10430
mean symmetry               0.2419    0.1812    0.2069    0.2597    0.1809
mean fractal dimension     0.07871   0.05667   0.05999   0.09744   0.05883
radius error                1.0950    0.5435    0.7456    0.4956    0.7572
texture error               0.9053    0.7339    0.7869    1.1560    0.7813
perimeter error              8.589     3.398     4.585     3.445     5.438
area error                  153.40     74.08     94.03     27.23     94.44
smoothness error          0.006399  0.005225  0.006150  0.009110  0.011490
compactness error          0.04904   0.01308   0.04006   0.07458   0.02461
concavity error            0.05373   0.01860   0.03832   0.05661   0.05688
concave points error       0.01587   0.01340   0.02058   0.01867   0.01885
symmetry error             0.03003   0.01389   0.02250   0.05963   0.01756
fractal dimension error   0.006193  0.003532  0.004571  0.009208  0.005115
worst radius                 25.38     24.99     23.57     14.91     22.54
worst texture                17.33     23.41     25.53     26.50     16.67
worst perimeter             184.60    158.80    152.50     98.87    152.20
worst area                  2019.0    1956.0    1709.0     567.7    1575.0
worst smoothness            0.1622    0.1238    0.1444    0.2098    0.1374
worst compactness           0.6656    0.1866    0.4245    0.8663    0.2050
worst concavity             0.7119    0.2416    0.4504    0.6869    0.4000
worst concave points        0.2654    0.1860    0.2430    0.2575    0.1625
worst symmetry              0.4601    0.2750    0.3613    0.6638    0.2364
worst fractal dimension    0.11890   0.08902   0.08758   0.17300   0.07678
target                           0         0         0         0         0
# plotting
plotnumber = 1
fig = plt.figure(figsize=(20, 20))
for column in df.drop('target', axis=1):
    if plotnumber <= 30:
        plt.subplot(5, 6, plotnumber)
        sns.distplot(df[df['target'] == 0][column], label = 'malignant')
        sns.distplot(df[df['target'] == 1][column], label = 'benign')
        plt.legend()
        plt.title(column)
    plotnumber += 1
fig.tight_layout()
What I have so far
I want to divide them into three groups: a 'mean' group, an 'error' group, and a 'worst' group. Each group should contain 10 plots (5 rows, 2 columns).
These don't answer the question:
- How to plot in multiple subplots
  - Doesn't show how to group the subplots
- Matplotlib different size subplots
  - Also doesn't show how to group the subplots
Solution
- This requires using matplotlib.figure.Figure.subfigures (available since matplotlib 3.4), and then adding .subplots to each subfigure.
- As of seaborn v0.11.2, sns.distplot is deprecated (the deprecation dates from v0.11.0) and replaced by sns.histplot.
- itertools.chain is used to flatten groups, but other methods can be found in How do I make a flat list out of a list of lists?.
- Tested in python 3.10, pandas 1.4.2, matplotlib 3.5.1, seaborn 0.11.2
  - If you are using Anaconda, then use conda update --all, otherwise update packages with pip. Alternatively, create a new environment with the appropriate package versions.
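To see the subfigure mechanism on its own, here is a minimal sketch that uses only matplotlib and NumPy, with synthetic normal data in place of the breast cancer features. The group titles, figure size, and random data are placeholders chosen for illustration and are not part of the answer's code; Figure.subfigures needs matplotlib 3.4 or later.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # synthetic stand-in data, not the real dataset

fig = plt.figure(figsize=(12, 8), constrained_layout=True)
# one subfigure per group, placed side by side
subfigs = fig.subfigures(1, 3)

group_titles = ["Group A", "Group B", "Group C"]  # hypothetical titles
axes_per_group = []
for subfig, title in zip(subfigs, group_titles):
    subfig.suptitle(title)
    # each subfigure gets its own independent 5 x 2 grid of axes
    axs = subfig.subplots(5, 2).flatten()
    for ax in axs:
        ax.hist(rng.normal(size=200), bins=20)
    axes_per_group.append(axs)

# fig.savefig("subfigures_sketch.png")  # uncomment to write the image
```

Each subfigure behaves like a small figure of its own, which is why every group can carry a separate suptitle and grid.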
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_breast_cancer # sample data
from itertools import chain # to lazily flatten the nested list
# starting with the sample dataframe in the op
data = load_breast_cancer()
df = pd.DataFrame(data.data, columns=data.feature_names)
df["target"] = data.target
# change the target name to what should be in the legend
df.target = df.target.map({0: 'Malignant', 1: 'Benign'})
# create the groups of column names for each set of subplots
col_groups = [df.columns[df.columns.str.contains(v)] for v in ['mean', 'error', 'worst']]
# create the subfigures and subplots
fig = plt.figure(figsize=(20, 20), constrained_layout=True)
subfigs = fig.subfigures(1, 3, width_ratios=[1, 1, 1], wspace=.15)
axs0 = subfigs[0].subplots(5, 2)
axs0 = axs0.flatten()
subfigs[0].suptitle('Mean Values', fontsize=20)
axs1 = subfigs[1].subplots(5, 2)
axs1 = axs1.flatten()
subfigs[1].suptitle('Standard Error Values', fontsize=20)
axs2 = subfigs[2].subplots(5, 2)
axs2 = axs2.flatten()
subfigs[2].suptitle('Worst Values', fontsize=20)
# create a flattened list of tuples containing an axes and column name
groups = chain(*[list(zip(axes, group)) for axes, group in zip([axs0, axs1, axs2], col_groups)])
# iterate through each axes and column
for ax, col in groups:
    sns.histplot(data=df, x=col, hue='target', kde=True, stat='density', ax=ax)
    l = ax.get_legend()  # remove this line to keep default legend
    l.remove()  # remove this line to keep default legend
# get the existing label text, otherwise use a custom list (e.g. labels = ['Malignant', 'Benign'])
# remove this line to keep default legend
labels = [v.get_text() for v in l.get_texts()]
# Legend.legendHandles was renamed to Legend.legend_handles in matplotlib 3.7; use whichever exists
# remove this line to keep default legend
handles = l.legend_handles if hasattr(l, 'legend_handles') else l.legendHandles
# add a single legend at the bottom of the figure; change loc and bbox_to_anchor to move the legend
# remove this line to keep default legend
fig.legend(title='Tumor Classification', handles=handles, labels=labels, loc='lower center', ncol=2, bbox_to_anchor=(0.5, -0.03))
fig.suptitle('Breast Cancer Data', fontsize=30, y=1.05)
fig.savefig('test.png', bbox_inches="tight")
plt.show()
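The grouping and flattening steps can be tried without any plotting library. The sketch below is an illustration under assumptions: it uses a small hand-picked subset of the feature names, and placeholder strings stand in for the matplotlib Axes objects. It also uses chain.from_iterable, a variant of the chain(*...) call above that avoids unpacking the outer list.

```python
from itertools import chain

# a small subset of the dataset's column names, enough to show the pattern
columns = [
    "mean radius", "mean texture",
    "radius error", "texture error",
    "worst radius", "worst texture",
    "target",
]
keywords = ["mean", "error", "worst"]

# one list of column names per keyword; "target" matches none of them
col_groups = [[c for c in columns if k in c] for k in keywords]

# placeholder strings stand in for the Axes of each subfigure (hypothetical names)
axes_groups = [
    [f"{k}-ax{i}" for i in range(len(group))]
    for k, group in zip(keywords, col_groups)
]

# pair each axes with its column inside a group, then flatten across groups
pairs = list(chain.from_iterable(
    zip(axes, group) for axes, group in zip(axes_groups, col_groups)
))
print(pairs)
```

Because zip pairs axes and columns within one group at a time, a column can only land on an axes that belongs to its own subfigure.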
Plotted with default legends
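As a contrast to the default per-axes legends, the shared-legend step can be tried in isolation. This is a sketch under assumptions: plain matplotlib line plots stand in for the seaborn histograms, and the figure size and data are made up. It reads the handles and labels from the last legend it removed, as the answer's code does, and handles the legendHandles to legend_handles rename from matplotlib 3.7.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2, figsize=(8, 3))
for ax in axs:
    # two labelled lines stand in for the two hue levels of a histogram
    ax.plot([0, 1], [0, 1], label="Malignant")
    ax.plot([0, 1], [1, 0], label="Benign")
    ax.legend()

legend = None
for ax in axs:
    legend = ax.get_legend()
    legend.remove()  # detach the per-axes legend; the object stays usable

# reuse the text and handles of the last removed legend for one figure-level legend
labels = [text.get_text() for text in legend.get_texts()]
handles = legend.legend_handles if hasattr(legend, "legend_handles") else legend.legendHandles
fig.legend(handles=handles, labels=labels, title="Tumor Classification",
           loc="lower center", ncol=2)
```

This only works as a shared legend when every axes uses the same categories and colors, which holds in the answer because each histogram is split by the same target column.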
Answered By - Trenton McKinney