📄 phoneticdecisiontree.h
字号:
// file: $isip/class/pr/PhoneticDecisionTree/PhoneticDecisionTree.h// version: $Id: PhoneticDecisionTree.h,v 1.10 2003/04/01 02:06:52 duncan Exp $//// make sure definitions are only made once//#ifndef ISIP_PHONETIC_DECISION_TREE#define ISIP_PHONETIC_DECISION_TREE// isip include files//#ifndef ISIP_DECISION_TREE_BASE#include <DecisionTreeBase.h>#endif#ifndef ISIP_PHONETIC_DECISION_TREE_NODE#include <PhoneticDecisionTreeNode.h>#endif#ifndef ISIP_SEARCH_SYMBOL#include <SearchSymbol.h>#endif#ifndef ISIP_STATISTICAL_MODEL#include <StatisticalModel.h>#endif#ifndef ISIP_GAUSSIAN_MODEL#include <GaussianModel.h>#endif#ifndef ISIP_MIXTURE_MODEL#include <MixtureModel.h>#endif#ifndef ISIP_BIGRAPH_ARC#include <BiGraphArc.h>#endif#ifndef ISIP_CONTEXT_MAP#include <ContextMap.h>#endif#ifndef ISIP_SEARCH_NODE#include <SearchNode.h>#endif// PhoneticDecisionTree: a class that computes the PhoneticDecisionTree.// currently.//class PhoneticDecisionTree: public DecisionTreeBase<PhoneticDecisionTreeNode> { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // other important constants // //---------------------------------------- // define the algorithm choices // enum ALGORITHM { ML = 0, DEF_ALGORITHM = ML }; // define the implementation choices // enum IMPLEMENTATION { DEFAULT = 0, DEF_IMPLEMENTATION = DEFAULT }; // define the static NameMap objects // static const NameMap ALGO_MAP; static const NameMap IMPL_MAP; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String PARAM_ALGORITHM; static const String PARAM_IMPLEMENTATION; static const String PARAM_SPLIT_THRESHOLD; static const String PARAM_MERGE_THRESHOLD; static const String PARAM_NUM_OCC_THRESHOLD; static const String PARAM_BDT; //---------------------------------------- // // other static constants // //---------------------------------------- static const String YES; static const String NO; static const String CPH; static const String POS; //---------------------------------------- // // default values and arguments // //---------------------------------------- // define default values for the thresholds // static const float DEF_SPLIT_THRESHOLD = 10; static const float DEF_MERGE_THRESHOLD = 5; static const float DEF_NUM_OCC_THRESHOLD = 100; //---------------------------------------- // // error codes // //---------------------------------------- static const long ERR = 00100300; //--------------------------------------------------------------------------- // // protected data // //---------------------------------------------------------------------------protected: // define the structures //typedef Triple< Pair<Long, Long>, Float, Boolean> TopoTriple; typedef Triple<Long, StatisticalModel, HashTable<String, String> > DataPoint; typedef SingleLinkedList<DataPoint> Data; typedef BiGraphVertex<PhoneticDecisionTreeNode> TreeNode; // algorithm name // ALGORITHM algorithm_d; // implementation name // IMPLEMENTATION implementation_d; // data on the root node // PhoneticDecisionTreeNode pdt_rootnode_d; // thresholds for building the decision trees // Float split_threshold_d; Float merge_threshold_d; Float num_occ_threshold_d; // static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //---------------------------------------------------------------------------public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static boolean diagnose(Integral::DEBUG debug_level); // debug methods: // setDebug is inherited from the base class // boolean debug(const unichar* msg) const; // method: destructor // ~PhoneticDecisionTree() { } // method: default constructor // PhoneticDecisionTree(ALGORITHM algorithm = DEF_ALGORITHM, IMPLEMENTATION implementation = DEF_IMPLEMENTATION, float split_threshold = DEF_SPLIT_THRESHOLD, float merge_threshold = DEF_MERGE_THRESHOLD, float num_occ_threshold = DEF_NUM_OCC_THRESHOLD) { algorithm_d = algorithm; implementation_d = implementation; split_threshold_d = split_threshold; merge_threshold_d = merge_threshold; num_occ_threshold_d = num_occ_threshold; } // method: copy constructor // PhoneticDecisionTree(const PhoneticDecisionTree& arg) { assign(arg); } // assign methods // boolean assign(const PhoneticDecisionTree& arg); // method: operator= // PhoneticDecisionTree& operator= (const PhoneticDecisionTree& arg) { assign(arg); return *this; } // i/o methods // long sofSize() const; // method: read // boolean read(Sof& sof, long tag) { return read(sof, tag, name()); } boolean read(Sof& sof, long tag, const String& name); // method: write // boolean write(Sof& sof, long tag) const { return write(sof, tag, name()); } boolean write(Sof& sof, long tag, const String& name) const; // method: readData // boolean readData(Sof& sof, const String& pname = DEF_PARAM, long size = SofParser::FULL_OBJECT, boolean param = true, boolean nested = false); // method: writeData // boolean writeData(Sof& sof, const String& pname = DEF_PARAM) const; // equality methods // boolean eq(const PhoneticDecisionTree& arg) const; // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static boolean setGrowSize(long grow_size) { return mgr_d.setGrow(grow_size); } // other memory management methods // boolean clear(Integral::CMODE ctype = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods: // set methods // //--------------------------------------------------------------------------- // method: setAlgorithm // boolean setAlgorithm(ALGORITHM algorithm) { algorithm_d = algorithm; return true; } // method: setImplementation // boolean setImplementation(IMPLEMENTATION implementation) { implementation_d = implementation; return true; } // method: setSplitThreshold // boolean setSplitThreshold(float split_threshold) { split_threshold_d = split_threshold; return true; } // method: setMergeThreshold // boolean setMergeThreshold(float merge_threshold) { merge_threshold_d = merge_threshold; return true; } // method: setNumOccThreshold // boolean setNumOccThreshold(float num_occ_threshold) { num_occ_threshold_d = num_occ_threshold; return true; } // method: set // boolean set(ALGORITHM algorithm = DEF_ALGORITHM, IMPLEMENTATION implementation = DEF_IMPLEMENTATION, float split_threshold = DEF_SPLIT_THRESHOLD, float merge_threshold = DEF_MERGE_THRESHOLD, float num_occ_threshold = DEF_NUM_OCC_THRESHOLD) { algorithm_d = algorithm; implementation_d = implementation; split_threshold_d = split_threshold; merge_threshold_d = merge_threshold; num_occ_threshold_d = num_occ_threshold; return true; } //--------------------------------------------------------------------------- // // class-specific public methods: // get methods // //--------------------------------------------------------------------------- // method: getAlgorithm // ALGORITHM getAlgorithm() const { return algorithm_d; } // method: getImplementation // IMPLEMENTATION getImplementation() const { return implementation_d; } // method: getSplitThreshold // float getSplitThreshold() const { return split_threshold_d; } // method: getMergeThreshold // float getMergeThreshold() const { return merge_threshold_d; } // method: getNumOccThreshold // double getNumOccThreshold() const { return num_occ_threshold_d; } // method: get // boolean get(ALGORITHM& algorithm, IMPLEMENTATION& implementation, float& split_threshold, float& merge_threshold, float& num_occ_threshold) { algorithm = algorithm_d; implementation = implementation_d; split_threshold = split_threshold_d; merge_threshold = merge_threshold_d; num_occ_threshold = num_occ_threshold_d; return true; } // method: getStatTrain // boolean getStatTrain(Vector<ContextMap>& context_map, Vector<DiGraph<SearchNode> >& sub_graphs, Vector<SearchSymbol>& symbol_table, Vector<SearchSymbol>& contextless_symbol_table, HashTable<SearchSymbol,Long>& symbol_hash, Vector<StatisticalModel>& stat_models, Filename& phonetic_dt_file, HashTable<SearchSymbol, Long>& tied_model_hash, Vector<StatisticalModel>& tied_stat_models); // method: getStatTest // boolean getStatTest(Vector<ContextMap>& context_map, long& left_context, long& right_context, Vector<SearchSymbol>& upper_symbol_table, Vector<SearchSymbol>& upper_contextless_symbol_table, Vector<DiGraph<SearchNode> >& sub_graphs, Vector<SearchSymbol>& symbol_table, HashTable<SearchSymbol, Long>& symbol_hash, Filename& ques_ans_file); //--------------------------------------------------------------------------- // // class-specific public methods: // computational methods // //--------------------------------------------------------------------------- // method: runDecisionTree // boolean runDecisionTree(); // method: trainDecisionTree // boolean trainDecisionTree(); // method: load // boolean load(const Attributes& attributes, PhoneticDecisionTreeNode& pdtnode); // method: loadTrain // boolean loadTrain(Vector<ContextMap>& context_map, long& left_context, long& right_context, Vector<SearchSymbol>& upper_symbol_table, Vector<SearchSymbol>& contextless_symbol_table, Vector<DiGraph<SearchNode> >& sub_graphs, Vector<SearchSymbol>& symbol_table, HashTable<SearchSymbol, Long>& symbol_hash, Vector<StatisticalModel>& stat_models, Filename& ques_ans_file, HashTable<SearchSymbol, Long>& tied_symbol_hash, Vector<StatisticalModel>& tied_stat_models); // method: loadTest method // boolean loadTest(Filename& phonetic_dt_file); // method to set the parser // boolean setParser(SofParser* parser); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: // method: classifyDataPoint // Long classifyDataPoint(DataPoint& datapoint); // classification and merging methods // boolean classifyData(TreeNode* node, Attribute& attribute); boolean mergeLeafNodes(TreeNode* start_node, TreeNode* best_node); // subtree manipulation methods // boolean splitSubTree(TreeNode* node); boolean mergeSubTree(TreeNode* node); boolean reindexSubTree(TreeNode* node, long& index); Long findClass(TreeNode* node, DataPoint& datapoint); // method to find best attribute // boolean findBestAttribute(TreeNode* node, Attribute& best_attribute, float& likelihood); // method to find the index of a typical StatisticalModel at a node // Long findTypicalIndex(TreeNode* node); // method to mark a node // boolean markNode(TreeNode* node, boolean& flag); // method to update the typical-index of the best-node // boolean updateTypicalIndex(TreeNode* start_node, TreeNode* best_node); // mathematical manipulation methods // boolean computeSumOccupancy(TreeNode* node, float& sum_num_occ); boolean isSplitOccupancyBelowThreshold(TreeNode* node, Attribute& attribute); boolean computeDeterminantPooledCovariance(TreeNode* node, float& det_pooled_covariance); double computeScale(StatisticalModel& stat_model); // compute likelihood methods // boolean computeLikelihoodNode(TreeNode* node, float& likelihood); boolean computeLikelihoodSplitNode(TreeNode* node, Attribute& attribute, float& split_likelihood); boolean computeLikelihoodMergeNodes(TreeNode* start_node, TreeNode* node, float& merge_likelihood); // contexts generation methods // boolean createContexts(Vector<SearchSymbol>& symbols, long& length, Vector<ContextMap>& all_contexts); boolean appendContextLevel(Vector<SearchSymbol>& symbols, long& level, Vector<ContextMap>& all_contexts); boolean validateContexts(Vector<SearchSymbol>& contextless_symbol_table, Vector<ContextMap>& all_contexts, Vector<ContextMap>& valid_contexts); boolean getUnseenContexts(Vector<ContextMap>& seen_contexts, Vector<ContextMap>& valid_contexts, Vector<ContextMap>& unseen_contexts); // method: updateLowerLevel // boolean updateLowerLevel(Vector<ContextMap>& context_map, Vector<ContextMap>& unseen_context_map, Vector<DiGraph<SearchNode> >& sub_graphs, Vector<SearchSymbol>& symbol_table, HashTable<SearchSymbol,Long>& symbol_hash); // method: getCentralSymbols // boolean getCentralSymbols(Vector<SearchSymbol>& symbol_table, Vector<SearchSymbol>& contextless_symbol_table, SingleLinkedList<String>& central_symbols); // method: readQuestionAnswer // boolean readQuestionAnswer(Filename& ques_ans_file, SingleLinkedList<Pair<Long, String> >& questions, HashTable<String, String>& answers); // method: poolStatisticalModel // boolean poolStatisticalModel(Vector<ContextMap>& context_map, Vector<SearchSymbol>& contextless_symbol_table, Vector<DiGraph<SearchNode> >& sub_graphs, Vector<SearchSymbol>& symbol_table, HashTable<SearchSymbol,Long>& symbol_hash, Vector<StatisticalModel>& stat_models, long& context_len, SingleLinkedList<Pair<Long, String> >& questions, HashTable<String, String>& answers, Data& data, HashTable<SearchSymbol, Long>& tied_symbol_hash, Vector<StatisticalModel>& tied_stat_models); // method: isTiedSSymbol // boolean isTiedSSymbol(SearchSymbol& search_symbol, HashTable<SearchSymbol, Long>& symbol_hash);};// end of include file// #endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -