⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rlearner.java

📁 用java写的一个强化学习程序
💻 JAVA
字号:
import java.util.Vector;
import java.lang.*;
import java.lang.reflect.*;

/**
 * Tabular reinforcement-learning agent that learns Q-values for an {@code RLWorld} and stores
 * them in an {@code RLPolicy}.
 *
 * <p>Three update rules are supported ({@link #Q_LEARNING}, {@link #SARSA}, {@link #Q_LAMBDA})
 * together with two action-selection strategies ({@link #E_GREEDY}, {@link #SOFTMAX}). Learning
 * proceeds episode by episode ("epochs") and stops early whenever {@link #running} is false.
 */
public class RLearner {

    RLWorld thisWorld;
    RLPolicy policy;

    // Learning types
    public static final int Q_LEARNING = 1;
    public static final int SARSA = 2;
    public static final int Q_LAMBDA = 3; // Good params were lambda=0.05, gamma=0.1, alpha=0.01, epsilon=0.1

    // Action selection types
    public static final int E_GREEDY = 1;
    public static final int SOFTMAX = 2;

    /** Maximum number of state-action pairs kept in the Q(lambda) eligibility trace. */
    private static final int MAX_TRACE_LENGTH = 10;

    int learningMethod;
    int actionSelection;

    double epsilon; // probability of a random (exploratory) action under E_GREEDY
    double temp;    // temperature dividing the Q-values under SOFTMAX

    double alpha;   // learning rate (step size of every Q update)
    double gamma;   // discount applied to the bootstrapped next-state value
    double lambda;  // trace decay for Q_LAMBDA (earlier pairs scaled by (gamma*lambda)^k)

    int[] dimSize;
    int[] state;
    int[] newstate;
    int action;
    double reward;

    int epochs;
    // NOTE(review): never written in this class; presumably maintained by a caller — confirm.
    public int epochsdone;

    // NOTE(review): not referenced in this class.
    Thread thisThread;

    /**
     * Learning loops stop as soon as this becomes false. It is only read in this class, so it is
     * presumably toggled from another thread; {@code volatile} makes that write visible here.
     */
    public volatile boolean running;

    /** Recent state-action pairs for Q(lambda); each entry is the state values followed by the action. */
    Vector<int[]> trace = new Vector<int[]>();
    int[] saPair;

    /** Timestamp (ms) of the last progress report in {@link #runTrial()}. */
    long timer;

    /** True when the most recent epsilon-greedy choice was exploratory; cuts the Q(lambda) trace. */
    boolean random = false;

    // NOTE(review): not referenced in this class.
    Runnable a;

    /**
     * Creates a learner for the given world with a freshly initialised policy.
     *
     * @param world the environment to learn in; its dimensions size the Q-table
     */
    public RLearner( RLWorld world) {
        // Getting the world from the invoking method.
        thisWorld = world;

        // Get dimensions of the world.
        dimSize = thisWorld.getDimension();

        // Creating new policy with dimensions to suit the world.
        policy = new RLPolicy( dimSize );

        // Initializing the policy with the initial values defined by the world.
        policy.initValues( thisWorld.getInitValues() );

        learningMethod = Q_LEARNING;  //Q_LAMBDA;//SARSA;
        actionSelection = E_GREEDY;

        // set default values
        epsilon = 0.1;
        temp = 1;

        alpha = 1; // For CliffWorld alpha = 1 is good
        gamma = 0.1;
        lambda = 0.1;  // For CliffWorld gamma = 0.1, l = 0.5 (l*g=0.05) is a good choice.

        System.out.println( "RLearner initialised" );
    }

    /**
     * Runs up to {@link #epochs} episodes, stopping early once {@link #running} is false.
     * Every 1000 epochs it prints the wall-clock milliseconds elapsed since the previous report.
     */
    public void runTrial() {
        System.out.println( "Learning! ("+epochs+" epochs)\n" );
        // Start the clock here; otherwise the first report (i == 0) would print "now - 0".
        timer = System.currentTimeMillis();
        for (int i = 0; i < epochs; i++) {
            if (!running) {
                break;
            }

            runEpoch();

            if (i % 1000 == 0) {
                // give text output
                long now = System.currentTimeMillis();
                System.out.println("Epoch:" + i + " : " + (now - timer));
                timer = now;
            }
        }
    }

    /**
     * Runs one episode, starting from the world's reset state and using the configured
     * learning method.
     */
    public void runEpoch() {
        // Reset state to start position defined by the world.
        state = thisWorld.resetState();

        // Each case ends with break: the original fell through from Q_LEARNING into the
        // other algorithms, triggering a spurious extra action selection per episode.
        switch (learningMethod) {
            case Q_LEARNING:
                runQLearningEpoch();
                break;
            case SARSA:
                runSarsaEpoch();
                break;
            case Q_LAMBDA:
                runQLambdaEpoch();
                break;
            default:
                break;
        }
    }

    /** One Q-learning episode: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)). */
    private void runQLearningEpoch() {
        while (!thisWorld.endState()) {
            if (!running) {
                break;
            }
            action = selectAction(state);
            newstate = thisWorld.getNextState(action);
            reward = thisWorld.getReward();

            double thisQ = policy.getQValue(state, action);
            double maxQ = policy.getMaxQValue(newstate);
            policy.setQValue(state, action, thisQ + alpha * (reward + gamma * maxQ - thisQ));

            state = newstate;
        }
    }

    /** One SARSA episode: bootstraps from the Q-value of the action actually chosen next. */
    private void runSarsaEpoch() {
        action = selectAction(state);
        while (!thisWorld.endState()) {
            if (!running) {
                break;
            }
            newstate = thisWorld.getNextState(action);
            reward = thisWorld.getReward();

            int newaction = selectAction(newstate);

            double thisQ = policy.getQValue(state, action);
            double nextQ = policy.getQValue(newstate, newaction);
            policy.setQValue(state, action, thisQ + alpha * (reward + gamma * nextQ - thisQ));

            // Set state to the new state and action to the new action.
            state = newstate;
            action = newaction;
        }
    }

    /**
     * One Q(lambda) episode.
     *
     * <p>Each step first applies the Q-learning update to the current pair. The same TD error
     * is then propagated to up to {@code MAX_TRACE_LENGTH - 1} earlier pairs, scaled by
     * {@code (gamma * lambda)^k} for the pair k steps back. The trace is cleared after an
     * exploratory epsilon-greedy action.
     */
    private void runQLambdaEpoch() {
        // Remove all eligibility traces.
        trace.removeAllElements();

        while (!thisWorld.endState()) {
            if (!running) {
                break;
            }

            action = selectAction(state);

            // Record (state..., action) in the trace, keeping only the newest pairs.
            // NOTE(review): assumes dimSize.length == state.length + 1 (state dims + action dim).
            saPair = new int[dimSize.length];
            System.arraycopy(state, 0, saPair, 0, state.length);
            saPair[state.length] = action;
            trace.add(saPair);
            if (trace.size() > MAX_TRACE_LENGTH) {
                trace.removeElementAt(0);
            }

            newstate = thisWorld.getNextState(action);
            reward = thisWorld.getReward();

            double maxQ = policy.getMaxQValue(newstate);
            double thisQ = policy.getQValue(state, action);
            double delta = reward + gamma * maxQ - thisQ;
            policy.setQValue(state, action, thisQ + alpha * delta);

            // Decode earlier pairs into a scratch array. The original copied them into `state`,
            // which came from the world and may be shared with it.
            int[] tracedState = new int[state.length];
            int newest = trace.size() - 1;
            for (int e = newest - 1; e >= 0; e--) {
                int[] pair = trace.get(e);
                System.arraycopy(pair, 0, tracedState, 0, tracedState.length);
                int tracedAction = pair[tracedState.length];

                double tracedQ = policy.getQValue(tracedState, tracedAction);
                policy.setQValue(tracedState, tracedAction,
                        tracedQ + alpha * delta * Math.pow(gamma * lambda, newest - e));
            }

            if (random) {
                trace.removeAllElements();
            }

            // Set state to the new state.
            state = newstate;
        }
    }

    /**
     * Chooses an action for {@code state} with the configured selection strategy.
     * The returned action is always one the world reports as valid.
     */
    private int selectAction( int[] state ) {
        double[] qValues = policy.getQValuesAt(state);

        // Reset for every strategy so a stale "explored" flag never outlives the E_GREEDY call
        // that set it (the original only reset it inside E_GREEDY).
        random = false;

        switch (actionSelection) {
            case E_GREEDY:
                return selectEpsilonGreedy(qValues);
            case SOFTMAX:
                return selectSoftmax(qValues);
            default:
                return -1;
        }
    }

    /**
     * Epsilon-greedy selection: explores uniformly with probability epsilon, otherwise picks a
     * greedy action and breaks ties uniformly. If the chosen action is invalid, valid
     * replacements are drawn uniformly.
     */
    private int selectEpsilonGreedy(double[] qValues) {
        int selected = -1;

        if (Math.random() < epsilon) {
            // Explore.
            random = true;
        } else {
            double maxQ = -Double.MAX_VALUE;
            int[] ties = new int[qValues.length];
            int tieCount = 0;

            for (int act = 0; act < qValues.length; act++) {
                if (qValues[act] > maxQ) {
                    maxQ = qValues[act];
                    ties[0] = act;
                    tieCount = 1;
                } else if (qValues[act] == maxQ) {
                    ties[tieCount++] = act;
                }
            }

            if (tieCount == 1) {
                selected = ties[0];
            } else if (tieCount > 1) {
                selected = ties[(int) (Math.random() * tieCount)];
            }
        }

        // Uniform random action when exploring or when no Q-value exceeded -Double.MAX_VALUE.
        if (selected == -1) {
            selected = (int) (Math.random() * qValues.length);
        }

        // Choose new action if not valid.
        while (!thisWorld.validAction(selected)) {
            selected = (int) (Math.random() * qValues.length);
        }
        return selected;
    }

    /**
     * Softmax (Boltzmann) selection with temperature {@code temp}. Invalid draws are rejected
     * and resampled.
     */
    private int selectSoftmax(double[] qValues) {
        int n = qValues.length;

        // Shift by the max Q before exponentiating: exp(q/temp) overflows to Infinity for large
        // q, which turned the original probabilities into NaN. The distribution is unchanged.
        double maxQ = Double.NEGATIVE_INFINITY;
        for (double q : qValues) {
            maxQ = Math.max(maxQ, q);
        }

        double[] weight = new double[n];
        double total = 0;
        for (int act = 0; act < n; act++) {
            weight[act] = Math.exp((qValues[act] - maxQ) / temp);
            total += weight[act];
        }

        if (!(total > 0 && total < Double.POSITIVE_INFINITY)) {
            // Degenerate Q-values (NaN or infinite): fall back to a uniform choice.
            for (int act = 0; act < n; act++) {
                weight[act] = 1;
            }
            total = n;
        }

        while (true) {
            // Sample on the unnormalised scale; rounding at the top end falls back to the
            // last action instead of selecting nothing.
            double r = Math.random() * total;
            int chosen = n - 1;
            double cumulative = 0;
            for (int act = 0; act < n; act++) {
                cumulative += weight[act];
                if (r < cumulative) {
                    chosen = act;
                    break;
                }
            }
            if (thisWorld.validAction(chosen)) {
                return chosen;
            }
        }
    }

    /** Returns the policy (Q-table) currently being learned. */
    public RLPolicy getPolicy() {
        return policy;
    }

    /**
     * Sets the learning rate; values outside [0, 1] are ignored. The upper bound is inclusive
     * so that the constructor default of 1 can be restored.
     */
    public void setAlpha( double a ) {
        if (a >= 0 && a <= 1) {
            alpha = a;
        }
    }

    /** Returns the learning rate. */
    public double getAlpha() {
        return alpha;
    }

    /** Sets the discount factor; values outside the open interval (0, 1) are ignored. */
    public void setGamma( double g ) {
        if (g > 0 && g < 1) {
            gamma = g;
        }
    }

    /** Returns the discount factor. */
    public double getGamma() {
        return gamma;
    }

    /** Sets the exploration probability; values outside the open interval (0, 1) are ignored. */
    public void setEpsilon( double e ) {
        if (e > 0 && e < 1) {
            epsilon = e;
        }
    }

    /** Returns the exploration probability. */
    public double getEpsilon() {
        return epsilon;
    }

    /** Sets the number of epochs run by {@link #runTrial()}; non-positive values are ignored. */
    public void setEpisodes( int e ) {
        if (e > 0) {
            epochs = e;
        }
    }

    /** Returns the number of epochs run by {@link #runTrial()}. */
    public int getEpisodes() {
        return epochs;
    }

    /** Selects the action-selection strategy; unknown values fall back to {@link #E_GREEDY}. */
    public void setActionSelection( int as ) {
        switch (as) {
            case SOFTMAX:
                actionSelection = SOFTMAX;
                break;
            case E_GREEDY:
            default:
                actionSelection = E_GREEDY;
                break;
        }
    }

    /** Returns the current action-selection strategy. */
    public int getActionSelection() {
        return actionSelection;
    }

    /** Selects the learning method; unknown values fall back to {@link #Q_LEARNING}. */
    public void setLearningMethod( int lm ) {
        switch (lm) {
            case SARSA:
                learningMethod = SARSA;
                break;
            case Q_LAMBDA:
                learningMethod = Q_LAMBDA;
                break;
            case Q_LEARNING:
            default:
                learningMethod = Q_LEARNING;
                break;
        }
    }

    /** Returns the current learning method. */
    public int getLearningMethod() {
        return learningMethod;
    }

    /**
     * Discards the learned values. A new policy is built and initialised with the world's
     * initial values, then returned.
     */
    //AK: let us clear the policy
    public RLPolicy newPolicy() {
        policy = new RLPolicy( dimSize );
        // Initializing the policy with the initial values defined by the world.
        policy.initValues( thisWorld.getInitValues() );
        return policy;
    }
}
	

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -