📄 chainmult.m
字号:
% using negative integers.
                    error('Cannot write temporary result over input matrices.');
                end
                tempIdx = -x12_idx;  % create positive temp index

                % Record current temp size
                tempSpecs.Sizes{tempIdx} = x12_size;
                tempSpecs.Complexity(tempIdx) = x12_cplx;
                temp_realEle = prod(x12_size);
                if x12_cplx
                    temp_realEle = 2*temp_realEle;
                end
                if length(max_temp_realEle) < tempIdx,
                    % This temp matrix was never previously used
                    % entry is the maximum
                    max_temp_realEle(tempIdx) = temp_realEle;
                else
                    % Determine if # elements in this temp matrix
                    % exceeds that of the corresponding temp storage area
                    % If so, store new max storage size:
                    %
                    if temp_realEle > max_temp_realEle(tempIdx),
                        max_temp_realEle(tempIdx) = temp_realEle;
                    end
                end
            end  % output vs. temp storage area
        end  % scalar exception
    end  % loop over all multiplication pairs in the chain

    % Determine if any temp storage area could overlay the
    % output storage area, thus removing the need for one
    % additional temporary matrix.
    %
    % NOTE: Matrix multiplications cannot be performed "in place".
    %       Thus, the last multiply producing the output matrix cannot
    %       use the output area as one of the (temp) multiplicands.
    %
    % If one of the temp areas can overlay the output area,
    % remove the temp area from the list.
    %
    % Also, if any of the input spaces are re-usable, we should
    % consider those spaces as alternatives to additional temp areas.
    %
    % xxx At present, we assume input spaces are read-only.
    %
    % Determine # elements in each temp matrix:
    if allowOutputOverwrite,
        % Disqualify any temp area which is used in the last
        % matrix multiply, as the output area receives the result
        % from the last multiply, and no area can be used twice
        % during one multiply. Identify & skip these temp areas:
        % (Could be 0, 1, or 2 temp areas that become disqualified)
        %
        ti=find(chain{end}<0);  % neg indices indicate temp areas
        disqualifiedTempIdx = -chain{end}(ti);  % pos temp indices
        modifiedTempEle = max_temp_realEle;
        modifiedTempEle(disqualifiedTempIdx) = Inf;

        % Determine # real-equivalent elements in output matrix:
        output_realEle = matrix_sizes{1}(1) * matrix_sizes{end}(2);
        if any(cplx),
            output_realEle = 2*output_realEle;
        end

        % Find which temp matrices fit in output space:
        i = find(modifiedTempEle <= output_realEle);
        if ~isempty(i),
            % At least one temp matrix can fit in output space.
            % Get index of the largest temp matrix that can fit.
            % If multiple "largest" temps, choose the first:
            [val,j]=max(max_temp_realEle(i));  % index of largest element in tempEle
                                               % that is <= outputEle
            tempOverOutputIdx = i(j);

            % Remove the first such temp matrix, since it does not
            % need explicit storage any longer:
            %
            tempSpecs.Sizes(tempOverOutputIdx) = [];
            tempSpecs.Complexity(tempOverOutputIdx) = [];
            max_temp_realEle(tempOverOutputIdx) = [];
        end
    end
end  % single_mat

% Determine total size of all temp storage areas
% for this multiplication chain:
sum_temp_realEle = sum(max_temp_realEle);

% Return structure result:
metrics.flopCount = flopCount;
metrics.tempOverOutputIdx = tempOverOutputIdx;
metrics.temp_ele = sum_temp_realEle;
%metrics.tempSpecs = tempSpecs;


% ---------------------------------------------------------------
function Opt = chainOpt(metrics)
% chainOpt Find optimal solutions for flops and memory storage.
%
%   Input:
%     metrics - struct array; the fields .temp_ele and .flopCount of
%               every element are read.
%
%   Output (struct Opt, all fields are indices into metrics):
%     .bothIdx  - entries that attain both the minimum temp_ele and the
%                 minimum flopCount (may be empty)
%     .flopsIdx - among the entries with minimum flopCount, those with
%                 the smallest temp_ele
%     .memIdx   - among the entries with minimum temp_ele, those with
%                 the smallest flopCount

% Get memory and flops metrics from structure
% into individual metric vectors:
memCount   = [metrics.temp_ele];
flopsCount = [metrics.flopCount];

% Optimization:
%
% 1) Choose smallest among all maximum storage requirements
%    for each permutation.
% 2) Choose smallest flop count.
%
val = min(memCount);                 % Memory optimization
memIdx = find(memCount == val);      % find all minima
val = min(flopsCount);               % Flop optimization
flopsIdx = find(flopsCount == val);  % find all minima

% Find common index/indices in memIdx and flopsIdx
% May be empty:
Opt.bothIdx = intersect(memIdx,flopsIdx);

% If there are multiple flop minima, only keep that
% which (locally) minimizes memory:
val = min(memCount(flopsIdx));
Opt.flopsIdx = intersect(flopsIdx, find(memCount == val));

