⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 markovnet.cpp

📁 gibbs
💻 CPP
字号:
#include "MarkovNet.h"

// DEBUG
//#define DEBUG
//#define DEBUG2
//#define DEBUG3

/**
 * Builds a pairwise Markov network from a Bayes net.
 *
 * One node is created per Bayes-net variable. Each variable is then linked
 * to its parents:
 *  - no parents:   the node's prior phi is filled from its decision tree
 *                  via dtree.getProb(val, NULL);
 *  - one parent:   one edge child--parent carrying Potential(dtree, schema);
 *  - k > 1 parents: an extra "mediator" node is added whose domain is the
 *                  cross product of the parents' ranges. Each parent is tied
 *                  to the mediator by a 0/1 consistency potential, and the
 *                  child is linked to the mediator by Potential(dtree, schema).
 *
 * The mediator encodes a joint parent assignment r as a mixed-radix number
 * with the first parent in bn.parents[i] as the least-significant digit:
 *     parentVal = (r / cumulativeDim) % parentDim
 * NOTE(review): Potential(dtree, schema) must index parent configurations
 * the same way for the child--mediator edge; confirm in Potential.
 *
 * The node array holds 2 * numVars entries: the originals plus at most one
 * mediator per variable.
 */
MarkovNet::MarkovNet(const BayesNet& bn)
    : nodes(new Node[bn.getNumVars() * 2]), numNodes(0)
{
    VarSchema schema = bn.getSchema();

    // Create one node per original variable
    int numVars = schema.getNumVars();
    for (int i = 0; i < numVars; i++) {
        nodes[i].index = i;
        nodes[i].fixedValue = -1;   // -1 == not fixed (no evidence)
        nodes[i].marginal = Distribution(schema.getRange(i));
        nodes[i].phi = Distribution(schema.getRange(i));
        numNodes++;
    }
#ifdef DEBUG
    cout << "Added one node per original variable.\n";
#endif

    // Link nodes with edges and potential functions; add extra nodes
    // as necessary.
    for (int i = 0; i < numVars; i++) {
#ifdef DEBUG
        cout << "Linking node " << i << endl;
#endif
        DecisionTree& dtree = *bn.decisionTrees[i];
        const list<int>& parents = bn.parents[i];

        if (parents.size() == 0) {
#ifdef DEBUG
            cout << "    No parents; done.\n";
#endif
            // Construct a marginal distribution from the decision tree
            nodes[i].phi = Distribution(schema.getRange(i));
            for (int val = 0; val < schema.getRange(i); val++) {
                nodes[i].phi[val] = dtree.getProb(val, NULL);
            }
        } else if (parents.size() == 1) {
#ifdef DEBUG
            cout << "    One parent; done.\n";
#endif
            Potential psi(dtree, schema);
            addEdge(&nodes[i], &nodes[parents.front()], psi);
        } else {
            // Multiple parents case
#ifdef DEBUG
            cout << "    Multiple parents.\n";
#endif
            // Get total dimension of inputs to decision tree
            int inputDim = 1;
            list<int>::const_iterator p;
            for (p = parents.begin(); p != parents.end(); p++) {
                inputDim *= schema.getRange(*p);
            }

            // Add mediator node to represent all values of all parents
            Node* mediator = &nodes[numNodes++];
            mediator->index      = numNodes - 1;
            mediator->fixedValue = -1;
            mediator->marginal   = Distribution(inputDim);
            mediator->phi        = Distribution(inputDim);

            // Add edges from mediator node to actual parents,
            // to enforce value consistency: psi(r, c) is 1 exactly when
            // parent *p takes value c in joint assignment r, else 0.
            int cumulativeDim = 1;
            for (p = parents.begin(); p != parents.end(); p++) {
                int parentDim = schema.getRange(*p);
                Potential psi(inputDim, parentDim);
                for (int r = 0; r < inputDim; r++) {
                    for (int c = 0; c < parentDim; c++) {
                        psi.set(r, c, 0.0);
                    }
                    int parentVal = (r / cumulativeDim) % parentDim;
                    psi.set(r, parentVal, 1.0);
                }
                cumulativeDim *= parentDim;
                addEdge(mediator, &nodes[*p], psi);

                // DEBUG
                //cout << "psi(" << *p << ") = " << psi << endl;
            }

            // Finally, add an edge from the child to the mediator
            // using its distribution over parents from the Bayes net.
            Potential psi(dtree, schema);
            addEdge(&nodes[i], mediator, psi);
        }
    }
}

/**
 * Connects n1 and n2 with a new edge carrying potential psi.
 *
 * A default Edge is appended to this->edges, filled in place, and a pointer
 * to the stored element is appended to both endpoints' edge lists.
 * NOTE(review): keeping &edges.back() is only safe if `edges` never relocates
 * existing elements on push_back (true for std::list/std::deque, false for
 * std::vector) -- confirm the container type in MarkovNet.h.
 */
void MarkovNet::addEdge(Node* n1, Node* n2, Potential& psi)
{
    Edge e;
    edges.push_back(e);
    edges.back().n1 = n1;
    edges.back().n2 = n2;
    edges.back().psi = psi;
#ifdef DEBUG3
    cout << "Psi = " << psi << endl;
#endif
    n1->edges.push_back(&edges.back());
    n2->edges.push_back(&edges.back());
}

// Disabled: kept for reference. Clamps evidence, runs BP, and returns the
// product of the query variables' individual marginals.
#if 0
double MarkovNet::getLikelihood(VarSet query, VarSet evidence,
        double threshold, double damping)
{
    // DEBUG
    cout << "Resetting all nodes\n";

    // Set evidence
    resetAllNodes();

    // DEBUG
    cout << "Resetting all messages\n";
    resetAllMessages();

    // DEBUG
    cout << "Setting all evidence\n";
    for (int i = 0; i < evidence.getNumVars(); i++) {
        if (evidence.isTested(i)) {
            fixNodeValue(i, (int)evidence[i]);
        }
    }

    // DEBUG
    cout << "Running belief propagation\n";

    // Run belief propogation to convergence
    runBP(threshold, damping);

    // Compute product of marginal distributions
    double prob = 1.0;
    for (int i = 0; i < query.getNumVars(); i++) {
        if (query.isTested(i) && !evidence.isTested(i)) {
            prob *= getMarginal(i).get((int)query[i]);
        }
    }

    return prob;
}
#endif

/**
 * Runs loopy belief propagation. Iteration stops when BPiter() returns a
 * value below convergenceThreshold, or after 1000 iterations.
 *
 * Before iterating:
 *  1. every fixed node (fixedValue >= 0) sends its marginal once to each
 *     non-fixed neighbour; BPiter() skips fixed nodes afterwards;
 *  2. every non-fixed node whose neighbours are all fixed (including a node
 *     with no edges at all) gets its marginal computed once as phi times all
 *     incoming messages, normalised.
 *
 * @param convergenceThreshold  stop once BPiter() returns less than this.
 *                              BPiter() never returns less than 1.0, so a
 *                              threshold <= 1.0 runs all 1000 iterations.
 * @param dampingFactor         passed through unchanged to BPiter().
 */
void MarkovNet::runBP(double convergenceThreshold, double dampingFactor)
{
    list<Edge*>::iterator neighbor;

    // For all fixed nodes, we need only send our messages once.
    for (int n = 0; n < numNodes; n++) {

        // Only consider fixed nodes at this point
        if (nodes[n].fixedValue < 0) {
            continue;
        }

        // Send messages to each neighbor
        for (neighbor = nodes[n].edges.begin();
                neighbor != nodes[n].edges.end(); neighbor++) {
            Edge* currEdge = *neighbor;

            // Don't send messages to other fixed nodes
            if (currEdge->otherNode(&nodes[n])->fixedValue >= 0) {
                continue;
            }

            // Send the messages
            currEdge->sendMsg(&nodes[n], nodes[n].marginal);
        }
    }

    // For nodes with only fixed neighbors, compute marginals exactly once.
    // NOTE(review): such nodes are tagged by setting fixedValue = 0, which
    // makes them indistinguishable from evidence value 0 and leaves them
    // marked after this call returns. Confirm that anything else reading
    // fixedValue (or a later runBP without a node reset) expects this.
    for (int n = 0; n < numNodes; n++) {

        // Already fixed
        if (nodes[n].fixedValue >= 0) {
            continue;
        }

        // Assume fixed until proven otherwise
        nodes[n].fixedValue = 0;
        for (neighbor = nodes[n].edges.begin();
                neighbor != nodes[n].edges.end(); neighbor++) {

            // Neighbor not fixed: we guessed wrong
            if ((*neighbor)->otherNode(&nodes[n])->fixedValue < 0) {
                nodes[n].fixedValue = -1;
                break;
            }
        }

        // Compute marginal from all incoming messages
        if (nodes[n].fixedValue >= 0) {
            nodes[n].marginal = nodes[n].phi;
            for (neighbor = nodes[n].edges.begin();
                    neighbor != nodes[n].edges.end(); neighbor++) {
                nodes[n].marginal *= (*neighbor)->msgTo(&nodes[n]);
#ifdef DEBUG2
                cout << "Pre-incoming: " << (*neighbor)->msgTo(&nodes[n]) << endl;
#endif
            }
#ifdef DEBUG2
            cout << "New marginal: " << nodes[n].marginal << endl;
#endif
            nodes[n].marginal.normalize();
        }
    }

    int maxIters = 1000;
    double delta = 0.0;
    int i;
    for (i = 0; i < maxIters; i++) {
        delta = BPiter(dampingFactor);
        if (delta < convergenceThreshold) {
            break;
        }
#ifdef DEBUG
        cout << "Iteration " << i << ": delta = " << delta << endl;
#endif
    }

#ifdef DEBUG
    if (i == maxIters) {
        cout << "Did not converge after " << maxIters << " iterations.\n";
    } else {
        cout << "Successfully converged in " << i << " iterations.\n";
    }
    cout << "Final delta: " << delta << endl;
#endif
}

/**
 * Performs one sweep of belief propagation over all non-fixed nodes, in node
 * order.
 *
 * For a node with N edges and incoming messages m_0..m_{N-1} (in edge-list
 * order), prefix and suffix products are built once:
 *     forwardDistribs[k] = phi * m_0 * ... * m_{k-1}      (k = 0..N)
 *     reverseDistribs[k] = m_{N-k} * ... * m_{N-1}        (k = 0..N-1)
 * The message to neighbour j is then
 *     forwardDistribs[j] * reverseDistribs[N-j-1]
 * which is phi times every incoming message except m_j. This costs O(N)
 * products per node instead of O(N^2). The node's new marginal is
 * forwardDistribs[N], normalised.
 * reverseDistribs[0] is Distribution(dim); this assumes that constructor
 * yields the multiplicative identity. TODO confirm in Distribution.
 *
 * Messages are delivered with Edge::sendMsg() as soon as they are computed,
 * so nodes visited later in the same sweep already see them.
 *
 * NOTE(review): dampingFactor is currently unused -- messages are sent
 * undamped. Its intended convention is not visible in this file.
 *
 * @return the largest ratio max(new/old, old/new) between a node's new and
 *         previous marginal entries, over all non-fixed nodes; never less
 *         than 1.0.
 */
double MarkovNet::BPiter(double dampingFactor)
{
    double delta = 1.0;
    list<Edge*>::iterator neighbor;
    list<Edge*>::reverse_iterator rneighbor;
    int index;

    for (int i = 0; i < numNodes; i++) {

        // Skip fixed nodes; their messages have already been sent
        if (nodes[i].fixedValue >= 0) {
            continue;
        }

        int numNeighbors = nodes[i].edges.size();
        vector<Distribution> forwardDistribs(numNeighbors + 1);
        // One spare slot so reverseDistribs[0] stays in bounds for a node
        // with no edges (previously an out-of-range write when N == 0).
        vector<Distribution> reverseDistribs(numNeighbors + 1);

        // Compute products of first n messages and phi (the prior
        // distribution at this node), for all n
        forwardDistribs[0] = nodes[i].phi;
        index = 1;
        for (neighbor = nodes[i].edges.begin();
                neighbor != nodes[i].edges.end(); neighbor++) {

            // Multiply previous distrib by the incoming message
            forwardDistribs[index] = forwardDistribs[index-1];
            forwardDistribs[index] *= (*neighbor)->msgTo(&nodes[i]);
#ifdef DEBUG2
            cout << "Incoming message: " << (*neighbor)->msgTo(&nodes[i]) << endl;
#endif
            index++;
        }

        // Compute products of last n messages, for all n
        // (Except for product of all messages, which is redundant.)
        reverseDistribs[0] = Distribution(nodes[i].phi.dim());
        index = 1;
        for (rneighbor = nodes[i].edges.rbegin();
                index < numNeighbors; rneighbor++) {

            // Multiply by the incoming message on each edge
            reverseDistribs[index] = reverseDistribs[index-1];
            reverseDistribs[index] *= (*rneighbor)->msgTo(&nodes[i]);
            index++;
        }

        // The marginal is just the product of all messages and phi,
        // which we've computed above.
        Distribution marginal = forwardDistribs[numNeighbors];
        marginal.normalize();
#ifdef DEBUG2
        cout << "Marginal: " << marginal << endl;
#endif

        // Compute messages to send to neighbors
        index = 0;
        for (neighbor = nodes[i].edges.begin();
                neighbor != nodes[i].edges.end(); neighbor++, index++) {
            Edge* currEdge = *neighbor;

            // Don't send messages to fixed nodes
            if (currEdge->otherNode(&nodes[i])->fixedValue >= 0) {
                continue;
            }

            // Compute message: phi times every incoming message but this one
            Distribution message = forwardDistribs[index];
            if (index != numNeighbors - 1) {
                message *= reverseDistribs[numNeighbors - index - 1];
            }

            // Distribute message
            currEdge->sendMsg(&nodes[i], message);

#if 0
            // HACK DEBUG -- Comparing messages!
            Distribution outgoing = nodes[i].phi;
            if (nodes[i].fixedValue < 0) {
                for (list<Edge*>::iterator n2 = nodes[i].edges.begin();
                        n2 != nodes[i].edges.end(); n2++) {
                    if (*n2 != currEdge) {
                        outgoing *= ((*n2)->n1 == &nodes[i]) ? (*n2)->message2to1
                                                             : (*n2)->message1to2;
                    }
                }
            } else {
                outgoing = nodes[i].marginal;
            }

            // DEBUG -- Report inconsistencies between old and new methods
            if (nodes[i].fixedValue < 0) {
                bool printStuff = false;
                for (int a = 0; a < message.dim(); a++) {
                    if (outgoing[a]/message[a] > 1.1
                            || message[a]/outgoing[a] > 1.1) {
                        printStuff = true;
                    }
                }
                if (printStuff) {
                    for (int a = 0; a < numNeighbors; a++) {
                        cout << "Forward " << a << ": "
                             << forwardDistribs[a] << endl;
                    }
                    for (int a = 0; a < numNeighbors; a++) {
                        cout << "Reverse " << a << ": "
                             << reverseDistribs[a] << endl;
                    }
                    cout << "Old: " << outgoing << endl;
                    cout << "New: " << message << endl;
                    cout << "Marginal: " << marginal << endl;
                    cout << "Index: " << index << endl;
                    cout << "Num neighbors: " << numNeighbors << endl;
                }
            }

            //  HACK
            message = outgoing;
#endif
        }

        // Compare to previous marginal at this node
        // and store the new one. The change is measured as a ratio folded
        // to be >= 1, so delta never drops below its initial 1.0.
        for (int x_i = 0; x_i < marginal.dim(); x_i++) {
            double delta_x_i = marginal[x_i]/nodes[i].marginal[x_i];
            if (delta_x_i < 1.0) {
                delta_x_i = 1.0/delta_x_i;
            }
            if (delta < delta_x_i) {
                delta = delta_x_i;
            }
        }
        nodes[i].marginal = marginal;
    }

    return delta;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -