⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 kmltest.cpp

📁 高效的k-means算法实现
💻 CPP
📖 第 1 页 / 共 4 页
字号:
void validateAssignments(    KMdataPtr			data,		// the data points    const KMfilterCenters&	ctrs,		// the center points    KMctrIdxArray		closeCtr,	// closest centers    double*			sqDist)		// squared distances{    int errCt = 0;				// number of errors    KMdataArray dataPts = data->getPts();	// get points and centers    KMcenterArray ctrPts = ctrs.getCtrPts();    int nPts = data->getNPts();    int kCtrs = ctrs.getK();    *kmOut << "  (Validating assignments. ";    for (int i = 0; i < nPts; i++) {	KMdist minDist = KM_DIST_INF;		// distance to nearest center	int minK = 0;				// index of this center	KMpoint thisPt = dataPts[i];		// current data point	for (int j = 0; j < kCtrs; j++) {	// compute closest candidate	    KMdist dist = kmDist(dim, ctrPts[j], thisPt);	    if (dist < minDist) {		// best so far?		minDist = dist;			// yes, save it		minK = j;			// ...and its index	    }	}	if (sqDist[i] > minDist + KM_ERR ||	// distance mismatch	    sqDist[i] < minDist - KM_ERR) {	    *kmOut << "\n    Mismatch: data[" << i << "] assigned to ctr["	    	<< closeCtr[i] << "] (" << sqDist[i]		<< ") rather than ctr[" << minK << "] ("		<< minDist << ")";	    errCt++;	}    }    *kmOut << " Found " << errCt << " mismatches.)" << endl;}//------------------------------------------------------------------------//  Print summary at end of run//------------------------------------------------------------------------static void printSummary(    KMlocalPtr		    theAlg,		// the algorithm    KMdataPtr		    dataPts,		// data points    KMfilterCenters&  	    ctrs)		// the centers{    int nStages = theAlg->getTotalStages();    double totalTime = exec_time + kc_build_time;    if (kmStatLev > SILENT) {	*kmOut << "\n[k-means completed:\n"	     << "  n_stages      = " << nStages << "\n";	if (kmStatLev >= EXEC_TIME) {	// print exec time summary	    *kmOut << "  total_time    = "		 << totalTime << " sec\n"		 << "  init_time     = "		 << kc_build_time << " sec\n"		 << "  stage_time    = "		 << double(exec_time)/double(nStages)		 << " sec/stage_(excl_init) "		 << double(totalTime)/double(nStages)		 << " sec/stage_(incl_init)\n"		 << "  average_distort = "		 << ctrs.getDist(false)/double(ctrs.getNPts())		 << "\n";	}	if (kmStatLev >= SUMMARY) {	// print final results	    int max_ctrs = kcenters;	// number of centers to print	    if (kmStatLev < STAGE && max_ctrs > 10)	      max_ctrs = 10;		// low interest? just print 10	    *kmOut << "  (Final Center Points:\n";	    ctrs.print();	    *kmOut << "  )\n";	}	*kmOut << "]" << endl;    }    if (show_assign) {		// want to see point assignments?	KMctrIdxArray closeCtr = new KMctrIdx[dataPts->getNPts()];	double* sqDist = new double[dataPts->getNPts()];	ctrs.getAssignments(closeCtr, sqDist);	*kmOut	<< "  (Cluster assignments:\n"		<< "    Point  Center  Squared Dist\n"		<< "    -----  ------  ------------\n";	for (int i = 0; i < dataPts->getNPts(); i++) {	    *kmOut	<< "   " << setw(5) << i			<< "   " << setw(5) << closeCtr[i]			<< "   " << setw(10) << sqDist[i]			<< "\n";	}	*kmOut << "  )\n";	if (validate) {	    validateAssignments(dataPts, ctrs, closeCtr, sqDist);	}	delete [] closeCtr;	delete [] sqDist;    }}//------------------------------------------------------------------------//  runKmeans - run the k-means algorithm//  This procedure is given an algorithm type, a pointer to a set of data//  points and termination object.////  This runs the k-means algorithm.  We assume that the data point set//  has been generated (dataPts != NULL).  From these we create the//  center points (using KMfilterCenters), print a header (if//  desired), resets the performance statistics and resets the clock.////  Based on the algorithm type we create an algorithm object of the//  appropriate type and execute it.  Afterwards we print a summary of//  the results.//------------------------------------------------------------------------static void runKmeans(KMalg alg, KMdataPtr dataPts, KMterm &term){    if (dataPts == NULL) {			// failed to create data      kmError("Data points have not been generated", KMabort);    }    						// center points    KMfilterCenters ctrs(kcenters, *dataPts, damp_factor);    printHeader(alg, dataPts, term);		// print header    KMlocalPtr theAlg = NULL;			// the search algorithm    switch (alg) {				// select the search algorithm    case LLOYD:					// Lloyd's algorithm	theAlg = new KMlocalLloyds(ctrs, term);	break;    case SWAP:					// the swap heuristic	theAlg = new KMlocalSwap(ctrs, term, max_swaps);	break;    case HYBRID:				// the hybrid algorithm	theAlg = new KMlocalHybrid(ctrs, term);	break;    case EZ_HYBRID:				// the EZ-hybrid algorithm	theAlg = new KMlocalEZ_Hybrid(ctrs, term);	break;    default:	kmError("Internal error: Invalid algorithm", KMabort);	break;    }    clock_t start = clock();			// start the clock    ctrs = theAlg->execute();			// execute the algorithm    exec_time = elapsedTime(start);		// get elapsed time    						// print summary    printSummary(theAlg, dataPts, ctrs);}//------------------------------------------------------------------------//  Build kc-tree for the points//	This should be called whenever the point set is modified//	This time to construct the tree is saved in the global variable,//	kc_build_time.//------------------------------------------------------------------------static void buildKcTree(		// build kc-tree for points    KMdataPtr		dataPts)	// point array{    clock_t start = clock();			// start the clock    dataPts->buildKcTree();			// build the tree    kc_build_time = elapsedTime(start);		// get elapsed time    if (kmStatLev >= TREE) {			// print the tree (if req'd)	*kmOut << "Contents of the kc-tree: [\n";	dataPts->getKcTree()->print(false);	*kmOut << "]" << endl;    }}//------------------------------------------------------------------------// Print summary of distribution.//------------------------------------------------------------------------void printDistribSummary(    KMpointArray	pa,		// point array    int			nClus)		// number of clusters (for MULTI_CLUS){    if (kmStatLev > SILENT) {	*kmOut << "[Generating Data Points:\n"	     << "  data_size     = " << data_size << "\n"	     << "  dim           = " << dim << "\n";						// output distribution info	*kmOut << "  distribution  = " << distr_table[distr] << "\n";	if (kmIdum < 0)	    *kmOut << "  seed          = " << kmIdum << "\n";	if (distr == GAUSS	 || distr == CLUS_GAUSS	 || distr == MULTI_CLUS	 || distr == CLUS_ORTH_FLATS)	    *kmOut << "  std_dev       = " << std_dev << "\n";	if (distr == CLUS_ELLIPSOIDS) {	    *kmOut << "  std_dev       = " << std_dev << " (small) \n"		   << "  std_dev_lo    = " << std_dev_lo << "\n"		   << "  std_dev_hi    = " << std_dev_hi << "\n";	}	if (distr == CO_GAUSS        || distr == CO_LAPLACE)	    *kmOut << "  corr_coef     = " << corr_coef << "\n";	if (distr == CLUS_GAUSS      || distr == CLUS_ORTH_FLATS	 || distr == CLUS_ELLIPSOIDS) {	    *kmOut << "  colors        = " << n_color << "\n";	    if (new_clust)		*kmOut << "  (cluster centers regenerated)\n";	}	if (distr == CLUS_ORTH_FLATS || distr == CLUS_ELLIPSOIDS) {	    *kmOut << "  max_dim       = " << max_dim << "\n";	}	if (distr == CLUS_GAUSS) {	    *kmOut << "  cluster_sep   = " << clus_sep << "\n";	}	if (distr == MULTI_CLUS) {	    *kmOut << "  n_clusters    = " << nClus << "\n";	}    }    if (kmStatLev >= TREE) {			// want to see points?						// clustered gaussian data?	if (distr == CLUS_GAUSS) {	    KMpointArray clusts = kmGetCGclusters();	    kmPrintPts("Cluster_Centers", clusts, n_color, dim);	}    }    if (print_points) {				// print the points?	kmPrintPts("Data_Points", pa, data_size, dim);    }    *kmOut << "]" << endl;}//------------------------------------------------------------------------//  Generate data points from a distribution//	genDataPts calls the appropriate generation function and//	prints the summary.//------------------------------------------------------------------------static void genDataPts(    KMdataPtr		&dataPts,	// pointer to data points (returned)    bool		new_clust)	// new cluster centers desired?{    int nClus;					// number of clusters (ignored)    if (dataPts == NULL) {			// allocate storage for points	dataPts = new KMdata(dim, data_size);	// allocate new structure    }    else {	dataPts->resize(dim, data_size);		// else resize    }    KMpointArray pa = dataPts->getPts();	// get the point array    switch (distr) {	case UNIFORM:				// uniform over cube [-1,1]^d.	    kmUniformPts(pa, data_size, dim);	    break;	case GAUSS:				// Gaussian with mean 0	    kmGaussPts(pa, data_size, dim, std_dev);	    break;	case LAPLACE:				// Laplacian, mean 0 and var 1	    kmLaplacePts(pa, data_size, dim);	    break;	case CO_GAUSS:				// correlated Gaussian	    kmCoGaussPts(pa, data_size, dim, corr_coef);	    break;	case CO_LAPLACE:			// correlated Laplacian	    kmCoLaplacePts(pa, data_size, dim, corr_coef);	    break;	case CLUS_GAUSS:			// clustered Gaussian	    kmClusGaussPts(pa, data_size, dim, n_color, new_clust,	    		std_dev, &clus_sep);	    break;	case CLUS_ORTH_FLATS:			// clustered on orthog flats	    kmClusOrthFlats(pa, data_size, dim, n_color, new_clust,	    		std_dev, max_dim);	    break;	case CLUS_ELLIPSOIDS:			// clustered ellipsoids	    kmClusEllipsoids(pa, data_size, dim, n_color, new_clust,	    		std_dev, std_dev_lo, std_dev_hi, max_dim);	    break;	case MULTI_CLUS:			// multi-sized clusters	    kmMultiClus(pa, data_size, dim, nClus, std_dev);	    break;	default:	    kmError("INTERNAL ERROR: Unknown distribution", KMabort);	    break;    }    printDistribSummary(pa, nClus);		// print summary of distrib}//------------------------------------------------------------------------// readDataPts - read a set of data points from a file//------------------------------------------------------------------------static void readDataPts(    KMdataPtr		&dataPts,	// point array (returned)    int			array_size,	// array size    const string	&file_nm)	// file name{    int i;    //--------------------------------------------------------------------    //  Open input file and read points    //--------------------------------------------------------------------    ifstream in_file(file_nm.c_str());		// try to open data file    if (!in_file) {	*kmErr << "File name: " << file_nm << "\n";	kmError("Cannot open input data/query file", KMabort);    }    if (dataPts == NULL) {			// allocate storage for points	dataPts = new KMdata(dim, array_size);	// allocate new structure    }    else {	dataPts->resize(dim, array_size);		// else resize    }    for (i = 0; i < array_size; i++) {		// read the data	if (!(in_file >> (*dataPts)[i][0])) break;	for (int d = 1; d < dim; d++) {	    in_file >> (*dataPts)[i][d];	}    }    char ignore_me;				// character for EOF test    in_file >> ignore_me;			// try to get one more character    if (!in_file.eof()) {			// exhausted space before eof	kmError("data_size too small; input file truncated", KMwarn);    }    int n = i;    dataPts->setNPts(n);			// set number of points    //--------------------------------------------------------------------    //  Print summary    //--------------------------------------------------------------------    if (kmStatLev > SILENT) {	*kmOut << "[Read Data Points:\n"	     << "  data_size  = " << n << "\n"	     << "  file_name  = " << file_nm << "\n"	     << "  dim        = " << dim << "\n";	if (print_points) {			// print the points?	    kmPrintPts("Data_Points", dataPts->getPts(), n, dim);	}	*kmOut << "]" << endl;    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -