chainmult.m

MATLAB source code that optimizes the multiplication of multiple matrices (a matrix chain) to improve computational efficiency.
               % using negative integers.
               error('Cannot write temporary result over input matrices.');
            end

            tempIdx = -x12_idx; % create positive temp index

            % Record current temp size
            tempSpecs.Sizes{tempIdx}      = x12_size;
            tempSpecs.Complexity(tempIdx) = x12_cplx;
            temp_realEle                  = prod(x12_size);

            if x12_cplx
               temp_realEle = 2*temp_realEle;
            end

            if length(max_temp_realEle) < tempIdx,
               % This temp matrix was never previously used, so this
               % entry is the maximum
               max_temp_realEle(tempIdx) = temp_realEle;
            else
               % Determine if # elements in this temp matrix
               % exceeds that of the corresponding temp storage area
               % If so, store new max storage size:
               %
               if temp_realEle > max_temp_realEle(tempIdx),
                  max_temp_realEle(tempIdx) = temp_realEle;
               end
            end
         end  % output vs. temp storage area
      end % scalar exception
   end  % loop over all multiplication pairs in the chain

   % Determine if any temp storage area could overlay the
   % output storage area, thus removing the need for one
   % additional temporary matrix.
   %
   % NOTE: Matrix multiplications cannot be performed "in place".
   % Thus, the last multiply producing the output matrix cannot
   % use the output area as one of the (temp) multiplicands.
   %
   % If one of the temp areas can overlay the output area,
   % remove the temp area from the list.
   %
   % Also, if any of the input spaces are re-usable, we should
   % consider those spaces as alternatives to additional temp areas.
   %
   % xxx At present, we assume input spaces are read-only.
   %
   % Determine # elements in each temp matrix:

   if allowOutputOverwrite,
      % Disqualify any temp area which is used in the last
      % matrix multiply, as the output area receives the result
      % from the last multiply, and no area can be used twice
      % during one multiply.  Identify & skip these temp areas:
      % (Could be 0, 1, or 2 temp areas that become disqualified)
      %
      ti=find(chain{end}<0);  % neg indices indicate temp areas
      disqualifiedTempIdx = -chain{end}(ti); % pos temp indices
      modifiedTempEle = max_temp_realEle;
      modifiedTempEle(disqualifiedTempIdx) = Inf;

      % Determine # real-equivalent elements in output matrix:
      output_realEle = matrix_sizes{1}(1) * matrix_sizes{end}(2);
      if any(cplx),
         output_realEle = 2*output_realEle;
      end

      % Find which temp matrices fit in output space:
      i = find(modifiedTempEle <= output_realEle);

      if ~isempty(i),
         % At least one temp matrix can fit in output space.
         % Get index of the largest temp matrix that can fit.
         % If multiple "largest" temps, choose the first:
         [val,j]=max(max_temp_realEle(i));
         % index of largest element in tempEle that is <= outputEle
         tempOverOutputIdx = i(j);

         % Remove the first such temp matrix, since it does not
         % need explicit storage any longer:
         %
         tempSpecs.Sizes(tempOverOutputIdx)      = [];
         tempSpecs.Complexity(tempOverOutputIdx) = [];
         max_temp_realEle(tempOverOutputIdx)     = [];
      end
   end
end % single_mat

% Determine total size of all temp storage areas
% for this multiplication chain:
sum_temp_realEle = sum(max_temp_realEle);

% Return structure result:
metrics.flopCount         = flopCount;
metrics.tempOverOutputIdx = tempOverOutputIdx;
metrics.temp_ele          = sum_temp_realEle;
%metrics.tempSpecs         = tempSpecs;

% ---------------------------------------------------------------
function Opt = chainOpt(metrics)
% chainOpt Find optimal solutions for flops and memory storage.

% Get memory and flops metrics from structure
% into individual metric vectors:
memCount   = [metrics.temp_ele];
flopsCount = [metrics.flopCount];

% Optimization:
%
%  1) Choose smallest among all maximum storage requirements
%     for each permutation.
%  2) Choose smallest flop count.
%
val = min(memCount);          % Memory optimization
memIdx = find(memCount == val); % find all minima

val = min(flopsCount);          % Flop optimization
flopsIdx = find(flopsCount == val); % find all minima

% Find common index/indices in memIdx and flopsIdx
% May be empty:
Opt.bothIdx = intersect(memIdx,flopsIdx);

% If there are multiple flop minima, only keep that
% which (locally) minimizes memory:
val = min(memCount(flopsIdx));
Opt.flopsIdx = intersect(flopsIdx, find(memCount == val));

% If there are multiple mem minima, only keep that
% which (locally) minimizes flops:
val = min(flopsCount(memIdx));
Opt.memIdx = intersect(memIdx, find(flopsCount == val));

% ---------------------------------------------------------------
function [x_size, x_cplx] = getMatrixSize(x_idx, multSpecs, tempSpecs)
% getMatrixSize Return the size of the i'th matrix chain result
%
% Positive indices are taken to be input matrix sizes,
% negative indices are temp matrix sizes.
% The 0'th index is the output matrix, and is not available
% for use in intermediate results (at least in the enumerated
% specifications).

%  If index is negative, return size-vector from temp_sizes
%  Otherwise, return size-vector from matrix_sizes
if x_idx<0,
   x_size = tempSpecs.Sizes{-x_idx};
   x_cplx = tempSpecs.Complexity(-x_idx);
elseif x_idx>0,
   x_size = multSpecs.Sizes{x_idx};
   x_cplx = multSpecs.Complexity(x_idx);
else
   error(['Input index to matrix multiplier pair cannot be ' ...
         'specified as the output matrix (index 0).']);
end

%-----------------------------------------------------------------
function list = resultList(y)
% Construct display list for results:
%
x          = y.Input.Sizes;
x_cplx     = y.Input.Complexity;
perms      = y.All.Chain;
ppv        = y.All.ChainStr;
maxTempEle = [y.All.Metrics.temp_ele];
flopCount  = [y.All.Metrics.flopCount];
bothIdx    = y.Optimized.bothIdx;
flopsIdx   = y.Optimized.flopsIdx;
memIdx     = y.Optimized.memIdx;

% max # chars in Idx:
idxChars = max(3, round(log10(length(perms))+.5));

% max # chars in Mem:
% (the vector test below is true only if every entry is zero)
if maxTempEle == 0,
   memChars=3;
else
   memChars = max(3, round(log10(max(maxTempEle))+.5));
end

% max # chars in Flops:
% (the vector test below is true only if every entry is zero)
if flopCount == 0,
   flopChars = 5;
else
   flopChars = max(5,round(log10(max(flopCount))+.5));
end

% # chars in prettyprint expression:
% = n variables + 2*(n-2) parens + n-1 asterisks + 5 chars
exprChars = max(10, 4*length(x));

% Construct printing format strings:
fmtStr1 = [' ' ...
      '%' num2str(idxChars)   's  ' ...
      '%' num2str(memChars)   's  ' ...
      '%' num2str(flopChars)  's  ' ...
      '%-' num2str(exprChars) 's  ' ...
      '%s'];
fmtStr2 = [' ' ...
      '%' num2str(idxChars)   'd  ' ...
      '%' num2str(memChars)   'd  ' ...
      '%' num2str(flopChars)  'd  ' ...
      '%-' num2str(exprChars) 's  ' ...
      '%s'];

% List header:
list{1} = sprintf(fmtStr1, 'Idx','Mem','Flops','Expression','Optimization');

% Create result list entries:
for i=1:length(perms),
   % If this is the "optimal" entry, flag it:
   isOpt='';
   if ~isempty(bothIdx),
      if any(i==bothIdx),
         isOpt='<- flops & storage';
      end
   else
      if any(i==flopsIdx),
         isOpt = '<- flops then storage';
      elseif any(i==memIdx),
         isOpt = '<- storage then flops';
      end
   end
   list{i+1} = sprintf(fmtStr2, ...
      i, maxTempEle(i), flopCount(i), ppv{i}, isOpt);
end
list = char(list);

% -----------------------------------------------------------
function [Sizes, Complexity, varNames] = getRandomInput
% getRandomInput Generate random problem with 1-6 matrices and sizes from 0-20

% Generate a random integer in range [2,7].
% This yields 2 to 7 matrix size entries, i.e. 1 to 6 matrices:
Nsiz = 1+round(6*rand+.5);
Nmats = Nsiz-1;

% Generate random size elements in range [0,20]
% (a size of 0 is possible and gives an empty matrix dimension)
S = round(20 * rand(1,Nsiz));
for i=1:Nmats,
    x{i} = S([i i+1]);
end
Sizes = x;
Complexity = (rand(1,Nmats) > 0.5);
varNames = cellstr(char(double('A') + (0:Nmats-1))');

% x = {[5 8] [8 6][6 8][8 5][5 8]}; % 3 opt indices

% -----------------------------------------------------------
function multSpecs = parse_args(varargin)
% Create parse args structure
%
% Recognized input syntaxes are:
%   f()
%   f(size_list)
%   f(size_list, complexity)
%   f(..., varNames)
%   f(..., string_opts)
%
% Structure contains the following fields:
%   .ReuseOutput
%   .Optimize
%   .Sizes
%   .Complexity
%   .varNames

% Parse trailing string arguments:
opt=1;   % defaults
reuse=1;
while (length(varargin)>0) && ischar(varargin{end}),
    str = varargin{end};
    varargin(end)=[];
    switch str
    case 'opt'
        opt=1;
    case 'noopt'
        opt=0;
    case 'reuse'
        reuse=1;
    case 'noreuse'
        reuse=0;
    otherwise
        error('Invalid input option specified.');
    end
end
multSpecs.ReuseOutput = reuse;
multSpecs.Optimize    = opt;

% Parse primary inputs:
%
% f()
% f(size_list)
% f(size_list, complexity)
% f(size_list, complexity, varNames)
if isempty(varargin),
   % Generate random defaults:
   [multSpecs.Sizes,multSpecs.Complexity,multSpecs.varNames] = getRandomInput;

else
   % Parse user inputs:
   % (size_list, complexity, varNames)

   multSpecs.Sizes = varargin{1};
   varargin(1)=[];

   % Parse variable name list:
   if (length(varargin)>0),
      if iscell(varargin{end}),
         % should contain only strings in the varList cell array
         multSpecs.varNames = varargin{end};
         varargin(end)=[];
      else
         multSpecs.varNames = {};
      end
   else
      multSpecs.varNames = {};
   end

   % Parse complexity vector:
   if (length(varargin)>0),
      multSpecs.Complexity = (varargin{1} ~= 0);
   else
      multSpecs.Complexity = false(size(multSpecs.Sizes));
   end

   % NOTE: cplx is assigned here but is not used in the rest of
   % this function.
   if length(varargin)<2,
      % Create logical 0's
      cplx = false(size(multSpecs.Sizes));
   else
      % is this complexity or varNames?
      cplx = varargin{2};
   end
end

% Check that length of the complexity vector matches the size list:
if ~isequal(length(multSpecs.Sizes), length(multSpecs.Complexity) ),
   error('Lengths of inputs must be equal.');
end

% ==================================================================================
% Support functions
% ==================================================================================

% ==================================================================================
% genStructs:
%  Generate cell array of structures enumerating all associative-
%  rule combinations for the chain multiplication of N multiplicands.
% ==================================================================================
function list = genStructs(N)

list={};
if N==0, return; end

TreeStack('reset');
TreeStack('push',char('A'+(0:N-1)));
[v,ok]=TreeStack('pop');
while ok,
   [y,stop]=SplitOneLevel(v,[],0);
   if stop==-1,
      list{end+1}=y;
      % Done with this entry
   end
   %if stop==0,
   %   error('Should not have stop=0 at top level.');
   %end
   [v,ok]=TreeStack('pop');
end

% ----------------------------------------------------
function field = convertIdxToField(idx)

field={};
for i=length(idx):-1:1,
   if idx(i)==1,
      field{i}='Lvar';
   else
      field{i}='Rvar';
   end
end

% ----------------------------------------------------
function [y,stop]=SplitOneLevel(node,idx,stop)
% Recursive associative-rule decomposition:

persistent orig

if stop,
   y=node; return
end

if isempty(idx),
   orig=node;
end

if ~isstruct(node),
   N=length(node);
   if N==1,
      y=node;
      stop=-1; % xxx
   else
      field=convertIdxToField(idx);
      for i=1:N-1,
