📄 exprdag.ml
字号:
(* simplifiers for various kinds of nodes *)

(* [snumM n]: build a numeric constant node.  Zero is normalized to
   [Number.zero]; a negative constant is stored as [Uminus (Num |n|)] by way
   of [suminusM], so the [Num] nodes built here never hold a negative value. *)
let rec snumM = function
    n when Number.is_zero n -> makeNode (Num (Number.zero))
  | n when Number.negative n ->
      makeNode (Num (Number.negate n)) >>= suminusM
  | n -> makeNode (Num n)

(* [suminusM a]: build the negation of [a].  A double negation cancels, and
   negating a zero constant yields the zero constant again. *)
and suminusM = function
    Uminus x -> makeNode x
  | Num a when (Number.is_zero a) -> snumM Number.zero
  | a -> makeNode (Uminus a)

(* [stimesM (a, b)]: build the product of [a] and [b].
   - a negation on either factor is pulled out and reapplied to the result;
   - two constants are folded into one;
   - [Num a] times [Times (Num b, c)] is reassociated as (a * b) times c;
   - multiplication by 0, 1 and -1 is simplified away;
   - a constant right factor is swapped to the left, so a surviving product
     with one constant factor has the shape [Times (Num _, _)]. *)
and stimesM = function
  | (Uminus a, b) -> stimesM (a, b) >>= suminusM
  | (a, Uminus b) -> stimesM (a, b) >>= suminusM
  | (Num a, Num b) -> snumM (Number.mul a b)
  | (Num a, Times (Num b, c)) ->
      snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
  | (Num a, b) when Number.is_zero a -> snumM Number.zero
  | (Num a, b) when Number.is_one a -> makeNode b
  | (Num a, b) when Number.is_mone a -> suminusM b
  | (a, (Num _ as b')) -> stimesM (b', a)
  | (a, b) -> makeNode (Times (a, b))

(* [reduce_sumM terms]: fold the numeric constants in a list of summands.
   Two adjacent constants (plain or negated) are added together.  A constant
   followed by a non-constant is swapped past it, so constants bubble toward
   the end of the list where they meet and combine.  A lone zero constant is
   dropped.  Non-constant terms keep their relative order. *)
and reduce_sumM x = match x with
    [] -> returnM []
  | [Num a] -> if (Number.is_zero a) then returnM [] else returnM x
  | [Uminus (Num a)] -> if (Number.is_zero a) then returnM [] else returnM x
  | (Num a) :: (Num b) :: s ->
      snumM (Number.add a b) >>= fun x -> reduce_sumM (x :: s)
  | (Num a) :: (Uminus (Num b)) :: s ->
      snumM (Number.sub a b) >>= fun x -> reduce_sumM (x :: s)
  | (Uminus (Num a)) :: (Num b) :: s ->
      snumM (Number.sub b a) >>= fun x -> reduce_sumM (x :: s)
  | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
      snumM (Number.add a b) >>= suminusM >>= fun x -> reduce_sumM (x :: s)
  | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
  | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
  | a :: s -> reduce_sumM s >>= fun s' -> returnM (a :: s')

(* collectCoeffM transforms
 *       n x + n y => n (x + y)
 * where n is a number.
 *
 * The coefficient of the first product term in the list is used at each
 * step.  [filterM coeff] splits a list into the cofactors of the terms whose
 * coefficient equals [coeff] (the cofactor is negated for a negated product)
 * and the remaining terms.  [foundCoeffM] multiplies the coefficient back
 * onto the sum of those cofactors and recurses on the remaining terms.  A
 * leading term that is not a product is kept as is. *)
and collectCoeffM x =
  let rec filterM coeff = function
      Times (Num a, b) as y :: rest ->
        filterM coeff rest >>= fun (w, wo) ->
        if (Number.equal a coeff) then returnM (b :: w, wo)
        else returnM (w, y :: wo)
    | Uminus (Times (Num a, b)) as y :: rest ->
        filterM coeff rest >>= fun (w, wo) ->
        if (Number.equal a coeff) then
          suminusM b >>= fun b' -> returnM (b' :: w, wo)
        else returnM (w, y :: wo)
    | y :: rest ->
        filterM coeff rest >>= fun (w, wo) -> returnM (w, y :: wo)
    | [] -> returnM ([], [])
  and foundCoeffM a x =
    filterM a x >>= fun (w, wo) ->
    collectCoeffM wo >>= fun wo' ->
    (* a single cofactor is rebuilt directly instead of being summed *)
    (match w with [d] -> makeNode d | _ -> splusM w) >>= fun p ->
    snumM a >>= fun a' ->
    stimesM (a', p) >>= fun ap ->
    returnM (ap :: wo')
  in match x with
    [] -> returnM []
  | Times (Num a, _) :: _ -> foundCoeffM a x
  | (Uminus (Times (Num a, b))) :: _ -> foundCoeffM a x
  | (a :: c) -> collectCoeffM c >>= fun c' -> returnM (a :: c')

(* transform n1 * x + n2 * x ==> (n1 + n2) * x
 *
 * [findCoeffM t] splits a term into a coefficient node and an expression:
 * a product [Times (Num c, e)] gives (c, e), its negation gives (-c, e), a
 * negation [Uminus e] gives (-1, e), and any other term gives (1, e).  Terms
 * are grouped by physical equality (==) of their expression part with that
 * of the first term.  The grouped coefficients are summed with [splusM] and
 * multiplied back onto the expression.  Lists of zero or one term are
 * returned unchanged. *)
and collectExprM x =
  let rec findCoeffM = function
      Times (Num a as a', b) -> returnM (a', b)
    | Uminus (Times (Num a as a', b)) ->
        suminusM a' >>= fun a'' -> returnM (a'', b)
    | Uminus x ->
        snumM Number.one >>= suminusM >>= fun mone -> returnM (mone, x)
    | x -> snumM Number.one >>= fun one -> returnM (one, x)
  and filterM xpr = function
      [] -> returnM ([], [])
    | a :: b ->
        filterM xpr b >>= fun (w, wo) ->
        findCoeffM a >>= fun (c, x) ->
        if (xpr == x) then returnM (c :: w, wo) else returnM (w, a :: wo)
  in match x with
    [] -> returnM x
  | [a] -> returnM x
  | a :: b ->
      findCoeffM a >>= fun (_, xpr) ->
      filterM xpr x >>= fun (w, wo) ->
      collectExprM wo >>= fun wo' ->
      splusM w >>= fun w' ->
      stimesM (w', xpr) >>= fun t' ->
      returnM (t':: wo')

(* The simplification pipeline applied to the summands of a sum.  The result
   is still a list of terms. *)
and mangleSumM x = returnM x
    >>= reduce_sumM
    >>= collectExprM
    >>= collectCoeffM
    >>= reduce_sumM
    >>= eliminateButterflyishPatternsM
    >>= reduce_sumM

(* Non-negated terms keep their order.  The negated terms end up after them,
   in the reverse of their original relative order. *)
and reorder_uminus = function (* push all Uminuses to the end *)
    [] -> []
  | ((Uminus a) as a' :: b) -> (reorder_uminus b) @ [a']
  | (a :: b) -> a :: (reorder_uminus b)

(* Turn a simplified term list into a node: no terms gives 0, one term gives
   that term, and several terms give a [Plus] (negated terms last) that is
   then passed through [generateFusedMultAddM]. *)
and canonicalizeM = function
    [] -> snumM Number.zero
  | [a] -> makeNode a (* one term *)
  | a -> makeNode (Plus (reorder_uminus a)) >>= generateFusedMultAddM

(* True for an explicitly negated node. *)
and negative = function
    Uminus _ -> true
  | _ -> false

(*
 * simplify patterns of the form
 *
 *   (c_1 * a + ...) + (c_2 * a + ...)
 *
 * The pattern includes arbitrary coefficients and minus signs.
 * A common case of this pattern is the butterfly
 *   (a + b) + (a - b)
 *   (a + b) - (a - b)
 *
 * [findTerms depth x] lists [x] with negations and constant factors stripped.
 * While [depth] is positive it also lists the summands of a [Plus]; each
 * stripped constant factor costs one level of depth.  [duplicates] keeps the
 * listed nodes that occur more than once (physical equality via memq).
 *
 * If any node is shared, every term is flattened with [flattenPlusM]: the
 * accumulated constant factor is distributed over the summands of a nested
 * [Plus], stopping at shared nodes.  The result is re-collected and
 * re-simplified.  Otherwise the list is returned unchanged.
 *)
and eliminateButterflyishPatternsM l =
  let rec findTerms depth x = match x with
    | Uminus x -> findTerms depth x
    | Times (Num a, b) -> findTerms (depth - 1) b
    | Plus l when depth > 0 ->
        x :: List.flatten (List.map (findTerms (depth - 1)) l)
    | x -> [x]
  and duplicates = function
      [] -> []
    | a :: b -> if List.memq a b then a :: duplicates b else duplicates b
  in let rec flattenPlusM d coef x =
    if (List.memq x d) then
      snumM coef >>= fun coef' ->
      stimesM (coef', x) >>= fun x' ->
      returnM [x']
    else match x with
    | Times (Num a, b) -> flattenPlusM d (Number.mul a coef) b
    | Uminus x -> flattenPlusM d (Number.negate coef) x
    | Plus l ->
        snumM coef >>= fun coef' -> mapM (fun x -> stimesM (coef', x)) l
    | x ->
        snumM coef >>= fun coef' ->
        stimesM (coef', x) >>= fun x' ->
        returnM [x']
  in let l' = List.flatten (List.map (findTerms 1) l)
  in let d = duplicates l'
  in if (List.length d) > 0 then
    mapM (flattenPlusM d Number.one) l >>= fun a ->
    collectExprM (List.flatten a) >>= mangleSumM
  else
    returnM l

(* [splusM terms]: build a sum node from a list of summands after running
   them through [mangleSumM].  If every term is negated, or if
   [Oracle.should_flip_sign] asks for it, the terms are negated, summed
   recursively, and the resulting sum is negated. *)
and splusM l =
  mangleSumM l >>= fun l' ->
  (* no terms are negative. Don't do anything *)
  if not (List.exists negative l') then
    canonicalizeM l'
  (* all terms are negative. Negate all of them and collect the minus sign *)
  else if List.for_all negative l' then
    mapM suminusM l' >>= splusM >>= suminusM
  (* some terms are positive and some are negative. We are in trouble.
     Ask the Oracle *)
  else if Oracle.should_flip_sign (Plus l') then
    mapM suminusM l' >>= splusM >>= suminusM
  else
    canonicalizeM l'

(* When [Magic.enable_fma] is set and a [Plus] contains at least two
   coefficient products (as counted by [count is_multiplication], defined
   elsewhere), the sum is rewritten as max times (sum of the products scaled
   by 1/max) plus the other terms.  Here max is the largest coefficient found
   by [separate], which partitions the summands into products and other
   terms.  Any other node is returned as is.
   NOTE(review): presumably this shapes the sum for fused multiply-add
   emission; TODO confirm intent. *)
and generateFusedMultAddM =
  let rec is_multiplication = function
    | Times (Num a, b) -> true
    | Uminus (Times (Num a, b)) -> true
    | _ -> false
  and separate = function
      [] -> ([], [], Number.zero)
    | (Times (Num a, b)) as this :: c ->
        let (x, y, max) = separate c in
        let newmax = if (Number.greater a max) then a else max in
        (this :: x, y, newmax)
    | (Uminus (Times (Num a, b))) as this :: c ->
        let (x, y, max) = separate c in
        let newmax = if (Number.greater a max) then a else max in
        (this :: x, y, newmax)
    | this :: c ->
        let (x, y, max) = separate c in
        (x, this :: y, max)
  in function
      Plus l when (count is_multiplication l >= 2) && !Magic.enable_fma ->
        let (w, wo, max) = separate l in
        snumM (Number.div Number.one max) >>= fun invmax' ->
        snumM max >>= fun max' ->
        mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
        stimesM (max', pw') >>= fun mw' ->
        splusM (wo @ [mw'])
    | x -> returnM x

(* monadic style algebraic simplifier for the dag
 *
 * Children are simplified first, then the node is rebuilt with the smart
 * constructors above.  The work is wrapped in
 * [memoizing lookupSimpM insertSimpM] (defined elsewhere), which presumably
 * caches the result per node so shared subexpressions are simplified only
 * once. *)
let rec algsimpM x =
  memoizing lookupSimpM insertSimpM
    (function
        Num a -> snumM a
      | Plus a ->
          mapM algsimpM a >>= splusM
      | Times (a, b) ->
          algsimpM a >>= fun a' ->
          algsimpM b >>= fun b' ->
          stimesM (a', b')
      | Uminus a ->
          algsimpM a >>= suminusM
      | Store (v, a) ->
          algsimpM a >>= fun a' ->
          makeNode (Store (v, a'))
      | x -> makeNode x)
    x

(* Initial monad state: a pair of empty tables.
   NOTE(review): presumably the simplification memo table and the node table
   used by [makeNode]; TODO confirm. *)
let initialTable = (empty, empty)

let simp_roots = mapM algsimpM

(* Simplify every root of a dag within a single run of the state monad, so
   all roots share the same tables. *)
let algsimp (Dag dag) = Dag (runM initialTable simp_roots dag)
end (* presumably closes module AlgSimp, given the AlgSimp.algsimp calls below *)

(* simplify the dag
 *
 * Runs [AlgSimp.algsimp] three times, applying [Reverse.reverse] (defined
 * elsewhere) between consecutive runs.  Before each of the five steps, and
 * once at the end, it logs a progress message and [Stats.complexity] through
 * [info]. *)
let algsimp v =
  let _ = info " first simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = AlgSimp.algsimp v in
  let _ = info " second simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = Reverse.reverse v in
  let _ = info " third simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = AlgSimp.algsimp v in
  let _ = info " fourth simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = Reverse.reverse v in
  let _ = info " fifth simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = AlgSimp.algsimp v in
  let _ = info " simplification done..." in
  let _ = info (Stats.complexity v) in
  v

(* Wrap a list of root nodes as a dag. *)
let make nodes = Dag nodes

(*************************************************************
 * Conversion of the dag to an assignment list
 *************************************************************)
(*
 * This function is messy.  The main problem is that we want to
 * inline dag nodes conditionally, depending on how many times they
 * are used.  The Right Thing to do would be to modify the
 * state monad to propagate some of the state backwards, so that
 * we know whether a given node will be used again in the future.
 * This modification is trivial in a lazy language, but it is
 * messy in a strict language like ML.
 *
 * In this implementation, we just do the obvious thing, i.e., visit
 * the dag twice, the first to count the node usages, and the second to
 * produce the output.
 *)
module Destructor : sig
  val to_assignments : dag -> (Variable.variable * Expr.expr) list
end = struct
  open StateMonad
  open MemoMonad
  open AssocTable

  (* The monad state is a triple:
     - first:  assignments emitted so far, most recent first (consed by
               [with_varM], reversed in [to_assignments]);
     - second: per-node use counts, filled by [counting] during [visitM];
     - third:  memo table for [expr_of_nodeM]. *)

  let fresh = Variable.make_temporary

  let fetchAl = fetchState >>= (fun (al, _, _) -> returnM al)
  let storeAl al =
    fetchState >>= (fun (_, visited, visited') ->
      storeState (al, visited, visited'))

  let fetchVisited = fetchState >>= (fun (_, v, _) -> returnM v)
  let storeVisited visited =
    fetchState >>= (fun (al, _, visited') ->
      storeState (al, visited, visited'))

  let fetchVisited' = fetchState >>= (fun (_, _, v') -> returnM v')
  let storeVisited' visited' =
    fetchState >>= (fun (al, visited, _) ->
      storeState (al, visited, visited'))

  (* Memo-table accessors for [expr_of_nodeM]; keys compared with (==). *)
  let lookupVisitedM' key =
    fetchVisited' >>= fun table ->
    returnM (AssocTable.lookup hash_node (==) key table)
  let insertVisitedM' key value =
    fetchVisited' >>= fun table ->
    storeVisited' (AssocTable.insert hash_node key value table)

  (* [counting f x]: increment the use count of [x].  On the first visit
     only, [f x] runs (visiting the children) before a count of 1 is
     recorded. *)
  let counting f x =
    fetchVisited >>= (fun v ->
      match AssocTable.lookup hash_node (==) x v with
        Some count ->
          fetchVisited >>= (fun v' ->
            storeVisited (AssocTable.insert hash_node x (count + 1) v'))
      | None ->
          f x >>= fun () ->
          fetchVisited >>= (fun v' ->
            storeVisited (AssocTable.insert hash_node x 1 v')))

  (* [with_varM v x]: record the assignment pair (v, x) and return a
     reference to [v]. *)
  let with_varM v x =
    fetchAl >>= (fun al -> storeAl ((v, x) :: al)) >> returnM (Expr.Var v)

  (* Use an expression in place, without an assignment. *)
  let inlineM = returnM

  (* Bind [x] to a fresh temporary variable. *)
  let with_tempM x = with_varM (fresh ()) x

  (* declare a temporary only if node is used more than once
     (single-use nodes are inlined only when [Magic.inline_single] is set;
     otherwise a temporary is always declared).  Fails if [node] was never
     counted by [visitM]. *)
  let with_temp_maybeM node x =
    fetchVisited >>= (fun v ->
      match AssocTable.lookup hash_node (==) node v with
        Some count ->
          if (count = 1 && !Magic.inline_single) then inlineM x
          else with_tempM x
      | None -> failwith "with_temp_maybeM")

  type fma =
      NO_FMA
    | FMA of node * node * node   (* FMA (a, b, c) => a + b * c *)
    | FMS of node * node * node   (* FMS (a, b, c) => -a + b * c *)
    | FNMS of node * node * node  (* FNMS (a, b, c) => a - b * c *)

  (* Recognize a two-term sum of one of the [fma] shapes, in either term
     order.  Only active when [Magic.enable_fma_expansion] is set.  The
     pattern order matters: negated forms are tried before plain ones. *)
  let build_fma l =
    if (not !Magic.enable_fma_expansion) then NO_FMA
    else match l with
    | [Uminus a; Times (b, c)] -> FMS (a, b, c)
    | [Times (b, c); Uminus a] -> FMS (a, b, c)
    | [a; Uminus (Times (b, c))] -> FNMS (a, b, c)
    | [Uminus (Times (b, c)); a] -> FNMS (a, b, c)
    | [a; Times (b, c)] -> FMA (a, b, c)
    | [Times (b, c); a] -> FMA (a, b, c)
    | _ -> NO_FMA

  (* The three operands of a recognized FMA-like sum, if any. *)
  let children_fma l =
    match build_fma l with
      FMA (a, b, c) -> Some (a, b, c)
    | FMS (a, b, c) -> Some (a, b, c)
    | FNMS (a, b, c) -> Some (a, b, c)
    | NO_FMA -> None

  (* First pass: count the uses of every node reachable from [x]. *)
  let rec visitM x =
    counting (function
        Load v -> returnM ()
      | Num a -> returnM ()
      | Store (v, x) -> visitM x
      | Plus a ->
          (match children_fma a with
            None -> mapM visitM a >> returnM ()
          | Some (a, b, c) ->
              (* visit fma's arguments twice to make sure they get a variable *)
              visitM a >> visitM a >> visitM b >> visitM b >>
              visitM c >> visitM c)
      | Times (a, b) -> visitM a >> visitM b
      | Uminus a -> visitM a)
      x

  let visit_rootsM = mapM visitM

  (* Second pass: convert a node to an [Expr.expr], memoized in the third
     state component.  Loads are inlined when [Magic.inline_loads] is set and
     get a temporary otherwise.  Constants and negations are always inlined;
     a store binds its target variable.  Sums and products go through
     [with_temp_maybeM], and a sum matching [build_fma] is emitted as a
     single two-term expression of the corresponding [fma] shape. *)
  let rec expr_of_nodeM x =
    memoizing lookupVisitedM' insertVisitedM'
      (function x ->
        match x with
          Load v ->
            if (!Magic.inline_loads) then inlineM (Expr.Var v)
            else with_tempM (Expr.Var v)
        | Num a -> inlineM (Expr.Num a)
        | Store (v, x) -> expr_of_nodeM x >>= with_varM v
        | Plus a ->
            (* inside this match [a], [b], [c] are rebound to the FMA
               operands, but [x] is still the whole sum node, which is what
               [with_temp_maybeM] must look up *)
            (match build_fma a with
              FMA (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                expr_of_nodeM c >>= fun c' ->
                with_temp_maybeM x (Expr.Plus [a'; Expr.Times (b', c')])
            | FMS (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                expr_of_nodeM c >>= fun c' ->
                with_temp_maybeM x
                  (Expr.Plus [Expr.Times (b', c'); Expr.Uminus a'])
            | FNMS (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                expr_of_nodeM c >>= fun c' ->
                with_temp_maybeM x
                  (Expr.Plus [a'; Expr.Uminus (Expr.Times (b', c'))])
            | NO_FMA ->
                mapM expr_of_nodeM a >>= fun a' ->
                with_temp_maybeM x (Expr.Plus a'))
        | Times (a, b) ->
            expr_of_nodeM a >>= fun a' ->
            expr_of_nodeM b >>= fun b' ->
            with_temp_maybeM x (Expr.Times (a', b'))
        | Uminus a ->
            expr_of_nodeM a >>= fun a' ->
            inlineM (Expr.Uminus a'))
      x

  let expr_of_rootsM = mapM expr_of_nodeM

  (* Run both passes over all roots and return the accumulated assignments
     (most recent first). *)
  let peek_alistM roots =
    visit_rootsM roots >> expr_of_rootsM roots >> fetchAl

  (* Start from an empty state; the list is reversed so the assignments come
     out in the order they were emitted. *)
  let to_assignments (Dag dag) =
    List.rev (runM ([], empty, empty) peek_alistM dag)
end

let to_assignments = Destructor.to_assignments

(* Turn a (variable, expression) pair into an [Expr.Assign] statement. *)
let wrap_assign (a, b) = Expr.Assign (a, b)

(* Simplify a dag and convert it into a list of assignment statements. *)
let simplify_to_alist dag =
  let d1 = algsimp dag in
  List.map wrap_assign (to_assignments d1)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -