⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_struct_api.cpp

📁 SVMhmm: Learns a hidden Markov model from examples. Training examples (e.g. for part-of-speech taggi
💻 CPP
📖 第 1 页 / 共 3 页
字号:
/***********************************************************************/
/*                                                                     */
/*   svm_struct_api.c                                                  */
/*                                                                     */
/*   Definition of API for attaching implementing SVM learning of      */
/*   structures (e.g. parsing, multi-label classification, HMM)        */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 03.07.04                                                    */
/*                                                                     */
/*   Copyright (c) 2004  Thorsten Joachims - All rights reserved       */
/*                                                                     */
/*   This software is available for non-commercial use only. It must   */
/*   not be modified and distributed without prior permission of the   */
/*   author. The author is not responsible for implications from the   */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#include <csignal>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <cctype> //tolower
#include <iostream>
#include <fstream>
#include <sstream> //istringstream
#include <iomanip>
#include <string>
#include <algorithm> //transform()
using namespace std;
#include <ext/hash_map> //this location is compiler-dependent
using __gnu_cxx::hash; //__gnu_cxx is where gcc sticks nonstandard STL stuff
using __gnu_cxx::hash_map;
#include <boost/tuple/tuple.hpp>
using boost::tuple;
#include "svm_struct/svm_struct_common.h"
#include "svm_struct_api.h"

/*
define an assertion handler for when a BOOST assertion gets triggered and we want to be able to trace it upward in gdb
(set gdb to break inside the function below)
*/
namespace boost
{
/*
user-supplied handler for a failed BOOST_ASSERT (presumably the build defines
BOOST_ENABLE_ASSERT_HANDLER so Boost routes failures here -- confirm)

reports the failed expression, its location and the enclosing function on stdout,
followed by a blank line, and then simply returns; break here in gdb to trace the
failure upward
*/
void assertion_failed(char const * expr, char const * function, char const * file, long line)
{
	std::printf("boost assertion failed: %s at %s(%ld) (%s)\n", expr, file, line, function);
	std::putchar('\n'); //blank separator line after the report
}
}

#ifndef min
#define min(a, b) ((a) < (b) ? (a) : (b))
#endif
#ifndef max
#define max(a, b) ((a) > (b) ? (a) : (b))
#endif

/*
square a value using T's own operator*; the product is converted back to T
*/
template <typename T>
T sqr(T t)
{
	T squared = t * t;
	return squared;
}

/********** debug-printing ************/

/*
debug helper: print the string of every token in sentence x on one line
(each followed by a space), then a newline
*/
void printSentence(PATTERN x)
{
	unsigned int pos = 0;
	while(pos < x.getLength())
	{
		//keep c_str() within one expression: getString() may return a temporary
		printf("%s ", x.getToken(pos).getString().c_str());
		++pos;
	}
	printf("\n");
}

/*
debug helper: print the tag string for every position of label sequence y on one
line (each followed by a space), then a newline

getTagByID() throws invalid_argument if a position holds an unregistered tag ID
*/
void printLabelSeq(LABEL y)
{
	unsigned int pos = 0;
	while(pos < y.getLength())
	{
		printf("%s ", getTagByID(y.getTag(pos)).c_str());
		++pos;
	}
	printf("\n");
}

/**************** tags ****************/

namespace
{
class hashString
{
	public:

		hashString() {}

		size_t operator () (const string& s) const
		{
			return hasher(s.c_str());
		}

	private:

		static const hash<const char*> hasher;
};
const hash<const char*> hashString::hasher = hash<const char*>();

hash_map<tagID, tag> idToTagMap;
hash_map<string, tagID, hashString> tagToIDMap;

/*
during classification, we read in the training-set tags from the model, then
the test-set tags from the input; we want to register the first set but not the second,
so provide a flag
*/
bool registryWritable = true;

}

/*
allow (w == true) or forbid (w == false) registerTag() to assign IDs to new tags
*/
void setTagRegistryWritable(bool w) {registryWritable = w;}

/*
look up the ID of tag t, registering t first if necessary

- if t has already been registered, return its existing ID
- otherwise, if the registry is writable, give t the next unused ID (IDs are
  handed out consecutively starting at 0) and return it
- otherwise (registry read-only, e.g. while reading test-set tags during
  classification) return -1, which wraps to UINT_MAX for an unsigned tagID
*/
tagID registerTag(const tag& t)
{
	const hash_map<string, tagID, hashString>::const_iterator found = tagToIDMap.find(t);
	if(found != tagToIDMap.end()) //tag has been registered
		return found->second;
	if(!registryWritable) //tag has not been registered, but registry is read-only
		return -1; //wraps to UINT_MAX
	//tag has not been registered: assign the next consecutive ID and record both directions
	const tagID newID = idToTagMap.size();
	idToTagMap[newID] = t;
	tagToIDMap[t] = newID;
	return newID;
}

/*
number of distinct tags registered so far; tags are registered while input is read
*/
unsigned int getNumTags()
{
	const unsigned int count = idToTagMap.size();
	return count;
}

/*
return the tag registered under ID id

throws invalid_argument if no tag has that ID (e.g. the -1/UINT_MAX that
registerTag() returns for an unseen tag while the registry is read-only)
*/
const tag& getTagByID(tagID id) throw(invalid_argument)
{
	//reuse the lookup result instead of hashing id a second time via operator[]
	const hash_map<tagID, tag>::const_iterator found = idToTagMap.find(id);
	if(found == idToTagMap.end()) throw invalid_argument("getTagByID(): unknown ID");
	return found->second;
}

/************* class token ************/

/*
construct a token with an empty string and an empty feature vector
*/
token::token()
{
	initFeatures();
}

/*
construct a token for string s with an empty feature vector
*/
token::token(const string& s) : str(s)
{
	initFeatures();
}

/*
copy constructor: copies the string and shares (does not deep-copy) the
feature vector held by t
*/
token::token(const token& t) : str(t.str), features(t.features)
{
}

/*
nothing to release by hand: features is a shared_ptr and drops its own reference
*/
token::~token()
{
}

/*
initialize the features map/list

should only be called from a constructor
*/
void token::initFeatures()
{
	features = shared_ptr<SVECTOR>(new SVECTOR);
	features->words = (WORD*)my_malloc(sizeof(WORD));
	features->words[0].wnum = 0;
	//svmlight yells if userdefined is NULL; it must have at least one element
	features->userdefined = (char*)my_malloc(sizeof(char));
	features->userdefined[0] = 0;
	features->next = NULL;
	features->factor = 1;
}

/*
assign from t: the string is copied, while the feature vector is shared with t
(both tokens then refer to the same SVECTOR); returns *this
*/
const token& token::operator = (const token& t)
{
	this->str = t.str;
	this->features = t.features;
	return *this;
}

/************* class label ************/

/*
two labels are equal iff they have the same length and carry the same tag ID at
every position
*/
bool label::operator == (const label& l) const
{
	if(getLength() != l.getLength()) return false;
	unsigned int i = 0;
	//advance over the common prefix of matching tags
	while(i < getLength() && getTag(i) == l.getTag(i))
		i++;
	return i == getLength();
}

/********** class strMatcher **********/

/*
auxiliary to read_struct_examples(): wrap literal s so that it can be consumed with
operator >>, as in `instr >> match(" qid:")`
*/
inline strMatcher match(const string& s)
{
	return strMatcher(s);
}

/*
auxiliary to read_struct_examples(): try to match a string literal in an input stream

consumes the characters of m.str one at a time; on the first mismatch (or at end of
input) failbit is set and the stream is left where the mismatch occurred, so the
stream may be partially read if an error occurs. A stream that is already in a
failed state is returned untouched.
*/
istream& operator >> (istream& in, const strMatcher& m)
{
	if(!in) return in;
	for(size_t i = 0; i < m.str.length(); i++)
	{
		//compare in int_type space: peek() yields 0..255 or EOF, whereas a plain (possibly
		//signed) char >= 0x80 is negative and would never match -- and '\xff' would equal EOF
		if(in.peek() != istream::traits_type::to_int_type(m.str[i]))
		{
			in.setstate(ios_base::failbit); //set failure
			return in;
		}
		in.get(); //extract one character
	}
	return in;
}

/**************************************/

/*
learning-side hook, called before anything else is done so that any needed
initialization can happen; nothing to do for this model
*/
void svm_struct_learn_api_init(int argc, char* argv[])
{
}

/*
learning-side hook, called at the very end for any clean-up; nothing to do
*/
void svm_struct_learn_api_exit()
{
}

/*
prediction-side hook, called before anything else is done so that any needed
initialization can happen; nothing to do for this model
*/
void svm_struct_classify_api_init(int argc, char* argv[])
{
}

/*
prediction-side hook, called at the very end for any clean-up; nothing to do
*/
void svm_struct_classify_api_exit()
{
}

/**************************************/

/*
this function gets called by both the learning and prediction modules

automatically generate the feature vector filename from the POS filename: POS_BASE.ext -> POS_BASE_feats.dat

NOTE we want this to be called before init_struct_model() so we'll have defined the input language size
*/
SAMPLE      read_struct_examples(const char *filename, STRUCT_LEARN_PARM *sparm)
{
	/*
	if we're reading the training set, fSS hasn't been set; if we're on classification,
	it was set when we read in the model
	*/
	bool onClassification = (sparm->featureSpaceSize != 0);

  /* Reads struct examples and returns them in sample. The number of
     examples must be written into sample.n */
  SAMPLE   sample;

  //holding space until we allocate sample.examples; note the shared_ptr default ctor gives a null pointer
  vector<shared_ptr<vector<token> > > tokens;
  vector<shared_ptr<vector<tagID> > > tagIDs;

  ifstream infile(filename);
  if(!infile)
  {
	  fprintf(stderr, "read_struct_examples(): can't open '%s' for reading; exiting\n", filename);
	  exit(-1);
  }

  unsigned int lineNum = 0;
  string line, comment, _tag, word;
  unsigned int exNum, exIndex, featNum, maxFeatNumFound = 0;
  double featVal;
  while(getline(infile, line, '\n') && line.length() > 0) //an empty line ends input
  {
#define PARSE_ERROR(infoDesc, lineNo)\
	{\
		fprintf(stderr, "parse error reading %s on line %u of '%s'", infoDesc, lineNo, filename);\
		exit(-1);\
	}
		//if there's a comment on the line, remove it into another string
		size_t commentIndex = line.find("#", line.find_first_of("1234567890")); //a # before the feature list must be a word; don't remove it
		if(commentIndex != string::npos)
		{
			comment = line.substr(commentIndex + 1);
			line = line.substr(0, commentIndex);
		}
		else comment = "";
		//parse tag
		istringstream instr(line);
		if(!(instr >> _tag >> match(" qid:") >> exNum >> match(".") >> exIndex)) PARSE_ERROR("token info", lineNum);
		//resize temporary storage of tokens and tags
		if(tokens.size() < exNum) //input example and token numbers start at 1
		{
			tokens.resize(exNum);
			tagIDs.resize(exNum);
		}
		if(tokens[exNum - 1].get() == NULL)
		{
			tokens[exNum - 1] = shared_ptr<vector<token> >(new vector<token>);
			tagIDs[exNum - 1] = shared_ptr<vector<tagID> >(new vector<tagID>);
		}
		if(tokens[exNum - 1]->size() < exIndex)
		{
			tokens[exNum - 1]->resize(exIndex);
			tagIDs[exNum - 1]->resize(exIndex);
		}
		(*tagIDs[exNum - 1])[exIndex - 1] = registerTag(_tag); //returns a new id only if the tag hasn't been seen before
		//parse features
		SVECTOR& features = (*tokens[exNum - 1])[exIndex - 1].getFeatureMap();
		unsigned int numFeats = 0;
		while(instr >> featNum >> match(":") >> featVal)
		{
			if(onClassification) //avoid features with higher numbers than what we saw during training
			{
				if(featNum <= sparm->featureSpaceSize)
				{
					features.words = (WORD*)realloc(features.words, ++numFeats * sizeof(WORD));
					features.words[numFeats - 1].wnum = featNum; //feature numbers start at 1 in the input
					features.words[numFeats - 1].weight = featVal;
				}
			}
			else
			{
				features.words = (WORD*)realloc(features.words, ++numFeats * sizeof(WORD));
				features.words[numFeats - 1].wnum = featNum; //feature numbers start at 1 in the input
				features.words[numFeats - 1].weight = featVal;
				if(featNum > maxFeatNumFound) maxFeatNumFound = featNum;
			}
		}
		features.words = (WORD*)realloc(features.words, ++numFeats * sizeof(WORD));
		features.words[numFeats - 1].wnum = 0; //signal to end word list
		if(instr.bad()) PARSE_ERROR("features", lineNum); //read error, as opposed to just reaching end of line
		//parse the comment (first word, if any, is interpreted as the token string; rest is ignored)
		size_t wordStart = comment.find_first_not_of(" \t\n\r");
		if(wordStart != string::npos) //the comment contains non-whitespace
		{
			comment = comment.substr(wordStart, comment.find_last_not_of(" \t\n\r") + 1); //trim whitespace
			istringstream incomment(comment);
			if(incomment >> word) (*tokens[exNum - 1])[exIndex - 1].setString(word);
		}
		lineNum++;
#undef PARSE_ERROR
  }
  infile.close();

	if(!onClassification) //if during training, figure out the feature space size
	{
  		if(maxFeatNumFound == 0)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -