⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 prioritizedsweeping.java

📁 Reinforcement Learning
💻 JAVA
字号:
import java.text.DecimalFormat;
import java.util.Date;
import java.util.Random;
import java.util.Vector;

public class PrioritizedSweeping //implements Algorithms
{
	private double epsilon, discountFactor, pjog;
	private final int PATH_COST= 1;
	private int BACKUPS_ALLOWED;
	private Maze myMaze;

	private StateInfo[][] model;
	private int[][] policy;
	private int[][] selectedAction;
	private double[][][] qsa;
	private ValueFunction currValues;
	public boolean isBestAct = true;
	public boolean receivedPenalty = false;
	State start, curr, next;;
//	double currStateJ, nextStateJ ;		//estimated "reward to go" for the current state
	double priority;
	
	PSSimulator ui;
	DecimalFormat df = new DecimalFormat("0.0");
	Vector backedUpStates;
	

	PriorityQueue pQueue;	//priority queue
	double tinyThreshold;
	final int NUMBER_OF_BACKUPS = 10;
	
	//for analysis purposes
	ValueFunction optVals;
	boolean isOptValCalc;
	double PRECISION=0.01;
	double policyScore;//score of the policy that u get after each episode
	int numEpisodes;
	long startTime, stopTime;
	Random rand = new Random();
	
	public int totalBackups = 0;
	public int numSteps = 0;
	
	public int maxCounter;
    public int maxCounterLast=0;
    public int averageBackups;
    
    //Constant for time-worn exploration
    private static final int Min_Transition_Required = 3;
	
	static class Properties {
	    public static int PJOG=1;
	    public static int MaxBackups=2;
	    public static int Epsilon=3;
	    public static int TinyThreshold=4;
	}
	
	public PrioritizedSweeping(Maze _maze, double _pjog, double _epsilon, int MaxBackUps, double _tinyThreshold)
	{
		myMaze = _maze;
		start = new State(0,0);
		curr = new State(0,0);
		
		pjog = _pjog;
		epsilon = _epsilon;
		BACKUPS_ALLOWED=MaxBackUps;

		model = new StateInfo[myMaze.width][myMaze.height];
		policy = new int[myMaze.width][myMaze.height];
		selectedAction = new int[myMaze.width][myMaze.height];
		
		tinyThreshold = _tinyThreshold;
		
		qsa = new double[myMaze.width][myMaze.height][Action.numActions];
		
		//for analysis purposes
		optVals = new ValueFunction(myMaze.width, myMaze.height);
		isOptValCalc = false;
		
		calcTrueValues();
		initialize();
	}
	
	public void setSeed(long seed)
	{
	    rand.setSeed(seed);
	}
	
	
	public void initialize()
	{
		//initialize pqueue
		pQueue = new PriorityQueue();
		
		//initialize all estimated "value to go" for each state to zero
		for (int i=0;i<model.length;i++)
			for(int j=0;j<model[i].length;j++)
			{
//				System.out.println("i,j "+i+j);
				model[i][j] = new StateInfo();
			}
		
		//initialise policy for all states as -1
		for(int i=0;i<policy.length;i++)
			for (int j=0;j<policy[i].length;j++)
			{
				policy[i][j] = -1;
				selectedAction[i][j] = -1;
			}
		
		//initialize all currValues.stateValues as 0
		currValues = new ValueFunction(myMaze.width, myMaze.height);
		
		//initialize qsa
		for(int i=0;i<qsa.length;i++)
			for (int j=0;j<qsa[i].length;j++)
				for (int k=0;k<qsa[i][j].length;k++)
					qsa[i][j][k] = 0;
		
		
		backedUpStates = new Vector();
		
		numEpisodes = 0;
		policyScore = 0;
		//calcTrueValues();
		maxCounter =0;
	}
	
	public void setProperty (int name, String value)
	{
	    if(name==Properties.PJOG) {
	        pjog = Double.parseDouble(value);
	    }
	    else if (name==Properties.Epsilon){
	        epsilon = Double.parseDouble(value);
	    }
	    else if (name==Properties.MaxBackups){
	        BACKUPS_ALLOWED = Integer.parseInt(value);
	    }
	    else if (name==Properties.TinyThreshold){
	        tinyThreshold = Double.parseDouble(value);
	    }
	}
	
	public ValueFunction getValueFunction()
	{
		for(int i=0;i<currValues.stateValue.length;i++)
			for(int j=0;j<currValues.stateValue[i].length;j++)
			{
				currValues.stateValue[i][j] = model[i][j].getMinQsa();
			}
	    return currValues;
	}
	
	public int[][] getPolicy()
	{
		for(int i=0;i<policy.length;i++)
			for(int j=0;j<policy[i].length;j++)
				policy[i][j] = model[i][j].getBestAction();
		return policy;
	}
	
	public double[][][] getQsa()
	{
		return qsa;
	}
	
	public State getCurrState()
	{
		return curr;
	}
	
	private boolean reachedGoal(State s)
	{
		if(myMaze.goals.contains(s))
			return true;
		else
			return false;
	}
	
	public void execute(int numIterations)
	{
		curr.copy(start);
		while(!reachedGoal(curr))
		{
			step();
		}		
	}
	
	
	private State getStart()
	{
	    int x = rand.nextInt(myMaze.width);
	    int y = rand.nextInt(myMaze.height);
	    //System.out.print("("+x+","+y+")");
	    State start = new State(x,y);
	    return start;
	}
	public boolean step()
	{
		double reward;
		int currAction;
		
		if(reachedGoal(curr)) {
			//recordStopTime();
		    maxCounterLast =maxCounter;
		    maxCounter=0;
			numEpisodes += 1;
			//curr=copy(start);
			curr = getStart();
			if(numSteps!=0)
			    averageBackups = totalBackups/numSteps;
			else 
			    averageBackups = 0;
			totalBackups = 0;
			numSteps =0;
			/*policyScore = evalPolicy();
			System.out.print(numEpisodes+"\t"+policyScore+"\n");
			*/
			return true;
		}
		
		//Choose action using epsilon greedy exploration policy
		currAction = chooseAction(curr, Math.random());

		//Perform the choosen action. Noise in environment is represented using PJOG
		next = Action.performAction(curr, currAction, pjog); 
		
		if(!myMaze.isValidTransition(curr, next)) {
			//stay in same state and add penalty 
			reward = myMaze.getReward(curr, next);
			receivedPenalty = true;
			next.copy(curr);
		}
		else {
			//else transition cost = pathCost
			reward = PATH_COST;			
			receivedPenalty = false;
		}
		
		model[curr.x][curr.y].addSuccs(currAction, next, reward);
		model[next.x][next.y].addPreds(curr, currAction);

		pQueue.placeOnTop(curr, priority);
//		if((deltaQ/qBefore) > 0.4)//change is greater than 40% of original value
		
		processQueue();
		
		curr.copy(next);
		numSteps++;
		return false;
	}

	//Chooses the best action according to epsilon greedy exploration policy
	private int chooseAction(State curr, double randNum)
	{    
		int bestAction = model[curr.x][curr.y].getBestAction();
		selectedAction[curr.x][curr.y] = bestAction;	//updates only for display purposes
		double d = epsilon/(Action.numActions);
		int choosenAction = bestAction;
		
		for(int i=0;i<Action.numActions;i++) {
			if (randNum < (i+1)*d ) {
				choosenAction=i;
				break;
			}
		}
		
		Utility.show("BestAction:"+bestAction);
		Utility.show("Rand"+randNum);
		Utility.show("ChoosenAction:"+choosenAction);
		
		if(choosenAction == bestAction)
		    isBestAct = true;
		else 
		    isBestAct = false;
		
		return choosenAction;
	}

	
	//this contains the actual prioritized sweeping algorithm
	private void processQueue()
	{
		//Utility.show("===========>Processing Queue<================");
		int counter = 0;
		while(!pQueue.isEmpty() && counter<=BACKUPS_ALLOWED)
		{
			//Utility.show("Counter: "+counter);
			//pQueue.printQueue();
			//step 2.1
			State top = pQueue.pop();

			double jOldVal = model[top.x][top.y].getMinQsa();
			double roMin = Double.MAX_VALUE;
			int minAction = 0;
			
			//performing step 2.2
			for(int i=0;i<Action.numActions;i++)
			{
				model[top.x][top.y].actionSuccs[i].qsa = calcExpectedVal(top,i);
				if(model[top.x][top.y].actionSuccs[i].qsa<roMin)
				{
					roMin  = model[top.x][top.y].actionSuccs[i].qsa;
					minAction = i;
				}
			}
			double deltaMax = Math.abs(roMin - jOldVal);	//step 2.3
			
			//step 2.4 : no computation required under current implementation as
			// we use getMinQsa for Ji-hat value.

			//step 2.5
			Vector[] preds = model[top.x][top.y].getPreds();
			for(int i=0;i<Action.numActions;i++)
			{
				Vector predStates = preds[i];
				for(int j=0;j<predStates.size();j++)
				{
					StateAndVisits snv = (StateAndVisits)predStates.get(j);
					int num  = snv.getVisits();
					
					StateActionInfo temp = model[snv.state.x][snv.state.y].actionSuccs[i];
					int den = temp.actionCounter;
					
					double prob = (double)num/den;
					double priority = prob*deltaMax;

					if(priority>tinyThreshold)
						pQueue.insert(snv.state, priority);
				}
			}
			//filling qsa array for visualization purposes
			for(int i=0;i<Action.numActions;i++)
				qsa[top.x][top.y][i] = model[top.x][top.y].actionSuccs[i].qsa;

			counter+=1;
//			currValues.stateValue[top.x][top.y] = model[top.x][top.y].getMinQsa();
		}
		totalBackups += (counter-1); 
		if (counter-1>maxCounter)
		    maxCounter=counter-1;
		//Utility.show("===========>Backupstate over<================");
		//Utility.show("");
	}
	
	private double calcExpectedVal(State s, int a)
	{
		double val;
		StateInfo currStInfo= model[s.x][s.y];
		
		//number of times action a was executed till now
		int actionCtr = currStInfo.actionSuccs[a].actionCounter;
		if (0==actionCtr)
			return 0;
		//Added by vivek recently to do what andrew suggested
		//if (actionCtr < Min_Transition_Required)
		//    return 0;
		
		//estimated reward associated with state-action pair s-a
		double estReward = currStInfo.actionSuccs[a].totalReward/actionCtr;
		
		//all successors for action a 
		Vector succs = currStInfo.getSuccs(a);
		
		double sum = 0;
		StateAndVisits temp;
		for(int i=0;i<succs.size();i++)
		{
			temp = (StateAndVisits)succs.get(i);

			int num = temp.getVisits();	//num of times transition from temp to currSt was done
			int den = actionCtr;		//num of times action a was executed
			double prob = (double)num/den;
			
			double jVal = model[temp.state.x][temp.state.y].getMinQsa();
			sum += prob*jVal;
		}
		
		val = estReward + sum;
		return val;
	}
	
	void printPredecessors(Vector preds)
	{
//		StateActionInfo sap;
//		for(int i=0;i<preds.size();i++)
//		{
//			sap = (StateActionInfo)preds.get(i);
//			sap.printState();
//		}
	}
	
	void printBackedUpStates()
	{
		System.out.println("printing backedup states");
		State st;
		for(int i=0;i<backedUpStates.size();i++)
		{
			st = (State)backedUpStates.get(i);
			st.printState();
		}
		System.out.println();
	}
	
	//for visualization purposes
	public Vector getBackedUpStates()
	{
		return backedUpStates;
	}
	
	public double evalPolicy()
	{
		double score=0;
		int[][] currPolicy;
		currPolicy = getPolicy();
    	ValueFunction evalVals = new ValueFunction(myMaze.width, myMaze.height);
    	ValueFunction prevEvalVals = new ValueFunction(myMaze.width, myMaze.height);
    	prevEvalVals.initialize();
    	
       	State currState, desiredNextState;
		double maxDelta = 0, delta = 0, v=0, maxV, minV=10000, prob,safe;
		boolean valueConverged = false;
		int valueIters=0, MAX_VALUE_ALLOWED = 1000;
		
		while (!valueConverged)
		{
			evalVals.initialize();
		    maxDelta = 0;
		    maxV =0;
			for(int i=0; i<myMaze.width; i++) {
				for(int j=0; j<myMaze.height;j++) {
					v = 0;
				    currState = new State(i,j);
				    if(myMaze.goals.contains(currState)) {
				        evalVals.stateValue[i][j] = 0;
				        continue;
				    }
					Vector allNext = new Vector(myMaze.getSuccessors(currState));	//vector of successor states
					if(-1 == policy[i][j])
					{
						evalVals.stateValue[i][j] = 0;
						continue;
					}

					desiredNextState = Action.performAction(currState, policy[i][j]);
					
				    for(int m=0;m<allNext.size();m++)
				    {
			            State s = (State)allNext.get(m);
			            
			            if(!desiredNextState.equals(s))
			                prob = pjog/(Action.numActions-1);
			            else
			                prob = 1-pjog;
			            
			            if(myMaze.isValidTransition(currState,s))
			                safe = prevEvalVals.stateValue[s.x][s.y];
			            else
			                safe = myMaze.getReward(currState,s) + prevEvalVals.stateValue[i][j];
			            
			            v += prob*safe;        
				    }
				    v=v+1;//this 1 is added as pathcost or transition cost
				    evalVals.stateValue[i][j] = v;
				    
				    maxV = (maxV<v) ? v : maxV;
				    
					delta = Math.abs(evalVals.stateValue[i][j] - prevEvalVals.stateValue[i][j]);
					if(maxDelta<delta) 
						maxDelta = delta;
				}
			}
			valueIters++;
			
			if(maxDelta<PRECISION)
				valueConverged = true;
			
			if(maxV>MAX_VALUE_ALLOWED)
			    valueConverged = true;
			
			for(int i=0;i<myMaze.width;i++)
				for(int j=0;j<myMaze.height;j++)
					prevEvalVals.stateValue[i][j] = evalVals.stateValue[i][j]; 
		}
		score = computeScore(evalVals);
		return score;
	}
    	
	private double computeScore(ValueFunction evalVals)
	{
		double netScore=0;
		for(int i=0;i<evalVals.stateValue.length;i++)
			for(int j=0;j<evalVals.stateValue[i].length;j++)
			{
				netScore += Math.abs(optVals.stateValue[i][j] - evalVals.stateValue[i][j]);
			}
		
		netScore = netScore/(myMaze.height*myMaze.width);
		return netScore;
	}
	
	public double getScore()
	{
		return policyScore;
	}
	
	public double getTime()
	{
//		System.out.println("time= "+(stopTime - startTime));
		return (stopTime - startTime);
	}
	
	private void calcTrueValues()
	{
//		System.out.println("calc'ing opt values");
		ValueIteration valitr = new ValueIteration(myMaze, pjog, PRECISION);
		while(!valitr.step());
		optVals = valitr.getValueFunction();
		isOptValCalc = true;
	}
	
	public void recordStartTime()
	{
	    startTime = new Date().getTime();
//	    System.out.println("start time= "+startTime);
	}
	
	public void recordStopTime()
	{
	    stopTime = new Date().getTime();
//	    System.out.println("stop time= "+stopTime);
	}
	
	public void setEpsilon(double _eps)
	{
		epsilon = _eps;
	}
	
	
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -