📄 kmltest.cpp
字号:
void validateAssignments( KMdataPtr data, // the data points const KMfilterCenters& ctrs, // the center points KMctrIdxArray closeCtr, // closest centers double* sqDist) // squared distances{ int errCt = 0; // number of errors KMdataArray dataPts = data->getPts(); // get points and centers KMcenterArray ctrPts = ctrs.getCtrPts(); int nPts = data->getNPts(); int kCtrs = ctrs.getK(); *kmOut << " (Validating assignments. "; for (int i = 0; i < nPts; i++) { KMdist minDist = KM_DIST_INF; // distance to nearest center int minK = 0; // index of this center KMpoint thisPt = dataPts[i]; // current data point for (int j = 0; j < kCtrs; j++) { // compute closest candidate KMdist dist = kmDist(dim, ctrPts[j], thisPt); if (dist < minDist) { // best so far? minDist = dist; // yes, save it minK = j; // ...and its index } } if (sqDist[i] > minDist + KM_ERR || // distance mismatch sqDist[i] < minDist - KM_ERR) { *kmOut << "\n Mismatch: data[" << i << "] assigned to ctr[" << closeCtr[i] << "] (" << sqDist[i] << ") rather than ctr[" << minK << "] (" << minDist << ")"; errCt++; } } *kmOut << " Found " << errCt << " mismatches.)" << endl;}//------------------------------------------------------------------------// Print summary at end of run//------------------------------------------------------------------------static void printSummary( KMlocalPtr theAlg, // the algorithm KMdataPtr dataPts, // data points KMfilterCenters& ctrs) // the centers{ int nStages = theAlg->getTotalStages(); double totalTime = exec_time + kc_build_time; if (kmStatLev > SILENT) { *kmOut << "\n[k-means completed:\n" << " n_stages = " << nStages << "\n"; if (kmStatLev >= EXEC_TIME) { // print exec time summary *kmOut << " total_time = " << totalTime << " sec\n" << " init_time = " << kc_build_time << " sec\n" << " stage_time = " << double(exec_time)/double(nStages) << " sec/stage_(excl_init) " << double(totalTime)/double(nStages) << " sec/stage_(incl_init)\n" << " average_distort = " << ctrs.getDist(false)/double(ctrs.getNPts()) << "\n"; } if (kmStatLev >= SUMMARY) { // print final results int max_ctrs = kcenters; // number of centers to print if (kmStatLev < STAGE && max_ctrs > 10) max_ctrs = 10; // low interest? just print 10 *kmOut << " (Final Center Points:\n"; ctrs.print(); *kmOut << " )\n"; } *kmOut << "]" << endl; } if (show_assign) { // want to see point assignments? KMctrIdxArray closeCtr = new KMctrIdx[dataPts->getNPts()]; double* sqDist = new double[dataPts->getNPts()]; ctrs.getAssignments(closeCtr, sqDist); *kmOut << " (Cluster assignments:\n" << " Point Center Squared Dist\n" << " ----- ------ ------------\n"; for (int i = 0; i < dataPts->getNPts(); i++) { *kmOut << " " << setw(5) << i << " " << setw(5) << closeCtr[i] << " " << setw(10) << sqDist[i] << "\n"; } *kmOut << " )\n"; if (validate) { validateAssignments(dataPts, ctrs, closeCtr, sqDist); } delete [] closeCtr; delete [] sqDist; }}//------------------------------------------------------------------------// runKmeans - run the k-means algorithm// This procedure is given an algorithm type, a pointer to a set of data// points and termination object.//// This runs the k-means algorithm. We assume that the data point set// has been generated (dataPts != NULL). From these we create the// center points (using KMfilterCenters), print a header (if// desired), resets the performance statistics and resets the clock.//// Based on the algorithm type we create an algorithm object of the// appropriate type and execute it. Afterwards we print a summary of// the results.//------------------------------------------------------------------------static void runKmeans(KMalg alg, KMdataPtr dataPts, KMterm &term){ if (dataPts == NULL) { // failed to create data kmError("Data points have not been generated", KMabort); } // center points KMfilterCenters ctrs(kcenters, *dataPts, damp_factor); printHeader(alg, dataPts, term); // print header KMlocalPtr theAlg = NULL; // the search algorithm switch (alg) { // select the search algorithm case LLOYD: // Lloyd's algorithm theAlg = new KMlocalLloyds(ctrs, term); break; case SWAP: // the swap heuristic theAlg = new KMlocalSwap(ctrs, term, max_swaps); break; case HYBRID: // the hybrid algorithm theAlg = new KMlocalHybrid(ctrs, term); break; case EZ_HYBRID: // the EZ-hybrid algorithm theAlg = new KMlocalEZ_Hybrid(ctrs, term); break; default: kmError("Internal error: Invalid algorithm", KMabort); break; } clock_t start = clock(); // start the clock ctrs = theAlg->execute(); // execute the algorithm exec_time = elapsedTime(start); // get elapsed time // print summary printSummary(theAlg, dataPts, ctrs);}//------------------------------------------------------------------------// Build kc-tree for the points// This should be called whenever the point set is modified// This time to construct the tree is saved in the global variable,// kc_build_time.//------------------------------------------------------------------------static void buildKcTree( // build kc-tree for points KMdataPtr dataPts) // point array{ clock_t start = clock(); // start the clock dataPts->buildKcTree(); // build the tree kc_build_time = elapsedTime(start); // get elapsed time if (kmStatLev >= TREE) { // print the tree (if req'd) *kmOut << "Contents of the kc-tree: [\n"; dataPts->getKcTree()->print(false); *kmOut << "]" << endl; }}//------------------------------------------------------------------------// Print summary of distribution.//------------------------------------------------------------------------void printDistribSummary( KMpointArray pa, // point array int nClus) // number of clusters (for MULTI_CLUS){ if (kmStatLev > SILENT) { *kmOut << "[Generating Data Points:\n" << " data_size = " << data_size << "\n" << " dim = " << dim << "\n"; // output distribution info *kmOut << " distribution = " << distr_table[distr] << "\n"; if (kmIdum < 0) *kmOut << " seed = " << kmIdum << "\n"; if (distr == GAUSS || distr == CLUS_GAUSS || distr == MULTI_CLUS || distr == CLUS_ORTH_FLATS) *kmOut << " std_dev = " << std_dev << "\n"; if (distr == CLUS_ELLIPSOIDS) { *kmOut << " std_dev = " << std_dev << " (small) \n" << " std_dev_lo = " << std_dev_lo << "\n" << " std_dev_hi = " << std_dev_hi << "\n"; } if (distr == CO_GAUSS || distr == CO_LAPLACE) *kmOut << " corr_coef = " << corr_coef << "\n"; if (distr == CLUS_GAUSS || distr == CLUS_ORTH_FLATS || distr == CLUS_ELLIPSOIDS) { *kmOut << " colors = " << n_color << "\n"; if (new_clust) *kmOut << " (cluster centers regenerated)\n"; } if (distr == CLUS_ORTH_FLATS || distr == CLUS_ELLIPSOIDS) { *kmOut << " max_dim = " << max_dim << "\n"; } if (distr == CLUS_GAUSS) { *kmOut << " cluster_sep = " << clus_sep << "\n"; } if (distr == MULTI_CLUS) { *kmOut << " n_clusters = " << nClus << "\n"; } } if (kmStatLev >= TREE) { // want to see points? // clustered gaussian data? if (distr == CLUS_GAUSS) { KMpointArray clusts = kmGetCGclusters(); kmPrintPts("Cluster_Centers", clusts, n_color, dim); } } if (print_points) { // print the points? kmPrintPts("Data_Points", pa, data_size, dim); } *kmOut << "]" << endl;}//------------------------------------------------------------------------// Generate data points from a distribution// genDataPts calls the appropriate generation function and// prints the summary.//------------------------------------------------------------------------static void genDataPts( KMdataPtr &dataPts, // pointer to data points (returned) bool new_clust) // new cluster centers desired?{ int nClus; // number of clusters (ignored) if (dataPts == NULL) { // allocate storage for points dataPts = new KMdata(dim, data_size); // allocate new structure } else { dataPts->resize(dim, data_size); // else resize } KMpointArray pa = dataPts->getPts(); // get the point array switch (distr) { case UNIFORM: // uniform over cube [-1,1]^d. kmUniformPts(pa, data_size, dim); break; case GAUSS: // Gaussian with mean 0 kmGaussPts(pa, data_size, dim, std_dev); break; case LAPLACE: // Laplacian, mean 0 and var 1 kmLaplacePts(pa, data_size, dim); break; case CO_GAUSS: // correlated Gaussian kmCoGaussPts(pa, data_size, dim, corr_coef); break; case CO_LAPLACE: // correlated Laplacian kmCoLaplacePts(pa, data_size, dim, corr_coef); break; case CLUS_GAUSS: // clustered Gaussian kmClusGaussPts(pa, data_size, dim, n_color, new_clust, std_dev, &clus_sep); break; case CLUS_ORTH_FLATS: // clustered on orthog flats kmClusOrthFlats(pa, data_size, dim, n_color, new_clust, std_dev, max_dim); break; case CLUS_ELLIPSOIDS: // clustered ellipsoids kmClusEllipsoids(pa, data_size, dim, n_color, new_clust, std_dev, std_dev_lo, std_dev_hi, max_dim); break; case MULTI_CLUS: // multi-sized clusters kmMultiClus(pa, data_size, dim, nClus, std_dev); break; default: kmError("INTERNAL ERROR: Unknown distribution", KMabort); break; } printDistribSummary(pa, nClus); // print summary of distrib}//------------------------------------------------------------------------// readDataPts - read a set of data points from a file//------------------------------------------------------------------------static void readDataPts( KMdataPtr &dataPts, // point array (returned) int array_size, // array size const string &file_nm) // file name{ int i; //-------------------------------------------------------------------- // Open input file and read points //-------------------------------------------------------------------- ifstream in_file(file_nm.c_str()); // try to open data file if (!in_file) { *kmErr << "File name: " << file_nm << "\n"; kmError("Cannot open input data/query file", KMabort); } if (dataPts == NULL) { // allocate storage for points dataPts = new KMdata(dim, array_size); // allocate new structure } else { dataPts->resize(dim, array_size); // else resize } for (i = 0; i < array_size; i++) { // read the data if (!(in_file >> (*dataPts)[i][0])) break; for (int d = 1; d < dim; d++) { in_file >> (*dataPts)[i][d]; } } char ignore_me; // character for EOF test in_file >> ignore_me; // try to get one more character if (!in_file.eof()) { // exhausted space before eof kmError("data_size too small; input file truncated", KMwarn); } int n = i; dataPts->setNPts(n); // set number of points //-------------------------------------------------------------------- // Print summary //-------------------------------------------------------------------- if (kmStatLev > SILENT) { *kmOut << "[Read Data Points:\n" << " data_size = " << n << "\n" << " file_name = " << file_nm << "\n" << " dim = " << dim << "\n"; if (print_points) { // print the points? kmPrintPts("Data_Points", dataPts->getPts(), n, dim); } *kmOut << "]" << endl; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -