⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cachedkernel.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

package weka.classifiers.functions.supportVector;

import weka.core.*;

/**
 * Base class for RBFKernel and PolyKernel that implements a simple LRU
 * (least-recently-used) cache if the cache size is set to a value &gt; 0.
 * Otherwise (cache size &lt;= 0) it uses a full cache: the lower triangle of
 * the kernel matrix over all instances of the dataset, computed on first use.
 * 
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel)
 * @author Steven Hugg (hugg@fasterlight.com) (refactored, LRU cache)
 * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz) (full cache)
 * @version $Revision: 1.1 $
 */
public abstract class CachedKernel extends Kernel {

  /** Counts the number of kernel evaluations (calls to evaluate()). */
  private int m_kernelEvals = 0;

  /** Counts the number of kernel cache hits. */
  private int m_cacheHits = 0;

  /**
   * The number of LRU cache entries (ideally a prime number, since an entry
   * is located via key % m_cacheSize). A value &lt;= 0 selects the full cache.
   */
  private int m_cacheSize;

  /** LRU cache values: m_cacheSlots consecutive slots per cache entry. */
  private double[] m_storage;

  /**
   * LRU cache keys, parallel to m_storage. Each used slot holds (pair key + 1)
   * so that 0 can mark an empty slot.
   */
  protected long[] m_keys;

  /**
   * The kernel matrix if the full cache is used (cache size &lt;= 0). Row i
   * only holds columns 0..i (lower triangle, including the diagonal).
   */
  private double[][] m_kernelMatrix;

  /** The number of instances in the dataset; used to build LRU pair keys. */
  private int m_numInsts;

  /** Number of cache slots in an LRU cache entry. */
  private int m_cacheSlots = 4;


  /**
   * Initializes the kernel cache. The actual size of the LRU cache in bytes is
   * (64 * cacheSize): each entry has 4 slots, each holding one double value
   * and one long key.
   *
   * @param data the dataset whose instances are referenced by index in eval()
   * @param cacheSize the number of LRU cache entries, or a value &lt;= 0 to
   *          use the full kernel matrix instead
   */
  protected CachedKernel(Instances data, int cacheSize) {
    m_data = data;
    m_cacheSize = cacheSize;
    if (cacheSize > 0) {
      // Use LRU cache
      allocateLRUCache();
    }

    m_numInsts = m_data.numInstances();
  }

  /**
   * This method is overridden in subclasses to implement specific kernels.
   * 
   * @param id1
   *            the index of instance 1
   * @param id2
   *            the index of instance 2
   * @param inst1
   *            the instance 1 object
   * @return the dot product
   * @throws Exception
   */
  protected abstract double evaluate(int id1, int id2, Instance inst1)
    throws Exception;

  /**
   * Implements the abstract function of Kernel using the cache. This method
   * uses the evaluate() method to do the actual dot product.
   *
   * @param id1 the index of instance 1, or a negative value if it is unknown
   *          (in which case the value is computed directly and not cached)
   * @param id2 the index of instance 2
   * @param inst1 the instance 1 object
   * @return the (possibly cached) kernel value
   * @throws Exception if evaluate() fails
   */
  public double eval(int id1, int id2, Instance inst1) throws Exception {

    double result = 0;
    long key = -1;
    int location = -1;

    // we can only cache if we know the indexes
    if (id1 >= 0) {

      // Use full cache? (documented for any cache size that is not > 0)
      if (m_cacheSize <= 0) {
        if (m_kernelMatrix == null) {
          buildKernelMatrix();
        }
        m_cacheHits++;
        // Only the lower triangle (row >= column) is stored.
        return (id1 > id2) ? m_kernelMatrix[id1][id2] : m_kernelMatrix[id2][id1];
      }

      // Use LRU cache; its storage may have been released by clean().
      if (m_keys == null || m_storage == null) {
        allocateLRUCache();
      }

      // Order-independent key for the pair:
      // larger index + smaller index * m_numInsts.
      if (id1 > id2) {
        key = (id1 + ((long) id2 * m_numInsts));
      } else {
        key = (id2 + ((long) id1 * m_numInsts));
      }
      // First slot of the cache entry this key maps to.
      location = (int) (key % m_cacheSize) * m_cacheSlots;
      int loc = location;
      for (int i = 0; i < m_cacheSlots; i++) {
        long thiskey = m_keys[loc];
        if (thiskey == 0) {
          break; // empty slot, so break out of loop early
        }
        if (thiskey == (key + 1)) {
          m_cacheHits++;
          // Promote the hit to the front slot of the entry by swapping it
          // with the current front occupant (only if it's not already there).
          if (i > 0) {
            double tmps = m_storage[loc];
            m_storage[loc] = m_storage[location];
            m_keys[loc] = m_keys[location];
            m_storage[location] = tmps;
            m_keys[location] = thiskey;
            return tmps;
          }
          return m_storage[loc];
        }
        loc++;
      }
    }

    result = evaluate(id1, id2, inst1);

    m_kernelEvals++;

    // store result in the LRU cache (key is only set on the LRU path)
    if (key != -1) {
      // Shift the entry's slots back by one, discarding the last slot,
      // to make room for the new value in the front slot.
      System.arraycopy(m_keys, location, m_keys, location + 1,
                       m_cacheSlots - 1);
      System.arraycopy(m_storage, location, m_storage, location + 1,
                       m_cacheSlots - 1);
      m_storage[location] = result;
      m_keys[location] = (key + 1);
    }
    return result;
  }

  /**
   * Allocates an empty LRU cache (all keys 0, i.e. all slots empty) with
   * m_cacheSize entries of m_cacheSlots slots each.
   */
  private void allocateLRUCache() {
    m_storage = new double[m_cacheSize * m_cacheSlots];
    m_keys = new long[m_cacheSize * m_cacheSlots];
  }

  /**
   * Computes the lower triangle of the kernel matrix over all instances of
   * m_data and publishes it to m_kernelMatrix only once it is complete, so a
   * failing evaluate() never leaves a partially filled matrix behind.
   *
   * @throws Exception if evaluate() fails
   */
  private void buildKernelMatrix() throws Exception {
    int n = m_data.numInstances();
    double[][] matrix = new double[n][];
    for (int i = 0; i < n; i++) {
      matrix[i] = new double[i + 1];
      for (int j = 0; j <= i; j++) {
        m_kernelEvals++;
        matrix[i][j] = evaluate(i, j, m_data.instance(i));
      }
    }
    m_kernelMatrix = matrix;
  }

  /**
   * Returns the number of kernel evaluations, i.e. calls to evaluate()
   * (including those made while filling the full kernel matrix).
   * 
   * @return the number of kernel evaluations.
   */
  public int numEvals() {
    return m_kernelEvals;
  }

  /**
   * Returns the number of cache hits on dot products.
   * 
   * @return the number of cache hits.
   */
  public int numCacheHits() {
    return m_cacheHits;
  }

  /**
   * Frees the cache used by the kernel. A later cached eval() rebuilds the
   * full kernel matrix, or re-allocates an empty LRU cache, as needed.
   */
  public void clean() {
    m_storage = null;
    m_keys = null;
    m_kernelMatrix = null;
  }

  /**
   * Calculates a dot product between two instances, ignoring the class
   * attribute of m_data. Walks both sparse representations in a single merge
   * pass, so it relies on each instance's sparse indices being in ascending
   * order.
   * 
   * @param inst1
   *            the first instance
   * @param inst2
   *            the second instance
   * @return the dot product of the two instances.
   * @exception Exception
   *                if an error occurs
   */
  protected final double dotProd(Instance inst1, Instance inst2)
    throws Exception {

    double result = 0;

    // we can do a fast dot product
    int n1 = inst1.numValues();
    int n2 = inst2.numValues();
    int classIndex = m_data.classIndex();
    for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
      int ind1 = inst1.index(p1);
      int ind2 = inst2.index(p2);
      if (ind1 == ind2) {
        // Matching attribute in both instances: accumulate unless it is the class.
        if (ind1 != classIndex) {
          result += inst1.valueSparse(p1) * inst2.valueSparse(p2);
        }
        p1++;
        p2++;
      } else if (ind1 > ind2) {
        p2++;
      } else {
        p1++;
      }
    }
    return (result);
  }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -