Issue
I'm generating a figure that has 4 curves (for example), divided into 2 types: Type1 and Type2 (2 curves of each type). I'm drawing Type1 as a solid line and Type2 as a dashed line. To avoid overloading the figure, I want to add text somewhere in the figure explaining that the solid lines are Type1 and the dashed lines are Type2, instead of repeating this in every legend entry as in the following example:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
x = np.arange(1,10)
a_1 = 2 * x
a_2 = 3 * x
b_1 = 5 * x
b_2 = 6 * x
p1, = plt.plot(x, a_1,)
p2, = plt.plot(x, a_2, linestyle='--')
p3, = plt.plot(x, b_1)
p4, = plt.plot(x, b_2, linestyle='--')
plt.legend([(p1, p2), (p3, p4)], ['A Type1/Type2', 'B Type1/Type2'],
           numpoints=1, handler_map={tuple: HandlerTuple(ndivide=None)},
           handlelength=3)
plt.show()
The result is: (image not included)
What I would like is something like this: (image not included)
In it, I removed Type1/Type2 from each legend entry and instead added it once, in black, at a suitable spot in the figure (marked by a red circle).
Can anybody help?
Solution
- I think it's easiest to let the plotting API build the legend rather than constructing it manually, which means labeling the data properly before passing it to the API.
- In the following example, the data is loaded into a dict, where each value has been assigned a category and a type.
  - ['A']*len(a_1) creates a list of labels based on the length of a given array.
  - ['A']*len(a_1) + ['A']*len(a_2) combines multiple lists into a single list.
  - 'x': np.concatenate((x, x, x, x)) ensures that each value in vals is plotted with the correct x value.
- seaborn.lineplot, which is a high-level API for matplotlib, can load data directly from the dict. Setting hue='cat' colors the lines by category and style='type' sets the line style by type, so seaborn builds a legend with a separate section for each, and each type appears only once.
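The label-building steps above can also be wrapped in a loop, so the label lists always have the same length as the values they describe. This is a minimal sketch: build_long_form and the series mapping are hypothetical names introduced here for illustration, not part of the original answer. It produces the same dict as the hand-written version.

```python
import numpy as np

# Series from the question, keyed by (category, type); this mapping is a hypothetical layout
x = np.arange(1, 10)
series = {
    ('A', 'T1'): 2 * x,
    ('A', 'T2'): 3 * x,
    ('B', 'T1'): 5 * x,
    ('B', 'T2'): 6 * x,
}


def build_long_form(x, series):
    """Hypothetical helper: stack every (category, type) series into long-form columns."""
    xs, ys, cats, types = [], [], [], []
    for (cat, typ), y in series.items():  # dicts keep insertion order
        xs.append(x)
        ys.append(y)
        # one label per point, the same idea as ['A']*len(a_1)
        cats += [cat] * len(y)
        types += [typ] * len(y)
    return {'x': np.concatenate(xs), 'vals': np.concatenate(ys),
            'cat': cats, 'type': types}


data = build_long_form(x, series)
print(len(data['x']), len(data['vals']), len(data['cat']), len(data['type']))  # 36 36 36 36
```

The resulting dict can be passed to sns.lineplot exactly like the hand-built one.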
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# data from the OP
x = np.arange(1, 10)
a_1 = 2 * x
a_2 = 3 * x
b_1 = 5 * x
b_2 = 6 * x
# load the data from the OP into the dict
data = {'x': np.concatenate((x, x, x, x)),
        'vals': np.concatenate((a_1, a_2, b_1, b_2)),
        'cat': ['A']*len(a_1) + ['A']*len(a_2) + ['B']*len(b_1) + ['B']*len(b_2),
        'type': ['T1']*len(a_1) + ['T2']*len(a_2) + ['T1']*len(b_1) + ['T2']*len(b_2)}
# plot the data: color by category, line style by type
p = sns.lineplot(data=data, x='x', y='vals', hue='cat', style='type')
# move the legend outside the axes
p.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()
- Plotting can also be accomplished with seaborn.relplot:
sns.relplot(data=data, kind='line', x='x', y='vals', hue='cat', style='type')
Manual Legend Creation
- matplotlib: Legend Guide
- matplotlib: Composing Custom Legends
- SO: Manually add legend Items Python matplotlib
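For the manual route described in those guides, one common pattern is two legends on the same Axes built from proxy artists: colored lines for the categories and black solid/dashed lines for the types, which is close to the figure described in the question. This is a minimal sketch that assumes one color per category (C0 for A, C1 for B). In the question's code each curve instead gets its own color from the default cycle. The legend positions are arbitrary choices.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

x = np.arange(1, 10)
fig, ax = plt.subplots()

# Color encodes the category (assumed), line style encodes the type
ax.plot(x, 2 * x, color='C0')                  # A, Type1
ax.plot(x, 3 * x, color='C0', linestyle='--')  # A, Type2
ax.plot(x, 5 * x, color='C1')                  # B, Type1
ax.plot(x, 6 * x, color='C1', linestyle='--')  # B, Type2

# Proxy artists (not added to the Axes): colored lines for the categories...
cat_handles = [Line2D([], [], color='C0', label='A'),
               Line2D([], [], color='C1', label='B')]
# ...and black lines for the types, so the style key is shown only once
type_handles = [Line2D([], [], color='black', label='Type1'),
                Line2D([], [], color='black', linestyle='--', label='Type2')]

cat_legend = ax.legend(handles=cat_handles, loc='upper left')
# a second ax.legend() call replaces the first, so keep it as a separate artist
ax.add_artist(cat_legend)
type_legend = ax.legend(handles=type_handles, loc='lower right')
```

Because the proxies are created with Line2D directly, they appear only in the legends and are not drawn as data on the Axes.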
Answered By - Trenton McKinney