📄 i_qlearner_id.java
                // numerator: number of times we did the last
                // action in the last state, times the previous
                // average reward, plus the current reward,
                // plus the projected reward
                // denominator: number of times we did the last
                // state/action, plus this state, plus the next
            }
            p[xn][an]++;        // count times in the last state/action
            profile[xn][an]++;  // count times for this trial
        }
        else
            first_of_trial = 0;

        /*
         * Select a random action, possibly.
         */
        if (rgen.nextDouble() <= randomrate)
        {
            action = rgen.nextInt() % numactions;
            if (action < 1) action = -1 * action;
            if (true)
                System.out.println("random action, rate is " + randomrate);
        }
        randomrate *= randomratedecay;

        /*
         * Remember for next time.
         */
        xn = yn;
        an = action;

        if (logging)
            CheckForChanges();

        return action;
    }

    /**
     * Local method to see how much the policy has changed.
     */
    private void CheckForChanges()
    {
        int i, j;

        for (i = 0; i < numstates; i++)
        {
            double val = -9999999999f;
            int action = 0;
            for (j = 0; j < numactions; j++)
            {
                if (q[i][j] > val)
                {
                    action = j;
                    val = q[i][j];
                }
            }
            if (last_policy[i] != action)
            {
                changes++;
                last_policy[i] = action;
            }
        }
        if (logging)
            log(String.valueOf(changes));
    }

    /**
     * Called when the current trial ends.
     *
     * @param Vn double, the value of the absorbing state.
     * @param rn double, the reward for the last output.
     */
    public void endTrial(double Vn, double rn)
    {
        total_reward += rn;
        if (DEBUG) System.out.println("xn =" + xn + " an =" + an + " rn=" + rn);

        if (criteria == DISCOUNTED)
        {
            // Watkins update rule:
            q[xn][an] = (1 - alpha) * q[xn][an] + alpha * (rn + gamma * Vn);
        }
        else // criteria == AVERAGE
        {
            // average update rule
            q[xn][an] = (p[xn][an] * q[xn][an] + rn) / (p[xn][an] + 2);
            // see the update above in query() for an explanation
            // of this rule
        }
        p[xn][an] += 1;
        profile[xn][an] += 1;

        if (logging)
        {
            CheckForChanges();
            try
            {
                savePolicy();
            }
            catch (IOException e) {} // ignore failures when saving the policy
        }
    }

    /**
     * Called to initialize for a new trial.
     */
    public int initTrial(int s)
    {
        first_of_trial = 1;
        changes = 0;
        queries = 0;
        total_reward = 0;
        profile = new double[numstates][numactions];
        System.out.println("Q init");
        return (query(s, 0));
    }

    /**
     * Report the average reward per step in the trial.
     *
     * @return the average.
     */
    public double getAvgReward()
    {
        return (total_reward / (double) queries);
    }

    /**
     * Report the number of queries in the trial.
     *
     * @return the total.
     */
    public int getQueries()
    {
        return (queries);
    }

    /**
     * Report the number of policy changes in the trial.
     *
     * @return the total.
     */
    public int getPolicyChanges()
    {
        return (changes);
    }

    /**
     * Read the policy from the file named by policyfilename.
     */
    public void readPolicy() throws IOException
    {
        int i, j, k;
        String linein;
        FileInputStream f;
        InputStreamReader isr;
        StreamTokenizer p;

        try
        {
            f = new FileInputStream(policyfilename);
            isr = new InputStreamReader(f);
            p = new StreamTokenizer(isr);
        }
        catch (SecurityException e)
        {
            return;
        }

        // configure the tokenizer
        p.parseNumbers();
        p.slashSlashComments(true);
        p.slashStarComments(true);

        k = p.nextToken();
        alpha = p.nval;
        k = p.nextToken();
        gamma = p.nval;
        k = p.nextToken();
        randomrate = p.nval;
        // to get around a Java bug that can't read e-xxx numbers
        if (randomrate > 1.0)
        {
            k = p.nextToken();
            randomrate = 0;
        }
        k = p.nextToken();
        randomratedecay = p.nval;

        for (i = 0; i < numstates; i++)
        {
            for (j = 0; j < numactions; j++)
            {
                k = p.nextToken();
                this.p[i][j] = p.nval;
                k = p.nextToken();
                q[i][j] = p.nval;
            }
        }
        f.close();
        return;
    }
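    /*
     * Illustrative only (not produced by this class): a hypothetical policy
     * file for a 2-state, 2-action learner, in the format that savePolicy()
     * writes and readPolicy() parses (some comment header lines abbreviated).
     * The parameter values come first, then one row per state holding
     * hit-count / Q-value pairs for each action.
     *
     *   // Q-learning Parameters:
     *   0.2 // alpha
     *   0.9 // gamma
     *   0.05 // random rate
     *   0.999 // random rate decay
     *   // The policy.
     *   3.0 1.5 1.0 0.2
     *   0.0 0.0 4.0 2.1
     */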
    /**
     * Write the policy to the file named by policyfilename.
     */
    public void savePolicy() throws IOException
    {
        int i, j;
        String lineout;
        FileOutputStream f = new FileOutputStream(policyfilename);
        PrintWriter p = new PrintWriter(f);

        p.println("// Q-learning Parameters:");
        p.println(alpha + " // alpha");
        p.println(gamma + " // gamma");
        p.println(randomrate + " // random rate");
        p.println(randomratedecay + " // random rate decay");
        p.println("// The policy. ");
        p.println("// The first number is the hits in that ");
        p.println("// state/action, the following num is the s/a ");
        p.println("// Q-value. ");
        for (i = 0; i < numstates; i++)
        {
            for (j = 0; j < numactions; j++)
            {
                p.print(this.p[i][j] + " ");
                p.print(q[i][j] + " ");
            }
            p.println();
        }
        p.flush(); // flush the PrintWriter before closing the underlying stream
        f.close();
        return;
    }

    /**
     * Write the policy profile to a file.
     *
     * @param profile_filename String, the name of the file to write to.
     */
    public void saveProfile(String profile_filename) throws IOException
    {
        int i, j;
        String lineout;
        double total_hits = 0;
        PrintWriter p = new PrintWriter(
            new BufferedWriter(
                new FileWriter(profile_filename)));

        p.println("// Policy profile:");
        p.println("// Q-learning Parameters:");
        p.println("// " + alpha + " // alpha");
        p.println("// " + gamma + " // gamma");
        p.println("// " + randomrate + " // random rate");
        p.println("// " + randomratedecay + " // random rate decay");
        p.println("// Number of states ");
        p.println(numstates);
        p.println("// The profile. ");
        p.println("// proportion of hits, action, number of hits");

        /*--- total up all the state/action hits ---*/
        for (i = 0; i < numstates; i++)
        {
            for (j = 0; j < numactions; j++)
            {
                total_hits += (double) this.profile[i][j];
            }
        }

        for (i = 0; i < numstates; i++)
        {
            double Vn = -9999999999f; // very bad
            int action = 0;
            int hits = 0;
            for (j = 0; j < numactions; j++)
            {
                hits += this.profile[i][j];
                if (q[i][j] > Vn)
                {
                    Vn = q[i][j];
                    action = j;
                }
            }
            p.println((double) hits / total_hits + " " + action + " " + hits);
        }
        p.flush();
        p.close();
        return;
    }
}
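For reference, endTrial() (and, per the comments, the matching update inside query()) supports two value-update criteria: the discounted Watkins rule, q(x,a) = (1 - alpha)*q(x,a) + alpha*(r + gamma*V), and an average-reward rule, q(x,a) = (n*q(x,a) + r)/(n + 2), where n is the visit count p[x][a]. The standalone sketch below simply replays both rules on made-up numbers to show the arithmetic; QUpdateDemo and all of its alpha, gamma, q, n, r, and Vn values are illustrative assumptions, not part of the learner class.

// Illustrative only: the two update rules used by the learner, applied to made-up numbers.
public class QUpdateDemo
{
    public static void main(String[] args)
    {
        double alpha = 0.2, gamma = 0.9;   // hypothetical learning rate and discount
        double q = 1.0;                    // hypothetical current Q-value for (state, action)
        double n = 4.0;                    // hypothetical visit count p[x][a]
        double r = 0.5, Vn = 2.0;          // hypothetical reward and next-state value

        // Discounted criterion (Watkins): blend the old estimate with the new sample.
        double discounted = (1 - alpha) * q + alpha * (r + gamma * Vn);

        // Average-reward criterion: numerator is (visits * old average + new reward),
        // denominator is (visits + 2), matching the rule in endTrial().
        double average = (n * q + r) / (n + 2);

        System.out.println("discounted update: " + discounted); // 1.26
        System.out.println("average update:    " + average);    // 0.75
    }
}

With these inputs the discounted rule keeps 80% of the old estimate and mixes in 20% of the new sample (giving 1.26), while the average rule folds the new reward into a running mean over the visit count (giving 0.75).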