% If there are multiple mem minima, only keep that
% which (locally) minimizes flops:
val = min(flopsCount(memIdx));
Opt.memIdx = intersect(memIdx, find(flopsCount == val));


% ---------------------------------------------------------------
function [x_size, x_cplx] = getMatrixSize(x_idx, multSpecs, tempSpecs)
% getMatrixSize Return the size of the i'th matrix chain result
%
% Positive indices are taken to be input matrix sizes,
% negative indices are temp matrix sizes.
% The 0'th index is the output matrix, and is not available
% for use in intermediate results (at least in the enumerated
% specifications).
%
%   x_idx     - signed index of the operand
%   multSpecs - struct; .Sizes{k} and .Complexity(k) are read for x_idx = k > 0
%   tempSpecs - struct; .Sizes{k} and .Complexity(k) are read for x_idx = -k < 0
%
%   x_size    - the selected entry of .Sizes
%   x_cplx    - the selected entry of .Complexity
%
% An error is raised when x_idx is 0.

% If index is negative, return size-vector from tempSpecs
% Otherwise, return size-vector from multSpecs
if x_idx<0,
    x_size = tempSpecs.Sizes{-x_idx};
    x_cplx = tempSpecs.Complexity(-x_idx);
elseif x_idx>0,
    x_size = multSpecs.Sizes{x_idx};
    x_cplx = multSpecs.Complexity(x_idx);
else
    error(['Input index to matrix multiplier pair cannot be ' ...
           'specified as the output matrix (index 0).']);
end


%-----------------------------------------------------------------
function list = resultList(y)
% resultList Construct display list for results.
%
%   y    - struct; the fields read are y.Input.Sizes, y.All.Chain,
%          y.All.ChainStr, y.All.Metrics(k).temp_ele,
%          y.All.Metrics(k).flopCount, y.Optimized.bothIdx,
%          y.Optimized.flopsIdx and y.Optimized.memIdx.
%   list - char matrix: one header row followed by one row per entry
%          of y.All.Chain.

x          = y.Input.Sizes;
perms      = y.All.Chain;
ppv        = y.All.ChainStr;
maxTempEle = [y.All.Metrics.temp_ele];
flopCount  = [y.All.Metrics.flopCount];
bothIdx    = y.Optimized.bothIdx;
flopsIdx   = y.Optimized.flopsIdx;
memIdx     = y.Optimized.memIdx;

% round(log10(v)+.5) is the number of decimal digits of a positive
% integer v; each column is at least as wide as its header.

% max # chars in Idx:
idxChars = max(3, round(log10(length(perms))+.5));

% max # chars in Mem:
% (all-zero case is handled separately because log10(0) is -Inf;
%  all() also gives an empty list the minimum width)
if all(maxTempEle == 0),
    memChars = 3;
else
    memChars = max(3, round(log10(max(maxTempEle))+.5));
end

% max # chars in Flops:
if all(flopCount == 0),
    flopChars = 5;
else
    flopChars = max(5, round(log10(max(flopCount))+.5));
end

% # chars in prettyprint expression:
%  = n variables + 2*(n-2) parens + n-1 asterisks + 5 chars
exprChars = max(10, 4*length(x));

% Construct printing format strings:
% fmtStr1 formats the header (all strings), fmtStr2 the data rows.
fmtStr1 = [' ' ...
           '%' num2str(idxChars) 's ' ...
           '%' num2str(memChars) 's ' ...
           '%' num2str(flopChars) 's ' ...
           '%-' num2str(exprChars) 's ' ...
           '%s'];
fmtStr2 = [' ' ...
           '%' num2str(idxChars) 'd ' ...
           '%' num2str(memChars) 'd ' ...
           '%' num2str(flopChars) 'd ' ...
           '%-' num2str(exprChars) 's ' ...
           '%s'];

% List header:
list = cell(1, length(perms)+1);
list{1} = sprintf(fmtStr1, 'Idx','Mem','Flops','Expression','Optimization');

% Create result list entries:
for i=1:length(perms),
    % If this is the "optimal" entry, flag it:
    isOpt='';
    if ~isempty(bothIdx),
        if any(i==bothIdx),
            isOpt='<- flops & storage';
        end
    else
        if any(i==flopsIdx),
            isOpt = '<- flops then storage';
        elseif any(i==memIdx),
            isOpt = '<- storage then flops';
        end
    end
    list{i+1} = sprintf(fmtStr2, ...
                        i, maxTempEle(i), flopCount(i), ppv{i}, isOpt);
end
list = char(list);


% -----------------------------------------------------------
function [Sizes, Complexity, varNames] = getRandomInput
% getRandomInput Generate random problem with 1-6 matrices and
%                dimensions from 0-20
%
%   Sizes      - 1-by-Nmats cell array; Sizes{i} is the 2-element size
%                vector of the i'th matrix (adjacent matrices share
%                their inner dimension)
%   Complexity - 1-by-Nmats logical vector, true with probability 0.5
%   varNames   - cell array of the names 'A', 'B', ... one per matrix
%
% NOTE: a dimension of 0 can be generated by the rounding below.

