一、算法原理
K-Means算法是一种聚类分析(cluster analysis)算法,在事先不知道类别的情况下,将相似的对象归入同一个簇中;其归类过程通过多次迭代计算欧氏距离(Euclidean Distance)来完成。
欧氏距离:对于点 $p=(x_p, y_p)$ 与簇中心 $c=(x_c, y_c)$,$d(p, c) = \sqrt{(x_p - x_c)^2 + (y_p - y_c)^2}$
算法过程如下:计算每个点与k个簇中心之间的欧氏距离,将该点归入距离最近的簇;全部点归类完成后,以各簇内所有点的坐标均值作为新的簇中心;如此重复迭代多次,找出最优的簇中心点。
二、算法Java实现
1.工具类KMeansUtil.java
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
| import java.util.ArrayList; import java.util.List;
/**
 * K-Means clustering over two-dimensional {@link PointBean}s.
 *
 * <p>Cluster IDs are 1-based: the center stored at {@code centers.get(i)} carries
 * cluster ID {@code i + 1}. Running the algorithm mutates the supplied points in
 * place by setting each point's cluster ID.
 */
public class KMeansUtil {
    /**
     * Legacy "large distance" value. It is kept so existing references still
     * compile, but it is no longer used as the starting minimum when assigning
     * points, because that left points farther than this from every center
     * without a cluster.
     */
    public static final double DISTANCE = 10000.00;

    /** Number of clusters to build. */
    public int k;

    /** Number of assign/adjust passes performed by {@link #runKmeans()}. */
    public int maxIter;

    /** The points being clustered; this is the caller's list, not a copy. */
    public List<PointBean> points;

    /** Current cluster centers, one per cluster, ordered by cluster ID. */
    public List<PointBean> centers;

    /**
     * Stores the parameters and seeds the initial centers.
     *
     * @param k       number of clusters; must be between 0 and {@code points.size()}
     * @param maxIter number of passes {@link #runKmeans()} will perform
     * @param points  points to cluster; must not be {@code null}
     * @throws NullPointerException     if {@code points} is {@code null}
     * @throws IllegalArgumentException if {@code k} is negative or exceeds the number of points
     */
    public KMeansUtil(int k, int maxIter, List<PointBean> points) {
        this.k = k;
        this.maxIter = maxIter;
        this.points = points;
        initCenters();
    }

    /**
     * Seeds {@link #centers} with copies of {@code k} input points.
     *
     * <p>Every second point is used as a seed when the list is long enough for
     * that; otherwise the first {@code k} points are used.
     *
     * @throws NullPointerException     if {@link #points} is {@code null}
     * @throws IllegalArgumentException if {@code k} is negative or exceeds the number of points
     */
    public void initCenters() {
        if (points == null) {
            throw new NullPointerException("points must not be null");
        }
        int size = points.size();
        if (k < 0 || k > size) {
            throw new IllegalArgumentException(
                    "k must be between 0 and the number of points (" + size + "), got " + k);
        }

        // The largest seed index with stride 2 is 2 * (k - 1); fall back to
        // stride 1 when that would run past the end of the list.
        int stride = (k - 1 <= (size - 1) / 2) ? 2 : 1;

        centers = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            PointBean seed = points.get(i * stride);
            PointBean center = new PointBean(seed.getX(), seed.getY());
            center.setClusterID(i + 1);
            centers.add(center);
        }
    }

    /**
     * Runs {@link #maxIter} passes; each pass assigns every point to its nearest
     * center and then moves the centers to the mean of their points.
     */
    public void runKmeans() {
        for (int iteration = 0; iteration < maxIter; iteration++) {
            for (PointBean point : points) {
                assignPointToCluster(point);
            }
            adjustCenters();
        }
    }

    /**
     * Moves each center to the mean position of the points currently assigned
     * to it. A center with no assigned points keeps its previous position.
     */
    public void adjustCenters() {
        // Java zero-initializes new arrays, so no explicit reset is needed.
        double[] sumX = new double[k];
        double[] sumY = new double[k];
        int[] count = new int[k];

        for (PointBean point : points) {
            int index = point.getClusterID() - 1;
            if (index < 0 || index >= k) {
                // Unassigned point (for example ID -1); it contributes to no cluster.
                continue;
            }
            sumX[index] += point.getX();
            sumY[index] += point.getY();
            count[index]++;
        }

        for (int i = 0; i < k; i++) {
            PointBean center = centers.get(i);
            center.setClusterID(i + 1);
            if (count[i] == 0) {
                // Dividing 0.0 by 0 would turn the center into NaN.
                continue;
            }
            center.setX(sumX[i] / count[i]);
            center.setY(sumY[i] / count[i]);
        }
    }

    /**
     * Sets the point's cluster ID to that of the nearest center. On a tie the
     * earlier center in {@link #centers} wins. If there are no centers the ID
     * is set to -1.
     *
     * @param point the point to classify; its cluster ID is overwritten
     */
    public void assignPointToCluster(PointBean point) {
        double minDistance = Double.POSITIVE_INFINITY;
        int clusterID = -1;

        for (PointBean center : centers) {
            double dis = EurDistance(point, center);
            if (dis < minDistance) {
                minDistance = dis;
                clusterID = center.getClusterID();
            }
        }
        point.setClusterID(clusterID);
    }

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param point  first point
     * @param center second point
     * @return {@code sqrt(dx * dx + dy * dy)}
     */
    public double EurDistance(PointBean point, PointBean center) {
        double detX = point.getX() - center.getX();
        double detY = point.getY() - center.getY();
        return Math.sqrt(detX * detX + detY * detY);
    }
}
|
2.实体类PointBean.java
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
| import lombok.Data; import lombok.NoArgsConstructor;
/**
 * A two-dimensional point together with the ID of the cluster it belongs to.
 *
 * <p>Accessors and the no-argument constructor are supplied by the Lombok
 * annotations on the class.
 */
@Data
@NoArgsConstructor
public class PointBean {
    /** Horizontal coordinate. */
    private Double x;

    /** Vertical coordinate. */
    private Double y;

    /** Cluster ID; starts at -1 until a value is set. */
    private int clusterID = -1;

    /**
     * Creates a point at the given coordinates, leaving the cluster ID at its
     * initial value of -1.
     *
     * @param x horizontal coordinate
     * @param y vertical coordinate
     */
    public PointBean(Double x, Double y) {
        this.x = x;
        this.y = y;
    }

    /**
     * Creates a point at the given coordinates with an explicit cluster ID.
     *
     * @param x         horizontal coordinate
     * @param y         vertical coordinate
     * @param clusterID cluster the point belongs to
     */
    public PointBean(Double x, Double y, int clusterID) {
        this(x, y);
        this.clusterID = clusterID;
    }
}
|
3.使用实例
1 2 3 4 5 6 7
| KMeansUtil kMeansUtil = new KMeansUtil(mapEnterprise.getPointCount(), 5, beans); kMeansUtil.runKmeans(); // k = mapEnterprise.getPointCount(), maxIter = 5, input points = beans
List<PointBean> points = kMeansUtil.points; // same list instance as beans; runKmeans() sets each element's clusterID in place
List<PointBean> centers = kMeansUtil.centers; // the cluster center points held by the util
|