⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 exprdag.ml

📁 用于FFT,来自MIT的源码
💻 ML
📖 第 1 页 / 共 2 页
字号:
  (* simplifiers for various kinds of nodes *)

  (* [snumM n] builds the node for the constant [n].  Zero becomes
   * [Num Number.zero] and a constant for which Number.negative holds
   * becomes [Uminus (Num (Number.negate n))], so the [Num] nodes built
   * here never hold a negative number. *)
  let rec snumM = function
      n when Number.is_zero n ->
        makeNode (Num (Number.zero))
    | n when Number.negative n ->
        makeNode (Num (Number.negate n)) >>= suminusM
    | n -> makeNode (Num n)

  (* [suminusM x] builds the negation of [x]: a double negation cancels
   * out and the negation of a zero constant is zero. *)
  and suminusM = function
      Uminus x -> makeNode x
    | Num a when (Number.is_zero a) -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* [stimesM (a, b)] builds the product [a * b].  Negations are pulled
   * out of the product, constant factors are multiplied together,
   * constant factors accepted by Number.is_zero / is_one / is_mone are
   * simplified away, and a constant factor is always put on the left. *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, b) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, (Num _ as b')) -> stimesM (b', a)
    | (a, b) -> makeNode (Times (a, b))

  (* [reduce_sumM terms] combines all numeric constants of a list of
   * summands (plain [Num] or [Uminus (Num _)]) into a single constant
   * placed at the end of the list, and drops that constant when it is
   * zero.  Non-constant terms keep their relative order. *)
  and reduce_sumM x = match x with
      [] -> returnM []
    | [Num a] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | [Uminus (Num a)] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | (Num a) :: (Num b) :: s ->
        snumM (Number.add a b) >>= fun x ->
          reduce_sumM (x :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
        snumM (Number.sub a b) >>= fun x ->
          reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
        snumM (Number.sub b a) >>= fun x ->
          reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
        snumM (Number.add a b) >>=
        suminusM >>= fun x ->
          reduce_sumM (x :: s)
    (* a constant followed by a non-constant term: swap them, so the
     * constant keeps moving right until it meets another constant or
     * the end of the list *)
    | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* collectCoeffM transforms
   *       n x + n y   =>  n (x + y)
   * where n is a number *)
  and collectCoeffM x =
    (* [filterM coeff terms] returns (w, wo): [w] holds the cofactors of
     * the terms whose constant coefficient equals [coeff] (negated when
     * the whole term was negated), [wo] holds every other term in its
     * original order. *)
    let rec filterM coeff = function
        Times (Num a, b) as y :: rest ->
          filterM coeff rest >>= fun (w, wo) ->
            if (Number.equal a coeff) then
              returnM (b :: w, wo)
            else
              returnM (w, y :: wo)
      | Uminus (Times (Num a, b)) as y :: rest ->
          filterM coeff rest >>= fun (w, wo) ->
            if (Number.equal a coeff) then
              suminusM b >>= fun b' ->
                returnM (b' :: w, wo)
            else
              returnM (w, y :: wo)
      | y :: rest ->
          filterM coeff rest >>= fun (w, wo) ->
            returnM (w, y :: wo)
      | [] -> returnM ([], [])
    (* [foundCoeffM a terms] replaces every term with coefficient [a] by
     * the single product [a * (sum of their cofactors)], placed at the
     * front, and recursively collects the remaining terms. *)
    and foundCoeffM a x =
      filterM a x >>= fun (w, wo) ->
        collectCoeffM wo >>= fun wo' ->
          (match w with
             [d] -> makeNode d
           | _ -> splusM w) >>= fun p ->
            snumM a >>= fun a' ->
              stimesM (a', p) >>= fun ap ->
                returnM (ap :: wo')
    in match x with
      [] -> returnM []
    | Times (Num a, _) :: _ -> foundCoeffM a x
    | (Uminus (Times (Num a, b))) :: _  -> foundCoeffM a x
    | (a :: c) ->
        collectCoeffM c >>= fun c' ->
          returnM (a :: c')

  (* transform   n1 * x + n2 * x ==> (n1 + n2) * x *)
  and collectExprM x =
    (* [findCoeffM t] splits term [t] into (coefficient node, expression):
     * [n * e] gives (n, e), [-(n * e)] gives (-n, e), [-e] gives (-1, e)
     * and any other [e] gives (1, e). *)
    let rec findCoeffM = function
        Times (Num a as a', b) -> returnM (a', b)
      | Uminus (Times (Num a as a', b)) ->
          suminusM a' >>= fun a'' ->
            returnM (a'', b)
      | Uminus x ->
          snumM Number.one >>= suminusM >>= fun mone ->
            returnM (mone, x)
      | x ->
          snumM Number.one >>= fun one ->
            returnM (one, x)
    (* [filterM xpr terms] returns (w, wo): [w] holds the coefficients of
     * the terms whose expression is physically equal to [xpr], [wo] the
     * remaining terms.  NOTE(review): physical equality is used as node
     * identity, presumably because makeNode hash-conses nodes; confirm. *)
    and filterM xpr = function
        [] -> returnM ([], [])
      | a :: b ->
          filterM xpr b >>= fun (w, wo) ->
            findCoeffM a >>= fun (c, x) ->
              if (xpr == x) then
                returnM (c :: w, wo)
              else
                returnM (w, a :: wo)
    in match x with
      [] -> returnM x
    | [a] -> returnM x
    | a :: b ->
        (* gather every term sharing the expression of the first term *)
        findCoeffM a >>= fun (_, xpr) ->
          filterM xpr x >>= fun (w, wo) ->
            collectExprM wo >>= fun wo' ->
              splusM w >>= fun w' ->
                stimesM (w', xpr) >>= fun t' ->
                  returnM (t':: wo')

  (* [mangleSumM terms] runs the sum simplifications above, in this
   * order, over a list of summands. *)
  and mangleSumM x = returnM x
      >>= reduce_sumM
      >>= collectExprM
      >>= collectCoeffM
      >>= reduce_sumM
      >>= eliminateButterflyishPatternsM
      >>= reduce_sumM

  (* push all Uminuses to the end.  Non-negated terms keep their order;
   * the negated terms end up in reverse of their original order. *)
  and reorder_uminus = function
      [] -> []
    | ((Uminus a) as a' :: b) -> (reorder_uminus b) @ [a']
    | (a :: b) -> a :: (reorder_uminus b)

  (* [canonicalizeM terms] builds the node for an already simplified sum:
   * the zero constant for no terms, the term itself for one term, and
   * otherwise a Plus node (negated terms last) that is then handed to
   * generateFusedMultAddM. *)
  and canonicalizeM = function
      [] -> snumM Number.zero
    | [a] -> makeNode a                    (* one term *)
    | a -> makeNode (Plus (reorder_uminus a)) >>= generateFusedMultAddM

  (* [negative t] holds iff [t] is an explicit negation. *)
  and negative = function
      Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  (c_1 * a + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *)
  and eliminateButterflyishPatternsM l =
    (* [findTerms depth x] lists the subterms reachable from [x], looking
     * through negations and constant factors (each constant factor uses
     * up one level of [depth]) and expanding nested sums only while
     * [depth] is positive. *)
    let rec findTerms depth x = match x with
      | Uminus x -> findTerms depth x
      | Times (Num a, b) -> findTerms (depth - 1) b
      | Plus l when depth > 0 ->
          x :: List.flatten (List.map (findTerms (depth - 1)) l)
      | x -> [x]
    (* [duplicates l] keeps each element that occurs again, physically,
     * later in [l]. *)
    and duplicates = function
        [] -> []
      | a :: b -> if List.memq a b then a :: duplicates b
        else duplicates b
    (* [flattenPlusM d coef x] returns the summands of [coef * x]: members
     * of [d] are kept whole, constant factors and negations are folded
     * into [coef], and a Plus node is expanded into its scaled
     * summands. *)
    in let rec flattenPlusM d coef x =
      if (List.memq x d) then
        snumM coef >>= fun coef' ->
          stimesM (coef', x) >>= fun x' -> returnM [x']
      else match x with
      | Times (Num a, b) ->
          flattenPlusM d (Number.mul a coef) b
      | Uminus x ->
          flattenPlusM d (Number.negate coef) x
      | Plus l ->
          snumM coef >>= fun coef' ->
            mapM (fun x -> stimesM (coef', x)) l
      | x -> snumM coef >>= fun coef' ->
          stimesM (coef', x) >>= fun x' -> returnM [x']
    in let l' = List.flatten (List.map (findTerms 1) l)
    in let d = duplicates l'
    in if (List.length d) > 0 then
      (* some subterm is shared between summands: expand the sum around
       * the shared subterms and simplify it again *)
      mapM (flattenPlusM d Number.one) l >>= fun a ->
        collectExprM (List.flatten a) >>=
        mangleSumM
    else
      returnM l

  (* [splusM terms] simplifies a list of summands with mangleSumM and
   * builds the resulting sum node.  When the simplified terms are all
   * negated (or the Oracle asks for it), the sum is instead built as the
   * negation of the sum of the negated terms. *)
  and splusM l = mangleSumM l >>= fun l' ->
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate all of them and collect the minus sign *)
    else if List.for_all negative l' then
      mapM suminusM l' >>= splusM >>= suminusM
    (* some terms are positive and some are negative.  We are in trouble.
       Ask the Oracle *)
    else if Oracle.should_flip_sign (Plus l') then
      mapM suminusM l' >>= splusM >>= suminusM
    else
      canonicalizeM l'

  (* [generateFusedMultAddM node]: when Magic.enable_fma is set and [node]
   * is a Plus whose terms satisfy [count is_multiplication l >= 2]
   * (terms of the form c * x or -(c * x) with c a constant), factors the
   * largest such c out of those terms, rebuilding the sum as
   *   others + c_max * (sum of (1 / c_max) * term).
   * Any other node is returned unchanged.
   * NOTE(review): relies on c_max being nonzero; stimesM never builds a
   * product with a zero constant factor. *)
  and generateFusedMultAddM =
    let rec is_multiplication = function
      | Times (Num a, b) -> true
      | Uminus (Times (Num a, b)) -> true
      | _ -> false
    (* [separate terms] returns (products, other terms, largest
     * coefficient of the products, starting from zero) *)
    and separate = function
        [] -> ([], [], Number.zero)
      | (Times (Num a, b)) as this :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | (Uminus (Times (Num a, b))) as this :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | this :: c ->
          let (x, y, max) = separate c in
          (x, this :: y, max)
    in function
        Plus l when (count is_multiplication l >= 2) && !Magic.enable_fma ->
          let (w, wo, max) = separate l in
          snumM (Number.div Number.one max) >>= fun invmax' ->
            snumM max >>= fun max' ->
              mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
                stimesM (max', pw') >>= fun mw' ->
                  splusM (wo @ [mw'])
      | x -> returnM x

  (* monadic style algebraic simplifier for the dag: rebuilds node [x]
   * bottom-up through the smart constructors above, memoizing the result
   * per node with lookupSimpM / insertSimpM *)
  let rec algsimpM x =
    memoizing lookupSimpM insertSimpM
      (function
          Num a -> snumM a
        | Plus a ->
            mapM algsimpM a >>= splusM
        | Times (a, b) ->
            algsimpM a >>= fun a' ->
              algsimpM b >>= fun b' ->
                stimesM (a', b')
        | Uminus a ->
            algsimpM a >>= suminusM
        | Store (v, a) ->
            algsimpM a >>= fun a' ->
              makeNode (Store (v, a'))
        | x -> makeNode x)
      x

  (* initial monad state: a pair of empty tables.  NOTE(review): presumably
   * the simplification memo table and the node table used by makeNode;
   * confirm against their definitions. *)
  let initialTable = (empty, empty)

  let simp_roots = mapM algsimpM

  (* simplify every root of the dag in a single run of the state monad *)
  let algsimp (Dag dag) = Dag (runM initialTable simp_roots dag)
end

(* simplify the dag: three AlgSimp.algsimp passes interleaved with two
 * Reverse.reverse passes, logging Stats.complexity before each pass and
 * once more at the end *)
let algsimp v =
  let _ = info "  first simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = AlgSimp.algsimp v in
  let _ = info "  second simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = Reverse.reverse v in
  let _ = info "  third simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = AlgSimp.algsimp v in
  let _ = info "  fourth simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = Reverse.reverse v in
  let _ = info "  fifth simplification pass..." in
  let _ = info (Stats.complexity v) in
  let v = AlgSimp.algsimp v in
  let _ = info "  simplification done..." in
  let _ = info (Stats.complexity v) in
  v

let make nodes = Dag nodes

(*************************************************************
 * Conversion of the dag to an assignment list
 *************************************************************)

(*
 * This function is messy.  The main problem is that we want to
 * inline dag nodes conditionally, depending on how many times they
 * are used.  The Right Thing to do would be to modify the
 * state monad to propagate some of the state backwards, so that
 * we know whether a given node will be used again in the future.
 * This modification is trivial in a lazy language, but it is
 * messy in a strict language like ML.
 *
 * In this implementation, we just do the obvious thing, i.e., visit
 * the dag twice, the first to count the node usages, and the second to
 * produce the output.
 *)
module Destructor :
  sig
    val to_assignments : dag -> (Variable.variable * Expr.expr) list
  end = struct
  open StateMonad
  open MemoMonad
  open AssocTable

  let fresh = Variable.make_temporary

  (* The monad state is a triple (al, visited, visited2), where visited2
   * stands for the third component, named visited' in the code:
   *   al       - assignments emitted so far, most recent first;
   *   visited  - table from node to the number of times visitM reached it;
   *   visited2 - memo table from node to the expression expr_of_nodeM
   *              built for it. *)
  let fetchAl =
    fetchState >>= (fun (al, _, _) -> returnM al)
  let storeAl al =
    fetchState >>= (fun (_, visited, visited') ->
      storeState (al, visited, visited'))
  let fetchVisited = fetchState >>= (fun (_, v, _) -> returnM v)
  let storeVisited visited =
    fetchState >>= (fun (al, _, visited') ->
      storeState (al, visited, visited'))
  let fetchVisited' = fetchState >>= (fun (_, _, v') -> returnM v')
  let storeVisited' visited' =
    fetchState >>= (fun (al, visited, _) ->
      storeState (al, visited, visited'))

  (* lookup / insert on the expr_of_nodeM memo table, keyed by node
   * identity (physical equality) *)
  let lookupVisitedM' key =
    fetchVisited' >>= fun table ->
      returnM (AssocTable.lookup hash_node (==) key table)
  let insertVisitedM' key value =
    fetchVisited' >>= fun table ->
      storeVisited' (AssocTable.insert hash_node key value table)

  (* [counting f x] increments the usage count of node [x].  Only on the
   * first visit does it run [f x] (the traversal of the children), after
   * which it records a count of 1. *)
  let counting f x =
    fetchVisited >>= (fun v ->
      match AssocTable.lookup hash_node (==) x v with
        Some count ->
          fetchVisited >>= (fun v' ->
            storeVisited (AssocTable.insert hash_node
                            x (count + 1) v'))
      | None ->
          f x >>= fun () ->
            (* re-read the table: [f x] may have updated it *)
            fetchVisited >>= (fun v' ->
              storeVisited (AssocTable.insert hash_node
                              x 1 v')))

  (* [with_varM v x] emits the assignment of [x] to [v] and returns the
   * expression [Expr.Var v] *)
  let with_varM v x =
    fetchAl >>= (fun al -> storeAl ((v, x) :: al)) >> returnM (Expr.Var v)

  (* use the expression in place, without an assignment *)
  let inlineM = returnM

  (* assign the expression to a fresh temporary *)
  let with_tempM x = with_varM (fresh ()) x

  (* declare a temporary only if node is used more than once: inline [x]
   * when [node] was counted once and Magic.inline_single is set,
   * otherwise bind it to a fresh temporary.  Fails if [node] was never
   * counted by visitM. *)
  let with_temp_maybeM node x =
    fetchVisited >>= (fun v ->
      match AssocTable.lookup hash_node (==) node v with
        Some count ->
          if (count = 1 && !Magic.inline_single) then
            inlineM x
          else
            with_tempM x
      | None ->
          failwith "with_temp_maybeM")

  type fma =
      NO_FMA
    | FMA of node * node * node   (* FMA (a, b, c) => a + b * c *)
    | FMS of node * node * node   (* FMS (a, b, c) => -a + b * c *)
    | FNMS of node * node * node  (* FNMS (a, b, c) => a - b * c *)

  (* [build_fma l] recognizes a two-term sum containing a product as one
   * of the fused forms above.  Gives NO_FMA when
   * Magic.enable_fma_expansion is off or the shape does not match. *)
  let build_fma l =
    if (not !Magic.enable_fma_expansion) then NO_FMA
    else match l with
    | [Uminus a; Times (b, c)] -> FMS (a, b, c)
    | [Times (b, c); Uminus a] -> FMS (a, b, c)
    | [a; Uminus (Times (b, c))] -> FNMS (a, b, c)
    | [Uminus (Times (b, c)); a] -> FNMS (a, b, c)
    | [a; Times (b, c)] -> FMA (a, b, c)
    | [Times (b, c); a] -> FMA (a, b, c)
    | _ -> NO_FMA

  (* the three operands of the fused form recognized by build_fma, if any *)
  let children_fma l = match build_fma l with
      FMA (a, b, c) -> Some (a, b, c)
    | FMS (a, b, c) -> Some (a, b, c)
    | FNMS (a, b, c) -> Some (a, b, c)
    | NO_FMA -> None

  (* first pass: count how many times each node is used *)
  let rec visitM x =
    counting (function
        Load v -> returnM ()
      | Num a -> returnM ()
      | Store (v, x) -> visitM x
      | Plus a -> (match children_fma a with
            None -> mapM visitM a >> returnM ()
          | Some (a, b, c) ->
              (* visit fma's arguments twice to make sure they get a variable *)
              visitM a >> visitM a >>
              visitM b >> visitM b >>
              visitM c >> visitM c)
      | Times (a, b) ->
          visitM a >> visitM b
      | Uminus a -> visitM a)
      x

  let visit_rootsM = mapM visitM

  (* second pass: convert each node to an Expr.expr (memoized per node),
   * emitting assignments for stores and, via with_temp_maybeM, for sums
   * and products that are not inlined *)
  let rec expr_of_nodeM x =
    memoizing lookupVisitedM' insertVisitedM'
      (function x -> match x with
          Load v ->
            if (!Magic.inline_loads) then
              inlineM (Expr.Var v)
            else
              with_tempM (Expr.Var v)
        | Num a ->
            inlineM (Expr.Num a)
        | Store (v, x) ->
            expr_of_nodeM x >>=
            with_varM v
        | Plus a -> (match build_fma a with
              FMA (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                  expr_of_nodeM b >>= fun b' ->
                    expr_of_nodeM c >>= fun c' ->
                      with_temp_maybeM x (Expr.Plus [a'; Expr.Times (b', c')])
            | FMS (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                  expr_of_nodeM b >>= fun b' ->
                    expr_of_nodeM c >>= fun c' ->
                      with_temp_maybeM x
                        (Expr.Plus [Expr.Times (b', c'); Expr.Uminus a'])
            | FNMS (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                  expr_of_nodeM b >>= fun b' ->
                    expr_of_nodeM c >>= fun c' ->
                      with_temp_maybeM x
                        (Expr.Plus [a'; Expr.Uminus (Expr.Times (b', c'))])
            | NO_FMA ->
                mapM expr_of_nodeM a >>= fun a' ->
                  with_temp_maybeM x (Expr.Plus a'))
        | Times (a, b) ->
            expr_of_nodeM a >>= fun a' ->
              expr_of_nodeM b >>= fun b' ->
                with_temp_maybeM x (Expr.Times (a', b'))
        | Uminus a ->
            expr_of_nodeM a >>= fun a' ->
              inlineM (Expr.Uminus a'))
      x

  let expr_of_rootsM = mapM expr_of_nodeM

  (* run both passes over the roots and return the accumulated
   * assignment list (most recent first) *)
  let peek_alistM roots =
    visit_rootsM roots >> expr_of_rootsM roots >> fetchAl

  (* the assignments of the dag, in the order they were emitted *)
  let to_assignments (Dag dag) =
    List.rev (runM ([], empty, empty) peek_alistM dag)
end

let to_assignments = Destructor.to_assignments

let wrap_assign (a, b) = Expr.Assign (a, b)

(* simplify [dag] and turn it into a list of Expr.Assign statements *)
let simplify_to_alist dag =
  let d1 = algsimp dag
  in List.map wrap_assign (to_assignments d1)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -