📄 main.cpp
字号:
// -*-c++-*-//===========================================================//= University of Illinois at Urbana-Champaign =//= Department of Computer Science =//= Dr. Dan Roth - Cognitive Computation Group =//= =//= Project: SNoW =//= =//= Module: Snow.cpp =//= Version: 3.1.4 =//= Authors: Jeff Rosen, Andrew Carlson, Nick Rizzolo, =//= Mark Sammons =//= Date: xx/xx/99 = //= =//= Comments: =//===========================================================//Main.cpp is now a shell for backward compatibility; //instantiates Snow object, possibly as part of server code.#include "TargetRank.h"#include "Network.h"#include "Snow.h"#include "Usage.h"#include "SnowParam.h"class GlobalParams;#ifdef SERVER_MODE_#include "SendReceive.h"#include <numeric>#include <iomanip>//#include <strstream>#include <sstream>#include <arpa/inet.h>#include <netinet/in.h>#include <sys/socket.h>#include <sys/types.h>#include <sys/wait.h>#include <signal.h>#include <pthread.h>#include <stdlib.h>#ifndef WIN32#include <unistd.h>#endif#include <errno.h>#include <stdio.h>#include <ctype.h>#include <string.h>#ifndef WIN32#include <strings.h>#define BACKLOG 5 #define BUFFER_SIZE 16384#endifstatic int serverRun = 1;struct ClientData{ int fileDesc; pthread_mutex_t* pLexiconMutex; Snow * snow;};void RunServer(Snow * s);void SigHandler( int sig );void* ProcessSocket( void* clientData );void ParseOptions(char* options, GlobalParams & globalParams);void SigHandler( int sig ){ cout << "SigHandler got signal " << sig << "...\n"; serverRun = 0;}#endifvoid ShowOptions(GlobalParams & globalParams);int main( int argc, char* argv[] ) { int rtn = 0; unsigned int prediction = 0; int & rtnVal = rtn; GlobalParams globalParams; //Display logo if(globalParams.verbosity != VERBOSE_QUIET){ cout << endl; for (int i = 0; Snow::logo[i][0]; ++i){ cout << Snow::logo[i]; } } if (argc == 1) { // no options were given-- output usage char** usagePtr = usage; while (**usagePtr != '&') { cout << *usagePtr; ++usagePtr; } return 0; } if (!ParseCmdLine(argc, argv, globalParams)){ cout << "Fatal Error: couldn't parse command line. " << endl; return 0; } else { ShowOptions(globalParams); } if (!ValidateParams(globalParams)){ cout << "Fatal Error: couldn't validate globalParams settings. " << endl; return 0; } Snow s(globalParams); if(s.globalParams.outputFile.length() == 0){ if (s.globalParams.verbosity != VERBOSE_QUIET) { cout << "Directing output to "; if (s.globalParams.runMode == MODE_SERVER) cout << "clients"; else cout << "console"; cout << ".\n"; } } //check for mode #ifdef SERVER_MODE_ if(s.globalParams.runMode == MODE_SERVER){ if(s.globalParams.serverPort > 0) RunServer(&s); } else {#endif if (s.globalParams.runMode == MODE_TRAIN) s.Train(); else if (s.globalParams.runMode == MODE_TEST) s.Test(); else if (s.globalParams.runMode == MODE_INTERACTIVE) s.Interactive(); else if (s.globalParams.runMode == MODE_SERVER) { cout << "Fatal Error: run mode is set to SERVER, but Main.cpp was" << " not compiled " << "with SERVER=1 -- " << "necessary functionality is missing. " << endl << "Recompile Main using 'gmake SERVER=1'." << endl; exit(1); } else //MODE_EVAL prediction = s.Evaluate(); if (s.globalParams.runMode == MODE_EVAL) rtnVal = (int)prediction; return rtnVal; }#ifdef SERVER_MODE_}void RunServer(Snow * s){ int sockfd, new_fd; /* listen on sock_fd, new connection on new_fd */ struct sockaddr_in my_addr; /* my address information */ struct sockaddr_in their_addr; /* connector's address information */ socklen_t sin_size; GlobalParams & globalParams = s->globalParams; if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) { cerr << "socket: " << strerror(errno) << endl; exit(1); } // if (globalParams.runMode == MODE_SERVER) cout << "Server port: " << globalParams.serverPort << endl; my_addr.sin_family = AF_INET; /* host byte order */ my_addr.sin_port = htons(globalParams.serverPort); my_addr.sin_addr.s_addr = INADDR_ANY; /* automatically fill with my IP */ bzero(&(my_addr.sin_zero), 8); /* zero the rest of the struct */ if (bind(sockfd, (struct sockaddr*)&my_addr, sizeof(struct sockaddr)) == -1) { cerr << "bind: " << strerror(errno) << endl; exit(1); } if (listen(sockfd, BACKLOG) == -1) { cerr << "listen: " << strerror(errno) << endl; exit(1); } pthread_mutex_t lexiconMutex; pthread_mutex_init(&lexiconMutex, NULL); signal(SIGHUP, SigHandler); // Copy the inputFile to testFile in this case globalParams.testFile = globalParams.inputFile; ofstream outputConjunctionStream; if (globalParams.writeConjunctions) { string outputConjunctionFile = globalParams.testFile + ".conjunctions"; outputConjunctionStream.open(outputConjunctionFile.c_str()); if (globalParams.verbosity >= VERBOSE_MIN && globalParams.runMode != MODE_TRAIN) cout << "Writing test examples with conjunctions to file: '" << outputConjunctionFile << "'\n"; } // Open the error file if necessary ofstream errorStream; if (!((globalParams).errorFile.empty())) { errorStream.open(globalParams.errorFile.c_str()); if (!errorStream) { cerr << "Fatal Error:\n"; cerr << "Failed to open error file '" << globalParams.errorFile.c_str() << "'\n\n"; s->Pause(); return; } else { errorStream << "Algorithms:\n"; (s->network)->WriteAlgorithms(&errorStream); errorStream << endl; } } while (serverRun) { /* main accept() loop */ sin_size = sizeof(struct sockaddr_in); cout << "Waiting for clients...\n"; if ((new_fd = accept(sockfd, (struct sockaddr*)&their_addr, &sin_size)) == -1) { cerr << "accept: " << strerror(errno) << endl; continue; } cout << "SNoW: got connection from " << inet_ntoa(their_addr.sin_addr) << endl; ClientData cd; cd.fileDesc = new_fd; // cd.network = (*s).network; cd.pLexiconMutex = &lexiconMutex; // cd.globalParams = &globalParams; cd.snow = s; pthread_t theThread; int threadVal; threadVal = pthread_create(&theThread, NULL, ProcessSocket, &cd); //cout << threadVal << endl; } pthread_mutex_destroy(&lexiconMutex); close(sockfd);}void* ProcessSocket(void* clientData){ const bool DEBUG_SERVER(false); int client_fd = ((ClientData*)clientData)->fileDesc; pthread_mutex_t* pLexiconMutex(((ClientData*)clientData)->pLexiconMutex); Snow * s = ((ClientData*)clientData)->snow; Network* network = s->network; GlobalParams & globalParams = s->globalParams; Example example(globalParams); TargetRanking rank(globalParams, network->SingleTarget(), network->FirstThreshold()); // parent thread doesn't need to 'join' us pthread_detach(pthread_self()); // Block the ourselves from handling SIGHUP // so the main thread always gets it sigset_t blockSet; sigemptyset(&blockSet); sigaddset(&blockSet, SIGHUP); pthread_sigmask(SIG_BLOCK, &blockSet, NULL); char* options; if (receive_bytes(client_fd, options, 0)) { ParseOptions(options, globalParams); delete [] options; } stringstream header(stringstream::out); globalParams.pResultsOutput = &header; if (globalParams.verbosity >= VERBOSE_MIN) { // output algorithm information *(globalParams.pResultsOutput) << "Algorithm information:\n"; network->WriteAlgorithms(globalParams.pResultsOutput); } send_bytes(client_fd, header.str().c_str(), header.str().length(), 0); int examples = 0, correct, suppressed, not_labeled; int count; do { char* msg; // char reply[BUFFER_SIZE]; // memset(reply, 0, BUFFER_SIZE); if(DEBUG_SERVER) { cerr << "##testing stringstream: " << endl; stringstream testss(stringstream::out); globalParams.pResultsOutput = &testss; int num = 4; *(globalParams.pResultsOutput) << "testing testss... int: '" << num << "' should be 4..." << endl; string test_out = testss.str(); cerr << "##testss string is: " << test_out << endl; } stringstream reply(stringstream::out); globalParams.pResultsOutput = &reply; count = receive_bytes(client_fd, msg, 0); if (count > 0) { rank.clear(); bool readResult = true; //reads up to null terminator -- means that //input must be sent one example at a time stringstream testStream; testStream << msg; //= new stringstream(msg, stringstream::in); if(DEBUG_SERVER) { cerr << "##SNoW: received message: " << msg << "calling EvaluateExample()... " << endl; } //EvaluateExample() calls Output() which uses globalParams.pResultsOutput // to send output; we need to make sure that this makes it into reply s->EvaluateExample(&testStream, correct, suppressed, examples, not_labeled, NULL); if(DEBUG_SERVER) { cerr << "##SNoW: evaluated example: reply contains " << reply.str() << "..." << endl; cerr << "##SNoW: sending message " << reply.str().c_str() << "..." << endl; } send_bytes(client_fd, reply.str().c_str(), reply.str().length(), 0); delete [] msg; }// end if count > 0 // delete globalParams.pResultsOutput; } while (count != 0); cout << "Closing client socket...\n"; close(client_fd); return NULL;}void ParseOptions(char* options, GlobalParams & globalParams){ cout << "Options: " << options << endl; options = strtok(options, " "); while (options) { while (options && options[0] != '-') options = strtok(NULL, " "); if (options) { char c = *++options; if (!*++options) options = strtok(NULL, " "); if (options) { switch (c) { case 'b': globalParams.bayesSmoothing = atof(options); break; case 'f': if (*options == '+') globalParams.fixedFeature = true; else if (*options == '-') globalParams.fixedFeature = false; break; case 'L': globalParams.targetOutputLimit = atol(options); break; case 'l': if (*options == '+') globalParams.labelsPresent = true; else if (*options == '-') globalParams.labelsPresent = false; break; case 'm': if (*options == '+') globalParams.multipleLabels = true; else if (*options == '-') globalParams.multipleLabels = false; break; case 'o': if (!strcmp(options, "accuracy")) globalParams.predictMethod = ACCURACY; else if (!strcmp(options, "winners")) globalParams.predictMethod = WINNERS; else if (!strcmp(options, "allpredictions")) globalParams.predictMethod = ALL_PREDICTIONS; else if (!strcmp(options, "allactivations")) globalParams.predictMethod = ALL_ACTIVATIONS; else if (!strcmp(options, "allboth")) globalParams.predictMethod = ALL_BOTH; break; case 'p': globalParams.predictionThreshold = atof(options); break; case 'v': if (!strcmp(options, "off")) globalParams.verbosity = VERBOSE_QUIET; else if (!strcmp(options, "min")) globalParams.verbosity = VERBOSE_MIN; else if (!strcmp(options, "med")) globalParams.verbosity = VERBOSE_MED; else if (!strcmp(options, "max")) globalParams.verbosity = VERBOSE_MAX; break; case 'w': globalParams.smoothing = atof(options); break; } } options = strtok(NULL, " "); } }}#endifvoid ShowOptions(GlobalParams & globalParams){ if (globalParams.verbosity != VERBOSE_QUIET) { if (globalParams.runMode != MODE_EVAL && globalParams.runMode != MODE_SERVER) cout << "Input file: '" << globalParams.inputFile << "'\n"; cout << "Network file: '" << globalParams.networkFile << "'\n"; if (globalParams.testFile.length() > 0) cout << "Test file: '" << globalParams.testFile << "'\n"; if (globalParams.errorFile.length() > 0) cout << "Error file: '" << globalParams.errorFile << "'\n"; #ifdef SERVER_MODE_ if (globalParams.runMode == MODE_SERVER) cout << "Server port: " << globalParams.serverPort << endl;#endif if (globalParams.runMode == MODE_TRAIN) { cout << "Training with " << globalParams.cycles << " cycles over training data.\n"; if (globalParams.examplesInMemory) cout << "Storing examples in memory.\n"; if (globalParams.discardMethod == DISCARD_ABS) cout << "Absolute discarding @ " << globalParams.discardThreshold << endl; } if (globalParams.rawMode) cout << "Conventional (\"raw\") mode enabled.\n"; if (globalParams.thickSeparator.positive || globalParams.thickSeparator.negative) cout << "Thick separator set to " << globalParams.thickSeparator.positive << ", " << globalParams.thickSeparator.negative << ".\n"; if (globalParams.threshold_relative) cout << "Threshold relative updating enabled.\n"; if (globalParams.constraintClassification) cout << "Training with constraint classification enabled.\n"; if (globalParams.gradientDescent) cout << "Gradient descent function approximation enabled.\n"; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -