// ReliefFAttributeEval.java (partial listing)
first.valueSparse(p1), 0);
p1++;
}
// distance += diff * diff;
distance += diff;
}
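// The active code returns a Manhattan-style sum of per-attribute differences;
// using the commented-out squared-difference line above together with the
// sqrt line below would instead give a Euclidean distance normalised by the
// number of attributes used.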
// return Math.sqrt(distance / m_NumAttributesUsed);
return distance;
}
/**
* update attribute weights given an instance when the class is numeric
*
* @param instNum the index of the instance to use when updating weights
*/
private void updateWeightsNumericClass (int instNum) {
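// This method accumulates the RReliefF statistics over the k nearest
// neighbours of instance instNum:
//   m_ndc      ~ P(different prediction | nearest instances)
//   m_nda[A]   ~ P(different value of attribute A | nearest instances)
//   m_ndcda[A] ~ P(different prediction and different value of A | nearest)
// In the standard RReliefF formulation (Robnik-Sikonja & Kononenko) these
// are later combined, outside this excerpt, as roughly
//   W[A] = m_ndcda[A]/m_ndc - (m_nda[A] - m_ndcda[A])/(m - m_ndc),
// where m is the number of instances sampled.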
int i, j;
double temp,temp2;
int[] tempSorted = null;
double[] tempDist = null;
double distNorm = 1.0;
int firstI, secondI;
Instance inst = m_trainInstances.instance(instNum);
// sort nearest neighbours and set up normalization variable
if (m_weightByDistance) {
tempDist = new double[m_stored[0]];
for (j = 0, distNorm = 0; j < m_stored[0]; j++) {
// copy the distances
tempDist[j] = m_karray[0][j][0];
// sum normalizer
distNorm += m_weightsByRank[j];
}
tempSorted = Utils.sort(tempDist);
}
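// When weighting by distance, m_weightsByRank is assumed to hold a
// precomputed rank-based weight for each neighbour (nearer ranks getting
// larger weights), so m_weightsByRank[i]/distNorm sums to 1 over the stored
// neighbours and acts as a soft, distance-sensitive vote.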
for (i = 0; i < m_stored[0]; i++) {
// P diff prediction (class) given nearest instances
if (m_weightByDistance) {
temp = difference(m_classIndex,
inst.value(m_classIndex),
m_trainInstances.
instance((int)m_karray[0][tempSorted[i]][1]).
value(m_classIndex));
temp *= (m_weightsByRank[i]/distNorm);
}
else {
temp = difference(m_classIndex,
inst.value(m_classIndex),
m_trainInstances.
instance((int)m_karray[0][i][1]).
value(m_classIndex));
temp *= (1.0/(double)m_stored[0]); // equal influence
}
m_ndc += temp;
Instance cmp;
cmp = (m_weightByDistance)
? m_trainInstances.instance((int)m_karray[0][tempSorted[i]][1])
: m_trainInstances.instance((int)m_karray[0][i][1]);
double temp_diffP_diffA_givNearest =
difference(m_classIndex, inst.value(m_classIndex),
cmp.value(m_classIndex));
// now the attributes
for (int p1 = 0, p2 = 0;
p1 < inst.numValues() || p2 < cmp.numValues();) {
if (p1 >= inst.numValues()) {
firstI = m_trainInstances.numAttributes();
} else {
firstI = inst.index(p1);
}
if (p2 >= cmp.numValues()) {
secondI = m_trainInstances.numAttributes();
} else {
secondI = cmp.index(p2);
}
if (firstI == m_trainInstances.classIndex()) {
p1++; continue;
}
if (secondI == m_trainInstances.classIndex()) {
p2++; continue;
}
temp = 0.0;
temp2 = 0.0;
if (firstI == secondI) {
j = firstI;
temp = difference(j, inst.valueSparse(p1), cmp.valueSparse(p2));
p1++;p2++;
} else if (firstI > secondI) {
j = secondI;
temp = difference(j, 0, cmp.valueSparse(p2));
p2++;
} else {
j = firstI;
temp = difference(j, inst.valueSparse(p1), 0);
p1++;
}
temp2 = temp_diffP_diffA_givNearest * temp;
// P of different prediction and different att value given
// nearest instances
if (m_weightByDistance) {
temp2 *= (m_weightsByRank[i]/distNorm);
}
else {
temp2 *= (1.0/(double)m_stored[0]); // equal influence
}
m_ndcda[j] += temp2;
// P of different attribute val given nearest instances
if (m_weightByDistance) {
temp *= (m_weightsByRank[i]/distNorm);
}
else {
temp *= (1.0/(double)m_stored[0]); // equal influence
}
m_nda[j] += temp;
}
}
}
/**
* update attribute weights given an instance when the class is discrete
*
* @param instNum the index of the instance to use when updating weights
*/
private void updateWeightsDiscreteClass (int instNum) {
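// Standard ReliefF update for a discrete class: each attribute's weight is
// decreased by the (averaged) difference to the k nearest hits (same class)
// and increased by the difference to the k nearest misses from every other
// class, with each miss class weighted by its prior probability renormalised
// over the non-hit classes (w_norm below).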
int i, j, k;
int cl;
double cc = m_numInstances;
double temp, temp_diff, w_norm = 1.0;
double[] tempDistClass;
int[] tempSortedClass = null;
double distNormClass = 1.0;
double[] tempDistAtt;
int[][] tempSortedAtt = null;
double[] distNormAtt = null;
int firstI, secondI;
// store the indexes (sparse instances) of non-zero elements
Instance inst = m_trainInstances.instance(instNum);
// get the class of this instance
cl = (int)m_trainInstances.instance(instNum).value(m_classIndex);
// sort nearest neighbours and set up normalization variables
if (m_weightByDistance) {
// do class (hits) first
// sort the distances
tempDistClass = new double[m_stored[cl]];
for (j = 0, distNormClass = 0; j < m_stored[cl]; j++) {
// copy the distances
tempDistClass[j] = m_karray[cl][j][0];
// sum normalizer
distNormClass += m_weightsByRank[j];
}
tempSortedClass = Utils.sort(tempDistClass);
// do misses (other classes)
tempSortedAtt = new int[m_numClasses][1];
distNormAtt = new double[m_numClasses];
for (k = 0; k < m_numClasses; k++) {
if (k != cl) // already done cl
{
// sort the distances
tempDistAtt = new double[m_stored[k]];
for (j = 0, distNormAtt[k] = 0; j < m_stored[k]; j++) {
// copy the distances
tempDistAtt[j] = m_karray[k][j][0];
// sum normalizer
distNormAtt[k] += m_weightsByRank[j];
}
tempSortedAtt[k] = Utils.sort(tempDistAtt);
}
}
}
if (m_numClasses > 2) {
// the amount of probability space left after removing the
// probability of this instance's class value
w_norm = (1.0 - m_classProbs[cl]);
}
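// Illustrative example (hypothetical priors): with three classes whose
// m_classProbs are {0.5, 0.3, 0.2} and the current instance in class 0,
// w_norm = 1.0 - 0.5 = 0.5, so misses from class 1 contribute with factor
// 0.3/0.5 = 0.6 and misses from class 2 with factor 0.2/0.5 = 0.4.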
// do the k nearest hits of the same class
for (j = 0, temp_diff = 0.0; j < m_stored[cl]; j++) {
Instance cmp;
cmp = (m_weightByDistance)
? m_trainInstances.
instance((int)m_karray[cl][tempSortedClass[j]][1])
: m_trainInstances.instance((int)m_karray[cl][j][1]);
for (int p1 = 0, p2 = 0;
p1 < inst.numValues() || p2 < cmp.numValues();) {
if (p1 >= inst.numValues()) {
firstI = m_trainInstances.numAttributes();
} else {
firstI = inst.index(p1);
}
if (p2 >= cmp.numValues()) {
secondI = m_trainInstances.numAttributes();
} else {
secondI = cmp.index(p2);
}
if (firstI == m_trainInstances.classIndex()) {
p1++; continue;
}
if (secondI == m_trainInstances.classIndex()) {
p2++; continue;
}
if (firstI == secondI) {
i = firstI;
temp_diff = difference(i, inst.valueSparse(p1),
cmp.valueSparse(p2));
p1++;p2++;
} else if (firstI > secondI) {
i = secondI;
temp_diff = difference(i, 0, cmp.valueSparse(p2));
p2++;
} else {
i = firstI;
temp_diff = difference(i, inst.valueSparse(p1), 0);
p1++;
}
if (m_weightByDistance) {
temp_diff *=
(m_weightsByRank[j]/distNormClass);
} else {
if (m_stored[cl] > 0) {
temp_diff /= (double)m_stored[cl];
}
}
m_weights[i] -= temp_diff;
}
}
// now do k nearest misses from each of the other classes
temp_diff = 0.0;
for (k = 0; k < m_numClasses; k++) {
if (k != cl) // already done cl
{
for (j = 0, temp = 0.0; j < m_stored[k]; j++) {
Instance cmp;
cmp = (m_weightByDistance)
? m_trainInstances.
instance((int)m_karray[k][tempSortedAtt[k][j]][1])
: m_trainInstances.instance((int)m_karray[k][j][1]);
for (int p1 = 0, p2 = 0;
p1 < inst.numValues() || p2 < cmp.numValues();) {
if (p1 >= inst.numValues()) {
firstI = m_trainInstances.numAttributes();
} else {
firstI = inst.index(p1);
}
if (p2 >= cmp.numValues()) {
secondI = m_trainInstances.numAttributes();
} else {
secondI = cmp.index(p2);
}
if (firstI == m_trainInstances.classIndex()) {
p1++; continue;
}
if (secondI == m_trainInstances.classIndex()) {
p2++; continue;
}
if (firstI == secondI) {
i = firstI;
temp_diff = difference(i, inst.valueSparse(p1),
cmp.valueSparse(p2));
p1++;p2++;
} else if (firstI > secondI) {
i = secondI;
temp_diff = difference(i, 0, cmp.valueSparse(p2));
p2++;
} else {
i = firstI;
temp_diff = difference(i, inst.valueSparse(p1), 0);
p1++;
}
if (m_weightByDistance) {
temp_diff *=
(m_weightsByRank[j]/distNormAtt[k]);
}
else {
if (m_stored[k] > 0) {
temp_diff /= (double)m_stored[k];
}
}
if (m_numClasses > 2) {
m_weights[i] += ((m_classProbs[k]/w_norm)*temp_diff);
} else {
m_weights[i] += temp_diff;
}
}
}
}
}
}
/**
* Find the K nearest instances to supplied instance if the class is numeric,
* or the K nearest Hits (same class) and Misses (K from each of the other
* classes) if the class is discrete.
*
* @param instNum the index of the instance to find nearest neighbours of
*/
private void findKHitMiss (int instNum) {
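// Nearest neighbours are kept per class in m_karray: m_karray[cl][n][0]
// holds the distance and m_karray[cl][n][1] the training-set index of the
// n-th stored neighbour for class cl (cl is always 0 when the class is
// numeric). m_index[cl] and m_worst[cl] track the position and distance of
// the furthest stored neighbour so it can be replaced when a closer one is found.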
int i, j;
int cl;
double ww;
double temp_diff = 0.0;
Instance thisInst = m_trainInstances.instance(instNum);
for (i = 0; i < m_numInstances; i++) {
if (i != instNum) {
Instance cmpInst = m_trainInstances.instance(i);
temp_diff = distance(cmpInst, thisInst);
// class of this training instance or 0 if numeric
if (m_numericClass) {
cl = 0;
}
else {
cl = (int)m_trainInstances.instance(i).value(m_classIndex);
}
// add this diff to the list for the class of this instance
if (m_stored[cl] < m_Knn) {
m_karray[cl][m_stored[cl]][0] = temp_diff;
m_karray[cl][m_stored[cl]][1] = i;
m_stored[cl]++;
// note the worst diff for this class
for (j = 0, ww = -1.0; j < m_stored[cl]; j++) {
if (m_karray[cl][j][0] > ww) {
ww = m_karray[cl][j][0];
m_index[cl] = j;
}
}
m_worst[cl] = ww;
}
else
/* if we already have stored knn for this class then check to
see if this instance is better than the worst */
{
if (temp_diff < m_karray[cl][m_index[cl]][0]) {
m_karray[cl][m_index[cl]][0] = temp_diff;
m_karray[cl][m_index[cl]][1] = i;
for (j = 0, ww = -1.0; j < m_stored[cl]; j++) {
if (m_karray[cl][j][0] > ww) {
ww = m_karray[cl][j][0];
m_index[cl] = j;
}
}
m_worst[cl] = ww;
}
}
}
}
}
// ============
// Test method.
// ============
/**
* Main method for testing this class.
*
* @param args the options
*/
public static void main (String[] args) {
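// A typical invocation (assuming the standard Weka option handling used by
// AttributeSelection.SelectAttributes) would look like:
//   java <package>.ReliefFAttributeEval -i data.arff -K 10 -W
// where -i names the training file, -K the number of neighbours and
// -W weights neighbours by distance.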
try {
System.out.println(AttributeSelection.
SelectAttributes(new ReliefFAttributeEval(), args));
}
catch (Exception e) {
// log the message together with the full stack trace; calling toString()
// on e.getStackTrace() would only print the array reference
log.error(e.getMessage(), e);
}
}
}