📄 chainmult.m

📁 MATLAB source code that optimizes the multiplication of a chain of matrices to improve computational efficiency.
💻 M
📖 Page 1 of 3
function y = chainmult(varargin)
%CHAINMULT Optimization of chain matrix multiply problem.
%   CHAINMULT(X) performs optimization of a specified matrix
%   multiply chain by determining the parenthesization of
%   multiplicands that best minimizes two metrics:
%   - the maximum size of all temporary matrices that must be
%     stored in memory at any time during the evaluation, and
%   - the total number of floating point operations (flops)
%     required to compute the result.
%
%   The simplest way to make use of this optimization function
%   is via the GUI front-end, CHAINGUI.
%
%   X is an N-element cell array where each entry in the
%   array contains the size of each matrix in the
%   multiplication chain.  For the chain multiply Y=A*B*C,
%   X is specified as X = {size(A) size(B) size(C)}.
%
%   CHAINMULT(X,C) specifies an N-element logical vector C that
%   indicates if each input matrix is complex.  A zero indicates
%   real matrices, non-zero indicates complex.  Note that length(C)
%   must equal length(X).  If omitted, all input matrices are
%   assumed to be real.
%
%   CHAINMULT(X,C,'reuse') specifies that the output matrix
%   may be used for temporary matrix computation results.
%   By default, the output matrix area is reused.  To disallow
%   reuse of the output matrix, specify 'noreuse'.
%
%   CHAINMULT(...,'noopt') indicates that no optimization is to
%   be performed.  The natural forward-chain multiplication is
%   selected, that is, (((A*B)*C)*D)...  To perform full optimization,
%   omit this string or specify 'opt'.
%
%   CHAINMULT returns a structure with the following fields:
%   .Optimized
%       A structure containing the indices of the results that
%       jointly optimize the flops and memory metrics.  Note
%       that no single method of computing the chain multiply
%       problem may jointly minimize both metrics simultaneously.
%       Thus, three different minimization results are returned.
%       The indices returned are used in conjunction with the list
%       of results returned in the Y.All field.
%       The specific fields of .Optimized are:
%
%       .bothIdx - a vector of indices that jointly
%          minimize both flops and memory, or empty if no
%          index minimized both.  If several methods of
%          computation return optimal (and thus identical)
%          results, a vector of indices is returned.
%
%       .flopsIdx - the index of the result that minimizes
%          the flops count.  If more than one such result occurs,
%          only those that secondarily minimize the memory
%          usage are returned.  At least one such result always
%          exists.  More than one method may lead to identical
%          minimization results, in which case a vector is returned.
%
%       .memIdx - the index of the result that minimizes
%          memory usage.  If more than one such result occurs,
%          only those that secondarily minimize the flops count
%          are returned.  At least one such result always
%          exists.  More than one method may lead to identical
%          minimization results, in which case a vector is returned.
%
%   .All
%       A structure containing the optimization results.
%        .ChainStr
%           A cell-array containing one entry for each possible
%           parenthesization of the input matrix list.  Each entry
%           is a string representing the fully-parenthesized MATLAB
%           command that would execute the specified matrix chain.
%        .Chain
%           A cell-array containing the steps necessary to compute
%           each specified matrix chain.  As opposed to the MATLAB
%           expressions given in .ChainStr, this field contains
%           vectors of indices that specify the matrix pairs to
%           be multiplied, and the output matrix to receive each
%           result.  This data structure is used internally to
%           store and transfer results.
%        .ChainCplx
%           Specifies the complexity of each matrix in the chain
%           multiplication, including both inputs and the output
%           of each step.  True indicates complex.
%        .Metrics
%           A structure containing the optimization metrics
%           for each parenthesization of the input matrices.
%             .flopCount
%                Total flops for each entry in .Chain
%             .temp_ele
%                Total number of temporary equivalent-real elements
%                required for each entry in .Chain.  Note that a single
%                complex element has 2 equivalent-real elements.
%             .tempOverOutputIdx
%                Index of the temporary matrix that overwrites the
%                output matrix area for the purpose of memory optimization.
%
%   .Input
%       A structure containing the input arguments to chainmult.
%        .Sizes
%           A copy of the matrix size specification (X).
%        .Complexity
%           A copy of the matrix complexity specification (C).
%        .ReuseOutput
%           Indicates whether the output matrix was specified as reusable
%           for temporary matrix results.
%        .Optimize
%           Indicates whether full optimization is to be performed, or
%           whether a simple forward-multiplication chain is to be substituted.
%        .varNames
%           An optional cell-array of strings containing the variable names
%           of each of the expressions in the matrix expression list.  By
%           default, this cell-array is empty.
%
%   .ResultList
%       A string matrix with all optimization results
%       formatted for display in a GUI (such as CHAINGUI).
%
%   EXAMPLE: Three real matrices are to be multiplied together:
%      A=rand([4 3]); B=rand([3 4]); C=rand([4 2]);
%
%   Y may be computed as either Y=(A*B)*C or Y=A*(B*C).
%   The two methods of computing the result differ in total
%   flops and memory, however.
%     y=(A*B)*C;  % requires 160 flops
%     y=A*(B*C);  % requires 96 flops
%
%   To optimize this chain matrix multiplication, execute:
%     Y=CHAINMULT({size(A) size(B) size(C)})
%
% See also CHAINGUI.

% Author: D. Orofino
% Copyright 1984-2002 The MathWorks, Inc.
% $Revision: 1.8.6.1 $ $Date: 2003/10/24 12:02:00 $

if nargin<1,
    multSpecs = parse_args;
else
    multSpecs = parse_args(varargin{:});
end

% Get cell-array of computational steps necessary to
% compute all associative-rule pairings for chain
% multiplication, based on the number of input matrices.
%
Nmats = length(multSpecs.Sizes);
chains = getAllChains(Nmats, multSpecs.Optimize);

% Compute cost of each parenthesization:
%
for i=1:length(chains),
   [metrics(i),chainCplx{i}] = chainMetrics(multSpecs, chains{i});
   chainStrs{i} = convertCell2Str(chains{i}, multSpecs.varNames);
end

y.All.Metrics   = metrics;
y.All.Chain     = chains;
y.All.ChainCplx = chainCplx;
y.All.ChainStr  = chainStrs;
y.Optimized     = chainOpt(metrics);
y.Input         = multSpecs;
y.ResultList    = resultList(y);

% -----------------------------------------------------------
function P = getAllChains(N, opt)
%getAllChains Return cell-array of matrix chain specifications.
%   P=getAllChains(N,OPT) returns a description of all possible
%   ways of computing the chain multiplication on N inputs.
%   N is the number of input matrices to be multiplied.
%   If OPT is false, only the forward chain is returned.
%
%   Each possible parenthesization of the matrix chain problem
%   is returned in each entry of a vector cell-array, P,
%       P = { parenChain1, parenChain2, ...}
%
%   Each parenthesization is itself a cell-array, specifying
%   the order in which the input matrices are to be multiplied
%   together as matrix pairs.
%      parenChain1 =  {matrixPair1, matrixPair2, ...}
%
%   The multiplication of each matrix pair is represented using
%   a vector of 3 indices.  The first two indices of the vector
%   determine the input matrices to be multiplied in this step
%   of the chain multiplication. The third index selects one
%   of possibly several result matrices to store intermediate
%   and final chain results.
%      matrixPair1 = [inputMatrix1 inputMatrix2 resultMatrix]
%
%   The indices are specified as integer-valued scalars.
%   Input matrix indices are positive integers > 0.
%        1,2,3,...   -> index of input matrices A, B, C, ...
%   Result matrix indices are either negative integers < 0,
%   or may be 0 indicating the ending result matrix location.
%        0           -> index of final output matrix
%       -1,-2,-3,... -> index of temporary result matrices
%
%   Note that the ending result matrix, index 0, may be re-used
%   for one or more intermediate chain computations only if the
%   intermediate result does not exceed the final result size.
%
%   Note that matrix multiplies are not generally computed
%   in-place.  Therefore, the location of the result matrix
%   is never the same as either input matrix.
%
%  Example:
%  --------
%  Determine the possible multiply chains for Y=A*B*C.
%
%  P = {{[1 2 -1], [-1 3 0]}  % (A*B)*C
%       {[2 3 -1], [1 -1 0]}  % A*(B*C)
%      }
%
%  The first row of P represents the computation of
%  (A*B)*C in 2 steps:
%    Step 1: [1 2 -1]
%       Multiply input matrices 1 and 2, storing the
%       result in temporary matrix 1.
%    Step 2: [-1 3 0]
%       Multiply temporary matrix 1 by input matrix 3,
%       storing the result in the output matrix Y.
%
%  The second row of P represents the computation of
%  A*(B*C), and is carried out in a similar fashion.

% The cell-array description of each matrix multiply chain
% is produced in two steps:
%  1 - Produce a structure representation of each possible
%      parenthesization, with no temporary matrix specifications.
%      A cell-array containing each possible structure is returned.
%
%  2 - Convert from the internal structure representation to
%      cell-array form, including temporary matrix specification.
%
% N is the # of matrices in the multiply chain
%
if opt,
    list = genStructs(N);
else
    list = genFwdStruct(N);
end

P = convertStruct2Cell(list);

% ---------------------------------------------------------------------------------
% Manual enumeration of results for testing:
%    switch length(x),  % # matrices
%    case 1           % A
%        p={{[1 0]}};  % (A)
%
%    case 2,           % A*B
%        p={{[1 2 0]}}; % (A*B)
%
%    case 3,
%        p={                      % A*B*C
%            % 2/1
%            {[1 2 -1], [-1 3 0]}  % (A*B)*C
%            % 1/2
%            {[2 3 -1], [1 -1 0]}  % A*(B*C)
%        };
%
%    case 4,
%        p={                                 % A*B*C*D
%            % 1/3
%            {[3 4 -1] [2 -1 -2] [ 1 -2 0]}   % A * (B*(C*D))
%            {[2 3 -1] [-1 4 -2] [ 1 -2 0]}   % A * ((B*C)*D)
%            % 3/1
%            {[1 2 -1] [-1 3 -2] [-2  4 0]}   % ((A*B)*C) * D
%            {[2 3 -1] [1 -1 -2] [-2  4 0]}   % (A*(B*C)) * D
%            % 2/2
%            {[1 2 -1] [3  4 -2] [-1 -2 0]}   % (A*B) * (C*D)
%        };
%
%    case 5,
%        p={                                        % A*B*C*D*E
%            % 1/4
%            {[4 5 -1] [3 -1 -2] [2 -2 -1] [1 -1 0]} % A * (B * (C*(D*E)))
%            {[3 4 -1] [-1 5 -2] [2 -2 -1] [1 -1 0]} % A * (B * ((C*D)*E))
%            {[2 3 -1] [-1 4 -2] [-2 5 -1] [1 -1 0]} % A * (((B*C)*D) * E)
%            {[3 4 -1] [2 -1 -2] [-2 5 -1] [1 -1 0]} % A * ((B*(C*D)) * E)
%            {[2 3 -1] [4 5 -2] [-1 -2 -3] [1 -3 0]} % A * ((B*C) * (D*E))
%            % 4/1
%            {[3 4 -1] [2 -1 -2] [1 -2 -1] [-1 5 0]} % (A * (B*(C*D))) * E
%            {[2 3 -1] [-1 4 -2] [1 -2 -1] [-1 5 0]} % (A * ((B*C)*D)) * E
%            {[1 2 -1] [-1 3 -2] [-2 4 -1] [-1 5 0]} % (((A*B)*C) * D) * E
%            {[2 3 -1] [1 -1 -2] [-2 4 -1] [-1 5 0]} % ((A*(B*C)) * D) * E
%            {[1 2 -1] [3 4 -2] [-1 -2 -3] [-3 5 0]} % ((A*B) * (C*D)) * E
%            % 2/3
%            {[4 5 -1] [3 -1 -2] [1 2 -1] [-1 -2 0]} % (A*B) * (C*(D*E))
%            {[3 4 -1] [-1 5 -2] [1 2 -1] [-1 -2 0]} % (A*B) * ((C*D)*E)
%            % 3/2
%            {[2 3 -1] [1 -1 -2] [4 5 -1] [-2 -1 0]} % (A*(B*C)) * (D*E)
%            {[1 2 -1] [-1 3 -2] [4 5 -1] [-2 -1 0]} % ((A*B)*C) * (D*E)
%        };
%
%    otherwise
%        p={};  % no enumeration specified
%    end
%end

% -----------------------------------------------------------
function [metrics,chainCplx] = chainMetrics(multSpecs, chain)
%chainMetrics Determine metrics for chain multiplication.
%   CHAINMETRICS(MULTSPECS, CHAIN) computes the flops count and
%   temporary matrix sizes required to compute the chain
%   multiplication specified by CHAIN using the matrix sizes
%   and complexity flags stored in MULTSPECS.Sizes and
%   MULTSPECS.Complexity.
%
%   Also returns the complexity counterpart to chain,
%   which indicates the complexity of all parts of the
%   chain matrix multiply.

matrix_sizes = multSpecs.Sizes;
cplx         = multSpecs.Complexity;
allowOutputOverwrite = multSpecs.ReuseOutput;

% Holds the current and maximum temporary matrix sizes
% Cell-array of 2-element size vectors, one per temp matrix
flopCount    = 0;
max_temp_realEle = [];  % max real-equivalent elements in each temp area
tempOverOutputIdx = 0;
tempSpecs.Sizes      = {};
tempSpecs.Complexity = [];

single_mat = (length(matrix_sizes)==1);
chainCplx = {multSpecs.Complexity(1)};  % default for single mat

if ~single_mat,

   for i=1:length(chain),
      % Get vector of indices for i'th matrix multiplier pair.
      % Vector specifies indices of the 2 input matrices in
      % the matrix_sizes cell-array, and specifies the location
      % for the result matrix.
      %
      % multIdx = [x1_idx x2_idx x12_idx]  -> x12 = x1 * x2

      if length(chain{i})==3,    % skip single-input case
         x1_idx  = chain{i}(1);  % x1 matrix index
         x2_idx  = chain{i}(2);  % x2 matrix index
         x12_idx = chain{i}(3);  % x12 result index

         % The output matrix index, x12_idx, cannot be the same as
         % either of the input matrix indices, i.e., there's no such
         % thing as an in-place matrix multiply (in general).
         % Stated another way, no intermediate computation may read
         % and write to the same temporary matrix simultaneously.
         % Thus, a spec such as [2 -1 -1] is invalid, since temp 1 is
         % read and written in the same step.
         %
         if any(x12_idx == [x1_idx x2_idx]),
            error('Illegal in-place multiply specified in chain permutation.');
         end

         % Get size vectors of x1 and x2
         % Each size vector specifies [rows cols]
         [x1_size, x1_cplx] = getMatrixSize(x1_idx, multSpecs, tempSpecs);
         [x2_size, x2_cplx] = getMatrixSize(x2_idx, multSpecs, tempSpecs);

         % Check that pair of matrices is appropriate for
         % matrix multiplication -> inner dimensions must agree:
         if length(x1_size)~=2 || length(x2_size)~=2,
            error('All matrix sizes must be specified as 2-element vectors.');
         end
         if x1_size(2) ~= x2_size(1),
            error('Invalid matrix sizes specified for chain multiplication.');
         end

         % Determine size of result matrix:
         x12_size = [x1_size(1) x2_size(2)];
         x12_cplx = x1_cplx | x2_cplx;

         % Construct chainCplx entry:
         chainCplx{i} = [x1_cplx x2_cplx x12_cplx];

         % Flop count for purely-real matrix multiply
         % Note that we assume 1 multiply and 1 add per
         % inner-product term, i.e. 2*rows*inner*cols flops
         %
         curr_flops = 2*prod(x1_size)*x2_size(2);
         %
         % Adjust for complex, as required:
         if x1_cplx && x2_cplx,
            % Both input matrices are complex:
            curr_flops = curr_flops * 6;
         elseif x1_cplx || x2_cplx,
            % One input matrix is complex:
            curr_flops = curr_flops * 2;
         end
         flopCount = flopCount + curr_flops;

         % Determine maximum matrix sizes of all required
         % temporary storage areas:
         %
         if x12_idx == 0,
            % output area specified for destination
            % this should be the last multiply to perform:
            if i~=length(chain),
               error(['Output index specified prior to last ' ...
                     'multiply operation in chain.']);
            end

            % Output area does not consume any temporary storage

         else
            % A temporary storage area has been specified

            % Check that index is negative, indicating a temp area:
            if x12_idx >= 0,
               % All temporary storage indices must be specified
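
A minimal standalone sketch of the step format described in getAllChains, applied to the A*B*C example from the CHAINMULT help text. It assumes real matrices only and counts 2*rows*inner*cols flops per multiply, matching the real-valued formula in chainMetrics. The variable names tempSizes and operandSizes are hypothetical, and the script does not call any chainmult internals.

```matlab
% Sketch only: real-valued flop count for both parenthesizations of Y = A*B*C,
% using the [input1 input2 result] step convention (positive = input matrix,
% negative = temporary, 0 = final output).
sizes  = {[4 3], [3 4], [4 2]};            % size(A), size(B), size(C)
chains = {{[1 2 -1], [-1 3 0]}, ...        % (A*B)*C
          {[2 3 -1], [1 -1 0]}};           % A*(B*C)

flops = zeros(1, numel(chains));
for c = 1:numel(chains)
    tempSizes = {};                        % tempSizes{k} is the size of temporary -k
    for s = 1:numel(chains{c})
        step = chains{c}{s};
        operandSizes = cell(1, 2);
        for k = 1:2
            idx = step(k);
            if idx > 0
                operandSizes{k} = sizes{idx};       % input matrix
            else
                operandSizes{k} = tempSizes{-idx};  % earlier temporary result
            end
        end
        x1 = operandSizes{1};
        x2 = operandSizes{2};
        assert(x1(2) == x2(1), 'Inner dimensions must agree.');
        % One multiply and one add per inner-product term.
        flops(c) = flops(c) + 2 * prod(x1) * x2(2);
        if step(3) < 0
            tempSizes{-step(3)} = [x1(1) x2(2)];    % record temporary size
        end
    end
end
disp(flops)   % flops is [160 96], the counts quoted in the help text
```

The second chain is cheaper because B*C shrinks the intermediate result to 3-by-2 before A is applied, whereas A*B first creates a 4-by-4 temporary.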
