Issue
I want to plot a dataset with different clusters.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.cluster
rng = np.random.default_rng(seed=5)
df_1_3 = pd.DataFrame(rng.normal(loc=(1, 3), size=(30, 2), scale=0.50), columns=["x", "y"])
df_5_1 = pd.DataFrame(rng.normal(loc=(5, 1), size=(30, 2), scale=0.25), columns=["x", "y"])
df_5_5 = pd.DataFrame(rng.normal(loc=(5, 5), size=(30, 2), scale=0.25), columns=["x", "y"])
df = pd.concat([df_1_3, df_5_1, df_5_5], keys=["df_1_3", "df_5_1", "df_5_5"])
A clustering algorithm calculates the cluster labels:
model = sklearn.cluster.AgglomerativeClustering(...)
df["cluster"] = model.fit_predict(df[["x", "y"]]) # [0, 0, 0, ... 1, 1, 1 ... 2, 2, 2]
df["cluster"] = df["cluster"].astype("category")
I want to visualize the data in one plot. Each original dataset should be distinguishable by its own marker, and the cluster label should be shown by the color.
To clarify: if the origins of all three datasets were close to each other, the algorithm would create just one cluster (that is, one category and one color), but the markers should still depend on the original keys 'df_1_3', 'df_5_1', and 'df_5_5'.
Actually I nearly got the result with:
fig, ax = plt.subplots()
for marker, (name, sdf) in zip(["o", "s", "^", "d"], df.groupby(level=0)):
    sdf.plot.scatter(x="x", y="y", c="cluster", marker=marker, cmap="viridis", ax=ax)
However, the colorbar is displayed three times, once for each call to plot.scatter.
How do I get rid of the redundant colorbars?
Solution
With seaborn you can do this without a for loop and get a cleaner-looking plot:
import seaborn as sns
sns.scatterplot(data=df, x='x', y='y', hue='cluster', style='cluster', markers=["o", "^", "d"], palette="viridis")
To keep the color and the marker separate, it is best to reset the dataframe index and use the keys, stored in level=0 of the index, for the markers.
# move index level 0 (the concat keys) into a 'key' column; the names argument needs pandas >= 1.5
df = df.reset_index(level=0, names=['key'])
# plot
ax = sns.scatterplot(data=df, x='x', y='y', hue='cluster', style='key', markers=["o", "^", "d"], palette="viridis")
sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)
df.head() after df.reset_index(level=0, names=['key']):
      key         x         y cluster
0  df_1_3  0.599034  2.337821       0
1  df_1_3  0.875819  3.210223       0
2  df_1_3  1.568023  3.054853       0
3  df_1_3  0.723676  2.607610       0
4  df_1_3  1.374373  3.817392       0
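If you would rather stay with the original matplotlib loop, one possible sketch is to call ax.scatter directly with a shared color range and then add a single colorbar afterwards. This is an alternative to the seaborn approach, not part of it. The cluster labels below are placeholder values standing in for the output of model.fit_predict, because the model arguments are not given in the question.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(seed=5)
frames = {
    "df_1_3": pd.DataFrame(rng.normal(loc=(1, 3), scale=0.50, size=(30, 2)), columns=["x", "y"]),
    "df_5_1": pd.DataFrame(rng.normal(loc=(5, 1), scale=0.25, size=(30, 2)), columns=["x", "y"]),
    "df_5_5": pd.DataFrame(rng.normal(loc=(5, 5), scale=0.25, size=(30, 2)), columns=["x", "y"]),
}
df = pd.concat(frames)  # the dict keys become level 0 of the index

# ASSUMPTION: placeholder labels instead of model.fit_predict(df[["x", "y"]])
df["cluster"] = np.repeat([0, 1, 2], 30)

fig, ax = plt.subplots()

# Use one color range for every group so equal labels get equal colors
vmin, vmax = df["cluster"].min(), df["cluster"].max()

for marker, (name, sdf) in zip(["o", "s", "^"], df.groupby(level=0)):
    sc = ax.scatter(
        sdf["x"], sdf["y"],
        c=sdf["cluster"], cmap="viridis", vmin=vmin, vmax=vmax,
        marker=marker, label=name,
    )

# One colorbar for the whole axes; any of the scatter artists works
# because they all share the same cmap, vmin and vmax
fig.colorbar(sc, ax=ax, ticks=[0, 1, 2], label="cluster")
ax.legend(title="key")
```

Call plt.show() or fig.savefig(...) to view the result. The figure then has the plot axes plus exactly one colorbar axes, and the legend maps each marker to its original key.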
Answered By - Suraj Shourie