chainmult.m
function y = chainmult(varargin)
%CHAINMULT Optimization of chain matrix multiply problem.
%   CHAINMULT(X) performs optimization of a specified matrix
%   multiply chain by determining the parenthesization of
%   multiplicands that best minimizes two metrics:
%     - the maximum size of all temporary matrices that must be
%       stored in memory at any time during the evaluation, and
%     - the total number of floating point operations (flops)
%       required to compute the result.
%
%   The simplest way to use this optimization function
%   is via the GUI front-end, CHAINGUI.
%
%   X is an N-element cell array where each entry in the
%   array contains the size of one matrix in the multiplication
%   chain.  For the chain multiply Y=A*B*C,
%   X is specified as X = {size(A) size(B) size(C)}.
%
%   CHAINMULT(X,C) specifies an N-element logical vector C that
%   indicates whether each input matrix is complex.  A zero indicates
%   a real matrix; non-zero indicates a complex matrix.  Note that
%   length(C) must equal length(X).  If omitted, all input matrices
%   are assumed to be real.
%
%   CHAINMULT(X,C,'reuse') specifies that the output matrix
%   may be used for temporary matrix computation results.
%   By default, the output matrix area is reused.  To disallow
%   reuse of the output matrix, specify 'noreuse'.
%
%   CHAINMULT(...,'noopt') indicates that no optimization is to
%   be performed.  The natural forward-chain multiplication is
%   selected, that is, (((A*B)*C)*D)...  To perform full optimization,
%   omit this string or specify 'opt'.
%
%   CHAINMULT returns a structure with the following fields:
%     .Optimized
%        A structure containing the indices of the results that
%        jointly optimize the flops and memory metrics.
%        Note that there may be no single method of computing the
%        chain multiply that minimizes both metrics simultaneously.
%        Thus, three different minimization results are returned.
%        The indices returned are used in conjunction with the list
%        of results returned in the Y.All field.
%        The specific fields of .Optimized are:
%
%          .bothIdx - a vector of indices that jointly
%             minimize both flops and memory, or empty if no
%             index minimizes both.  If several methods of
%             computation return optimal (and thus identical)
%             results, a vector of indices is returned.
%
%          .flopsIdx - the index of the result that minimizes
%             the flops count.  If more than one such result occurs,
%             only those that secondarily minimize the memory
%             usage are returned.  At least one such result always
%             exists.  More than one method may lead to identical
%             minimization results, in which case a vector is returned.
%
%          .memIdx - the index of the result that minimizes
%             memory usage.  If more than one such result occurs,
%             only those that secondarily minimize the flops count
%             are returned.  At least one such result always
%             exists.  More than one method may lead to identical
%             minimization results, in which case a vector is returned.
%
%     .All
%        A structure containing the optimization results.
%          .ChainStr
%             A cell-array containing one entry for each possible
%             parenthesization of the input matrix list.  Each entry
%             is a string representing the fully-parenthesized MATLAB
%             command that would execute the specified matrix chain.
%          .Chain
%             A cell-array containing the steps necessary to compute
%             each specified matrix chain.  As opposed to the MATLAB
%             expressions given in .ChainStr, this field contains
%             vectors of indices that specify the matrix pairs to
%             be multiplied, and the output matrix to receive each
%             result.  This data structure is used internally to
%             store and transfer results.
%          .ChainCplx
%             Specifies the complexity of each matrix in the chain
%             multiplication, including both inputs and the output
%             of each step.
%             True indicates complex.
%          .Metrics
%             A structure containing the optimization metrics
%             for each parenthesization of the input matrices.
%               .flopCount
%                  Total flops for each entry in .Chain
%               .temp_ele
%                  Total number of temporary equivalent-real elements
%                  required for each entry in .Chain.  Note that a single
%                  complex element has 2 equivalent-real elements.
%               .tempOverOutputIdx
%                  Index of the temporary matrix that overwrites the
%                  output matrix area for the purpose of memory optimization.
%
%     .Input
%        A structure containing the input arguments to chainmult.
%          .Sizes
%             A copy of the matrix size specification (X).
%          .Complexity
%             A copy of the matrix complexity specification (C).
%          .ReuseOutput
%             Indicates whether the output matrix was specified as reusable
%             for temporary matrix results.
%          .Optimize
%             Indicates whether full optimization is to be performed, or
%             whether a simple forward-multiplication chain is substituted.
%          .varNames
%             An optional cell-array of strings containing the variable names
%             of each of the expressions in the matrix expression list.  By
%             default, this cell-array is empty.
%
%     .ResultList
%        A string matrix with all optimization results
%        formatted for display in a GUI (such as CHAINGUI).
%
%   EXAMPLE  Three real matrices are to be multiplied together:
%      A=rand([4 3]); B=rand([3 4]); C=rand([4 2]);
%
%   Y may be computed as either Y=(A*B)*C or Y=A*(B*C).
%   Both produce the same result, but they differ in total
%   flops and memory:
%      y=(A*B)*C;   % requires 160 flops
%      y=A*(B*C);   % requires 96 flops
%
%   To optimize this chain matrix multiplication, execute:
%      Y=CHAINMULT({size(A) size(B) size(C)})
%
%   See also CHAINGUI.

% Author: D. Orofino
% Copyright 1984-2002 The MathWorks, Inc.
% $Revision: 1.8.6.1 $  $Date: 2003/10/24 12:02:00 $

if nargin<1,
    multSpecs = parse_args;
else
    multSpecs = parse_args(varargin{:});
end

% Get cell-array of computational steps necessary to
% compute all associative-rule pairings for chain
% multiplication, based on the number of input matrices.
%
Nmats  = length(multSpecs.Sizes);
chains = getAllChains(Nmats, multSpecs.Optimize);

% Compute cost of each parenthesization:
%
for i=1:length(chains),
    [metrics(i),chainCplx{i}] = chainMetrics(multSpecs, chains{i});
    chainStrs{i} = convertCell2Str(chains{i}, multSpecs.varNames);
end

y.All.Metrics   = metrics;
y.All.Chain     = chains;
y.All.ChainCplx = chainCplx;
y.All.ChainStr  = chainStrs;
y.Optimized     = chainOpt(metrics);
y.Input         = multSpecs;
y.ResultList    = resultList(y);


% -----------------------------------------------------------
function P = getAllChains(N, opt)
%getAllChains Return cell-array of matrix chain specifications.
%   P=getAllChains(N,OPT) returns a description of all possible
%   ways of computing the chain multiplication on N inputs.
%   N is the number of input matrices to be multiplied.  If OPT
%   is false, only the natural forward chain (((A*B)*C)*D)...
%   is returned.
%
%   Each possible parenthesization of the matrix chain problem
%   is returned in one entry of a vector cell-array, P,
%      P = { parenChain1, parenChain2, ...}
%
%   Each parenthesization is itself a cell-array, specifying
%   the order in which the input matrices are to be multiplied
%   together as matrix pairs.
%      parenChain1 = {matrixPair1, matrixPair2, ...}
%
%   The multiplication of each matrix pair is represented using
%   a vector of 3 indices.  The first two indices of the vector
%   determine the input matrices to be multiplied in this step
%   of the chain multiplication.  The third index selects one
%   of possibly several result matrices to store intermediate
%   and final chain results.
%      matrixPair1 = [inputMatrix1 inputMatrix2 resultMatrix]
%
%   The indices are specified as integer-valued scalars.
%   Input matrix indices are positive integers > 0.
%      1,2,3,...      -> index of input matrices A, B, C, ...
%   Result matrix indices are either negative integers < 0,
%   or 0, indicating the final result matrix location.
%      0              -> index of final output matrix
%     -1,-2,-3,...    -> index of temporary result matrices
%
%   Note that the final result matrix, index 0, may be re-used
%   for one or more intermediate chain computations only if the
%   intermediate result does not exceed the final result size.
%
%   Note that matrix multiplies are not generally computed
%   in-place.  Therefore, the location of the result matrix
%   is never the same as either input matrix.
%
%   Example:
%   --------
%   Determine the possible multiply chains for Y=A*B*C.
%
%     P = {{[1 2 -1], [-1 3 0]}   % (A*B)*C
%          {[2 3 -1], [1 -1 0]}   % A*(B*C)
%         }
%
%   The first row of P represents the computation of
%   (A*B)*C in 2 steps:
%     Step 1: [1 2 -1]
%         Multiply input matrices 1 and 2, storing the
%         result in temporary matrix 1.
%     Step 2: [-1 3 0]
%         Multiply temporary matrix 1 by input matrix 3,
%         storing the result in the output matrix Y.
%
%   The second row of P represents the computation of
%   A*(B*C), and is carried out in a similar fashion.

% The cell-array description of each matrix multiply chain
% is produced in two steps:
%   1 - Produce a structure representation of each possible
%       parenthesization, with no temporary matrix specifications.
%       A cell-array containing each possible structure is returned.
%
%   2 - Convert from the internal structure representation to
%       cell-array form, including temporary matrix specification.
%
% N is the # of matrices in the multiply chain
%
if opt,
    list = genStructs(N);
else
    list = genFwdStruct(N);
end
P = convertStruct2Cell(list);

% ---------------------------------------------------------------------------------
% Manual enumeration of results for testing:
% switch length(x),   % # matrices
% case 1   % A
%   p={{[1 0]}};   % (A)
%
% case 2,   % A*B
%   p={{[1 2 0]}};   % (A*B)
%
% case 3,
%   p={   % A*B*C
%     % 2/1
%     {[1 2 -1], [-1 3 0]}   % (A*B)*C
%     % 1/2
%     {[2 3 -1], [1 -1 0]}   % A*(B*C)
%   };
%
% case 4,
%   p={   % A*B*C*D
%     % 1/3
%     {[3 4 -1] [2 -1 -2] [ 1 -2 0]}   % A * (B*(C*D))
%     {[2 3 -1] [-1 4 -2] [ 1 -2 0]}   % A * ((B*C)*D)
%     % 3/1
%     {[1 2 -1] [-1 3 -2] [-2 4 0]}    % ((A*B)*C) * D
%     {[2 3 -1] [1 -1 -2] [-2 4 0]}    % (A*(B*C)) * D
%     % 2/2
%     {[1 2 -1] [3 4 -2] [-1 -2 0]}    % (A*B) * (C*D)
%   };
%
% case 5,
%   p={   % A*B*C*D*E
%     % 1/4
%     {[4 5 -1] [3 -1 -2] [2 -2 -1] [1 -1 0]}    % A * (B * (C*(D*E)))
%     {[3 4 -1] [-1 5 -2] [2 -2 -1] [1 -1 0]}    % A * (B * ((C*D)*E))
%     {[2 3 -1] [-1 4 -2] [-2 5 -1] [1 -1 0]}    % A * (((B*C)*D) * E)
%     {[3 4 -1] [2 -1 -2] [-2 5 -1] [1 -1 0]}    % A * ((B*(C*D)) * E)
%     {[2 3 -1] [4 5 -2] [-1 -2 -3] [1 -3 0]}    % A * ((B*C) * (D*E))
%     % 4/1
%     {[3 4 -1] [2 -1 -2] [1 -2 -1] [-1 5 0]}    % (A * (B*(C*D))) * E
%     {[2 3 -1] [-1 4 -2] [1 -2 -1] [-1 5 0]}    % (A * ((B*C)*D)) * E
%     {[1 2 -1] [-1 3 -2] [-2 4 -1] [-1 5 0]}    % (((A*B)*C) * D) * E
%     {[2 3 -1] [1 -1 -2] [-2 4 -1] [-1 5 0]}    % ((A*(B*C)) * D) * E
%     {[1 2 -1] [3 4 -2] [-1 -2 -3] [-3 5 0]}    % ((A*B) * (C*D)) * E
%     % 2/3
%     {[4 5 -1] [3 -1 -2] [1 2 -1] [-1 -2 0]}    % (A*B) * (C*(D*E))
%     {[3 4 -1] [-1 5 -2] [1 2 -1] [-1 -2 0]}    % (A*B) * ((C*D)*E)
%     % 3/2
%     {[2 3 -1] [1 -1 -2] [4 5 -1] [-2 -1 0]}    % (A*(B*C)) * (D*E)
%     {[1 2 -1] [-1 3 -2] [4 5 -1] [-2 -1 0]}    % ((A*B)*C) * (D*E)
%   };
%
% otherwise
%   p={};   % no enumeration specified
% end
%end

% -----------------------------------------------------------
function [metrics,chainCplx] = chainMetrics(multSpecs, chain)
%chainMetrics Determine metrics for chain multiplication.
%   CHAINMETRICS(MULTSPECS, CHAIN) computes the flops count and
%   temporary matrix sizes required to compute the chain
%   multiplication specified by CHAIN, using the matrix sizes in
%   MULTSPECS.Sizes and the matrix complexities in MULTSPECS.Complexity.
%
%   Also returns the complexity counterpart to CHAIN,
%   which indicates the complexity of all parts of the
%   chain matrix multiply.

matrix_sizes = multSpecs.Sizes;
cplx = multSpecs.Complexity;
allowOutputOverwrite = multSpecs.ReuseOutput;

% Holds the current and maximum temporary matrix sizes
% Cell-array of 2-element size vectors, one per temp matrix
flopCount = 0;
max_temp_realEle = [];  % max real-equivalent elements in each temp area
tempOverOutputIdx = 0;
tempSpecs.Sizes = {};
tempSpecs.Complexity = [];
single_mat = (length(matrix_sizes)==1);

chainCplx = {multSpecs.Complexity(1)};  % default for single mat
if ~single_mat,
    for i=1:length(chain),
        % Get vector of indices for i'th matrix multiplier pair.
        % Vector specifies indices of the 2 input matrices in
        % the matrix_sizes cell-array, and specifies the location
        % for the result matrix.
        %
        %   multIdx = [x1_idx x2_idx x12_idx]  ->  x12 = x1 * x2

        if length(chain{i})==3,  % skip single-input case
            x1_idx  = chain{i}(1);  % x1 matrix index
            x2_idx  = chain{i}(2);  % x2 matrix index
            x12_idx = chain{i}(3);  % x12 result index

            % The output matrix index, x12_idx, cannot be the same as
            % either of the input matrix indices, i.e., there's no such
            % thing as an in-place matrix multiply (in general).
            % Stated another way, no intermediate computation may read
            % and write to the same temporary matrix simultaneously.
            % Thus, a spec such as [2 -1 -1] is invalid, since temp 1 is
            % read and written in the same step.
            %
            if any(x12_idx == [x1_idx x2_idx]),
                error('Illegal in-place multiply specified in chain permutation.');
            end

            % Get size vectors of x1 and x2
            % Each size vector specifies [rows cols]
            [x1_size, x1_cplx] = getMatrixSize(x1_idx, multSpecs, tempSpecs);
            [x2_size, x2_cplx] = getMatrixSize(x2_idx, multSpecs, tempSpecs);

            % Check that pair of matrices is appropriate for
            % matrix multiplication -> inner dimensions must agree:
            if length(x1_size)~=2 || length(x2_size)~=2,
                error('All matrix sizes must be specified as 2-element vectors.');
            end
            if x1_size(2) ~= x2_size(1),
                error('Invalid matrix sizes specified for chain multiplication.');
            end

            % Determine size of result matrix:
            x12_size = [x1_size(1) x2_size(2)];
            x12_cplx = x1_cplx | x2_cplx;

            % Construct chainCplx entry:
            chainCplx{i} = [x1_cplx x2_cplx x12_cplx];

            % Flop count for purely-real matrix multiply.
            % One multiply and one add are counted for each inner-product
            % term, i.e., 2*M*K*N flops for an MxK by KxN product.
            %
            curr_flops = 2*prod(x1_size)*x2_size(2);
            %
            % Adjust for complex, as required:
            if x1_cplx && x2_cplx,
                % Both input matrices are complex:
                curr_flops = curr_flops * 6;
            elseif x1_cplx || x2_cplx,
                % One input matrix is complex:
                curr_flops = curr_flops * 2;
            end
            flopCount = flopCount + curr_flops;

            % Determine maximum matrix sizes of all required
            % temporary storage areas:
            %
            if x12_idx == 0,
                % Output area specified for destination;
                % this should be the last multiply to perform:
                if i~=length(chain),
                    error(['Output index specified prior to last ' ...
                           'multiply operation in chain.']);
                end
                % Output area does not consume any temporary storage
            else
                % A temporary storage area has been specified
                % Check that index is negative, indicating a temp area:
                if x12_idx >= 0,
                    % All temporary storage indices must be specified
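The flop figures quoted in the help example (160 flops for (A*B)*C versus 96 for A*(B*C)) can be checked by hand against the real-matrix cost model in chainMetrics, curr_flops = 2*prod(x1_size)*x2_size(2). The standalone script below is a minimal sketch of that arithmetic only, assuming all three matrices are real; the variable names and the anonymous function mulCost are hypothetical and not part of chainmult, and the sketch does not model memory usage or the complex-input scaling factors.

```matlab
% Hypothetical sketch (not part of chainmult): reproduce the flop counts
% from the CHAINMULT help example with the same real-matrix cost model
% used in chainMetrics. Assumes all matrices are real.

% Sizes of A, B, C from the help example
szA = [4 3];
szB = [3 4];
szC = [4 2];

% Flops for one real multiply X1*X2, as computed in chainMetrics:
% 2 * rows(X1) * cols(X1) * cols(X2)
mulCost = @(s1, s2) 2*prod(s1)*s2(2);

% (A*B)*C: temporary T = A*B is 4x4, then T*C
szAB = [szA(1) szB(2)];
flopsLeft = mulCost(szA, szB) + mulCost(szAB, szC);

% A*(B*C): temporary T = B*C is 3x2, then A*T
szBC = [szB(1) szC(2)];
flopsRight = mulCost(szB, szC) + mulCost(szA, szBC);

fprintf('(A*B)*C: %d flops\n', flopsLeft);   % 96 + 64 = 160
fprintf('A*(B*C): %d flops\n', flopsRight);  % 48 + 48 = 96
```

In this example the A*(B*C) ordering also needs the smaller temporary (3x2 versus 4x4), but in general the flops-optimal and memory-optimal orderings can differ, which is why the help text describes separate .flopsIdx and .memIdx results.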