Issue
In R I would do the following to make a grid of facets with a raster plot in each facet:
# R Code
DF <- data.frame(expand.grid(seq(0, 7), seq(0, 7), seq(0, 5)))
names(DF) <- c("x", "y", "z")
DF$I <- runif(nrow(DF), 0, 1)
# x y z I
# 1: 0 0 0 0.70252977
# 2: 1 0 0 0.74346071
# ---
# 383: 6 7 5 0.93409337
# 384: 7 7 5 0.14143277
library(ggplot2)
ggplot(DF, aes(x = x, y = y, fill = I)) +
facet_wrap(~z, ncol = 3) +
geom_raster() +
scale_fill_viridis_c() +
theme(legend.position = "bottom") # the legend should be at the bottom
How can I do that in Python (using matplotlib and probably seaborn)? I tried the following code, but had trouble plotting the images with plt.imshow. Because the data has to be reshaped for plt.imshow, I guess I need a custom plot function for g.map. I tried several things, but had problems with the Axes, with the colors, and with using the data inside the custom plot function.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
columns=['x', 'y', 'z'])
# order of values different than in R, but that shouldn't matter for plotting
df['I'] = np.random.rand(df.shape[0])
# x y z I
# 0 0 0 0 0.076338
# 1 0 0 1 0.148386
# 2 0 0 2 0.481053
# .. .. .. .. ...
# 382 7 7 4 0.144188
# 383 7 7 5 0.700624
g = sns.FacetGrid(df, col='z', col_wrap=2, height=4, aspect=1)
g.map(plt.imshow, color = 'I') # <- plt.imshow does not work here.
# How can this be corrected (probably with a custom plot function)?
plt.show()
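For the FacetGrid route attempted here, one option is a custom function passed to FacetGrid.map_dataframe, which calls it once per facet with that facet's rows as data= (plus a color= keyword that imshow does not need). The sketch below pivots each facet to a y-by-x grid and draws it with Axes.imshow, using shared vmin/vmax and one horizontal colorbar. It is a minimal sketch, not the accepted answer's method: the helper name facet_imshow is hypothetical, and it assumes a seaborn release that provides FacetGrid.figure.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Same long-format data as in the question
df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
df['I'] = np.random.default_rng(0).random(len(df))

# One color range for all facets, so a single colorbar applies to every panel
vmin, vmax = df['I'].min(), df['I'].max()


def facet_imshow(data, vmin, vmax, **kwargs):
    """Hypothetical helper: pivot one facet's rows to a y-by-x grid and imshow it."""
    kwargs.pop('color', None)  # map_dataframe always passes color=; imshow ignores it here
    grid = data.pivot(index='y', columns='x', values='I')
    ax = plt.gca()  # map_dataframe makes the current facet the active Axes
    # origin='lower' puts y = 0 at the bottom, like ggplot2's geom_raster
    ax.imshow(grid.to_numpy(), origin='lower', cmap='viridis', vmin=vmin, vmax=vmax,
              extent=(grid.columns.min() - 0.5, grid.columns.max() + 0.5,
                      grid.index.min() - 0.5, grid.index.max() + 0.5))


g = sns.FacetGrid(df, col='z', col_wrap=3, height=3)
g.map_dataframe(facet_imshow, vmin=vmin, vmax=vmax)
g.set_axis_labels('x', 'y')

# Use the first facet's image as the mappable for a shared colorbar at the bottom
im = g.axes.flat[0].images[0]
cbar = g.figure.colorbar(im, ax=g.axes, location='bottom', fraction=0.05)
plt.show()
```

col_wrap=3 mirrors ncol = 3 in the R facet_wrap call.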
Solution
- The key is to reshape each group of 'z' data with pandas.DataFrame.pivot into the wide format (rows = y, columns = x) that seaborn.heatmap expects.
- Set vmin and vmax to the min and max of the entire dataset, vmin=df.I.min() and vmax=df.I.max(), so every facet uses the same color scale and a single colorbar is valid for all of them.
- The following code predefines the fig and all axes with plt.subplots.
  - How to plot in multiple subplots shows other ways to create subplots, including an answer that creates a figure with fig = plt.figure() and adds subplots with fig.add_subplot(2, 3, idx).
- Tested in python v3.12.0, pandas v2.1.2, matplotlib v3.8.1, seaborn v0.13.0.
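To make the pivot step concrete, here is one z group before and after reshaping. This is a small sketch with a freshly generated random column (not the answer's seed), so the values differ from the printed frame further down; only the shapes and the row/column orientation matter.

```python
import itertools

import numpy as np
import pandas as pd

df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
df['I'] = np.random.default_rng(0).random(len(df))

# Long format: one row per (x, y) cell of the z == 5 facet
one = df[df['z'] == 5]
print(one.shape)  # (64, 4)

# Wide format for heatmap: rows are y, columns are x, cells hold I
grid = one.pivot(index='y', columns='x', values='I')
print(grid.shape)  # (8, 8)

# The cell at row y=3, column x=2 holds the I value of the long row with x=2, y=3
value = one.loc[(one['x'] == 2) & (one['y'] == 3), 'I'].item()
assert grid.loc[3, 2] == value

# One global color range shared by all facets, as with vmin/vmax in the answer
vmin, vmax = df['I'].min(), df['I'].max()
```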
import itertools
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# sample data
df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
np.random.seed(20231116)  # for reproducible data
df['I'] = np.random.rand(df.shape[0])

# create the figure and axes
fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
# flatten the axes into a 1d array for easy access
axes = axes.flat
# add a separate axes for the colorbar
cbar_ax = fig.add_axes([0.3, .03, .4, .03])

# enumerate is specifically for adding the colorbar
# zip each group of 'z' data to the appropriate axes
for i, (ax, (z, data)) in enumerate(zip(axes, df.groupby('z'))):
    # pivot data into the correct shape for heatmap
    data = data.pivot(index='y', columns='x', values='I')
    # plot the heatmap; only the first one (i == 0) draws the shared colorbar into cbar_ax
    sns.heatmap(data=data, cmap='viridis', ax=ax, cbar=i == 0, vmin=df.I.min(), vmax=df.I.max(),
                cbar_ax=None if i else cbar_ax, cbar_kws=dict(location="bottom"))
    # add a title
    ax.set(title=f'Z: {z}')
    # invert the y-axis so y = 0 is at the bottom, matching the OP's ggplot2 plot
    ax.invert_yaxis()

plt.show()
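The enumerate/cbar_ax bookkeeping exists only so that exactly one heatmap draws the colorbar. An alternative, sketched below under the same data layout (it swaps in a different colorbar technique and is not the answer's code), is to draw every heatmap with cbar=False and then build one colorbar from the first panel's mesh. It relies on sns.heatmap drawing with pcolormesh, so the QuadMesh is the first collection on each Axes.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
df['I'] = np.random.default_rng(0).random(len(df))
vmin, vmax = df['I'].min(), df['I'].max()

fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
for ax, (z, data) in zip(axes.flat, df.groupby('z')):
    grid = data.pivot(index='y', columns='x', values='I')
    # no per-panel colorbar: the shared vmin/vmax make one bar valid for all panels
    sns.heatmap(data=grid, cmap='viridis', ax=ax, cbar=False, vmin=vmin, vmax=vmax)
    ax.set(title=f'Z: {z}')
    ax.invert_yaxis()

# the heatmap's QuadMesh is the first collection on the Axes
mesh = axes.flat[0].collections[0]
# ax=axes lets matplotlib take space from all panels for a bottom colorbar
cbar = fig.colorbar(mesh, ax=axes, location='bottom', fraction=0.05)
plt.show()
```

Letting fig.colorbar reserve the space via ax= avoids hand-tuning the [0.3, .03, .4, .03] rectangle, at the cost of the panels shrinking slightly.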
data for z: 5 (after the loop, data holds the last pivoted group; rows are y, columns are x)
x 0 1 2 3 4 5 6 7
y
0 0.488408 0.855913 0.339374 0.452842 0.510380 0.690491 0.448773 0.500916
1 0.273653 0.561840 0.860269 0.387470 0.170281 0.718488 0.256749 0.463527
2 0.546085 0.093934 0.273339 0.503968 0.063212 0.537974 0.867814 0.135719
3 0.071505 0.792265 0.919784 0.559663 0.733996 0.032003 0.475792 0.690789
4 0.474310 0.265576 0.841875 0.496676 0.603356 0.328808 0.039460 0.461778
5 0.439142 0.119253 0.842653 0.155213 0.798092 0.093709 0.899745 0.927067
6 0.548373 0.259983 0.295939 0.700694 0.040197 0.679880 0.153048 0.328768
7 0.216977 0.176777 0.238436 0.610802 0.705161 0.614877 0.813430 0.527120
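In this frame y = 0 is the first row, and sns.heatmap draws a DataFrame in matrix order, with the first row at the top. That is why the answer calls ax.invert_yaxis(). The short check below is a sketch with its own random data; it only inspects the y-axis direction before and after the call.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
df['I'] = np.random.default_rng(0).random(len(df))
data = df[df['z'] == 5].pivot(index='y', columns='x', values='I')

fig, ax = plt.subplots()
sns.heatmap(data=data, cmap='viridis', ax=ax)
# heatmap inverts the y-axis itself, so the first row (y = 0) is at the top
inverted_after_heatmap = ax.yaxis_inverted()
print(inverted_after_heatmap)  # True

ax.invert_yaxis()
# now y = 0 is at the bottom, like the ggplot2 raster
inverted_after_fix = ax.yaxis_inverted()
print(inverted_after_fix)  # False
plt.close(fig)
```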
- Implementation with plt.figure and fig.add_subplot instead of plt.subplots (reuses the imports and df from the code above)
# create the figure and axes
fig = plt.figure(figsize=(15, 10))
# add a separate axes for the colorbar
cbar_ax = fig.add_axes([0.3, .03, .4, .03])

# enumerate is specifically for adding the colorbar and adding an axes
for i, (z, data) in enumerate(df.groupby('z')):
    # pivot data into the correct shape for heatmap
    data = data.pivot(index='y', columns='x', values='I')
    # create the axes (subplot indices are 1-based)
    ax = fig.add_subplot(2, 3, i+1)
    # plot the heatmap; only the first one (i == 0) draws the shared colorbar into cbar_ax
    sns.heatmap(data=data, cmap='viridis', ax=ax, cbar=i == 0, vmin=df.I.min(), vmax=df.I.max(),
                cbar_ax=None if i else cbar_ax, cbar_kws=dict(location="bottom"))
    # add a title
    ax.set(title=f'Z: {z}')
    # invert the y-axis so y = 0 is at the bottom, matching the OP's ggplot2 plot
    ax.invert_yaxis()

plt.show()
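The 2 x 3 grid is hard-coded for six z values, whereas ggplot2's facet_wrap(~z, ncol = 3) works out the number of rows itself. A minimal sketch of the same idea for the fig.add_subplot version: fix the column count (ncols and nrows are illustrative names) and derive the rows with math.ceil. To show that no empty panel is created, it hypothetically keeps only five z levels; the colorbar is omitted for brevity.

```python
import itertools
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
df['I'] = np.random.default_rng(0).random(len(df))
# hypothetical subset: five z levels, so one slot of the grid stays unused
df = df[df['z'] < 5]

groups = list(df.groupby('z'))
ncols = 3  # mirrors ncol = 3 in the R facet_wrap call
nrows = math.ceil(len(groups) / ncols)
vmin, vmax = df['I'].min(), df['I'].max()

fig = plt.figure(figsize=(5 * ncols, 5 * nrows))
for i, (z, data) in enumerate(groups):
    grid = data.pivot(index='y', columns='x', values='I')
    # add_subplot creates only the panels that are actually needed
    ax = fig.add_subplot(nrows, ncols, i + 1)
    sns.heatmap(data=grid, cmap='viridis', ax=ax, cbar=False, vmin=vmin, vmax=vmax)
    ax.set(title=f'Z: {z}')
    ax.invert_yaxis()

plt.show()
```

With plt.subplots, the unused trailing axes would instead have to be removed explicitly (for example with ax.remove()).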
Answered By - Trenton McKinney