// PredictionAssessmentOperator.java
// Reset the cursor of the MiningStoredData set, so the transform starts from
// the first record. Otherwise, the returned object might be NULL. TWang.
assessmentData.reset();
wekaInstances = (Instances) WekaCoreAdapter.PDMMiningInputStream2WekaInstances(assessmentData);
} catch (Exception e) {
e.printStackTrace();
throw new MiningException("Could not call Weka's model evaluation module correctly.");
}
Attribute targetAtt = wekaInstances.attribute(targetAttributeName);
if (targetAtt != null) {
// Set the target class in WekaInstances
if (targetAtt.numValues() != 2) {
throw new AppException("Only a data source with a two-level (binary) target variable is supported.");
}
wekaInstances.setClass(targetAtt);
} else {
throw new MiningException("Invalid model assessment data.");
}
// Call Weka's evaluation interfaces. For the KBBI platform, no cost matrix is
// supplied in the training version.
int instancesNum = wekaInstances.numInstances();
if (instancesNum > 0) {
// classifierResults[i][0]: observed class index;
// classifierResults[i][1]: probability with which the record belongs to class 0 - the 1st value of the nominal
// attribute;
// classifierResults[i][2]: probability with which the record belongs to class 1 - the 2nd value of the nominal
// attribute;
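// Illustrative example (hypothetical values): a row of {1.0, 0.3, 0.7} would mean the observed class
// index is 1, the classifier assigns probability 0.3 to class 0 and 0.7 to class 1, and therefore the
// predicted class is 1.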
classifierResults = new double[instancesNum][3];
try {
for (int i = 0; i < instancesNum; i++) {
double[] dist = ((Classifier) classifer).distributionForInstance(wekaInstances.instance(i));
// ==> The class index information can be obtained from the input model, and all the history
// could be repeated. TWang.
classifierResults[i][0] = (double) (wekaInstances.instance(i).classValue());
classifierResults[i][1] = dist[0];
classifierResults[i][2] = dist[1];
}
} catch (Exception e) {
e.printStackTrace();
throw new MiningException("Could not call Weka's model evaluation module correctly.");
}
return (calChartDataFromResults(classifierResults, instancesNum, nTargetLevel, chartType, nQuantileNum));
}
return null;
}
private double[][] calChartDataFromResults(double[][] classifierResults, int nResultsNum, int nTargetLevel,
String chartType, int nQuantileNum) throws MiningException, AppException, SysException {
double[][] chartData = null;
int[] sortedResultsIndex = null;
int[] groupedResultsIndex = null;
int nBaseColumnForSorting = 0;
// nTargetLevel == 0: descending order based on the 2nd column's value (probability of class 0).
// nTargetLevel == 1: descending order based on the 3rd column's value (probability of class 1).
if (EVAUATION_TARGETLEVEL_ZERO == nTargetLevel)
nBaseColumnForSorting = 1;
else if (EVAUATION_TARGETLEVEL_ONE == nTargetLevel)
nBaseColumnForSorting = 2;
sortedResultsIndex = PredictionAssessmentOperatorUtil.arraysorting(classifierResults, nResultsNum,
nBaseColumnForSorting);
groupedResultsIndex = PredictionAssessmentOperatorUtil.groupedIndex(nResultsNum, nQuantileNum);
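// Assumed shape of the index arrays (inferred from their use below, not from the utility's source):
// sortedResultsIndex[k] gives the row in classifierResults of the k-th record after sorting the chosen
// probability column in descending order, and groupedResultsIndex[i] gives the start offset of quantile i
// in that sorted order - e.g. with nResultsNum = 10 and nQuantileNum = 5 it would presumably be
// {0, 2, 4, 6, 8} (illustrative values only).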
if (sortedResultsIndex != null && groupedResultsIndex != null) {
chartData = new double[nQuantileNum][2];
// Calculate the total number of hits and the average response rate.
int totalHitsNum = 0;
int totalRecordsNum = 0;
totalRecordsNum = nResultsNum;
for (int i = 0; i < totalRecordsNum; i++) {
int classifiedResultsIndex = i;
// Calculate and count the predicted class. TWang.
int predictedClass = classifierResults[classifiedResultsIndex][1] > classifierResults[classifiedResultsIndex][2]
? 0
: 1;
if (EVAUATION_TARGETLEVEL_ZERO == nTargetLevel) {
if (0 == predictedClass) {
totalHitsNum++;
}
} else if (EVAUATION_TARGETLEVEL_ONE == nTargetLevel) {
if (1 == predictedClass) {
totalHitsNum++;
}
}
}
// Calculate chartData.
int cumulativeHitsNum = 0;
// Used to check for artificially inserted records. TWang.
int previousIndex = -1;
for (int i = 0; i < nQuantileNum; i++) {
chartData[i][0] = i;
int hitsNumPerGroup = 0;
// If we only calculated the cumulative gain chart, the variable "instancesNumPerGroup" would not be
// needed; it is declared here to help debug the current training version and to calculate other
// evaluation chart data in the future.
int instancesNumPerGroup = 0;
// Calculate the response and hit counts for each quantile.
int startIndex = groupedResultsIndex[i];
int realIndex = -1; // -1 means error. TWang.
if (nResultsNum > nQuantileNum) {
int numberOfRecord = 0;
if (i != groupedResultsIndex.length - 1) {
numberOfRecord = groupedResultsIndex[i + 1] - groupedResultsIndex[i];
} else {
numberOfRecord = nResultsNum - groupedResultsIndex[i];
}
instancesNumPerGroup = numberOfRecord;
for (int countIndex = 0; countIndex < numberOfRecord; countIndex++) {
realIndex = sortedResultsIndex[startIndex];
int predictedClass = classifierResults[realIndex][1] > classifierResults[realIndex][2] ? 0 : 1;
if (EVAUATION_TARGETLEVEL_ZERO == nTargetLevel) {
if (0 == predictedClass) {
hitsNumPerGroup++;
}
} else if (EVAUATION_TARGETLEVEL_ONE == nTargetLevel) {
if (1 == predictedClass) {
hitsNumPerGroup++;
}
}
startIndex++;
}
cumulativeHitsNum += hitsNumPerGroup;
} else { // If the total record number is smaller than the quantile number, each quantile has only one record.
instancesNumPerGroup = 1;
realIndex = sortedResultsIndex[startIndex];
int predictedClass = classifierResults[realIndex][1] > classifierResults[realIndex][2] ? 0 : 1;
if (EVAUATION_TARGETLEVEL_ZERO == nTargetLevel) {
if (0 == predictedClass) {
hitsNumPerGroup++; // only count the real records, skip all artificially inserted records.
if (startIndex != previousIndex) {
previousIndex = startIndex;
cumulativeHitsNum += hitsNumPerGroup;
}
}
} else if (EVAUATION_TARGETLEVEL_ONE == nTargetLevel) {
if (1 == predictedClass) {
hitsNumPerGroup++;
if (startIndex != previousIndex) {
previousIndex = startIndex;
cumulativeHitsNum += hitsNumPerGroup;
}
}
}
}
if (chartType.equals(PredictionAssessmentOperatorProperty.CHART_TYPE_CUMULATIVE_GAIN)) {
if (0 == totalHitsNum)
chartData[i][1] = 0;
else
// Formula to calculate the cumulative gain rate.
chartData[i][1] = cumulativeHitsNum * 100 / (double) totalHitsNum;
} else if (chartType.equals(PredictionAssessmentOperatorProperty.CHART_TYPE_LIFT)) {
if (0 == totalHitsNum)
chartData[i][1] = 0;
else
// Must convert to double first; otherwise the integer division is NOT correct. TWang.
chartData[i][1] = (hitsNumPerGroup * totalRecordsNum)
/ (double) (instancesNumPerGroup * totalHitsNum);
}
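// Worked example with illustrative numbers: if totalRecordsNum = 1000, totalHitsNum = 100, and a
// quantile of instancesNumPerGroup = 100 records contains hitsNumPerGroup = 30 hits
// (cumulativeHitsNum = 30 so far), then the cumulative gain is 30 * 100 / 100.0 = 30% and the lift is
// (30 * 1000) / (100 * 100.0) = 3.0.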
}
// int a = 0;
}
return chartData;
}
/**
* @return Returns the m_OperatorNode.
*/
public IOperatorNode getOperatorNode() {
return m_OperatorNode;
}
/**
* Return true if the operator can be connected as the parent. Returns true if a_Operator is an InputOperator or a
* TransformOperator, or a ModelOperator whose definition ID has been registered via registerParentsDefinitionID.
*
* @see eti.bi.alphaminer.ui.operator.Operator#acceptParent(eti.bi.alphaminer.ui.operator.Operator)
*/
public boolean acceptParent(Operator a_Operator) {
// Modeling operators accept input and transform operators
if (a_Operator != null) {
// Rebuild the data set each time; accurate and fast.
// Otherwise, data set management would be needed whenever an arrow is inserted or removed.
// try {
// m_DatasetsIDList.clear();
// m_DatasetsList.clear();
// addDataSetRecursively(this);
// m_DataSetBuilt = true;
// } catch (SysException e) {
// e.printStackTrace();
// }
if (a_Operator instanceof InputOperator) {
// if (m_DatasetsIDList != null && !m_DatasetsIDList.contains(a_Operator.getNodeID()))
return true;
} else if (a_Operator instanceof ModelOperator) {
return acceptParentDefinitionID(a_Operator.getOperatorDefinitionID());
} else if (a_Operator instanceof TransformOperator) {
// try {
// return !findDataSource(a_Operator.getNodeID());
// } catch (SysException e) {
// e.printStackTrace();
// }
return true;
}
}
return false;
}
/**
* Return true if the node has a parent data source that is already in the process; return false otherwise. The
* input should be a non-InputDataSource operator. Mar 17, 2005. TWang.
*
* @param a_HeadNodeID
* @return boolean
* @throws SysException
*/
@SuppressWarnings("unused")
private boolean findDataSource(String a_HeadNodeID) throws SysException {
Vector parentOperators = m_CaseHandler.getParentOperators(m_CaseID, a_HeadNodeID);
if (parentOperators == null)
return false;
for (int i = 0; i < parentOperators.size(); i++) {
Operator parentOp = (Operator) parentOperators.elementAt(i);
if (parentOp instanceof InputOperator) {
if (m_DatasetsIDList.contains(parentOp.getNodeID())) {
return true;
}
} else {
// Keep checking the remaining parents if this branch does not reach a data source.
if (findDataSource(parentOp.getNodeID())) {
return true;
}
}
}
return false;
}
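// Illustrative scenario for findDataSource (hypothetical node graph): if the head node has parents
// {TransformOperator T, InputOperator D} and D's node ID is in m_DatasetsIDList, the InputOperator branch
// returns true directly; if instead only one of T's own ancestors is such a data source, the recursive
// call on T finds it.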
// 2006/07/29 Xiaojun Chen: added to register a definition ID so that the
// PredictionAssessmentOperator can accept this type of operator as a parent.
/**
* Register an operator type that can be accepted as a PredictionAssessmentOperator's parent.
* @param aParentDefinitionID
*/
public static void registerParentsDefinitionID(String aParentDefinitionID) {
if (acceptParentsDefinitionID.indexOf(aParentDefinitionID) < 0) {
acceptParentsDefinitionID.add(aParentDefinitionID);
}
}
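// Illustrative usage (the definition ID string below is hypothetical, not a real registered ID):
//   PredictionAssessmentOperator.registerParentsDefinitionID("SomeClassificationModelOperatorID");
//   // After registration, acceptParentDefinitionID("SomeClassificationModelOperatorID") returns true,
//   // so a ModelOperator carrying that definition ID is accepted as a parent in acceptParent().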
// 2006/07/29 Xiaojun Chen: added to decide whether a type of operator can be
// accepted as a PredictionAssessmentOperator's parent.
/**
* Return true if the PredictionAssessmentOperator can accept this type of operator as its parent.
* @param aParentDefinitionID
*/
public static boolean acceptParentDefinitionID(String aParentDefinitionID) {
return acceptParentsDefinitionID.indexOf(aParentDefinitionID) > -1;
}
}