Issue
I am using seaborn to create a jointplot with kind='hist'. When I add a colorbar for the 2D histogram, the grid of the joint plot and the grid of the marginal histogram on top no longer line up. I am fairly sure this is because I append the colorbar to the joint axes, but I don't know how to do it differently and couldn't find anything helpful here or elsewhere.
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
#Function I use to generate some sampling weights (my real data is very imbalanced)
def generate_weights(n_bins:int, data:np.ndarray)->np.ndarray:
    hist_data, bin_edges = np.histogram(data, bins=n_bins)
    indices = np.digitize(data, bin_edges[:-1])-1
    weights = 1.0/hist_data[indices]
    return weights
#generating some data to plot
dataset_1 = np.random.normal(0, 50, size=1000)
dataset_2 = np.random.normal(9, 55, size=1000)
#generating weights for the data
weights_1 = generate_weights(50, dataset_1)
#weights_2 = generate_weights(50, dataset_2)
#calculate the min and max of both datasets to use them later for calculating the bins and for the limits of the axes
min_val = min([np.min(dataset_1), np.min(dataset_2)])
max_val = max([np.max(dataset_1), np.max(dataset_2)])
bin_width = (max_val-min_val)/100
# calculate bins
bins_2d = (np.linspace(min_val, max_val, 100), np.linspace(min_val, max_val, 100))
sns.set_style('darkgrid')
figsize = (8,6)
jointplot_hist = sns.jointplot(
x = dataset_1,
y = dataset_2,
kind = 'hist',
cmap = 'viridis',
bins = bins_2d,
weights = weights_1
)
#set axis labels
jointplot_hist.set_axis_labels('dataset 1', 'dataset_2')
#set axis limits
jointplot_hist.ax_joint.set_xlim([min_val, max_val])
jointplot_hist.ax_joint.set_ylim([min_val, max_val])
#plot a diagonal line on the joint axes
jointplot_hist.ax_joint.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='-', linewidth=0.5)
#here I add the colorbar
divider = make_axes_locatable(jointplot_hist.ax_joint)
cax = divider.append_axes('right', size='5%', pad=0.05)
mappable = jointplot_hist.ax_joint.collections[0]
plt.colorbar(mappable, cax=cax)
[Image: resulting plot, in which the grid of the joint plot no longer matches the grid of the marginal histogram]
I tried to set the colorbar differently, like this:
cbar_ax = jointplot_hist.fig.add_axes([0.85, 0.15, 0.05, 0.7])
mappable = jointplot_hist.ax_joint.collections[0]
jointplot_hist.fig.colorbar(mappable, cax=cbar_ax)
but then the colorbar is not placed between the joint plot and the marginal histogram on the right, which is the layout I want. Adjusting its size is also tricky.
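Why the grids drift apart can be reproduced without seaborn. In the sketch below (plain matplotlib; the two stacked axes are a hypothetical stand-in for the joint axes and the top marginal axes of a jointplot), append_axes takes the colorbar's space out of the main axes only, and that layout is applied at draw time. After the draw, the main axes is narrower than before, while the top axes keeps its original width, which is the mismatch visible in the plot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Top axes stands in for ax_marg_x, main axes for ax_joint (hypothetical layout)
fig, (ax_top, ax_main) = plt.subplots(
    2, 1, sharex=True, gridspec_kw={"height_ratios": [1, 4]}
)
rng = np.random.default_rng(0)
x, y = rng.normal(size=(2, 500))
*_, image = ax_main.hist2d(x, y, bins=30)  # last return value is the QuadMesh
ax_top.hist(x, bins=30)

width_before = ax_main.get_position().width

# Same colorbar setup as in the question
divider = make_axes_locatable(ax_main)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(image, cax=cax)

fig.canvas.draw()  # the divider only applies its layout when the figure is drawn

width_after = ax_main.get_position().width  # main axes has shrunk
top_width = ax_top.get_position().width     # top axes is unchanged
print(width_before, width_after, top_width)
```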
Solution
You can read the positions of both the joint axes and the top marginal axes, and then resize the marginal axes so that its right edge lines up with the right edge of the (now narrower) joint axes.
# at the end of the code in the question's post
jointplot_hist.fig.canvas.draw()  # a draw is needed to let matplotlib fix the positions
# .bounds is (x0, y0, width, height) in figure coordinates
_jx0, _jy0, jwidth, _jheight = jointplot_hist.ax_joint.get_position().bounds
mx0, my0, _mwidth, mheight = jointplot_hist.ax_marg_x.get_position().bounds
# keep the marginal's left edge and height, but give it the joint axes' width
jointplot_hist.ax_marg_x.set_position([mx0, my0, jwidth, mheight])
PS: Note that sns.jointplot() does not use a figsize; it takes a height= parameter instead (and creates a square plot), so the figsize variable in the question has no effect. You can use jointplot_hist.fig.set_size_inches(8, 6) to change the size once everything has been created.
Answered By - JohanC