📄 snow.cpp
字号:
if (globalParams.curveInterval && !(examples % globalParams.curveInterval) && (globalParams.currentCycle > 1 || !globalParams.noFirstCycleUpdate) && globalParams.testFile.length() > 0) { // New way to do the learning curve. Now it just outputs testing // results after every curveInterval examples instead of dumping // another network. // Prepare network network->TrainingComplete(); // Which also calls NormalizeConfidence() *globalParams.pResultsOutput << "Testing after " << examples << " examples in cycle " << globalParams.currentCycle << "...\n"; BasicTest(*network); *globalParams.pResultsOutput << endl; } } //end loop over examples if (globalParams.currentCycle == 1 && globalParams.eligibilityMethod == ELIGIBILITY_PERCENT) network->PerformPercentageEligibility(); if (!globalParams.examplesInMemory) { trainStream.clear(); trainStream.seekg(0L); } } //end for(currentCycle...gp.cycles) if (globalParams.currentCycle <= globalParams.cycles && globalParams.verbosity >= VERBOSE_MIN) { cout << "Only " << globalParams.currentCycle - 1 << " cycle"; if (globalParams.currentCycle - 1 == 1) cout << " was"; else cout << "s were"; cout << " run.\n"; } network->TrainingComplete(); // Also calls NormalizeConfidence() network->Discard(); ofstream output(globalParams.networkFile.c_str()); if (!output) { cerr << "Fatal Error:\n"; cerr << "Failed to open network file '" << globalParams.networkFile.c_str() << "' for writing\n\n"; Pause(); return; } output.setf(ios::fixed); output.precision(8); network->Write(output); output.close(); if (globalParams.testFile.length() > 0) { if (globalParams.examplesInMemory) training_set.clear(); cout << "Training complete. Testing...\n"; globalParams.curveInterval = 0; globalParams.currentCycle = 0; globalParams.runMode = MODE_TEST; BasicTest(*network); }}void Snow::Test(){ // Copy the inputFile to testFile in this case globalParams.testFile = globalParams.inputFile; BasicTest(*network);}void Snow::Interactive() { ifstream inStream; globalParams.rawMode = true; if (globalParams.inputFile != "") { inStream.open(globalParams.inputFile.c_str()); if (!inStream) { cerr << "Fatal Error:\n"; cerr << "Failed to open input file: '" << globalParams.inputFile.c_str() << "'\n\n"; return; } } // Open the error file if necessary ofstream errorStream; if (!globalParams.errorFile.empty()) { errorStream.open(globalParams.errorFile.c_str()); if (!errorStream) { cerr << "Fatal Error:\n"; cerr << "Failed to open error file '" << globalParams.errorFile.c_str() << "'\n\n"; Pause(); return; } else { errorStream << "Algorithms:\n"; network->WriteAlgorithms(&errorStream); errorStream << endl; } } int examples = 0, correct = 0, suppressed = 0, not_labeled = 0; Example example(globalParams); while ((globalParams.inputFile != "") ? !inStream.eof() : !cin.eof()) { if (globalParams.inputFile != "") example.Read(inStream); else example.Read(cin); switch(example.command) { case 'e': // evaluate { network->TrainingComplete(); TargetRanking rank(globalParams, network->SingleTarget(), network->FirstThreshold()); network->RankTargets(example, rank); Output(rank, &errorStream, example, correct, suppressed, examples, not_labeled); } break; case 'p': // promote case 'd': // demote network->PresentInteractiveExample(example); break; default: // other -- error break; } } ofstream output(globalParams.networkFile.c_str()); if (!output) { cerr << "Fatal Error:\n"; cerr << "Failed to open network file '" << globalParams.networkFile.c_str() << "' for writing\n\n"; Pause(); return; } output.setf(ios::fixed); output.precision(8); network->Write(output); output.close();}void Snow::BasicTest(Network& network){ ofstream outputConjunctionStream; if (globalParams.writeConjunctions) { string outputConjunctionFile = globalParams.testFile + ".conjunctions"; outputConjunctionStream.open(outputConjunctionFile.c_str()); if (globalParams.verbosity >= VERBOSE_MIN && globalParams.runMode != MODE_TRAIN) cout << "Writing test examples with conjunctions to file: '" << outputConjunctionFile << "'\n"; } // Open the error file if necessary ofstream errorStream; if (!globalParams.errorFile.empty()) { errorStream.open(globalParams.errorFile.c_str()); if (!errorStream) { cerr << "Fatal Error:\n"; cerr << "Failed to open error file '" << globalParams.errorFile.c_str() << "'\n\n"; Pause(); return; } else { errorStream << "Algorithms:\n"; network.WriteAlgorithms(&errorStream); errorStream << endl; } } ifstream testStream(globalParams.testFile.c_str()); if (!testStream) { cerr << "Fatal Error:\n"; cerr << "Failed to open test file '" << globalParams.testFile.c_str() << "'\n\n"; Pause(); return; } Example example(globalParams); TargetRanking rank(globalParams, network.SingleTarget(), network.FirstThreshold()); if (globalParams.verbosity >= VERBOSE_MIN && globalParams.curveInterval == 0) { // output algorithm information *globalParams.pResultsOutput << "Algorithm information:\n"; network.WriteAlgorithms(globalParams.pResultsOutput); } int examples = 0, correct = 0, suppressed = 0, not_labeled = 0; while (!testStream.eof()) { rank.clear(); bool readResult = true; if (globalParams.labelsPresent) readResult = example.ReadLabeled(testStream); else readResult = example.Read(testStream); if (!readResult) { if (!testStream.eof()) cerr << "Failed reading example " << examples << " from " << globalParams.inputFile.c_str() << endl; } else { if (globalParams.generateConjunctions == CONJUNCTIONS_ON) example.GenerateConjunctions(); if (globalParams.writeConjunctions) example.Write(outputConjunctionStream); network.RankTargets(example, rank); ++examples; Output(rank, &errorStream, example, correct, suppressed, examples, not_labeled); // If we're in online mode if (globalParams.onlineLearning && globalParams.runMode == MODE_TEST) { network.PresentExample(example); network.TrainingComplete(); // Also calls NormalizeConfidence() } } } if (globalParams.onlineLearning) { if (globalParams.eligibilityMethod == ELIGIBILITY_PERCENT) network.PerformPercentageEligibility(); // write updated network from online learning to a new network file char newNetwork[256]; strcpy(newNetwork,globalParams.networkFile.c_str()); strcat(newNetwork, ".new"); ofstream newNetworkOut(newNetwork); network.Write(newNetworkOut); } // take care of final output (accuracy, other statistics) //network.ShowSize(); FinalOutput(*globalParams.pResultsOutput, correct, suppressed, examples, not_labeled);} void Snow::Output( TargetRanking &rank, ofstream* errorStream, Example &ex, int &correct, int &suppressed, int examples, int not_labeled ){ const bool DEBUG_OUTPUT(false); int i; bool prediction_wrong = false; FeatureID prediction; //char prediction_text[32]; stringstream prediction_stream(stringstream::out); sort(rank.begin(), rank.end(), greater<TargetRank>()); TargetRanking::iterator it = rank.begin(); TargetRanking::iterator end = rank.end(); if(DEBUG_OUTPUT) { int count = 0; for(; it != end; ++it) count++; cerr << "##in Output: targetRanking has " << count << " elements." << endl; cerr << "## gp.targetOutputLimit is " << globalParams.targetOutputLimit << endl; it = rank.begin(); } if (globalParams.targetOutputLimit < rank.size()) end = it + globalParams.targetOutputLimit; if (rank.single_target) { if (it->baseActivation >= rank.threshold) { prediction_stream << "positive" << '\0'; prediction = 1; } else { prediction_stream << "negative" << '\0'; prediction = 0; } if (ex.features.Targets() - prediction == 0) ++correct; else prediction_wrong = true; } else { if(DEBUG_OUTPUT) { cerr << "##gp.targetIds.size is " << globalParams.targetIds.size() << endl; } bool rankPrediction = false; if (globalParams.targetIds.size() == 1 || (rankPrediction = Predict(rank))) { prediction = it->id; if(DEBUG_OUTPUT) { cerr << "##Snow::Output(): rankPrediction is " << prediction << endl; } prediction_stream << prediction << '\0'; if(DEBUG_OUTPUT) { cerr << "##wrote prediction to prediction_stream, which now says: " << prediction_stream.str() << "..." << endl; } if (ex.features.Targets() == 0) ++not_labeled; else if (ex.FeatureIsLabel(prediction)) ++correct; else prediction_wrong = true; } else { prediction = NO_PREDICTION; prediction_stream << "no prediction" << '\0'; prediction_wrong = ex.features.Targets() > 0; ++suppressed; } } switch (globalParams.predictMethod) { case ACCURACY: if (errorStream && errorStream->is_open() && prediction_wrong) { *errorStream << "Ex: " << examples << " Prediction: " << prediction_stream.str(); if (ex.features.Targets() > 0) { *errorStream << " Label: " << ex.features[0].id; for (i = 1; i < ex.features.Targets() && globalParams.multipleLabels; ++i) *errorStream << ", " << ex.features[i].id; } else *errorStream << " Not labeled.";
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -