lwl.java

MacroWeka extends the well-known data mining tool Weka. The listing below picks up partway through the method that classifies a single test instance: it weights the training data around that instance, builds the base classifier on the weighted copy, and returns the resulting class distribution.
    
    for(int i=0, insCount=0; i < m_Train.numInstances(); i++) {
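        // Compute the distance from the test instance to every training
        // instance. For the bounded kernels (LINEAR, EPANECHNIKOV, TRICUBE) a
        // max-heap of the k smallest distances seen so far is kept, and its
        // root acts as a cut-off so distance() can stop early for instances
        // that cannot be among the k nearest.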
        switch(m_WeightKernel) {
          case LINEAR:
          case EPANECHNIKOV:
          case TRICUBE:      
              if(insCount<k) {
                  distance[i] = distance(instance, m_Train.instance(i));
                  h.put(i, distance[i]);
              }
              else {
                  MyHeapElement temp = h.peek();
                  distance[i] = distance(instance, m_Train.instance(i), 
                                         temp.distance);
                  if(distance[i]<temp.distance) {
                      h.get();
                      h.put(i, distance[i]);
                  }
              }
              break;
          default:
              distance[i] = distance(instance, m_Train.instance(i));
              break;
        }
        insCount++;
    }
    
    int [] sortKey;
    sortKey = Utils.sort(distance);
    
    if (m_Debug) {
      System.out.println("Instance Distances");
      for (int i = 0; i < sortKey.length; i++) {
	System.out.println("" + distance[sortKey[i]]);
      }
    }
    
    // Determine the bandwidth: the distance to the k-th nearest neighbour,
    // which the distances are then rescaled by so that neighbour lands at 1
    double bandwidth = distance[sortKey[k-1]];

    // Check for bandwidth zero
    if (bandwidth <= 0) {
      for (int i = k; i < sortKey.length; i++) {
	if (distance[sortKey[i]] > bandwidth) {
	  bandwidth = distance[sortKey[i]];
	  break;
	}
      }    
      if(bandwidth <= 0) {
	throw new Exception("All training instances coincide with test "+
                            "instance!");
      }
    }
    
    // Rescale the distances by the bandwidth
    for (int i = 0; i < distance.length; i++) {
      distance[i] = distance[i] / bandwidth;
    }

    // Pass the distances through a weighting kernel
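    // Kernel shapes applied to the rescaled distance d (all non-increasing in d):
    //   LINEAR:       max(1.0001 - d, 0)
    //   EPANECHNIKOV: 3/4 * (1.0001 - d^2)   for d <= 1, else 0
    //   TRICUBE:      (1.0001 - d^3)^3       for d <= 1, else 0
    //   CONSTANT:     1                      for d <= 1, else 0
    //   INVERSE:      1 / (1 + d)
    //   GAUSS:        exp(-d^2)
    // The 1.0001 constant keeps the k-th neighbour (rescaled distance 1) at a
    // small positive weight rather than exactly zero.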
    for (int i = 0; i < distance.length; i++) {
      switch (m_WeightKernel) {
      case LINEAR:
	distance[i] = Math.max(1.0001 - distance[i], 0);
	break;
      case EPANECHNIKOV:
          if(distance[i]<=1)
              distance[i] = 3/4D*(1.0001 - distance[i]*distance[i]);
          else
              distance[i] = 0;
          break;
      case TRICUBE:
          if(distance[i]<=1)
              distance[i] = Math.pow( (1.0001 - Math.pow(distance[i], 3)), 3 );
          else
              distance[i] = 0;
          break;
      case CONSTANT:
          //System.err.println("using constant kernel");
          if(distance[i]<=1)
            distance[i] = 1;
          else
            distance[i] = 0;
          break;
      case INVERSE:
	distance[i] = 1.0 / (1.0 + distance[i]);
	break;
      case GAUSS:
	distance[i] = Math.exp(-distance[i] * distance[i]);
	break;
      }
    }

    if (m_Debug) {
      System.out.println("Instance Weights");
      for (int i = 0; i < sortKey.length; i++) {
	System.out.println("" + distance[sortKey[i]]);
      }
    }

    // Set the weights on a copy of the training data
    Instances weightedTrain = new Instances(m_Train, 0);
    double sumOfWeights = 0, newSumOfWeights = 0;
    for (int i = 0; i < sortKey.length; i++) {
      double weight = distance[sortKey[i]];
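      // Instances are visited in order of increasing distance and every kernel
      // above is non-increasing in distance, so once a weight falls below the
      // threshold all remaining weights are at least as small.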
      if (weight < 1e-20) {
	break;
      }
      Instance newInst = (Instance) m_Train.instance(sortKey[i]).copy();
      sumOfWeights += newInst.weight();
      newSumOfWeights += newInst.weight() * weight;
      newInst.setWeight(newInst.weight() * weight);
      weightedTrain.add(newInst);
    }
    if (m_Debug) {
      System.out.println("Kept " + weightedTrain.numInstances() + " out of "
			 + m_Train.numInstances() + " instances");
    }
    
    // Rescale weights
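    // Scale the kept instances so their total weight matches the total weight
    // they carried in the original training data (sumOfWeights), so the base
    // classifier sees the same effective number of instances.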
    for (int i = 0; i < weightedTrain.numInstances(); i++) {
      Instance newInst = weightedTrain.instance(i);
      newInst.setWeight(newInst.weight() * sumOfWeights / newSumOfWeights);
    }

    // Create a weighted classifier
    m_Classifier.buildClassifier(weightedTrain);

    if (m_Debug) {
      System.out.println("Classifying test instance: " + instance);
      System.out.println("Built base classifier:\n" 
			 + m_Classifier.toString());
    }

    // Return the classifier's predictions
    return m_Classifier.distributionForInstance(instance);
  }
 
  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    if (m_Train == null) {
      return "Locally weighted learning: No model built yet.";
    }
    String result = "Locally weighted learning\n"
      + "===========================\n";

    result += "Using classifier: " + m_Classifier.getClass().getName() + "\n";

    switch (m_WeightKernel) {
    case LINEAR:
      result += "Using linear weighting kernels\n";
      break;
    case EPANECHNIKOV:
      result += "Using epanechnikov weighting kernels\n";
      break;
    case TRICUBE:
      result += "Using tricube weighting kernels\n";
      break;
    case INVERSE:
      result += "Using inverse-distance weighting kernels\n";
      break;
    case GAUSS:
      result += "Using gaussian weighting kernels\n";
      break;
    case CONSTANT:
      result += "Using constant weighting kernels\n";
      break;
    }
    result += "Using " + (m_UseAllK ? "all" : "" + m_kNN) + " neighbours";
    return result;
  }

  /**
   * Calculates the distance between two instances
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances
   */          
  private double distance(Instance first, Instance second) throws Exception {
      return distance(first, second, Math.sqrt(Double.MAX_VALUE));
  }
  
  /**
   * Calculates the distance between two instances
   *
   * @param first the first instance
   * @param second the second instance
   * @param cutOffValue if the distance is going to become larger than this
   * value, the rest of the calculation is skipped and Double.MAX_VALUE is
   * returned
   * @return the distance between the two given instances
   */          
  private double distance(Instance first, Instance second, double cutOffValue) 
          throws Exception {
    return euclideanDistance(first, second, cutOffValue);
  }    
  
  /**
   * Calculates the Euclidean distance between two instances.
   *
   * @param first the first instance
   * @param second the second instance
   * @param cutOffValue if the distance is going to become larger than this
   * value, the rest of the calculation is skipped and Double.MAX_VALUE is
   * returned
   * @return the distance between the two given instances
   */
  private double euclideanDistance(Instance first, Instance second, 
                                   double cutOffValue) {

    double distance = 0;
    int firstI, secondI;
    cutOffValue = cutOffValue*cutOffValue;
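    // The cut-off is squared above so that partial sums of squared differences
    // can be compared against it inside the loop without taking square roots.
    // Both instances are then walked in parallel by attribute index; an
    // attribute absent from a sparse representation contributes a value of 0.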
    
    for (int p1 = 0, p2 = 0; 
	 p1 < first.numValues() || p2 < second.numValues();) {
      if (p1 >= first.numValues()) {
	firstI = m_Train.numAttributes();
      } else {
	firstI = first.index(p1); 
      }
      if (p2 >= second.numValues()) {
	secondI = m_Train.numAttributes();
      } else {
	secondI = second.index(p2);
      }
      if (firstI == m_Train.classIndex()) {
	p1++; continue;
      } 
      if (secondI == m_Train.classIndex()) {
	p2++; continue;
      } 
      double diff;
      if (firstI == secondI) {
	diff = difference(firstI, 
			  first.valueSparse(p1),
			  second.valueSparse(p2));
	p1++; p2++;
      } else if (firstI > secondI) {
	diff = difference(secondI, 
			  0, second.valueSparse(p2));
	p2++;
      } else {
	diff = difference(firstI, 
			  first.valueSparse(p1), 0);
	p1++;
      }
      distance += diff * diff;
      if(distance>cutOffValue)
          return Double.MAX_VALUE; //distance;
    }
    distance = Math.sqrt(distance);
    return distance;
  }
   
  /**
   * Computes the difference between two given attribute
   * values.
   */
  private double difference(int index, double val1, double val2) {
    
    switch (m_Train.attribute(index).type()) {
      case Attribute.NOMINAL:
        
        // If attribute is nominal
        if(Instance.isMissingValue(val1) ||
           Instance.isMissingValue(val2) ||
           ((int)val1 != (int)val2)) {
          return 1;
        } else {
          return 0;
        }
      case Attribute.NUMERIC:
        // If attribute is numeric
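        // Missing-value convention: if both values are missing, the largest
        // possible difference is assumed; if only one is missing, the
        // difference is the known value's distance to whichever extreme of
        // the attribute's range is further away.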
        if (Instance.isMissingValue(val1) ||
        Instance.isMissingValue(val2)) {
          if(Instance.isMissingValue(val1) &&
             Instance.isMissingValue(val2)) {
            if(m_NoAttribNorm==false)  // normalization is on, so the maximum difference is 1
              return 1;
            else
              return (m_Max[index] - m_Min[index]);
          } else {
            double diff;
            if (Instance.isMissingValue(val2)) {
              diff = (m_NoAttribNorm==false) ? norm(val1, index) : val1;
            } else {
              diff = (m_NoAttribNorm==false) ? norm(val2, index) : val2;
            }
            if (m_NoAttribNorm==false && diff < 0.5) {
              diff = 1.0 - diff;
            }
            else if (m_NoAttribNorm==true) {
              if((m_Max[index]-diff) > (diff-m_Min[index]))
                return m_Max[index]-diff;
              else
                return diff-m_Min[index];
            }
            return diff;
          }
        } else {
          return (m_NoAttribNorm==false) ? 
                                  (norm(val1, index) - norm(val2, index)) :
                                  (val1 - val2);
        }
      default:
        return 0;
    }
  }

  /**
   * Normalizes a given value of a numeric attribute.
   *
   * @param x the value to be normalized
   * @param i the attribute's index
   */
  private double norm(double x,int i) {

    if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i], m_Min[i])) {
      return 0;
    } else {
      return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
    }
  }
                      
  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {  

    for (int j = 0; j < m_Train.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
	if (Double.isNaN(m_Min[j])) {
	  m_Min[j] = instance.value(j);
	  m_Max[j] = instance.value(j);
	} else if (instance.value(j) < m_Min[j]) {
	  m_Min[j] = instance.value(j);
	} else if (instance.value(j) > m_Max[j]) {
	  m_Max[j] = instance.value(j);
	}
      }
    }
  }
  

  
  /**
   * Main method for testing this class.
   *
   * @param argv the options
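   *
   * Example invocation (a sketch assuming the standard Weka command-line
   * conventions; the dataset path is a placeholder):
   * <pre>
   *   java LWL -t training_data.arff
   * </pre>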
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(
	    new LWL(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
  
  
  private class MyHeap {
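    // A fixed-capacity max-heap keyed on distance: the root (m_heap[1]) holds
    // the largest distance currently stored, so it can be used as a cut-off
    // while scanning the remaining training instances for the k nearest.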
    //m_heap[0].index contains the current size of the heap
    //m_heap[0].distance is unused.
    MyHeapElement m_heap[] = null;
    public MyHeap(int maxSize) {
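        // Use an odd capacity so that a right-child index (2*i+1) examined in
        // downheap() never runs past the end of the backing array.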
        if((maxSize%2)==0)
            maxSize++;
        
        m_heap = new MyHeapElement[maxSize+1];
        m_heap[0] = new MyHeapElement(0, 0);
        //System.err.println("m_heap size is: "+m_heap.length);
    }
    public int size() {
        return m_heap[0].index;
    }
    public MyHeapElement peek() {
        return m_heap[1];
    }
    public MyHeapElement get() throws Exception  {
        if(m_heap[0].index==0)
            throw new Exception("No elements present in the heap");
        MyHeapElement r = m_heap[1];
        m_heap[1] = m_heap[m_heap[0].index];
        m_heap[0].index--;
        downheap();
        return r;
    }
    public void put(int i, double d) throws Exception {
        if((m_heap[0].index+1)>(m_heap.length-1))
            throw new Exception("the number of elements cannot exceed the "+
                                "initially set maximum limit");
        m_heap[0].index++;
        m_heap[m_heap[0].index] = new MyHeapElement(i, d);
        //m_heap[m_heap[0].index].index = i;
        //m_heap[m_heap[0].index].distance = d;        
        //System.err.print("new size: "+(int)m_heap[0]+", ");
        upheap();
    }
    private void upheap() {
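        // Sift the newly added element up until its parent is at least as large.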
        int i = m_heap[0].index;
        MyHeapElement temp;
        while( i > 1  && m_heap[i].distance>m_heap[i/2].distance) {
            temp = m_heap[i];
            m_heap[i] = m_heap[i/2];
            i = i/2;
            m_heap[i] = temp; // i is now the old i/2, so this writes the parent slot without another division
        }
    }
    private void downheap() {
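        // Sift the root element down, swapping with its larger in-range child,
        // until the max-heap property is restored.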
        int i = 1;
        MyHeapElement temp;
        while( ((2*i) <= m_heap[0].index && 
                m_heap[i].distance < m_heap[2*i].distance) || 
               ((2*i+1) <= m_heap[0].index && 
                m_heap[i].distance < m_heap[2*i+1].distance) ) {
            if((2*i+1)<=m_heap[0].index) {
                if(m_heap[2*i].distance>m_heap[2*i+1].distance) {
                    temp = m_heap[i];
                    m_heap[i] = m_heap[2*i];
                    i = 2*i;
                    m_heap[i] = temp;
                }
                else {
                    temp = m_heap[i];
                    m_heap[i] = m_heap[2*i+1];
                    i = 2*i+1;
                    m_heap[i] = temp;
                }
            }
            else {
                temp = m_heap[i];
                m_heap[i] = m_heap[2*i];
                i = 2*i;
                m_heap[i] = temp;
            }
        }
    }    
    
  }
  
  private class MyHeapElement {
      int index;
      double distance; 
      public MyHeapElement(int i, double d) {
          distance = d; index = i;
      }
  }
  
}
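
/*
 * A minimal programmatic sketch of using this classifier through the standard
 * Weka API (the dataset path, class index, and the setKNN option setter are
 * assumptions about this build, not confirmed by the listing above):
 *
 *   Instances data = new Instances(new java.io.FileReader("train.arff"));
 *   data.setClassIndex(data.numAttributes() - 1);
 *   LWL lwl = new LWL();
 *   lwl.setKNN(5);                       // assumed setter for m_kNN
 *   lwl.buildClassifier(data);           // trains on (stores) the data
 *   double[] dist = lwl.distributionForInstance(data.instance(0));
 */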
