Issue
I have a 4 column dataframe which I extracted from the iris dataset. I use kmeans to plot 3 clusters from all possible combinations of 2 columns.
However, there seems to be something wrong with the output, especially since the cluster centers are not placed at the center of the clusters. I have provided examples of the output. Only cluster_1 seems OK but the other 3 look completely wrong .
How best can I fix my clustering? This is the sample code I am using
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import itertools
df = pd.read_csv('iris.csv')
df_columns = ['column_a', 'column_b', 'column_c', 'column_d']
n_clusters=3
kmeans = KMeans(n_clusters=n_clusters, init = 'k-means++', max_iter=200)
kmeans = kmeans.fit(df)
centroids = kmeans.cluster_centers_
cluster_labels = kmeans.labels_
for i in itertools.combinations(df_columns, 2):
fig, ax = plt.subplots(figsize=(12, 8))
fig=plt.figure()
ax.scatter(df[i[0]].values, df[i[1]].values, c=cluster_labels , cmap='viridis', edgecolor='k', s=20, alpha = 0.5)
ax.scatter(centroids[:, 0], centroids[:, 1],s = 20, c = 'black', marker='*')
plt.show()
Dataset used:
**column_a**,**column_b**,**column_c**,**column_d**
5.1,3.5,1.4,0.2
4.9,3.0,1.4,0.2
4.7,3.2,1.3,0.2
4.6,3.1,1.5,0.2
5.0,3.6,1.4,0.2
5.4,3.9,1.7,0.4
4.6,3.4,1.4,0.3
5.0,3.4,1.5,0.2
4.4,2.9,1.4,0.2
4.9,3.1,1.5,0.1
5.4,3.7,1.5,0.2
4.8,3.4,1.6,0.2
4.8,3.0,1.4,0.1
4.3,3.0,1.1,0.1
5.8,4.0,1.2,0.2
5.7,4.4,1.5,0.4
5.4,3.9,1.3,0.4
5.1,3.5,1.4,0.3
5.7,3.8,1.7,0.3
5.1,3.8,1.5,0.3
5.4,3.4,1.7,0.2
5.1,3.7,1.5,0.4
4.6,3.6,1.0,0.2
5.1,3.3,1.7,0.5
4.8,3.4,1.9,0.2
5.0,3.0,1.6,0.2
5.0,3.4,1.6,0.4
5.2,3.5,1.5,0.2
5.2,3.4,1.4,0.2
4.7,3.2,1.6,0.2
4.8,3.1,1.6,0.2
5.4,3.4,1.5,0.4
5.2,4.1,1.5,0.1
5.5,4.2,1.4,0.2
4.9,3.1,1.5,0.1
5.0,3.2,1.2,0.2
5.5,3.5,1.3,0.2
4.9,3.1,1.5,0.1
4.4,3.0,1.3,0.2
5.1,3.4,1.5,0.2
5.0,3.5,1.3,0.3
4.5,2.3,1.3,0.3
4.4,3.2,1.3,0.2
5.0,3.5,1.6,0.6
5.1,3.8,1.9,0.4
4.8,3.0,1.4,0.3
5.1,3.8,1.6,0.2
4.6,3.2,1.4,0.2
5.3,3.7,1.5,0.2
5.0,3.3,1.4,0.2
7.0,3.2,4.7,1.4
6.4,3.2,4.5,1.5
6.9,3.1,4.9,1.5
5.5,2.3,4.0,1.3
6.5,2.8,4.6,1.5
5.7,2.8,4.5,1.3
6.3,3.3,4.7,1.6
4.9,2.4,3.3,1.0
6.6,2.9,4.6,1.3
5.2,2.7,3.9,1.4
5.0,2.0,3.5,1.0
5.9,3.0,4.2,1.5
6.0,2.2,4.0,1.0
6.1,2.9,4.7,1.4
5.6,2.9,3.6,1.3
6.7,3.1,4.4,1.4
5.6,3.0,4.5,1.5
5.8,2.7,4.1,1.0
6.2,2.2,4.5,1.5
5.6,2.5,3.9,1.1
5.9,3.2,4.8,1.8
6.1,2.8,4.0,1.3
6.3,2.5,4.9,1.5
6.1,2.8,4.7,1.2
6.4,2.9,4.3,1.3
6.6,3.0,4.4,1.4
6.8,2.8,4.8,1.4
6.7,3.0,5.0,1.7
6.0,2.9,4.5,1.5
5.7,2.6,3.5,1.0
5.5,2.4,3.8,1.1
5.5,2.4,3.7,1.0
5.8,2.7,3.9,1.2
6.0,2.7,5.1,1.6
5.4,3.0,4.5,1.5
6.0,3.4,4.5,1.6
6.7,3.1,4.7,1.5
6.3,2.3,4.4,1.3
5.6,3.0,4.1,1.3
5.5,2.5,4.0,1.3
5.5,2.6,4.4,1.2
6.1,3.0,4.6,1.4
5.8,2.6,4.0,1.2
5.0,2.3,3.3,1.0
5.6,2.7,4.2,1.3
5.7,3.0,4.2,1.2
5.7,2.9,4.2,1.3
6.2,2.9,4.3,1.3
5.1,2.5,3.0,1.1
5.7,2.8,4.1,1.3
6.3,3.3,6.0,2.5
5.8,2.7,5.1,1.9
7.1,3.0,5.9,2.1
6.3,2.9,5.6,1.8
6.5,3.0,5.8,2.2
7.6,3.0,6.6,2.1
4.9,2.5,4.5,1.7
7.3,2.9,6.3,1.8
6.7,2.5,5.8,1.8
7.2,3.6,6.1,2.5
6.5,3.2,5.1,2.0
6.4,2.7,5.3,1.9
6.8,3.0,5.5,2.1
5.7,2.5,5.0,2.0
5.8,2.8,5.1,2.4
6.4,3.2,5.3,2.3
6.5,3.0,5.5,1.8
7.7,3.8,6.7,2.2
7.7,2.6,6.9,2.3
6.0,2.2,5.0,1.5
6.9,3.2,5.7,2.3
5.6,2.8,4.9,2.0
7.7,2.8,6.7,2.0
6.3,2.7,4.9,1.8
6.7,3.3,5.7,2.1
7.2,3.2,6.0,1.8
6.2,2.8,4.8,1.8
6.1,3.0,4.9,1.8
6.4,2.8,5.6,2.1
7.2,3.0,5.8,1.6
7.4,2.8,6.1,1.9
7.9,3.8,6.4,2.0
6.4,2.8,5.6,2.2
6.3,2.8,5.1,1.5
6.1,2.6,5.6,1.4
7.7,3.0,6.1,2.3
6.3,3.4,5.6,2.4
6.4,3.1,5.5,1.8
6.0,3.0,4.8,1.8
6.9,3.1,5.4,2.1
6.7,3.1,5.6,2.4
6.9,3.1,5.1,2.3
5.8,2.7,5.1,1.9
6.8,3.2,5.9,2.3
6.7,3.3,5.7,2.5
6.7,3.0,5.2,2.3
6.3,2.5,5.0,1.9
6.5,3.0,5.2,2.0
6.2,3.4,5.4,2.3
5.9,3.0,5.1,1.8
Solution
You compute the clusters in four dimensions. Note this implies the centroids are four-dimensional points too. Then you plot two-dimensional projections of the clusters. So when you plot the centroids, you have to pick out the same two dimensions that you just used for the scatterplot of the individual points.
for i, j in itertools.combinations([0, 1, 2, 3], 2):
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(df.iloc[:, i], df.iloc[:, j], c=cluster_labels,
cmap='viridis', edgecolor='k', s=20, alpha=0.5)
ax.scatter(centroids[:, i], centroids[:, j], s=20, c='black', marker='*')
plt.show()
Answered By - Arne
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.