% Generate random integer in range [2,7].
% This yields 1 to 6 matrix size entries:
Nsiz = 1+round(6*rand+.5);
Nmats = Nsiz-1;

% Generate random size elements in range [0,20]
S = round(20 * rand(1,Nsiz));
x = cell(1,Nmats);
for i=1:Nmats,
    x{i} = S([i i+1]);
end
Sizes = x;
Complexity = (rand(1,Nmats) > 0.5);
varNames = cellstr(char(double('A') + (0:Nmats-1))');

% x = {[5 8] [8 6][6 8][8 5][5 8]}; % 3 opt indices


% -----------------------------------------------------------
function multSpecs = parse_args(varargin)
% Create parse args structure
%
% Recognized input syntax are:
%   f()
%   f(size_list)
%   f(size_list, complexity)
%   f(..., varNames)
%   f(..., string_opts)
%
% Structure contains the following fields:
%   .ReuseOutput
%   .Optimize
%   .Sizes
%   .Complexity
%   .varNames
%
% Errors are raised for an unrecognized option string and when the
% lengths of .Sizes and .Complexity differ.

% Parse trailing string arguments:
opt=1;  % defaults
reuse=1;
while ~isempty(varargin) && ischar(varargin{end}),
    str = varargin{end};
    varargin(end)=[];
    switch str
        case 'opt'
            opt=1;
        case 'noopt'
            opt=0;
        case 'reuse'
            reuse=1;
        case 'noreuse'
            reuse=0;
        otherwise
            error('Invalid input option specified.');
    end
end
multSpecs.ReuseOutput = reuse;
multSpecs.Optimize = opt;

% Parse primary inputs:
%
%   f()
%   f(size_list)
%   f(size_list, complexity)
%   f(size_list, complexity, varNames)
if isempty(varargin),
    % Generate random defaults:
    [multSpecs.Sizes,multSpecs.Complexity,multSpecs.varNames] = getRandomInput;

else
    % Parse user inputs:
    %   (size_list, complexity, varNames)
    multSpecs.Sizes = varargin{1};
    varargin(1)=[];

    % Parse variable name list:
    % a trailing cell array is taken to be the list of names
    if ~isempty(varargin) && iscell(varargin{end}),
        % should contain only strings in the varList cell array
        multSpecs.varNames = varargin{end};
        varargin(end)=[];
    else
        multSpecs.varNames = {};
    end

    % Parse complexity vector:
    % any nonzero entry marks the corresponding matrix as complex
    if ~isempty(varargin),
        multSpecs.Complexity = (varargin{1} ~= 0);
    else
        multSpecs.Complexity = false(size(multSpecs.Sizes));
    end
end

% Check that length of complexity vector matches size list:
if ~isequal(length(multSpecs.Sizes), length(multSpecs.Complexity) ),
    error('Lengths of inputs must be equal.');
end


% ==================================================================================
% Support functions
% ==================================================================================

% ==================================================================================
% genStructs:
%   Generate cell array of structures enumerating all associative-
%   rule combinations for the chain multiplication of N multiplicands.
% ==================================================================================
function list = genStructs(N)
% list is empty for N==0. Otherwise the string 'A'..(N letters) is
% pushed on TreeStack and entries are popped until the stack reports
% no more; each popped entry is passed to SplitOneLevel and its
% result is collected when SplitOneLevel returns stop == -1.
list={};
if N==0, return; end

TreeStack('reset');
TreeStack('push',char('A'+(0:N-1)));
[v,ok]=TreeStack('pop');
while ok,
    [y,stop]=SplitOneLevel(v,[],0);
    if stop==-1,
        list{end+1}=y;  % Done with this entry
    end
    %if stop==0,
    %   error('Should not have stop=0 at top level.');
    %end
    [v,ok]=TreeStack('pop');
end


% ----------------------------------------------------
function field = convertIdxToField(idx)
% Map each element of idx to a field name: 1 -> 'Lvar', anything
% else -> 'Rvar'. Returns a cell array with one name per element
% (empty cell for empty idx).
field={};
for i=length(idx):-1:1,  % backwards, so the cell is sized on first assignment
    if idx(i)==1,
        field{i}='Lvar';
    else
        field{i}='Rvar';
    end
end


% ----------------------------------------------------
function [y,stop]=SplitOneLevel(node,idx,stop)
% Recursive associative-rule decomposition:
persistent orig
if stop,
    y=node;
    return
end
if isempty(idx),
    orig=node;
end
if ~isstruct(node),
    N=length(node);
    if N==1,
        y=node;
        stop=-1;  % xxx
    else
        field=convertIdxToField(idx);
        for i=1:N-1,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -