Issue
I have three point plots I'm trying to chart and show in a legend. The colors in the legend do not match the colors used in the plots. I tried using the solution from this post, but that did not work.
Here is the code I'm using:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# l and model_m are defined elsewhere in my code
fig, ax = plt.subplots()
a = sns.pointplot(x=l[1:], y=np.exp(model_m.params[1:]), label='factor',
ax=ax, color='green')
b = sns.pointplot(x=l[1:], y=np.exp(model_m.conf_int()[1:][:, 1]),
ax=ax, label='conf_int+', color='red')
c = sns.pointplot(x=l[1:], y=np.exp(model_m.conf_int()[1:][:, 0]),
ax=ax, label='conf_int-', color='blue')
plt.title('Model M Discrete')
ax.legend(labels=['factor', 'conf_int+', 'conf_int-'],
title='legend')
Here is what it produces:
Solution
The easiest solution would be to use sns.lineplot instead of sns.pointplot. sns.lineplot accepts label= directly, and marker='o' still draws a dot at each point:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
fig, ax = plt.subplots()
x = np.arange(10)
sns.lineplot(x=x, y=1 + np.random.rand(10).cumsum(),
ax=ax, label='factor', color='green', marker='o')
sns.lineplot(x=x, y=2 + np.random.rand(10).cumsum(),
ax=ax, label='conf_int+', color='red', marker='o')
sns.lineplot(x=x, y=3 + np.random.rand(10).cumsum(),
ax=ax, label='conf_int-', color='blue', marker='o')
ax.set_title('Model M Discrete')
ax.legend(title='legend')
plt.tight_layout()
plt.show()
Another option is to keep sns.pointplot and iterate through the PathCollection objects it generates (ax.collections), assigning a label to each one (for some reason, label= doesn't work in sns.pointplot).
x = np.arange(10)
fig, ax = plt.subplots()
sns.pointplot(x=x, y=1 + np.random.rand(10).cumsum(),
ax=ax, color='green')
sns.pointplot(x=x, y=2 + np.random.rand(10).cumsum(),
ax=ax, color='red')
sns.pointplot(x=x, y=3 + np.random.rand(10).cumsum(),
ax=ax, color='blue')
# Assumes each pointplot call added one PathCollection (its markers),
# in plotting order; label them to match the colors used above.
for curve, label in zip(ax.collections, ['factor', 'conf_int+', 'conf_int-']):
    curve.set_label(label)
ax.set_title('Model M Discrete')
ax.legend(title='legend')
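A further alternative, not part of the original answer, is matplotlib's proxy-artist technique: build the legend handles yourself as empty Line2D objects and pass them with handles=. This does not depend on which artists pointplot happens to put in ax.collections (that may differ between seaborn versions), and the proxies can show a line with a marker instead of only a dot. A sketch using the same random stand-in data:

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.lines import Line2D

x = np.arange(10)
labels = ["factor", "conf_int+", "conf_int-"]
colors = ["green", "red", "blue"]

fig, ax = plt.subplots()
for offset, color in zip([1, 2, 3], colors):
    # Random stand-in data, one pointplot per series.
    sns.pointplot(x=x, y=offset + np.random.rand(10).cumsum(), ax=ax, color=color)

# Proxy artists: Line2D objects with no data that exist only to be drawn
# in the legend. Each one is built from the same color and label lists
# used for plotting, so the legend matches the plot by construction.
handles = [
    Line2D([], [], color=color, marker="o", label=label)
    for label, color in zip(labels, colors)
]
ax.set_title("Model M Discrete")
ax.legend(handles=handles, title="legend")
```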
Still another way is to mimic a long-form dataframe and pass hue=, which creates a legend automatically:
fig, ax = plt.subplots()
x = np.arange(10)
y1 = 1 + np.random.rand(10).cumsum()
y2 = 2 + np.random.rand(10).cumsum()
y3 = 3 + np.random.rand(10).cumsum()
sns.pointplot(x=np.tile(x, 3),
y=np.concatenate([y1, y2, y3]),
hue=np.repeat(['factor', 'conf_int+', 'conf_int-'], len(x)),
ax=ax, palette=['green', 'red', 'blue'])
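The same hue approach with an actual long-form pandas DataFrame might look like the sketch below. The column names x, series and value are made up for illustration; DataFrame.melt turns one column per series into one row per (x, series) pair, and the series column then drives hue.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

x = np.arange(10)
# Wide form: one column per series (random stand-in data).
wide = pd.DataFrame({
    "x": x,
    "factor": 1 + np.random.rand(10).cumsum(),
    "conf_int+": 2 + np.random.rand(10).cumsum(),
    "conf_int-": 3 + np.random.rand(10).cumsum(),
})

# Long form: columns x, series, value; series holds the legend labels.
long_df = wide.melt(id_vars="x", var_name="series", value_name="value")

fig, ax = plt.subplots()
sns.pointplot(data=long_df, x="x", y="value", hue="series",
palette=["green", "red", "blue"], ax=ax)
ax.set_title("Model M Discrete")
```

Here the hue levels follow their order of first appearance in the series column; passing hue_order=[...] makes the legend order (and the palette pairing) explicit.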
Note that in both pointplot-based approaches, only a dot is shown in the legend, not a line.
Answered By - JohanC