K-means clustering from scratch: a Python implementation

Tonya Chernyshova
6 min read · Jan 13, 2023


Image by the author

Nowadays, many libraries and frameworks make it easier for data scientists and machine learning developers to solve complex problems without writing everything from scratch. They provide pre-built algorithms and functions that save a lot of time and effort. However, implementing a machine learning algorithm from scratch can still be a rewarding learning experience: it gives you a deeper understanding of how the algorithm works and how it can be implemented. It also gives you more control over the details of the implementation and lets you customize it to fit specific needs. While it takes more time and effort, the deeper understanding and added control can be well worth it.

In this post, I’ll walk you through the k-means algorithm step by step. To understand how the libraries we use work under the hood, we’ll write the code for the algorithm from scratch in Python.

The k-means algorithm is one of the simplest and most popular unsupervised machine learning algorithms for clustering. It is called unsupervised because, unlike supervised learning, it does not require a target variable. Instead, it groups similar data points into clusters based on their feature values. K-means is widely used in many fields, with applications that include customer segmentation, image segmentation, and finding patterns in data.

K-means clustering begins with k randomly placed centroids (points that represent the centers of the clusters) and assigns every data point to the nearest centroid. After the assignment, each centroid is moved to the average location (the mean) of all the points assigned to it, and the assignments are redone. This process repeats until the assignments stop changing.

Image by the author

K-means cycle:

  1. Pick a value for k: The first step in running the k-means algorithm is to choose the number of clusters you want to create (k). This value should be chosen based on the structure of the data and the goals of your analysis.
  2. Initialize centroids: The next step is to initialize the centroids. This is typically done by selecting k random points from the dataset as the initial centroids.
  3. Create clusters: The algorithm then assigns each data point to the nearest centroid, creating k clusters.
  4. Update centroids: The centroids are then updated to the average location of all the data points assigned to them.
  5. Repeat until convergence: Steps 3 and 4 are repeated until the assignments of the data points to the centroids stop changing. At this point, the algorithm has converged and the final clusters have been determined.
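To make the cycle concrete, here is a minimal sketch of a single pass through steps 3 and 4 on a tiny hand-made dataset. The points, the starting centroids, and the function name `one_iteration` are illustrative assumptions rather than part of the implementation that follows; two obvious groups make the result easy to check by hand.

```python
import numpy as np

def one_iteration(points, centroids):
    """Run one assignment step (step 3) and one update step (step 4) of k-means.

    Assumes no cluster ends up empty (an empty cluster would give a NaN mean).
    """
    # Step 3: distance from every point to every centroid, shape (n_points, k)
    dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    labels = np.argmin(dists, axis=1)  # index of the nearest centroid for each point
    # Step 4: move each centroid to the mean of the points assigned to it
    new_centroids = np.array([points[labels == c].mean(axis=0)
                              for c in range(len(centroids))])
    return labels, new_centroids

# Two well-separated groups of two points each (illustrative data)
points = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
# Starting centroids, as if picked in step 2
centroids = np.array([[0.0, 0.0], [10.0, 10.0]])

labels, centroids = one_iteration(points, centroids)
print(labels)     # labels: [0 0 1 1]
print(centroids)  # new centroids: (0, 0.5) and (10, 10.5)
```

Calling `one_iteration` a second time with the updated centroids returns the same labels, which is exactly the convergence condition in step 5.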

Let’s implement this logic in code. The first step is to generate a dataset.

# Import packages
from copy import deepcopy
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import kaleido  # engine Plotly uses for static image export (fig.write_image)
from sklearn.datasets import make_blobs
import imageio

# Generate dataset
X, y = make_blobs(
    n_samples=2000,
    n_features=2,
    centers=3,
    cluster_std=1.35,
    random_state=11
)
# Put dataset into a DataFrame
data = pd.DataFrame(dict(x=X[:, 0], y=X[:, 1], label=y))
# Print data size and the first 3 rows
print('Data size')
print(data.shape)
print(data.head(3))
Output generated by the code above

For visualization, I use the Plotly library. I define a helper function, plot_data, that I’ll reuse for all the plots that follow.

# Define plot function
import os

def plot_data(df, centroids_list=None, labels=False, show=True, save=True):
    # Color points by their cluster label, or use one color for unlabeled data
    if labels:
        color_plot = df['label']
    else:
        color_plot = '#FFB266'
    fig = go.Figure(data=go.Scatter(
        x=df['x'],
        y=df['y'],
        mode='markers',
        marker=dict(
            color=color_plot,
            colorscale='Rainbow',
            size=10,
            line=dict(
                color='DarkSlateGrey',
                width=1
            )
        ), showlegend=False
    ))
    fig.update_layout(
        autosize=False,
        width=700,
        height=600)

    # Plot centroids as black crosses
    if centroids_list is not None:
        fig.add_trace(go.Scatter(x=[i[0] for i in centroids_list],
                                 y=[i[1] for i in centroids_list], mode='markers',
                                 marker=dict(color='#100101', size=14, symbol='x'),
                                 showlegend=False))

    # Save files for GIF generation. The frame number comes from the global
    # loop counter `t`; before the k-means loop defines it, saving is skipped.
    if save:
        try:
            os.makedirs('./images_kmeans', exist_ok=True)
            # Static export uses the kaleido engine (orca is deprecated)
            fig.write_image(f'./images_kmeans/img_{t}.png')
        except NameError:
            pass

    # Display the figure
    if show:
        fig.show()

First, let’s plot the generated dataset:

plot_data(data,centroids_list=None, labels=False, show=True, save=False)
The plot generated by Plotly Python library

Selecting the number of clusters and initializing the centroids are the first two steps of the k-means algorithm. I picked k = 3 and placed the 3 initial centroids at random integer coordinates within the range of the data.

# K value, number of clusters
k = 3
# Coordinates for random initial centroids
centroidX = np.random.randint(data.x.min(), data.x.max(), size=k)
centroidY = np.random.randint(data.y.min(), data.y.max(), size=k)
centroids_list = np.array(list(zip(centroidX, centroidY)), dtype=np.float32)
print('Initial Centroids')
print(centroids_list)
Output generated by the code above
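The code above draws integer coordinates anywhere inside the data’s bounding box. Step 2 of the cycle describes a common alternative: using k randomly chosen data points as the initial centroids, so that every centroid starts next to real data. Here is a minimal sketch, assuming NumPy’s Generator API (np.random.default_rng) with an arbitrary seed for reproducibility; `init_from_points` is a hypothetical helper name, not something used elsewhere in this post.

```python
import numpy as np

def init_from_points(points, k, seed=None):
    """Pick k distinct rows of `points` as initial centroids (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(points), size=k, replace=False)  # k distinct row indices
    return points[idx].astype(np.float32)  # float32 to match centroids_list above

# Illustrative usage on a small array of points
points = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
init = init_from_points(points, k=2, seed=0)
print(init.shape)  # (2, 2)
```

With this post’s variables, `centroids_list = init_from_points(data[['x', 'y']].to_numpy(), k)` could stand in for the randint lines.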

Now let’s plot the initial centroids:

plot_data(data, centroids_list, labels=False, show=True, save=False)
The plot generated by Plotly Python library

In the k-means algorithm, the Euclidean distance is often used to measure the distance between data points and centroids. The Euclidean distance between two points in n-dimensional space is defined as the square root of the sum of the squares of the differences between the coordinates of the points:

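$$d(\mathbf{p}, \mathbf{q}) = \sqrt{(p_1 - q_1)^2 + (p_2 - q_2)^2 + \dots + (p_n - q_n)^2}$$

where $\mathbf{p} = (p_1, \ldots, p_n)$ and $\mathbf{q} = (q_1, \ldots, q_n)$ are the two points.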

This distance measure is used to determine which centroid is closest to each data point, and the data point is then assigned to that centroid.
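As a quick check of the formula, take the points (1, 2) and (4, 6): the coordinate differences are 3 and 4, so the distance is sqrt(9 + 16) = 5. The short sketch below computes the same value three ways: by writing the formula out, with NumPy’s np.linalg.norm, and with the standard library’s math.dist (Python 3.8+). The two points are arbitrary illustrative values.

```python
import math
import numpy as np

p = (1.0, 2.0)
q = (4.0, 6.0)

# The formula written out: square root of the sum of squared coordinate differences
by_formula = math.sqrt(sum((pi - qi) ** 2 for pi, qi in zip(p, q)))
# NumPy's vector norm of the difference (the distance function below relies on it)
by_numpy = float(np.linalg.norm(np.array(p) - np.array(q)))
# Standard library equivalent (Python 3.8+)
by_stdlib = math.dist(p, q)

print(by_formula, by_numpy, by_stdlib)  # 5.0 5.0 5.0
```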

Let’s create a function that calculates the Euclidean distance between a data point and each centroid:

# Euclidean distance calculation function
def distance(a, b, ax=1):
    return np.linalg.norm(a - b, axis=ax)
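The `ax` parameter controls what np.linalg.norm reduces over. With the default ax=1, comparing one point with an array of centroids (NumPy broadcasts the point against every row) returns one distance per centroid, which is what the assignment step needs. With ax=None, the norm is taken over all elements of the difference array and returns a single number; the convergence check below uses this to measure how far all the centroids moved at once. The values here are illustrative.

```python
import numpy as np

def distance(a, b, ax=1):
    return np.linalg.norm(a - b, axis=ax)

point = np.array([0.0, 0.0])
centroids = np.array([[3.0, 4.0], [6.0, 8.0], [0.0, 1.0]])

# ax=1: one distance per row, i.e. from the point to each centroid
print(distance(point, centroids))             # [ 5. 10.  1.]
print(np.argmin(distance(point, centroids)))  # 2 (the nearest centroid)

# ax=None: a single norm over the whole difference array
old = np.array([[0.0, 0.0], [1.0, 1.0]])
new = np.array([[3.0, 4.0], [1.0, 1.0]])
print(distance(new, old, None))               # 5.0 (only the first centroid moved)
```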

It’s time to start implementing the algorithm. First, a little data preparation is needed:

# Data prep
x = data['x'].values
y = data['y'].values
dt = np.array(list(zip(x, y)))
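For reference, pandas can produce the same (n, 2) array directly: DataFrame.to_numpy() skips the intermediate Python list of tuples that zip creates. The small `sample` DataFrame below is a stand-in for the `data` DataFrame built earlier, chosen so the comparison is easy to inspect.

```python
import numpy as np
import pandas as pd

# Small stand-in for the `data` DataFrame built from make_blobs earlier
sample = pd.DataFrame({'x': [1.0, 2.0, 3.0], 'y': [4.0, 5.0, 6.0], 'label': [0, 1, 0]})

# Approach used in the post: zip the two columns into (x, y) pairs
dt_zip = np.array(list(zip(sample['x'].values, sample['y'].values)))
# Equivalent: select the coordinate columns and convert them in one step
dt_direct = sample[['x', 'y']].to_numpy()

print(dt_direct.shape)                    # (3, 2)
print(np.array_equal(dt_zip, dt_direct))  # True
```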

Next, create arrays to store the previous centroid positions and the cluster labels:

# Array for storing the previous centroid positions when they change location, filled with 0s
centroids_list_old = np.zeros(centroids_list.shape)
# Array holding the cluster label of each point, e.g. (1, 0, 2, 1, 2, 0, ..., 2, 1)
clusters = np.zeros(len(dt))

Then calculate the distance between the new and old centroids. We treat the centroids as final when this error equals 0, that is, when they no longer move:

# Distance between new centroids and old centroids, we'll find our centroid when error = 0 
error = distance(centroids_list,centroids_list_old, None)

While the error is not 0, we assign every point in the dataset to its nearest centroid and then move each centroid to the center (mean) of its cluster.
At each iteration of the loop, I use the plot_data function defined above to save a plot, and at the end I combine all of the plots into an animated GIF.

# For GIF generation: frame counter read by plot_data
t = 0
frames = []

while error != 0:
    # For each point, find the nearest centroid
    for i in range(len(dt)):
        distances = distance(dt[i], centroids_list)
        cluster = np.argmin(distances)  # index of the minimum distance
        clusters[i] = cluster  # cluster label for the i-th point
    # Copy the current centroids to the 'old' array
    # (on the first pass, centroids_list holds the random initial centroids)
    centroids_list_old = deepcopy(centroids_list)

    # Recalculate each centroid as the mean of the points assigned to it
    for i in range(k):
        points = [dt[j] for j in range(len(dt)) if clusters[j] == i]
        if points:  # keep the old position if a cluster is empty (avoids NaN centroids)
            centroids_list[i] = np.mean(points, axis=0)
    error = distance(centroids_list, centroids_list_old, None)

    # Plot: build a DataFrame of the points labeled by cluster
    df_plot = pd.DataFrame()
    for i in range(k):
        points = [np.append(dt[j], i) for j in range(len(dt)) if clusters[j] == i]
        df_temp = pd.DataFrame(points, columns=['x', 'y', 'label'])
        df_plot = pd.concat([df_plot, df_temp])

    plot_data(df_plot, centroids_list, labels=True, show=False, save=True)
    t += 1
GIF generated by imageio library
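The loop above visits every point in Python, which is easy to follow but slow for large datasets. As a sketch of an alternative (not the version used for the plots), NumPy broadcasting can compute all point-to-centroid distances at once. This version stops when the labels no longer change, which is the convergence condition from step 5, and caps the number of iterations; the function name `kmeans_vectorized` and the `max_iter` default are assumptions.

```python
import numpy as np

def kmeans_vectorized(points, centroids, max_iter=300):
    """Lloyd's k-means iterations using broadcasting (hypothetical helper)."""
    centroids = np.asarray(centroids, dtype=float).copy()
    labels = None
    for _ in range(max_iter):
        # (n, 1, 2) - (1, k, 2) -> (n, k, 2); norm over the last axis -> (n, k)
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        new_labels = np.argmin(dists, axis=1)
        # Converged: no point changed its cluster
        if labels is not None and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # Move each non-empty cluster's centroid to the mean of its points
        for c in range(len(centroids)):
            members = points[labels == c]
            if len(members) > 0:
                centroids[c] = members.mean(axis=0)
    return labels, centroids

# Illustrative usage: a poor start (both centroids in the same group) still converges
pts = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
labels, centers = kmeans_vectorized(pts, [[0.0, 0.0], [0.0, 1.0]])
print(labels)  # labels: [0 0 1 1]
```

Started from the same initial centroids (a copy taken before the loop above, which updates centroids_list in place), `kmeans_vectorized(dt, initial_centroids)` should converge to essentially the same centroids as the loop, since both perform the same assignment and update steps.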

Below is the code that generates the GIF with the imageio library:

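# GIF generation: read the saved frames in order and combine them
frames = []
n_frames = t  # number of iterations recorded by the k-means loop
for i in range(n_frames):
    image = imageio.v2.imread(f'./images_kmeans/img_{i}.png')
    frames.append(image)

# 500 ms per frame (2 frames per second); imageio 2.28+ rejects fps= for GIFs
imageio.mimsave('./example.gif', frames, duration=500)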

To check that my implementation is correct, I compare its centroids with the ones found by scikit-learn’s KMeans:

from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=k)
kmeans = kmeans.fit(dt)
labels = kmeans.predict(dt)
centroids = kmeans.cluster_centers_
print("My code")
print(centroids_list)
print("sklearn")
print(centroids)
Output generated by code above

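My implementation and scikit-learn find the same centroids (the clusters may be listed in a different order), which indicates that the code above works correctly. Keep in mind that k-means depends on the initial centroids, so a different random start can occasionally converge to a different solution.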
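Because the two implementations can number the clusters differently, the rows of the two centroid arrays may come out in a different order. The sketch below is an order-independent check that tries every pairing of rows, which assumes k is small (it checks k! permutations). The name `centroids_match` and the default tolerance are hypothetical choices; the tolerance also absorbs the float32 vs float64 difference between the two arrays.

```python
from itertools import permutations

import numpy as np

def centroids_match(a, b, tol=1e-3):
    """Return True if the rows of `a` and `b` agree within `tol` under some reordering."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        return False
    # Try every way of pairing the rows of `a` with the rows of `b`
    for perm in permutations(range(len(b))):
        if np.allclose(a, b[list(perm)], atol=tol):
            return True
    return False

# Illustrative values: the same centroids listed in a different order
mine = [[1.0, 2.0], [5.0, 5.0], [-3.0, 0.5]]
theirs = [[5.0, 5.0], [-3.0, 0.5], [1.0, 2.0]]
print(centroids_match(mine, theirs))  # True
```

With the variables above, the check would be `centroids_match(centroids_list, centroids)`.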

I hope this post has helped you understand the mechanics of the algorithm and how it works. Thanks for reading!

Connect with me on LinkedIn


Written by Tonya Chernyshova

Data Engineer 1Password (ex-Amazon)
