i_priqlearner_id.java
    return retval.toString();
  }

  protected void updateState(state st) {
    st.delta = 0;
    for (int action = 0; action < numactions; action++) {
      st.reCalcQValue(action);
    }
    double diff = st.reCalcValueAction();
    if (diff > 1) {
      System.out.println("State " + st.stateID + " changed by " + diff + " to " + st.stateValue);
      System.out.println("X: " + (st.stateID % 10) + " Y: " + ((st.stateID / 10) % 10));
      System.out.println("QueueSize: " + changeQueue.size());
    }
    if (diff > 0) {
      // Push every predecessor onto the priority queue, keyed by how much
      // its successor's value just changed.
      Iterator it = st.incommingStates.iterator();
      while (it.hasNext()) {
        state s = (state) it.next();
        s.delta += diff;
        if (changeQueue.contains(s))
          changeQueue.alteredKey(s);
        else
          changeQueue.add(s);
      }
    }
  }

  /**
   * Select an output based on the state and reward.
   *
   * @param yn int, the current state.
   * @param rn double, reward for the last output; positive
   *           numbers are "good."
   */
  public int query(int yn, double rn) {
    total_reward += rn;
    queries++;
    if ((yn < 0) || (yn > (numstates - 1))) {
      System.out.println("i_PriQLearner_id.query: state " + yn + " is out of range.");
      return 0;
    }
    // add this new transition
    if (states[yn] == null) {
      states[yn] = new state(yn);
    }
    if (!first_of_trial) {
      if (states[xn] == null)
        throw new UnexpectedException();
      states[xn].sawTransitionTo(states[yn], an, rn);
      profile[xn][an]++;
      // update our value table and policy
      updateState(states[xn]); // must always update the state we just left
      // run the updates until the delta falls below
      // minUpdate * (largest delta when we started)
      double limit = minUpdate;
      if (!changeQueue.isEmpty()) {
        state minState = (state) changeQueue.peekMin();
        limit *= minState.delta;
      }
      for (int i = 0; i < updateCount - 1; i++) {
        if (changeQueue.isEmpty())
          break;
        state st = (state) changeQueue.removeMin();
        double delta = st.delta;
        updateState(st);
        if (delta < limit)
          break;
      }
    } else {
      first_of_trial = false;
    }

    // choose the action: explore with probability randomrate, otherwise
    // act greedily according to the current value estimates
    int action;
    randomrate *= randomratedecay;
    if (rgen.nextDouble() <= randomrate) {
      action = Math.abs(rgen.nextInt() % numactions);
      System.out.println("random action, rate is " + randomrate);
    } else {
      action = states[yn].action;
    }

    // remember for next time
    xn = yn;
    an = action;
    if (logging)
      CheckForChanges();
    return action;
  }

  /**
   * Local method to see how much the policy has changed.
   */
  private void CheckForChanges() {
    for (int i = 0; i < numstates; i++) {
      if (states[i] != null) {
        int action = states[i].action;
        if (last_policy[i] != action) {
          changes++;
          last_policy[i] = action;
        }
      }
    }
    if (logging)
      log(String.valueOf(changes));
  }

  /**
   * Called when the current trial ends.
   *
   * @param Vn double, the value of the absorbing state.
   * @param rn double, the reward for the last output.
   */
  public void endTrial(double Vn, double rn) {
    total_reward += rn;
    if (DEBUG)
      System.out.println("xn =" + xn + " an =" + an + " rn=" + rn);
    // model the absorbing state as a dummy state whose value is fixed at Vn
    state resultingState = new state(-1);
    resultingState.stateValue = Vn;

    // add this new transition
    if (!first_of_trial) {
      if (states[xn] == null)
        throw new UnexpectedException();
      states[xn].sawTransitionTo(resultingState, an, rn);
      profile[xn][an]++;
      // update our value table and policy
      updateState(states[xn]); // must always update the state we just left
      for (int i = 0; i < updateCount - 1; i++) {
        if (changeQueue.isEmpty())
          break;
        state st = (state) changeQueue.removeMin();
        double delta = st.delta;
        updateState(st);
        if (delta < minUpdate)
          break;
      }
    }

    // record the policy
    if (logging) {
      CheckForChanges();
      try {
        savePolicy();
      } catch (IOException e) {
      }
    }
  }

  /**
   * Called to initialize for a new trial.
   */
  public int initTrial(int s) {
    first_of_trial = true;
    changes = 0;
    queries = 0;
    total_reward = 0;
    profile = new double[numstates][numactions];
    System.out.println("Prioritized Sweeping init");
    return (query(s, 0));
  }

  /**
   * Report the average reward per step in the trial.
   * @return the average.
   */
  public double getAvgReward() {
    return (total_reward / (double) queries);
  }

  /**
   * Report the number of queries in the trial.
   * @return the total.
   */
  public int getQueries() {
    return (queries);
  }

  /**
   * Report the number of policy changes in the trial.
   * @return the total.
   */
  public int getPolicyChanges() {
    return (changes);
  }

  /**
   * Read the policy from a file.
   */
  public void readPolicy() throws IOException {
    System.err.println(getClass().getName() + " readPolicy() not implemented!");
  }

  /**
   * Write the policy to a file.
   */
  public void savePolicy() throws IOException {
    System.err.println(getClass().getName() + " savePolicy() not implemented!");
  }

  /**
   * Write the policy profile to a file.
   *
   * @param profile_filename String, the name of the file to write to.
   */
  public void saveProfile(String profile_filename) throws IOException {
    double total_hits = 0;
    PrintWriter p = new PrintWriter(new BufferedWriter(new FileWriter(profile_filename)));
    p.println("// Policy profile:");
    p.println("// Q-learning Parameters:");
    p.println("// " + gamma + " // gamma");
    p.println("// " + randomrate + " // random rate");
    p.println("// " + randomratedecay + " // random rate decay");
    p.println("// Number of states ");
    p.println(numstates);
    p.println("// The profile. ");
    p.println("// proportion of hits, action, number of hits");

    // total up all the state/action hits
    for (int i = 0; i < numstates; i++) {
      for (int j = 0; j < numactions; j++) {
        total_hits += this.profile[i][j];
      }
    }
    for (int i = 0; i < numstates; i++) {
      int action = 0;
      int hits = 0;
      state s = states[i];
      for (int j = 0; j < numactions; j++) {
        hits += this.profile[i][j];
      }
      if (s != null) {
        action = s.action;
      }
      p.println((double) hits / total_hits + " " + action + " " + hits);
    }
    p.flush();
    p.close();
  }
}
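For reference, below is a minimal, self-contained sketch of the prioritized-sweeping idea that updateState() and changeQueue implement above, written against java.util.PriorityQueue and a tiny hard-coded chain MDP. The class name, the 5-state chain, and the remove/re-add stand-in for alteredKey() are illustrative assumptions, not part of i_priqlearner_id.java.

import java.util.*;

/**
 * Sketch only: prioritized sweeping over a known deterministic model.
 * Mirrors the "push predecessors keyed by value change" loop above,
 * but with standard-library types instead of the project's changeQueue.
 */
public class PrioritizedSweepingSketch {
  static final int N = 5, A = 2;                 // states, actions
  static final double GAMMA = 0.9, MIN_UPDATE = 1e-6;

  // next[s][a] and reward[s][a] play the role of the learned model.
  static int[][] next = new int[N][A];
  static double[][] reward = new double[N][A];
  static double[] value = new double[N];
  static List<List<Integer>> predecessors = new ArrayList<>();

  public static void main(String[] args) {
    // Deterministic chain: action 0 moves left, action 1 moves right;
    // the step from state N-2 to state N-1 pays +1.
    for (int s = 0; s < N; s++) {
      predecessors.add(new ArrayList<>());
      next[s][0] = Math.max(0, s - 1);
      next[s][1] = Math.min(N - 1, s + 1);
      reward[s][1] = (s == N - 2) ? 1.0 : 0.0;
    }
    for (int s = 0; s < N; s++)
      for (int a = 0; a < A; a++)
        predecessors.get(next[s][a]).add(s);

    double[] priority = new double[N];
    PriorityQueue<Integer> queue =
        new PriorityQueue<>((a, b) -> Double.compare(priority[b], priority[a]));
    // Seed the queue with every state, as if each had just been visited.
    for (int s = 0; s < N; s++) { priority[s] = 1.0; queue.add(s); }

    while (!queue.isEmpty()) {
      int s = queue.poll();
      if (priority[s] < MIN_UPDATE) continue;    // analogous to the minUpdate cutoff
      priority[s] = 0.0;
      double diff = backup(s);                   // analogous to updateState(st)
      if (diff <= 0) continue;
      // Re-queue every predecessor with priority raised by the change,
      // mirroring the incommingStates loop in updateState().
      for (int p : predecessors.get(s)) {
        priority[p] += diff;
        queue.remove(p);                         // crude stand-in for alteredKey()
        queue.add(p);
      }
    }
    System.out.println(Arrays.toString(value));
  }

  /** One Bellman backup; returns how much the state's value changed. */
  static double backup(int s) {
    double best = Double.NEGATIVE_INFINITY;
    for (int a = 0; a < A; a++)
      best = Math.max(best, reward[s][a] + GAMMA * value[next[s][a]]);
    double diff = Math.abs(best - value[s]);
    value[s] = best;
    return diff;
  }
}

Re-queueing predecessors with priority proportional to the value change is what concentrates backups where they matter; the learner above bounds the work done per step with updateCount and minUpdate instead of draining the queue, since it interleaves sweeping with acting.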