Issue
I'm building k-means in PyTorch using gradient descent on the centroid locations, instead of expectation-maximisation. The loss is the sum of squared distances from each point to its nearest centroid. To identify which centroid is nearest to each point, I use argmin, which is not differentiable everywhere. However, PyTorch is still able to backprop and update the weights (centroid locations), giving similar performance to sklearn's KMeans on the data.
Any ideas how this is working, or how I can figure this out within PyTorch? Discussion on the PyTorch GitHub suggests argmax is not differentiable: https://github.com/pytorch/pytorch/issues/1339.
Example code below (on random pts):
import numpy as np
import torch

num_pts, batch_size, n_dims, num_clusters, lr = 1000, 100, 200, 20, 1e-5

# generate random points
vector = torch.from_numpy(np.random.rand(num_pts, n_dims)).float()

# randomly pick distinct starting centroids (replace=False avoids duplicates)
idx = np.random.choice(num_pts, size=num_clusters, replace=False)
kmean_centroids = vector[idx][:, None, :]  # [num_clusters, 1, n_dims]
# make the centroids a leaf tensor that autograd tracks
kmean_centroids = kmean_centroids.clone().detach().requires_grad_(True)

for t in range(4001):
    # get batch
    idx = np.random.choice(num_pts, size=batch_size)
    vector_batch = vector[idx]

    # broadcasting: [#pts, #dims] - [num_clusters, 1, #dims]
    distances = vector_batch - kmean_centroids   # [num_clusters, #pts, #dims]
    distances = torch.sum(distances**2, dim=2)   # [num_clusters, #pts]

    # argmin: index of the nearest centroid for each point
    membership = torch.min(distances, 0)[1]      # [#pts]

    # cluster distances: sum each point's squared distance to its own centroid
    cluster_loss = 0
    for i in range(num_clusters):
        subset = torch.transpose(distances, 0, 1)[membership == i]
        if len(subset) != 0:  # to prevent NaN
            cluster_loss += torch.sum(subset[:, i])

    cluster_loss.backward()
    print(cluster_loss.item())

    # plain gradient descent step on the centroid locations
    with torch.no_grad():
        kmean_centroids -= lr * kmean_centroids.grad
        kmean_centroids.grad.zero_()
Solution
As alvas noted in the comments, argmax (and likewise the argmin in your code) is not differentiable. However, once you compute it and assign each datapoint to a cluster, the derivative of the loss with respect to the location of these clusters is well-defined. This is what your algorithm does.
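One way to check this from within PyTorch is to look at what autograd actually tracks. The snippet below is only a small sketch on made-up data (the names points, centroids and sq_dist are illustrative, not taken from the question): the indices returned by torch.min are a plain integer tensor with no gradient history, while the selected minimum values stay connected to the centroids.

```python
import torch

torch.manual_seed(0)
points = torch.rand(6, 2)                          # [n_points, n_dims], toy data
centroids = torch.rand(3, 2, requires_grad=True)   # [n_clusters, n_dims]

# squared distance of every point to every centroid: [n_points, n_clusters]
sq_dist = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(dim=2)

# torch.min along a dim returns (values, indices)
min_values, membership = torch.min(sq_dist, dim=1)

# the indices are integers and are not part of the autograd graph
print(membership.dtype, membership.requires_grad)

# the selected values still carry a grad_fn, so the loss can be backpropagated
print(min_values.grad_fn is not None)

loss = min_values.sum()
loss.backward()
print(centroids.grad)  # gradient w.r.t. centroid locations
```

So nothing is ever differentiated through the index computation itself; the gradient only flows through the distances that the indices select.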
Why does it work? If you had only one cluster (so that the argmax operation didn't matter), your loss function would be quadratic, with its minimum at the mean of the data points. Now with multiple clusters, you can see that your loss function is piecewise (in higher dimensions think volumewise) quadratic: for any set of centroids [C1, C2, C3, ...] each data point is assigned to some centroid CN and the loss is locally quadratic. The extent of this locality is given by all alternative centroids [C1', C2', C3', ...] for which the assignment coming from argmax remains the same; within this region the argmax can be treated as a constant rather than a function, and thus the derivative of the loss is well-defined.
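To make the "argmax as a constant" picture concrete: with the assignments frozen, the loss for centroid C_k is the sum of |x_i - C_k|^2 over its assigned points, so its gradient is 2 * sum(C_k - x_i). The following sketch compares that hand-written expression with what autograd returns. The data and variable names are made up for illustration, and float64 is used only to keep the comparison free of rounding noise.

```python
import torch

torch.manual_seed(0)
points = torch.rand(50, 2, dtype=torch.float64)
centroids = torch.rand(4, 2, dtype=torch.float64, requires_grad=True)

# squared distances: [n_points, n_clusters]
sq_dist = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(dim=2)

# hard assignment; treated as a constant by autograd
membership = sq_dist.argmin(dim=1)

# loss: each point's squared distance to its assigned centroid
loss = sq_dist[torch.arange(len(points)), membership].sum()
loss.backward()

# gradient of the piecewise-quadratic loss with assignments held fixed
with torch.no_grad():
    analytic = torch.zeros_like(centroids)
    for k in range(centroids.shape[0]):
        assigned = points[membership == k]
        # an empty cluster contributes a zero gradient
        analytic[k] = 2 * (centroids[k] - assigned).sum(dim=0)

print(torch.allclose(centroids.grad, analytic))
```

The two gradients should agree, which is just the statement that autograd differentiates the quadratic piece you are currently standing on.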
Now, in reality, it's unlikely you can treat argmax as constant, but you can still treat the naive "argmax-is-a-constant" derivative as pointing approximately towards a minimum, because the majority of data points are likely to indeed belong to the same cluster between iterations. And once you get close enough to a local minimum such that the points no longer change their assignments, the process can converge to a minimum.
Another, more theoretical way to look at it is that you're doing an approximation of expectation maximization. Normally, you would have the "compute assignments" step, which is mirrored by argmax, and the "minimize" step, which boils down to finding the minimizing cluster centers given the current assignments. The minimum is given by d(loss)/d([C1, C2, ...]) == 0, which for a quadratic loss is given analytically by the means of the data points within each cluster. In your implementation, you're solving the same equation but with a gradient descent step. In fact, if you used a 2nd order (Newton) update scheme instead of 1st order gradient descent, you would be implicitly reproducing exactly the baseline EM scheme.
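The Newton remark can be checked numerically. With assignments fixed, the Hessian of the loss with respect to centroid C_k is 2 * n_k * I, where n_k is the number of points assigned to it, so one Newton step is C_k - grad_k / (2 * n_k). The sketch below uses the full data set rather than mini-batches, and toy data with illustrative names; it assumes every cluster is non-empty, which holds here because the centroids are initialised on data points.

```python
import torch

torch.manual_seed(0)
n_clusters = 3
points = torch.rand(200, 2, dtype=torch.float64)

# start from the first three data points so that no cluster is empty
centroids = points[:n_clusters].clone().requires_grad_(True)

sq_dist = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(dim=2)
membership = sq_dist.argmin(dim=1)
loss = sq_dist[torch.arange(len(points)), membership].sum()
loss.backward()

# n_k: number of points assigned to each centroid
counts = torch.bincount(membership, minlength=n_clusters).to(points.dtype)

with torch.no_grad():
    # Newton step: the Hessian block for centroid k is 2 * n_k * I
    newton = centroids - centroids.grad / (2 * counts[:, None])
    # classic k-means "minimize" step: mean of the assigned points
    means = torch.stack(
        [points[membership == k].mean(dim=0) for k in range(n_clusters)]
    )

print(torch.allclose(newton, means))
```

A single Newton step lands on the per-cluster means, i.e. the usual k-means update, whereas gradient descent with a small learning rate only moves part of the way there.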
Answered By - Jatentaki