K-Means 聚类
前言
K-Means是一个用于发现聚类中心的方法。它的实现非常简单,K-Means中的K意味着有K个聚类中心,每个聚类中心都对应这一种类。
具体流程
K-Means算法的具体步骤为:
- 首先随机选择K个中心
- 将所有样本分为K个类,每个类中均为到同一个聚类中心最近的点。
- 根据每个类中的点计算新的聚类中心,具体方法为计算一类中每个点的平均位置。
- 如果新的聚类中心和之前相比没有变化,或者到达最大迭代次数,结束计算,否则重复到步骤2。
该动画可以帮助理解K—Means算法的工作原理。
实现
def distance(pos1, pos2):
"""Returns the Euclidean distance between pos1 and pos2, which are pairs.
>>> distance([1, 2], [4, 6])
5.0
"""
return sqrt((pos1[0] - pos2[0]) ** 2 + (pos1[1] - pos2[1]) ** 2)
def find_closest(location, centroids):
"""Return the centroid in centroids that is closest to location.
If multiple centroids are equally close, return the first one.
>>> find_closest([3.0, 4.0], [[0.0, 0.0], [2.0, 3.0], [4.0, 3.0], [5.0, 5.0]])
[2.0, 3.0]
"""
return min(centroids, key=lambda x:distance(location, x))
def group_by_first(pairs):
"""Return a list of pairs that relates each unique key in the [key, value]
pairs to a list of all values that appear paired with that key.
Arguments:
pairs -- a sequence of pairs
>>> example = [ [1, 2], [3, 2], [2, 4], [1, 3], [3, 1], [1, 2] ]
>>> group_by_first(example)
[[2, 3, 2], [2, 1], [4]]
"""
keys = []
for key, _ in pairs:
if key not in keys:
keys.append(key)
return [[y for x, y in pairs if x == key] for key in keys]
def group_by_centroid(restaurants, centroids):
"""Return a list of clusters, where each cluster contains all restaurants
nearest to a corresponding centroid in centroids. Each item in
restaurants should appear once in the result, along with the other
restaurants closest to the same centroid.
"""
restaurant_groups = [[find_closest(restaurant_location(i), centroids), i]
for i in restaurants]
return group_by_first(restaurant_groups)
def find_centroid(cluster):
"""Return the centroid of the locations of the restaurants in cluster."""
restaurant_locations = [restaurant_location(i) for i in cluster]
return [mean([loc[0] for loc in restaurant_locations]),
mean([loc[1] for loc in restaurant_locations])]
def k_means(restaurants, k, max_updates=100):
"""Use k-means to group restaurants by location into k clusters."""
assert len(restaurants) >= k, 'Not enough restaurants to cluster'
old_centroids, n = [], 0
# Select initial centroids randomly by choosing k different restaurants
centroids = [restaurant_location(r) for r in sample(restaurants, k)]
while old_centroids != centroids and n < max_updates:
old_centroids = centroids
grouped_restaurants = group_by_centroid(restaurants, centroids)
centroids = [find_centroid(c) for c in grouped_restaurants]
n += 1
return centroids