📄 instancelist.java
字号:
// NOTE(review): the statements below are the tail of a method whose beginning lies
// before this view (it fills and returns an InstanceList[] named ret; presumably
// split(...)). They are left exactly as they were.
    ret[i] = cloneMe.cloneEmpty();
    if (i > 0)
      maxind[i] += maxind[i-1];
  }
  // Turn the accumulated values in maxind into rounded absolute positions by scaling
  // them with the number of instances; maxind[j] then acts as the exclusive end index
  // of ret[j] in the assignment loop below.
  for (int i = 0; i < maxind.length; i++)
    maxind[i] = Math.rint (maxind[i] * instances.size());
  int j = 0;
  // This gives a slight bias toward putting an extra instance in the last InstanceList.
  for (int i = 0; i < instances.size(); i++) {
    // Advance to the first output list whose end index lies beyond position i.
    while (i >= maxind[j])
      j++;
    ret[j].instances.add (instances.get(i));
  }
  return ret;
}

/** Returns a pair of new lists such that the first list in the pair contains
 * every <code>m</code>th element of this list, starting with the first.
 * The second list contains all remaining elements.
 * <p>
 * Both result lists start as <code>cloneEmpty()</code> copies of this list, and the
 * elements are appended directly to their <code>instances</code> field.
 * <code>m</code> must be non-zero whenever this list is non-empty, because
 * <code>i % 0</code> throws <code>ArithmeticException</code>.
 */
public InstanceList[] splitByModulo (int m)
{
  InstanceList[] ret = new InstanceList[2];
  ret[0] = this.cloneEmpty();
  ret[1] = this.cloneEmpty();
  for (int i = 0; i < this.size(); i++) {
    if (i % m == 0)
      ret[0].instances.add (this.get(i));
    else
      ret[1].instances.add (this.get(i));
  }
  return ret;
}

/**
 * Returns a new list (a <code>cloneEmpty()</code> copy of this one) holding
 * <code>numSamples</code> instances drawn uniformly at random with replacement,
 * each index chosen by <code>r.nextInt(instances.size())</code>.
 * <p>
 * Samples are appended directly to the new list's <code>instances</code> field.
 * If this list is empty and <code>numSamples</code> is positive,
 * <code>Random.nextInt(0)</code> throws <code>IllegalArgumentException</code>.
 */
public InstanceList sampleWithReplacement (java.util.Random r, int numSamples)
{
  InstanceList ret = this.cloneEmpty();
  for (int i = 0; i < numSamples; i++)
    ret.instances.add (this.getInstance(r.nextInt(instances.size())));
  return ret;
}

/** Returns the <code>Instance</code> at the specified index
 * (the stored element cast to <code>Instance</code>). */
public Instance getInstance (int index)
{
  return (Instance) instances.get (index);
}

/**
 * Returns an <code>InstanceList</code> of the same size, where the instances come from the
 * random sampling (with replacement) of this list using the instance weights.
 * The new instances all have their weights set to one.
 * <p>
 * Collects <code>getInstanceWeight(i)</code> for every index and delegates to
 * <code>sampleWithWeights(r, weights)</code>.
 */
// added by Gary - ghuang@cs.umass.edu
public InstanceList sampleWithInstanceWeights(java.util.Random r)
{
  double[] weights = new double[size()];
  for (int i = 0; i < weights.length; i++)
    weights[i] = getInstanceWeight(i);
  return sampleWithWeights(r, weights);
}

/**
 * Returns an <code>InstanceList</code> of the same size, where the instances come from the
 * random sampling (with replacement) of this list using the given weights.
* The length of the weight array must be the same as the length of this list * The new instances all have their weights set to one. */ // added by Gary - ghuang@cs.umass.edu public InstanceList sampleWithWeights(java.util.Random r, double[] weights) { if (weights.length != size()) throw new IllegalArgumentException("length of weight vector must equal number of instances"); if (size() == 0) return cloneEmpty(); double sumOfWeights = 0; for (int i = 0; i < size(); i++) { if (weights[i] < 0) throw new IllegalArgumentException("weight vector must be non-negative"); sumOfWeights += weights[i]; } if (sumOfWeights <= 0) throw new IllegalArgumentException("weights must sum to positive value"); InstanceList newList = new InstanceList(); double[] probabilities = new double[size()]; double sumProbs = 0; for (int i = 0; i < size(); i++) { sumProbs += r.nextDouble(); probabilities[i] = sumProbs; } MatrixOps.timesEquals(probabilities, sumOfWeights / sumProbs); // make sure rounding didn't mess things up probabilities[size() - 1] = sumOfWeights; // do sampling int a = 0; int b = 0; sumProbs = 0; while (a < size() && b < size()) { sumProbs += weights[b]; while (a < size() && probabilities[a] <= sumProbs) { newList.add(getInstance(b)); newList.setInstanceWeight(a, 1); a++; } b++; } return newList; } //added by Fuchun /** Replaces the <code>Instance</code> at position <code>index</code> * with a new one. 
*/ public void setInstance(int index, Instance instance) { instances.set(index, instance); } public double getInstanceWeight (int index) { if (instanceWeights == null) return 1.0; else return instanceWeights.get(index); } public void setInstanceWeight (int index, double weight) { //System.out.println ("setInstanceWeight index="+index+" weight="+weight); if (weight != getInstanceWeight(index)) { if (instanceWeights == null) instanceWeights = new DoubleList (instances.size(), 1.0); instanceWeights.set (index, weight); } } public void setFeatureSelection (FeatureSelection selectedFeatures) { if (selectedFeatures != null && selectedFeatures.getAlphabet() != null // xxx We allow a null vocabulary here? See CRF3.java && selectedFeatures.getAlphabet() != getDataAlphabet()) throw new IllegalArgumentException ("Vocabularies do not match"); featureSelection = selectedFeatures; } public FeatureSelection getFeatureSelection () { return featureSelection; } public void setPerLabelFeatureSelection (FeatureSelection[] selectedFeatures) { if (selectedFeatures != null) { for (int i = 0; i < selectedFeatures.length; i++) if (selectedFeatures[i].getAlphabet() != getDataAlphabet()) throw new IllegalArgumentException ("Vocabularies do not match"); } perLabelFeatureSelection = selectedFeatures; } public FeatureSelection[] getPerLabelFeatureSelection () { return perLabelFeatureSelection; } /** Sets the "target" field to <code>null</code> in all instances. This makes unlabeled data. */ public void removeTargets() { for (int i = 0; i < instances.size(); i++) getInstance(i).setTarget (null); } /** Sets the "source" field to <code>null</code> in all instances. This will often save memory when the raw data had been placed in that field. */ public void removeSources() { for (int i = 0; i < instances.size(); i++) getInstance(i).clearSource(); } /** Returns the <code>Instance</code> at the specified index. 
*/ public Object get (int index) { return instances.get (index); } /** Constructs a new <code>InstanceList</code>, deserialized from <code>file</code>. If the string value of <code>file</code> is "-", then deserialize from {@link System.in}. */ public static InstanceList load (File file) { try { ObjectInputStream ois; if (file.toString().equals("-")) ois = new ObjectInputStream (System.in); else ois = new ObjectInputStream (new FileInputStream (file)); InstanceList ilist = (InstanceList) ois.readObject(); ois.close(); return ilist; } catch (Exception e) { e.printStackTrace(); throw new IllegalArgumentException ("Couldn't read InstanceList from file "+file); } } // Serialization of InstanceList private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 1; private void writeObject (ObjectOutputStream out) throws IOException { int i, size; out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject(instances); out.writeObject(instanceWeights); out.writeObject(pipe); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int i, size; int version = in.readInt (); instances = (ArrayList) in.readObject(); instanceWeights = (DoubleList) in.readObject(); pipe = (Pipe) in.readObject(); } // added - culotta@cs.umass.edu /** <code>CrossValidationIterator</code> allows iterating over pairs of <code>InstanceList</code>, where each pair is split into training/testing based on nfolds. 
*/ public class CrossValidationIterator implements java.util.Iterator, Serializable { int nfolds; InstanceList[] folds; int index; /** @param _nfolds number of folds to split InstanceList into @param seed seed for random number used to split InstanceList */ public CrossValidationIterator (int _nfolds, int seed) { assert (_nfolds > 0) : "nfolds: " + nfolds; this.nfolds = _nfolds; this.index = 0; folds = new InstanceList[_nfolds]; double fraction = (double) 1 / _nfolds; double[] proportions = new double[_nfolds]; for (int i=0; i < _nfolds; i++) proportions[i] = fraction; folds = split (new java.util.Random (seed), proportions); } public CrossValidationIterator (int _nfolds) { this (_nfolds, 1); } public boolean hasNext () { return index < nfolds; } /** * Returns the next training/testing split. * @return A pair of lists, where <code>InstanceList[0]</code> is the larger split (training) * and <code>InstanceList[1]</code> is the smaller split (testing) */ public InstanceList[] nextSplit () { InstanceList[] ret = new InstanceList[2]; ret[0] = new InstanceList (pipe); for (int i=0; i < folds.length; i++) { if (i==index) continue; InstanceList.Iterator iter = folds[i].iterator(); while (iter.hasNext()) ret[0].add (iter.nextInstance()); } ret[1] = folds[index].shallowClone(); index++; return ret; } /** Returns the next split, given the number of folds you want in * the training data. 
*/ public InstanceList[] nextSplit (int numTrainFolds) { InstanceList[] ret = new InstanceList[2]; ret[0] = new InstanceList (pipe); ret[1] = new InstanceList (pipe); // train on folds [index, index+numTrainFolds), test on rest for (int i = 0; i < folds.length; i++) { int foldno = (index + i) % folds.length; InstanceList addTo; if (i < numTrainFolds) { addTo = ret[0]; } else { addTo = ret[1]; } InstanceList.Iterator iter = folds[foldno].iterator(); while (iter.hasNext()) addTo.add (iter.nextInstance()); } index++; return ret; } public Object next () { return nextSplit(); } public void remove () { throw new UnsupportedOperationException(); } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -