compute_statematrix.m

来自「回声状态神经网络(ESN)是一种性能优异的递归神经网络」· M 代码 · 共 183 行

M
183
字号
function stateCollectMat = ...
    compute_statematrix(inputSequence, outputSequence, esn, nForgetPoints, varargin)
% compute_statematrix  runs the input through the ESN and writes the
% obtained input+reservoir states into stateCollectMat.
% The first nForgetPoints will be deleted, as the first few states could be
% not reliable due to initial transients
%
% inputs:
% inputSequence = input time series of size nTrainingPoints x nInputDimension
% outputSequence = output time series of size nTrainingPoints x nOutputDimension
% esn = an ESN structure, through which we run our input sequence
% nForgetPoints: an integer, may be negative, positive or zero.
%    If positive: the first nForgetPoints will be disregarded (washing out
%    initial reservoir transient)
%    If negative: the network is first driven from its starting state (zero
%    unless 'startingState' is given) with the first data point repeated
%    |nForgetPoints| times; afterwards all nTrainingPoints states are
%    collected into the state matrix
%    If zero: no washout accounted for, all states except the starting
%    state are collected into the state matrix
%
% Note: one of inputSequence and outputSequence may be the empty list [],
% but not both. If the inputSequence is empty, we are dealing with a purely
% generative task; states are then computed by teacher-forcing
% outputSequence. If outputSequence is empty, we are using this function to
% test a trained ESN; network output is then computed from network dynamics
% via output weights. If both are non-empty, states are computed by
% teacher-forcing outputSequence.
% nTrainingPoints is the number of rows of outputSequence when it is
% non-empty, otherwise the number of rows of inputSequence.
%
% optional input arguments (name/value pairs):
% 'startingState', v : start the ESN from the total state v instead of the
%    zero vector. v must have
%    esn.nInternalUnits + esn.nInputUnits + esn.nOutputUnits elements,
%    ordered as [internal state; input; output] (that is, it is a total
%    state, not an internal reservoir state). Call
%    compute_statematrix(inputSequence, outputSequence, esn, nForgetPoints, ...
%                        'startingState', v)
%
% output:
% stateCollectMat = matrix of size nCollected x (nInternalUnits + nInputUnits)
%    where nCollected = max(nTrainingPoints - nForgetPoints, 0) if
%    nForgetPoints >= 0, and nCollected = nTrainingPoints otherwise.
%    Row r holds the state after data point r + max(nForgetPoints, 0) has
%    been presented: columns 1..nInternalUnits are the reservoir
%    activations, the remaining nInputUnits columns are the scaled and
%    shifted input of that step.
%
% Version 1.0, April 30, 2006
% Copyright: Fraunhofer IAIS 2006 / Patents pending
% Revision 1, June 6, 2006, H. Jaeger
% Revision 2, June 23, 2007, H. Jaeger (added optional starting state
% input)
% Revision 3, July 1, 2007, H. Jaeger (added leaky1_esn update option)

if isempty(inputSequence) && isempty(outputSequence)
    error('error in compute_statematrix: two empty input args');
end

% Teacher forcing whenever an output sequence is supplied; the data length
% is taken from whichever sequence drives the network.
if isempty(outputSequence)
    teacherForcing = 0;
    nDataPoints = size(inputSequence, 1);
else
    teacherForcing = 1;
    nDataPoints = size(outputSequence, 1);
end

if nForgetPoints >= 0
    nCollected = max(nDataPoints - nForgetPoints, 0);
else
    nCollected = nDataPoints;
end
stateCollectMat = zeros(nCollected, esn.nInputUnits + esn.nInternalUnits);

%% set starting state
% Total state layout: [internal units; input units; output units].
nTotal = esn.nInternalUnits + esn.nInputUnits + esn.nOutputUnits;
totalstate = zeros(nTotal, 1);
if mod(numel(varargin), 2) ~= 0
    error('compute_statematrix: optional arguments must come in name/value pairs');
end
for k = 1:2:numel(varargin)
    switch varargin{k}
        case 'startingState'
            startState = varargin{k+1};
            if numel(startState) ~= nTotal
                error('compute_statematrix: startingState must have %d elements, got %d', ...
                    nTotal, numel(startState));
            end
            totalstate = startState(:);
        otherwise
            error('the option does not exist');
    end
end

%% validate the network type once, before any update is performed
switch esn.type
    case {'plain_esn', 'leaky_esn', 'leaky1_esn'}
        isTwi = false;
    case 'twi_esn'
        if esn.nInputUnits == 0
            error('twi_esn cannot be used without input to ESN');
        end
        isTwi = true;
    otherwise
        error('compute_statematrix: unknown esn.type ''%s''', esn.type);
end

%% if nForgetPoints is negative, ramp up ESN by feeding first input
%% |nForgetPoints| many times (these states are not collected)
if nForgetPoints < 0
    if isTwi
        rampArg = esn.avDist;
    else
        rampArg = [];
    end
    for k = 1:-nForgetPoints
        in = scale_input(esn, inputSequence, 1);
        [~, totalstate] = advance_state(totalstate, in, esn, rampArg, ...
            teacherForcing, outputSequence, 1);
    end
end

%% main run: update the network and collect states
collectIndex = 0;
for i = 1:nDataPoints
    in = scale_input(esn, inputSequence, i);

    % twi_esn needs the distance between consecutive (unscaled) inputs;
    % the first step falls back to the average distance esn.avDist.
    if ~isTwi
        typeSpecificArg = [];
    elseif i == 1
        typeSpecificArg = esn.avDist;
    else
        typeSpecificArg = norm(inputSequence(i,:) - inputSequence(i-1,:));
    end

    [internalState, totalstate] = advance_state(totalstate, in, esn, ...
        typeSpecificArg, teacherForcing, outputSequence, i);

    % collect every state after the washout (all states if nForgetPoints < 0)
    if nForgetPoints < 0 || i > nForgetPoints
        collectIndex = collectIndex + 1;
        stateCollectMat(collectIndex, :) = [internalState' in'];
    end
end
end

function in = scale_input(esn, inputSequence, row)
% scale_input  returns row `row` of inputSequence scaled and shifted as a
% column vector, or [] when the ESN has no input units.
if esn.nInputUnits > 0
    in = esn.inputScaling .* inputSequence(row, :)' + esn.inputShift;
else
    in = [];
end
end

function [internalState, totalstate] = advance_state(totalstate, in, esn, ...
    typeSpecificArg, teacherForcing, outputSequence, row)
% advance_state  performs one network update: writes the input into the
% total state, computes the new internal state via the update function
% named by esn.type, determines the output (teacher-forced from row `row`
% of outputSequence, or from the output weights) and returns the new total
% state [internalState; in; netOut].
if esn.nInputUnits > 0
    totalstate(esn.nInternalUnits+1 : esn.nInternalUnits+esn.nInputUnits) = in;
end
internalState = feval(esn.type, totalstate, esn, typeSpecificArg);
if teacherForcing
    netOut = esn.teacherScaling .* outputSequence(row, :)' + esn.teacherShift;
else
    netOut = feval(esn.outputActivationFunction, ...
        esn.outputWeights * [internalState; in]);
end
totalstate = [internalState; in; netOut];
end

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?