⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 testsparsevector.java

📁 常用机器学习算法,java编写源代码,内含常用分类算法,包括说明文档
💻 JAVA
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. *//** 		@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>*/package edu.umass.cs.mallet.base.types.tests;import edu.umass.cs.mallet.base.types.SparseVector;import java.io.*;import junit.framework.*;import edu.umass.cs.mallet.base.types.DenseVector;public class TestSparseVector extends TestCase{	public TestSparseVector (String name) {		super (name);	}		double[] dbl1 = new double[] {1, 2, 3, 4, 5};	double[] dbl2 = new double[] {1, 1.5, 2, 1, 1};	double[] dbl3 = new double[] { 2.0, 2.5, 3.0, 4.7, 3.5,																 3.6, 0,   0,   0,   0,																 0,   0,   0,   0,   0,																 0, };	int[] idxs = new int[] {3, 5, 7, 13, 15};	SparseVector s1 = new SparseVector (idxs, dbl1, dbl1.length, dbl1.length,																			true, true, true);	SparseVector s2 = new SparseVector (idxs, dbl2, dbl2.length, dbl2.length,																			true, true, true);	DenseVector d1 = new DenseVector (dbl3, true);	private void checkAnswer (SparseVector actual, double[] ans)	{		assertEquals ("Wrong number of locations:",									ans.length, actual.numLocations());		for (int i = 0; i < actual.numLocations(); i++) {	    assertEquals ("Value incorrect at location "+i+": ",										ans[i], actual.valueAtLocation (i) , 0.0);		}	}		public void testPlusEquals ()	{		SparseVector s = (SparseVector) s1.cloneMatrix ();		s.plusEqualsSparse (s2, 2.0);		checkAnswer (s, new double[] { 3, 5, 7, 6, 7 }); 		SparseVector s2p = new SparseVector											 (new int[] { 13 },												new double[] { 0.8 });		s.plusEqualsSparse (s2p, 1.0);		checkAnswer (s, new double[] { 3, 5, 7, 6.8, 7 }); 		SparseVector s3p = new SparseVector											 (new int[] { 14 },												new double[] { 0.8 });		s.plusEqualsSparse (s3p, 1.0);		checkAnswer (s, new double[] { 3, 5, 7, 6.8, 7 }); 		// verify s unchanged		SparseVector s4 = new SparseVector											(new int[] { 7, 14, 15 },											 new double[] { 0.2, 0.8, 1.2 });		s.plusEqualsSparse (s4, 1.0);		checkAnswer (s, new double[] { 3, 5, 7.2, 6.8, 8.2 });			SparseVector s5 = new SparseVector (new int[] { 7 }, new double[] { 0.2 });		s5.plusEqualsSparse (s1);		for (int i = 0; i < s5.numLocations(); i++) {	    assertEquals (7, s5.indexAtLocation (i));	    assertEquals (3.2, s5.valueAtLocation (i), 0.0);		}		SparseVector s6 = new SparseVector (new int[] { 7 }, new double[] { 0.2 });		s6.plusEqualsSparse (s1, 3.5);		for (int i = 0; i < s6.numLocations(); i++) {	    assertEquals (7, s6.indexAtLocation (i));	    assertEquals (10.7, s6.valueAtLocation (i), 0.0);		}	}	public void testDotProduct () {		SparseVector t1 = new SparseVector (new int[] { 7 }, new double[] { 0.2 });		assertEquals (0.6, t1.dotProduct (s1), 0.00001);		assertEquals (0.6, s1.dotProduct (t1), 0.00001);				assertEquals (19.0, s1.dotProduct (s2), 0.00001);		assertEquals (19.0, s2.dotProduct (s1), 0.00001);		assertEquals (11.9, s1.dotProduct (d1), 0.00001);		assertEquals (10.1, s2.dotProduct (d1), 0.00001);		// test dotproduct when vector with more locations has a lower		//   max-index than short vector		SparseVector t2 = new SparseVector (new int[] { 3, 30 }, new double[] { 0.2, 3.5 });		SparseVector t3 = new SparseVector (null, new double[] { 1, 1, 1, 1, });		assertEquals (0.2, t3.dotProduct (t2), 0.00001); 	}	public void testIncrementValue ()	{		SparseVector s = (SparseVector) s1.cloneMatrix ();		s.incrementValue (5, 0.75);		double[] ans = new double[] {1, 2.75, 3, 4, 5};		for (int i = 0; i < s.numLocations(); i++) {	    assertTrue (s.valueAtLocation (i) == ans[i]);		}	}		public void testSetValue ()	{		SparseVector s = (SparseVector) s1.cloneMatrix ();		s.setValue (5, 0.3);		double[] ans = new double[] {1, 0.3, 3, 4, 5};		for (int i = 0; i < s.numLocations(); i++) {	    assertTrue (s.valueAtLocation (i) == ans[i]);		}	}	public void testDenseSparseVector ()	{		SparseVector svDense = new SparseVector (null, dbl3);		double sdot = svDense.dotProduct (svDense);		double ddot = d1.dotProduct (d1);		assertEquals (sdot, ddot, 0.0001);		svDense.plusEqualsSparse (s1);		checkAnswer (svDense, new double[] { 2.0, 2.5, 3.0, 5.7, 3.5,																				 5.6, 0,   3,   0,   0,																				 0,   0,   0,   4,   0,																				 5, });		svDense.plusEqualsSparse (s1, 2.0);		checkAnswer (svDense, new double[] { 2.0, 2.5, 3.0, 7.7, 3.5,																				 9.6, 0,   9,   0,   0,																				 0,   0,   0,   12,   0,																				 15, });				double[] dbl4 = new double [dbl3.length + 1];		for (int i = 0; i < dbl4.length; i++) dbl4[i] = 2.0;		SparseVector sv4 = new SparseVector (null, dbl4);		svDense.plusEqualsSparse (sv4);		checkAnswer (svDense, new double[] { 4.0,  4.5,    5.0,  9.7,   5.5,																				 11.6, 2.0,   11.0,  2.0,   2.0,																				 2,   2,   2,   14,   2.0,																				 17, });	}	private static int[] idx2 = { 3, 7, 12, 15, 18 };	public void testBinaryVector ()	{		SparseVector binary1 = new SparseVector (idxs, null, idxs.length, idxs.length,																						 false, false, false);		SparseVector binary2 = new SparseVector (idx2, null, idx2.length, idx2.length,																						false, false, false);		assertEquals (3, binary1.dotProduct (binary2), 0.0001);		assertEquals (3, binary2.dotProduct (binary1), 0.0001);		assertEquals (15.0, binary1.dotProduct (s1), 0.0001);		assertEquals (15.0, s1.dotProduct (binary1), 0.0001);		assertEquals (9.0, binary2.dotProduct (s1), 0.0001);		assertEquals (9.0, s1.dotProduct (binary2), 0.0001);		SparseVector dblVec = (SparseVector) s1.cloneMatrix ();		dblVec.plusEqualsSparse (binary1);		checkAnswer (dblVec, new double[] { 2, 3, 4, 5, 6 });		SparseVector dblVec2 = (SparseVector) s1.cloneMatrix ();		dblVec2.plusEqualsSparse (binary2);		checkAnswer (dblVec2, new double[] { 2, 2, 4, 4, 6 });	}		public void testCloneMatrixZeroed ()	{		SparseVector s = (SparseVector) s1.cloneMatrixZeroed ();		for (int i = 0; i < s.numLocations(); i++) {	    assertTrue (s.valueAtLocation (i) == 0.0);	    assertTrue (s.indexAtLocation (i) == idxs [i]);		}	}	public void testPrint ()	{		ByteArrayOutputStream baos = new ByteArrayOutputStream ();		PrintStream out = new PrintStream (baos);		PrintStream oldOut = System.out;		System.setOut (out);		SparseVector standard = new SparseVector (idxs, dbl2);		standard.print ();		assertEquals ("SparseVector[3] = 1.0\nSparseVector[5] = 1.5\nSparseVector[7] = 2.0\nSparseVector[13] = 1.0\nSparseVector[15] = 1.0\n", baos.toString ());		baos.reset ();		SparseVector dense = new SparseVector (null, dbl2);		dense.print ();		assertEquals ("SparseVector[0] = 1.0\nSparseVector[1] = 1.5\nSparseVector[2] = 2.0\nSparseVector[3] = 1.0\nSparseVector[4] = 1.0\n", baos.toString ());		baos.reset ();		SparseVector binary = new SparseVector (idxs, null, idxs.length, idxs.length,																						false, false, false);		binary.print ();		assertEquals ("SparseVector[3] = 1.0\nSparseVector[5] = 1.0\nSparseVector[7] = 1.0\nSparseVector[13] = 1.0\nSparseVector[15] = 1.0\n", baos.toString ());		baos.reset ();	}	public static Test suite ()	{		return new TestSuite (TestSparseVector.class);	}	protected void setUp ()	{	}	public static void main (String[] args)	{		junit.textui.TestRunner.run (suite());	}	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -