⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 snow.cpp

📁 基于稀疏网络的精选机器学习模型
💻 CPP
📖 第 1 页 / 共 3 页
字号:
      if (globalParams.curveInterval          && !(examples % globalParams.curveInterval)          && (globalParams.currentCycle > 1              || !globalParams.noFirstCycleUpdate)          && globalParams.testFile.length() > 0)      {        // New way to do the learning curve.  Now it just outputs testing        // results after every curveInterval examples instead of dumping        // another network.        // Prepare network         network->TrainingComplete();  // Which also calls NormalizeConfidence()        *globalParams.pResultsOutput << "Testing after " << examples                                     << " examples in cycle "                                     << globalParams.currentCycle << "...\n";        BasicTest(*network);        *globalParams.pResultsOutput << endl;      }    } //end loop over examples    if (globalParams.currentCycle == 1        && globalParams.eligibilityMethod == ELIGIBILITY_PERCENT)      network->PerformPercentageEligibility();    if (!globalParams.examplesInMemory)    {      trainStream.clear();      trainStream.seekg(0L);    }  } //end for(currentCycle...gp.cycles)  if (globalParams.currentCycle <= globalParams.cycles      && globalParams.verbosity >= VERBOSE_MIN)  {    cout << "Only " << globalParams.currentCycle - 1 << " cycle";    if (globalParams.currentCycle - 1 == 1) cout << " was";    else cout << "s were";    cout << " run.\n";  }  network->TrainingComplete();  // Also calls NormalizeConfidence()  network->Discard();  ofstream output(globalParams.networkFile.c_str());  if (!output)  {    cerr << "Fatal Error:\n";    cerr << "Failed to open network file '"       << globalParams.networkFile.c_str()      << "' for writing\n\n";    Pause();    return;  }  output.setf(ios::fixed);  output.precision(8);  network->Write(output);  output.close();  if (globalParams.testFile.length() > 0)  {    if (globalParams.examplesInMemory) training_set.clear();    cout << "Training complete.  Testing...\n";    globalParams.curveInterval = 0;    globalParams.currentCycle = 0;    globalParams.runMode = MODE_TEST;    BasicTest(*network);  }}void Snow::Test(){  // Copy the inputFile to testFile in this case  globalParams.testFile = globalParams.inputFile;  BasicTest(*network);}void Snow::Interactive() {  ifstream inStream;   globalParams.rawMode = true;  if (globalParams.inputFile != "") {    inStream.open(globalParams.inputFile.c_str());    if (!inStream) {      cerr << "Fatal Error:\n";      cerr << "Failed to open input file: '"	   << globalParams.inputFile.c_str() << "'\n\n";      return;    }  }  // Open the error file if necessary  ofstream errorStream;  if (!globalParams.errorFile.empty())  {    errorStream.open(globalParams.errorFile.c_str());    if (!errorStream)    {      cerr << "Fatal Error:\n";      cerr << "Failed to open error file '" << globalParams.errorFile.c_str()	   << "'\n\n";      Pause();      return;    }    else    {      errorStream << "Algorithms:\n";      network->WriteAlgorithms(&errorStream);      errorStream << endl;    }  }  int examples = 0, correct = 0, suppressed = 0, not_labeled = 0;  Example example(globalParams);  while ((globalParams.inputFile != "") ? !inStream.eof() : !cin.eof()) {       if (globalParams.inputFile != "")      example.Read(inStream);    else      example.Read(cin);    switch(example.command)     {      case 'e':  // evaluate	{	  network->TrainingComplete();	  TargetRanking rank(globalParams, network->SingleTarget(), 			     network->FirstThreshold());	  network->RankTargets(example, rank);	  Output(rank, &errorStream, example, correct, suppressed, examples,		 not_labeled);	}	break;      case 'p':  // promote      case 'd': // demote	network->PresentInteractiveExample(example);	break;      default: // other -- error	break;    }  }  ofstream output(globalParams.networkFile.c_str());  if (!output)  {    cerr << "Fatal Error:\n";    cerr << "Failed to open network file '"	 << globalParams.networkFile.c_str()	 << "' for writing\n\n";    Pause();    return;  }  output.setf(ios::fixed);  output.precision(8);  network->Write(output);  output.close();}void Snow::BasicTest(Network& network){  ofstream outputConjunctionStream;  if (globalParams.writeConjunctions)  {    string outputConjunctionFile = globalParams.testFile + ".conjunctions";    outputConjunctionStream.open(outputConjunctionFile.c_str());    if (globalParams.verbosity >= VERBOSE_MIN        && globalParams.runMode != MODE_TRAIN)      cout << "Writing test examples with conjunctions to file: '"           << outputConjunctionFile << "'\n";  }  // Open the error file if necessary  ofstream errorStream;  if (!globalParams.errorFile.empty())  {    errorStream.open(globalParams.errorFile.c_str());    if (!errorStream)    {      cerr << "Fatal Error:\n";      cerr << "Failed to open error file '" << globalParams.errorFile.c_str()           << "'\n\n";      Pause();      return;    }    else    {      errorStream << "Algorithms:\n";      network.WriteAlgorithms(&errorStream);      errorStream << endl;    }          }  ifstream testStream(globalParams.testFile.c_str());  if (!testStream)  {    cerr << "Fatal Error:\n";    cerr << "Failed to open test file '" << globalParams.testFile.c_str()	 << "'\n\n";    Pause();    return;  }  Example example(globalParams);  TargetRanking rank(globalParams, network.SingleTarget(), 		     network.FirstThreshold());    if (globalParams.verbosity >= VERBOSE_MIN      && globalParams.curveInterval == 0)  {    // output algorithm information    *globalParams.pResultsOutput << "Algorithm information:\n";    network.WriteAlgorithms(globalParams.pResultsOutput);  }  int examples = 0, correct = 0, suppressed = 0, not_labeled = 0;  while (!testStream.eof())  {    rank.clear();     bool readResult = true;        if (globalParams.labelsPresent)      readResult = example.ReadLabeled(testStream);    else readResult = example.Read(testStream);        if (!readResult)    {             if (!testStream.eof())	cerr << "Failed reading example " << examples << " from "	     << globalParams.inputFile.c_str() << endl;    }    else    {      if (globalParams.generateConjunctions == CONJUNCTIONS_ON)	example.GenerateConjunctions();      if (globalParams.writeConjunctions)	example.Write(outputConjunctionStream);      network.RankTargets(example, rank);      ++examples;      Output(rank, &errorStream, example, correct, suppressed, examples,	     not_labeled);            // If we're in online mode      if (globalParams.onlineLearning && globalParams.runMode == MODE_TEST)      {	network.PresentExample(example);	network.TrainingComplete();  // Also calls NormalizeConfidence()      }    }  }  if (globalParams.onlineLearning)  {    if (globalParams.eligibilityMethod == ELIGIBILITY_PERCENT)      network.PerformPercentageEligibility();        // write updated network from online learning to a new network file    char newNetwork[256];    strcpy(newNetwork,globalParams.networkFile.c_str());    strcat(newNetwork, ".new");    ofstream newNetworkOut(newNetwork);    network.Write(newNetworkOut);  }  // take care of final output (accuracy, other statistics)  //network.ShowSize();  FinalOutput(*globalParams.pResultsOutput, correct, suppressed, examples,	      not_labeled);}                  void Snow::Output( TargetRanking &rank, ofstream* errorStream, Example &ex,             int &correct, int &suppressed, int examples, int not_labeled ){  const bool DEBUG_OUTPUT(false);  int i;  bool prediction_wrong = false;  FeatureID prediction;  //char prediction_text[32];  stringstream prediction_stream(stringstream::out);  sort(rank.begin(), rank.end(), greater<TargetRank>());  TargetRanking::iterator it = rank.begin();  TargetRanking::iterator end = rank.end();  if(DEBUG_OUTPUT)  {    int count = 0;    for(; it != end; ++it)      count++;    cerr << "##in Output: targetRanking has " << count << " elements." << endl;    cerr << "## gp.targetOutputLimit is " << globalParams.targetOutputLimit 	 << endl;    it = rank.begin();  }  if (globalParams.targetOutputLimit < rank.size())    end = it + globalParams.targetOutputLimit;  if (rank.single_target)  {    if (it->baseActivation >= rank.threshold)    {      prediction_stream << "positive" << '\0';      prediction = 1;    }    else    {      prediction_stream << "negative" << '\0';      prediction = 0;    }    if (ex.features.Targets() - prediction == 0) ++correct;    else prediction_wrong = true;  }  else  {    if(DEBUG_OUTPUT)    {      cerr << "##gp.targetIds.size is " << globalParams.targetIds.size() 	   << endl;    }    bool rankPrediction = false;    if (globalParams.targetIds.size() == 1 || (rankPrediction = Predict(rank)))    {      prediction = it->id;      if(DEBUG_OUTPUT)      {	cerr << "##Snow::Output(): rankPrediction is " << prediction << endl;      }      prediction_stream << prediction << '\0';      if(DEBUG_OUTPUT)      {	cerr << "##wrote prediction to prediction_stream, which now says: "	     << prediction_stream.str() << "..." << endl;      }      if (ex.features.Targets() == 0) ++not_labeled;      else if (ex.FeatureIsLabel(prediction)) ++correct;      else prediction_wrong = true;    }    else    {      prediction = NO_PREDICTION;      prediction_stream << "no prediction" << '\0';      prediction_wrong = ex.features.Targets() > 0;      ++suppressed;    }  }  switch (globalParams.predictMethod)  {    case ACCURACY:      if (errorStream && errorStream->is_open() && prediction_wrong)      {        *errorStream << "Ex: " << examples << " Prediction: "                     << prediction_stream.str();        if (ex.features.Targets() > 0)        {          *errorStream << " Label: " << ex.features[0].id;          for (i = 1;               i < ex.features.Targets() && globalParams.multipleLabels; ++i)            *errorStream << ", " << ex.features[i].id;        }        else *errorStream << " Not labeled.";

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -