⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bisecting_kmeans.java

📁 用java实现的k means算法
💻 JAVA
字号:
package Bisecting;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * KMeans聚类, 不仅适用于bisecting, 且适用于n-secting
 * 
 * @author zg
 * @version 1.0
 *
 */
public class Bisecting_Kmeans {
	
	/**
	 * 从文件中读取的数据, Map的key是每一行的标号, List是属性列
	 * 
	 * @see #LoadFile()
	 */
	private Map<Integer, List<Double> > data;
	
	/**
	 * 每一层的类集合, Set中记录每一行的标号,与Map中的key对应
	 * 每次分裂后都会重新排序,保持集合中的第1项(size)最大
	 */
	private List<Set<Integer> > set_of_eachlevel;	
	
	/**
	 * 聚类停止时类的数量
	 */
	private static int K = 11;
	
	/**
	 * 每次把最大的类分成B个小类
	 */
	private static int B = 2;
	
	/**
	 * 每次分裂时重复选择随机数的次数
	 */
	private static int ITER = 3;
	private static Double MAX_VALUE = 10000000d; 

	/**
	 * 构造函数
	 *
	 * @see #LoadFile()
	 */
	public Bisecting_Kmeans()
	{
		if (data == null)
			LoadFile();
		if (set_of_eachlevel == null)
		{
			Set<Integer> set = data.keySet();
			set_of_eachlevel = new ArrayList<Set<Integer> >(1);
			set_of_eachlevel.add(set);
		}		
	}
	
	/**
	 * 从文件中读取数据
	 *
	 * @exception IOException, FileNotFoundException
	 */
	public void LoadFile()
	{
		if (data == null)
			data = new HashMap<Integer, List<Double> >();
		
		String filename = "football.txt";
		try {
			FileReader fr = new FileReader(filename);
			BufferedReader br = new BufferedReader(fr);
			
			String line = "";
			String[] keys = null;
			List<Double> list = null;
			br.readLine();
			while((line = br.readLine()) != null)
			{
				keys = line.split("\t");
				list = new  ArrayList<Double>(keys.length - 1);
				for (int i = 1; i < keys.length; i++)
					list.add(Double.parseDouble(keys[i]));
				data.put(Integer.parseInt(keys[0]), list);
			}
			
			br.close();
			fr.close();
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

	/**
	 * 递归调用, 输出每一层的各个类, 同时分裂该层中最大的类
	 * 
	 * @param level 层数,从第1层开始
	 * @see #Split()
	 */
	public void PrintEachLevel(int level)
	{
		int num = (B - 1) *(level - 1) + 1;
		if (num >= K + 1)
			return;		
		
		System.out.println("第" + level + "层的" + num + "个类分别是:");
		Set<Integer> set = null;
		StringBuffer sb = null;
		for (int i = 0; i < num; i++)
		{
			sb = new StringBuffer();
			set = set_of_eachlevel.get(i);			
			sb.append("(");
			for (Iterator it = set.iterator(); it.hasNext(); )
				sb.append(it.next().toString() + ",");
			sb.replace(sb.length() - 1, sb.length(), ")");	
			System.out.println(sb.toString());
		}		
		
		Split();
		PrintEachLevel(level + 1);
	}
	
	/**
	 * 分裂当前层中最大的类为B个小类
	 *
	 * @see #Iterative(Set, List, StringBuffer)
	 */
	@SuppressWarnings("unchecked")
	public void Split()
	{
		/* set_of_eachlevel中第1项最大,即包含的点最多 */
		Set<Integer> set = set_of_eachlevel.get(0);
		/* 该数组只用来获取随机下标对应的类标号 */
		Object[] set_array = set.toArray();		
		
		/* 第1个List对应B个小类, 第2个List对应各小类的中心点坐标 */
		List<List<Double> > list_of_center = null;
		/* List对应B个小类, Set中记录各类中包含的节点标号 */
		List<Set<Integer> > list_of_points = null;
		/* 记录已生成的随机下标, 通过查询该Hashset保证生成的下标不重复 */
		HashSet<Integer> hs = null;
		
		/* important, 记录每次选择随机数并完成分裂后所有小类的距离和 */
		StringBuffer sb = null;
		List<Set<Integer> > templist_of_points = null;	
		double min_sum = MAX_VALUE;
		/* 重复ITER次选择随机下标, 并选择ITER次中距离和最小的作为分裂标准 */
		for (int k = 0; k < ITER; k++)
		{
			list_of_center = new ArrayList<List<Double> >(B);
			hs = new HashSet<Integer>();
			sb = new StringBuffer(" ");
			int random_num = 0;
			int key = 0;
			for (int i = 0; i < B; i++)
			{
				/* 生成随机下标 */
				random_num = (int)(Math.random() * set.size());
				if (hs.contains(random_num))
					i--;
				else
				{
					hs.add(random_num);
					/* 获取随机下标对应的节点标号(行号) */
					key = Integer.parseInt(set_array[random_num].toString());
					/* 获取该行的所有属性 */
					List<Double> temp_list = data.get(key);
					/* 初始化每个小类的中心为选择的随机点 */
					list_of_center.add(i, temp_list);
				}
			}
			hs.clear();
			
			/* 获取距离和最小的结果并记录在list_of_points中 */
			templist_of_points = Iterative(set, list_of_center, sb);
			if (Double.parseDouble(sb.toString()) < min_sum)
			{
				min_sum = Double.parseDouble(sb.toString());	
				list_of_points = templist_of_points;
			}			
		}
			
		/* 将生成的B个小类加入类集合中, 并移除原先的大类 */
		for (int i = 0; i < B; i++)
			set_of_eachlevel.add(list_of_points.get(i));
		set_of_eachlevel.remove(0);
		set_array = null;
		
		Set<Integer> temp_set1 = null;
		Set<Integer> temp_set2 = null;
		/* 对类集合进行1次排序, 将元素最多的类放在第1项 */
		for (int i = 1; i < set_of_eachlevel.size(); i++)
		{
			if (set_of_eachlevel.get(i).size() > set_of_eachlevel.get(0).size())
			{
				temp_set1 = ((Set<Integer>)((HashSet<Integer>)set_of_eachlevel.get(0)).clone());
				temp_set2 = ((Set<Integer>)((HashSet<Integer>)set_of_eachlevel.get(i)).clone());
				set_of_eachlevel.set(0, temp_set2);
				set_of_eachlevel.set(i, temp_set1);
			}
		}
	}
	
	/**
	 * 递归调用自身, 直至类不发生变化
	 * 还可以优化即每次迭代时只重新计算各点到发生改变的类的中心的距离, 并与到当前类中心的距离比较, 若变小则更新, 否则不变 
	 * 
	 * @param set 要分裂的大类
	 * @param list_of_center 每个类的中心点
	 * @param sb string对象,记录迭代停止时所有小类的距离和
	 * @return List<Set<Integer> > list_of_points: 稳定(迭代停止)时各小类中的节点集合
	 * 
	 * @see #IsChanged(List, List)
	 */
	public List<Set<Integer> > Iterative(Set<Integer> set, List<List<Double> > list_of_center, StringBuffer sb)
	{
		sb.delete(0, sb.length() - 1);
		/* 所有小类的距离和 */
		double sum_of_all = 0d;
				
		List<Set<Integer> > list_of_points = new ArrayList<Set<Integer> >(B);
		Set<Integer> empty_set = null;
		/* 初始化为空集合 */
		for (int i = 0; i < B; i++)
		{
			empty_set = new HashSet<Integer>();
			list_of_points.add(empty_set);
		}
		
		/* 类是否改变的标志 */
		boolean flag = true;
		List<Double> compare = null;
		List<Double> compared = null;
		Set<Integer> set_in_list = null;
		int min_index;
		double min_sum;
		double sum = 0;
		int compare_index = 0;
		int compared_index = 0;
		/* 对于大类中的每一个节点, 计算其到B个小类中心的距离并划分到距离最近的类中 */
		for (Iterator<Integer> compare_it = set.iterator(); compare_it.hasNext(); )
		{
			min_index = 0;
			min_sum = MAX_VALUE;	
			/* 节点标号 */
			compare_index = compare_it.next();
			/* 节点属性列 */
			compare = data.get(compare_index);
			compared_index = 0;
			/* 选择距离最近的类 */
			for (Iterator<List<Double> > compared_it = list_of_center.iterator(); compared_it.hasNext(); compared_index++)
			{
				sum = 0d;
				/* 中心点属性列 */
				compared = compared_it.next();
				int size = compare.size() < compared.size() ? compare.size() : compared.size();
				for (int index = 0; index < size; index++)
				{
					sum += Math.pow((compare.get(index) - compared.get(index)), 2);
				}
				if (sum < min_sum)
				{
					min_index = compared_index;
					min_sum = sum;
				}				
			}
			/* 更新类集合 */
			sum_of_all += min_sum;
			set_in_list = list_of_points.get(min_index);
			set_in_list.add(compare_index);
			list_of_points.set(min_index, set_in_list);
		}
		sb.append(sum_of_all);
		
		List<Double> store = null;
		List<Double> avg = null;
		/* 更新类中心 */
		for (int i = 0; i < list_of_points.size(); i++)
		{
			set_in_list = list_of_points.get(i);
			avg = null;
			for (Iterator<Integer> it = set_in_list.iterator(); it.hasNext(); )
			{
				store = data.get(it.next());
				if (avg == null)
				{
					avg = new ArrayList<Double>(store.size());
					for (int j = 0; j < store.size(); j++)
						avg.add(0, 0d);
				}
				for (int j = 0; j < avg.size(); j++)
					avg.set(j, avg.get(j) + store.get(j));
			}
			for (int j = 0; j < avg.size(); j++)
			{
				avg.set(j, avg.get(j) / set_in_list.size());
			}
			if (IsChanged(list_of_center.get(i), avg))
			{
				list_of_center.set(i, avg);
				flag = false;
			}
		}
		/* 稳定则返回, 否则迭代调用 */
		if (flag)
			return list_of_points;
		else
			return Iterative(set, list_of_center, sb);
	}
	
	/**
	 * 根据类中心点坐标判断其是否改变
	 * 
	 * @param Old 原先坐标点
	 * @param New 新生成的坐标点
	 * @return 只要1个属性值不同则为true, 否则false
	 */
	public boolean IsChanged(List<Double> Old, List<Double> New)
	{
		if (Old.size() != New.size())
			return true;
		for (int i = 0; i < Old.size(); i++)
		{
			if (Old.get(i).doubleValue() != New.get(i).doubleValue())
				return true;
		}
		return false;
	}
	
	
	public static void main(String[] args)
	{
		Bisecting_Kmeans bk = new Bisecting_Kmeans();
		bk.PrintEachLevel(1);
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -