Knowledge Base

Cluster Analysis

Glossary

Centroid: the center of a cluster. In k-means, it is the mean of all observations assigned to that cluster.

Cluster analysis (clustering): the task of grouping similar observations together into groups called clusters.

Practice

# Clustering
from sklearn.cluster import KMeans

# NOTE(review): assumes `data` (numeric feature table), `n_clusters` (int) and
# `centers` (initial centroids, shape (n_clusters, n_features)) are defined
# earlier in the notebook.
#
# n_clusters - number of clusters
# init       - initial centroids; because explicit starting points are given,
#              extra restarts would be identical, so a single run (n_init=1)
#              is enough and avoids sklearn's "explicit initial center" warning
model = KMeans(n_clusters=n_clusters, init=centers, n_init=1, random_state=12345)
model.fit(data)

# Obtaining cluster centroids
print(model.cluster_centers_)
# Objective function value (inertia): sum of squared distances from each
# observation to its closest cluster centroid
print(model.inertia_)

# Plot the pairplot graph with cluster fill and centroids
import pandas as pd
import seaborn as sns

# NOTE(review): assumes `data` is the DataFrame the fitted KMeans `model` was
# trained on. A 'label' column left by an earlier plotting step is ignored.
features = data.drop(columns='label', errors='ignore')

centroids = pd.DataFrame(model.cluster_centers_, columns=features.columns)
# One hue level per centroid, whatever the number of clusters
centroids['label'] = [f'{i} centroid' for i in range(len(centroids))]

# Attach the cluster number to a copy so `data` stays numeric-only and can
# still be refitted or reused by later cells.
data_labeled = features.assign(label=model.labels_.astype(str))

# ignore_index=True gives the combined frame a fresh 0..n-1 index
data_all = pd.concat([data_labeled, centroids], ignore_index=True)

# Plot the graph
sns.pairplot(data_all, hue='label', diag_kind='hist')

# Plot the PairGrid graph with cluster fill, initial and final centroids
import pandas as pd
import seaborn as sns

# NOTE(review): assumes `data` is the DataFrame the fitted KMeans `model` was
# trained on. A 'label' column left by an earlier plotting step is ignored.
features = data.drop(columns='label', errors='ignore')

# Final centroids, one hue level per centroid
centroids = pd.DataFrame(model.cluster_centers_, columns=features.columns)
centroids['label'] = [f'{i} centroid' for i in range(len(centroids))]

# Label a copy so `data` itself is not mutated.
data_all = pd.concat(
    [features.assign(label=model.labels_.astype(str)), centroids],
    ignore_index=True,
)

# Plot the graph: observations coloured by cluster plus the final centroids
pairgrid = sns.pairplot(data_all, hue='label', diag_kind='hist')

# Initial centroids: presumably the same values passed as `init` when fitting
# the model (TODO confirm); requires exactly three feature columns.
initial_centers = pd.DataFrame(
    [[20, 80, 8], [50, 20, 5], [20, 30, 10]],
    columns=features.columns,
)

# Draw the initial centroids as red stars on every off-diagonal panel.
# Plotting on the axes directly avoids replacing pairgrid.data: recent seaborn
# versions look up the grid's hue column ('label') in pairgrid.data and fail
# when it is missing.
for row, y_var in enumerate(pairgrid.y_vars):
    for col, x_var in enumerate(pairgrid.x_vars):
        if row != col:
            pairgrid.axes[row, col].scatter(
                initial_centers[x_var],
                initial_centers[y_var],
                s=200,
                marker='*',
                color='red',
            )

# Finding the optimal number of clusters with the elbow method
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

# NOTE(review): assumes `data` is the feature DataFrame used above. A string
# 'label' column attached by an earlier plotting step would make KMeans fail,
# so only the feature columns are used.
features = data.drop(columns='label', errors='ignore')

K = range(1, 8)  # number of clusters from 1 to 7
distortion = []  # objective function value (inertia) for each k
for k in K:
    # Separate name so the `model` fitted earlier is not overwritten
    km = KMeans(n_clusters=k, random_state=12345)
    km.fit(features)
    distortion.append(km.inertia_)

plt.figure(figsize=(12, 8))
plt.plot(K, distortion, 'bx-')
plt.xticks(list(K))  # the number of clusters is an integer
plt.xlabel('Number of clusters')
plt.ylabel('Objective function value')
plt.show()
Send Feedback
close
  • Bug
  • Improvement
  • Feature
Send Feedback
,