📄 functionnode.c
字号:
#include "FunctionNode.h"

#include <cstdio>
#include <vector>

#include "VariableNode.h"
#include "Constant.h"

// EPS (the normalisation threshold used in update()) must come from one of the
// included project headers; the local fallback below is intentionally disabled.
//#define EPS 1e-20

/**
 * Constructs a function (factor) node.
 *
 * @param name          node name, forwarded to Node (printed by draw()).
 * @param id            node id, forwarded to Node (printed by draw()).
 * @param numNeighbours number of ports / adjacent variable nodes.
 * @param localFunction local function evaluated on every joint configuration
 *                      in summation(). Only the pointer is stored.
 *                      NOTE(review): ownership/lifetime is not visible here;
 *                      presumably the caller keeps it alive — confirm.
 */
FunctionNode::FunctionNode(const char *name, int id, int numNeighbours, LocalFunction *localFunction)
    : Node(name, id, numNeighbours)
{
    this->localFunction = localFunction;
}

/**
 * Recomputes the outgoing message on `port` (towards neighbours[port]).
 *
 * For every value x of the target variable the unnormalised entry is
 *
 *     sum over configurations c with c[port] == x of
 *         prod_{l != port} in_l(c[l]) * localFunction(c),
 *
 * where in_l = neighbours[l]->getOutMessage(portNumbers[l]). The target
 * variable's own incoming message is deliberately excluded.
 *
 * The message is then normalised to sum to 1. If its total is not above EPS,
 * it is replaced by the uniform distribution 1/size.
 *
 * Brute force: one localFunction evaluation per joint configuration of the
 * other neighbours.
 *
 * @param port index of the neighbour whose outgoing message is recomputed;
 *             must be in [0, numNeighbours).
 */
void FunctionNode::update(int port)
{
    const int size = static_cast<VariableNode*>(neighbours[port])->getSize();

    // One scratch configuration reused for every target value: summation()
    // writes every slot before it reads it, so no reset is needed between runs.
    std::vector<int> configuration(numNeighbours, 0);

    for (int value = 0; value < size; value++) {
        double sum = 0.0;
        summation(0, configuration.data(), port, value, sum);
        outMessages[port][value] = sum;
    }

    // Normalise the message so its entries sum to 1.
    double total = 0.0;
    for (int value = 0; value < size; value++) {
        total += outMessages[port][value];
    }
    for (int value = 0; value < size; value++) {
        if (total > EPS) {
            outMessages[port][value] /= total;
        } else {
            // Degenerate (near-zero) message: fall back to uniform.
            outMessages[port][value] = 1.0 / size;
        }
    }
}

/**
 * Recursively enumerates joint configurations and accumulates their terms.
 *
 * Slots [0, neighbourIndex) of tempConfiguration are already fixed by the
 * caller. This call assigns slot neighbourIndex and recurses. The slot at
 * fixVariableIndex is clamped to fixVariableValue rather than enumerated, so
 * each configuration is visited exactly once.
 *
 * Once all numNeighbours slots are set, the term
 *     prod_{l != fixVariableIndex} in_l(c[l]) * localFunction(c)
 * is added to `sum`.
 *
 * @param neighbourIndex    slot to assign next (start with 0).
 * @param tempConfiguration scratch buffer of numNeighbours ints (overwritten).
 * @param fixVariableIndex  port of the variable the message is sent to.
 * @param fixVariableValue  value the target variable is clamped to.
 * @param sum               accumulator; terms are added to it, not reset.
 */
void FunctionNode::summation(int neighbourIndex, int *tempConfiguration, int fixVariableIndex, int fixVariableValue, double& sum)
{
    if (neighbourIndex == numNeighbours) {
        // Complete configuration: multiply the incoming messages of every
        // neighbour except the target, then the local function value.
        double term = 1.0;
        for (int l = 0; l < numNeighbours; l++) {
            if (l != fixVariableIndex) {
                term *= neighbours[l]->getOutMessage(portNumbers[l])[tempConfiguration[l]];
            }
        }
        term *= localFunction->getResult(tempConfiguration, numNeighbours);
        sum += term;
        return;
    }

    if (neighbourIndex == fixVariableIndex) {
        // The target variable is clamped, not summed over.
        tempConfiguration[neighbourIndex] = fixVariableValue;
        summation(neighbourIndex + 1, tempConfiguration, fixVariableIndex, fixVariableValue, sum);
        return;
    }

    const int size = static_cast<VariableNode*>(neighbours[neighbourIndex])->getSize();
    for (int value = 0; value < size; value++) {
        tempConfiguration[neighbourIndex] = value;
        summation(neighbourIndex + 1, tempConfiguration, fixVariableIndex, fixVariableValue, sum);
    }
}

/**
 * Sets every entry of every outgoing message to 0.0.
 */
void FunctionNode::reset(void)
{
    for (int n = 0; n < numNeighbours; n++) {
        const int size = static_cast<VariableNode*>(neighbours[n])->getSize();
        for (int i = 0; i < size; i++) {
            outMessages[n][i] = 0.0; // alternative (disabled): 1.0/size;
        }
    }
}

/**
 * Prints this node's outgoing messages to stdout (same format as draw(FILE*)).
 * The qualified call keeps the original non-virtual behaviour.
 */
void FunctionNode::draw(void)
{
    FunctionNode::draw(stdout);
}

/**
 * Writes a human-readable table of this node's outgoing messages to `file`.
 *
 * One row per port shows the port number, the neighbour's name and entry 0 of
 * the message. The remaining entries follow, one per line. Assumes each
 * neighbour's alphabet has at least one value (entry 0 is always printed).
 *
 * @param file open, writable stdio stream.
 */
void FunctionNode::draw(FILE *file)
{
    std::fprintf(file, "NODE: %20s, %6d\n", name, id);
    std::fprintf(file, "---------------------------------------------\n");
    std::fprintf(file, "PORT NEIGHBOUR OUTMESSAGE\n");
    for (int n = 0; n < numNeighbours; n++) {
        std::fprintf(file, "%3d %20s%15f\n", n, neighbours[n]->getName(), outMessages[n][0]);
        const int size = static_cast<VariableNode*>(neighbours[n])->getSize();
        for (int i = 1; i < size; i++) {
            std::fprintf(file, "%45f\n", outMessages[n][i]);
        }
    }
    std::fprintf(file, "\n");
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -