📄 ann_test.cpp
字号:
if (!strcmp(arg, "standard")) { method = STANDARD; } else if (!strcmp(arg, "priority")) { method = PRIORITY; } else { cerr << "Search type: " << arg << "\n"; Error("Search type must be \"standard\" or \"priority\"", ANNabort); } if (data_pts == NULL || query_pts == NULL) { Error("Either data set and query set not constructed", ANNabort); } if (the_tree == NULL) { Error("No search tree built.", ANNabort); } //------------------------------------------------------------ // Set up everything //------------------------------------------------------------ #ifdef ANN_PERF // performance only annResetStats(data_size); // reset statistics #endif clock0 = clock(); // start time // deallocate existing storage if (apx_nn_idx != NULL) delete [] apx_nn_idx; if (apx_dists != NULL) delete [] apx_dists; if (apx_pts_in_range != NULL) delete [] apx_pts_in_range; // allocate apx answer storage apx_nn_idx = new ANNidx[near_neigh*query_size]; apx_dists = new ANNdist[near_neigh*query_size]; apx_pts_in_range = new int[query_size]; annMaxPtsVisit(max_pts_visit); // set max points to visit //------------------------------------------------------------ // Run the queries //------------------------------------------------------------ // pointers for current query ANNidxArray curr_nn_idx = apx_nn_idx; ANNdistArray curr_dists = apx_dists; for (int i = 0; i < query_size; i++) { #ifdef ANN_PERF annResetCounts(); // reset counters #endif apx_pts_in_range[i] = 0; if (radius_bound == 0) { // no radius bound if (method == STANDARD) { the_tree->annkSearch( query_pts[i], // query point near_neigh, // number of near neighbors curr_nn_idx, // nearest neighbors (returned) curr_dists, // distance (returned) epsilon); // error bound } else if (method == PRIORITY) { the_tree->annkPriSearch( query_pts[i], // query point near_neigh, // number of near neighbors curr_nn_idx, // nearest neighbors (returned) curr_dists, // distance (returned) epsilon); // error bound } else { Error("Internal error - invalid method", ANNabort); } } else { // use radius bound if (method != STANDARD) { Error("A nonzero radius bound assumes standard search", ANNwarn); } apx_pts_in_range[i] = the_tree->annkFRSearch( query_pts[i], // query point ANN_POW(radius_bound), // squared radius search bound near_neigh, // number of near neighbors curr_nn_idx, // nearest neighbors (returned) curr_dists, // distance (returned) epsilon); // error bound } curr_nn_idx += near_neigh; // increment current pointers curr_dists += near_neigh; #ifdef ANN_PERF annUpdateStats(); // update stats #endif } long query_time = clock() - clock0; // end of query time if (validate) { // validation requested if (valid_dirty) getTrueNN(); // get true near neighbors doValidation(); // validate } //------------------------------------------------------------ // Print summaries //------------------------------------------------------------ if (stats > SILENT) { cout << "[Run Queries:\n"; cout << " query_size = " << query_size << "\n"; cout << " dim = " << dim << "\n"; cout << " search_method = " << arg << "\n"; cout << " epsilon = " << epsilon << "\n"; cout << " near_neigh = " << near_neigh << "\n"; if (max_pts_visit != 0) cout << " max_pts_visit = " << max_pts_visit << "\n"; if (radius_bound != 0) cout << " radius_bound = " << radius_bound << "\n"; if (validate) cout << " true_nn = " << true_nn << "\n"; if (stats >= EXEC_TIME) { // print exec time summary cout << " query_time = " << double(query_time)/(query_size*CLOCKS_PER_SEC) << " sec/query"; #ifdef ANN_PERF cout << " (biased by perf measurements)"; #endif cout << "\n"; } if (stats >= QUERY_STATS) { // output performance stats #ifdef ANN_PERF cout.flush(); annPrintStats(validate); #else cout << " (Performance statistics unavailable.)\n"; #endif } if (stats >= QUERY_RES) { // output results cout << " (Query Results:\n"; cout << " Pt\tANN\tDist\n"; curr_nn_idx = apx_nn_idx; // subarray pointers curr_dists = apx_dists; // output nearest neighbors for (int i = 0; i < query_size; i++) { cout << " " << setw(4) << i; for (int j = 0; j < near_neigh; j++) { // exit if no more neighbors if (curr_nn_idx[j] == ANN_NULL_IDX) { cout << "\t[no other pts in radius bound]\n"; break; } else { // output point info cout << "\t" << curr_nn_idx[j] << "\t" << ANN_ROOT(curr_dists[j]) << "\n"; } } // output range count if (radius_bound != 0) { cout << " pts_in_radius_bound = " << apx_pts_in_range[i] << "\n"; } // increment subarray pointers curr_nn_idx += near_neigh; curr_dists += near_neigh; } cout << " )\n"; } cout << "]\n"; } } //---------------------------------------------------------------- // Unknown directive //---------------------------------------------------------------- else { cerr << "Directive: " << directive << "\n"; Error("Unknown directive", ANNabort); } } //-------------------------------------------------------------------- // End of input loop (deallocate stuff that was allocated) //-------------------------------------------------------------------- if (the_tree != NULL) delete the_tree; if (data_pts != NULL) annDeallocPts(data_pts); if (query_pts != NULL) annDeallocPts(query_pts); if (apx_nn_idx != NULL) delete [] apx_nn_idx; if (apx_dists != NULL) delete [] apx_dists; if (apx_pts_in_range != NULL) delete [] apx_pts_in_range; annClose(); // close ANN return EXIT_SUCCESS;}//------------------------------------------------------------------------// generatePts - call appropriate routine to generate points of a// given distribution.//------------------------------------------------------------------------void generatePts( ANNpointArray &pa, // point array (returned) int n, // number of points to generate PtType type, // point type ANNbool new_clust, // new cluster centers desired? ANNpointArray src, // source array (if distr=PLANTED) int n_src) // source size (if distr=PLANTED){ if (pa != NULL) annDeallocPts(pa); // get rid of any old points pa = annAllocPts(n, dim); // allocate point storage switch (distr) { case UNIFORM: // uniform over cube [-1,1]^d. annUniformPts(pa, n, dim); break; case GAUSS: // Gaussian with mean 0 annGaussPts(pa, n, dim, std_dev); break; case LAPLACE: // Laplacian, mean 0 and var 1 annLaplacePts(pa, n, dim); break; case CO_GAUSS: // correlated Gaussian annCoGaussPts(pa, n, dim, corr_coef); break; case CO_LAPLACE: // correlated Laplacian annCoLaplacePts(pa, n, dim, corr_coef); break; case CLUS_GAUSS: // clustered Gaussian annClusGaussPts(pa, n, dim, n_color, new_clust, std_dev); break; case CLUS_ORTH_FLATS: // clustered on orthog flats annClusOrthFlats(pa, n, dim, n_color, new_clust, std_dev, max_dim); break; case CLUS_ELLIPSOIDS: // clustered ellipsoids annClusEllipsoids(pa, n, dim, n_color, new_clust, std_dev, std_dev_lo, std_dev_hi, max_dim); break; case PLANTED: // planted distribution annPlanted(pa, n, dim, src, n_src, std_dev); break; default: Error("INTERNAL ERROR: Unknown distribution", ANNabort); break; } if (stats > SILENT) { if(type == DATA) cout << "[Generating Data Points:\n"; else cout << "[Generating Query Points:\n"; cout << " number = " << n << "\n"; cout << " dim = " << dim << "\n"; cout << " distribution = " << distr_table[distr] << "\n"; if (annIdum < 0) cout << " seed = " << annIdum << "\n"; if (distr == GAUSS || distr == CLUS_GAUSS || distr == CLUS_ORTH_FLATS) cout << " std_dev = " << std_dev << "\n"; if (distr == CLUS_ELLIPSOIDS) { cout << " std_dev = " << std_dev << " (small) \n"; cout << " std_dev_lo = " << std_dev_lo << "\n"; cout << " std_dev_hi = " << std_dev_hi << "\n"; } if (distr == CO_GAUSS || distr == CO_LAPLACE) cout << " corr_coef = " << corr_coef << "\n"; if (distr == CLUS_GAUSS || distr == CLUS_ORTH_FLATS || distr == CLUS_ELLIPSOIDS) { cout << " colors = " << n_color << "\n"; if (new_clust) cout << " (cluster centers regenerated)\n"; } if (distr == CLUS_ORTH_FLATS || distr == CLUS_ELLIPSOIDS) { cout << " max_dim = " << max_dim << "\n"; } } // want to see points? if ((type == DATA && stats >= SHOW_PTS) || (type == QUERY && stats >= QUERY_RES)) { if(type == DATA) cout << "(Data Points:\n"; else cout << "(Query Points:\n"; for (int i = 0; i < n; i++) { cout << " " << setw(4) << i << "\t"; printPoint(pa[i], dim); cout << "\n"; } cout << " )\n"; } cout << "]\n";}//------------------------------------------------------------------------// readPts - read a collection of data or query points.//------------------------------------------------------------------------void readPts( ANNpointArray &pa, // point array (returned) int &n, // number of points char *file_nm, // file name PtType type) // point type (DATA, QUERY){ int i; //-------------------------------------------------------------------- // Open input file and read points //-------------------------------------------------------------------- ifstream in_file(file_nm); // try to open data file if (!in_file) { cerr << "File name: " << file_nm << "\n"; Error("Cannot open input data/query file", ANNabort); } // allocate storage for points if (pa != NULL) annDeallocPts(pa); // get rid of old points pa = annAllocPts(n, dim); for (i = 0; i < n; i++) { // read the data if (!(in_file >> pa[i][0])) break; for (int d = 1; d < dim; d++) { in_file >> pa[i][d]; } } char ignore_me; // character for EOF test in_file >> ignore_me; // try to get one more character if (!in_file.eof()) { // exhausted space before eof if (type == DATA) Error("`data_size' too small. Input file truncated.", ANNwarn); else Error("`query_size' too small. Input file truncated.", ANNwarn); } n = i; // number of points read //-------------------------------------------------------------------- // Print summary
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -