Cluster Analysis
Glossary
Centroid: the center of a cluster.
Cluster analysis (clustering): the task of combining similar observations into groups, called clusters.
Practice
# Clustering: fit k-means starting from explicitly chosen centroids.
# `data`, `n_clusters` and `centers` are defined outside this snippet.
from sklearn.cluster import KMeans

# n_clusters - number of clusters
# init - initial centroids
# random_state - fixed seed so repeated runs give the same result
model = KMeans(n_clusters=n_clusters, init=centers, random_state=12345)
model.fit(data)

# Obtaining cluster centroids
print(model.cluster_centers_)
# Objective function value
print(model.inertia_)
# Plot the pairplot graph with cluster fill and centroids.
# `data` and the fitted `model` are defined outside this snippet.
import pandas as pd
from sklearn.cluster import KMeans
import seaborn as sns

# Feature columns only: skip 'label' if an earlier run already added it to
# `data`, otherwise the centroid frame would get one column name too many.
feature_columns = data.columns.drop('label', errors='ignore')
centroids = pd.DataFrame(model.cluster_centers_, columns=feature_columns)
# Add a column with the cluster number
data['label'] = model.labels_.astype(str)
# One label per centroid row, so this works for any number of clusters
centroids['label'] = [f'{i} centroid' for i in range(len(centroids))]
# An index reset will be needed later
data_all = pd.concat([data, centroids], ignore_index=True)

# Plot the graph
sns.pairplot(data_all, hue='label', diag_kind='hist')
# Plot the PairGrid graph with cluster fill, initial and end centroids.
# `data` and the fitted `model` are defined outside this snippet.

import pandas as pd
from sklearn.cluster import KMeans
import seaborn as sns

# Feature columns only: skip 'label' if an earlier run already added it to
# `data`, otherwise the centroid frame would get one column name too many.
feature_columns = data.columns.drop('label', errors='ignore')
centroids = pd.DataFrame(model.cluster_centers_, columns=feature_columns)
# Add a column with the cluster number
data['label'] = model.labels_.astype(str)
# One label per centroid row, so this works for any number of clusters
centroids['label'] = [f'{i} centroid' for i in range(len(centroids))]
# An index reset will be needed later
data_all = pd.concat([data, centroids], ignore_index=True)

# Plot the graph
pairgrid = sns.pairplot(data_all, hue='label', diag_kind='hist')
# Replace the grid's data with the initial centroids so that the next
# map_offdiag call draws them as large red stars on the off-diagonal panels.
# NOTE(review): presumably the same values that were passed as init=centers
# when the model was fitted - confirm.
pairgrid.data = pd.DataFrame(
    [[20, 80, 8], [50, 20, 5], [20, 30, 10]],
    columns=feature_columns,
)
pairgrid.map_offdiag(func=sns.scatterplot, s=200, marker='*', color='red')
# Finding the optimal number of clusters with the elbow method.
# `data` is defined outside this snippet.

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

# Fit on the features only: the plotting snippets in this file add a 'label'
# column to `data`, which must not be clustered on. No-op if it is absent.
features = data.drop(columns=['label'], errors='ignore')

distortion = []
K = range(1, 8)  # number of clusters from 1 to 7
# NOTE: this loop rebinds `model`; afterwards it holds the 7-cluster fit.
for k in K:
    model = KMeans(n_clusters=k, random_state=12345)
    model.fit(features)
    distortion.append(model.inertia_)

# Objective function value against the number of clusters
plt.figure(figsize=(12, 8))
plt.plot(K, distortion, 'bx-')
plt.xlabel('Number of clusters')
plt.ylabel('Objective function value')
plt.show()