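(* exprdag.ml *)
(* NOTE(review): a stray font-size label from the web code viewer this file
   was copied from stood here; it is not part of the source. *)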
(*
 * Copyright (c) 1997-1999, 2003 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 *
 *)

(* $Id: exprdag.ml,v 1.43 2003/03/16 23:43:46 stevenj Exp $ *)
let cvsid = "$Id: exprdag.ml,v 1.43 2003/03/16 23:43:46 stevenj Exp $"

open Util

(* A vertex of the expression DAG.  The visitors and memo tables in this
   file key nodes by physical identity (==), so sharing is expressed by
   reusing the same node value. *)
type node =
  | Num of Number.number
  | Load of Variable.variable
  | Store of Variable.variable * node
  | Plus of node list
  | Times of node * node
  | Uminus of node

(* a dag is represented by the list of its roots *)
type dag = Dag of (node list)

module Hash = struct
  (* various hash functions *)

  (* Hashes a float by its frexp mantissa (0.5 <= |m| < 1 for nonzero x),
     quantized to 4 decimal digits; the exponent is ignored, so x and 2x
     hash alike.
     NOTE(review): values that Oracle.almost_equal considers equal can still
     get different hashes when they straddle a quantization step or a power
     of two, so Oracle lookups may miss in those cases. *)
  let hash_float x =
    let (mantissa, _) = frexp x in
    truncate (mantissa *. 10000.0)

  let hash_variable = Variable.hash

  (* Structural hash of a node; the children of Plus and Times are hashed
     with the polymorphic Hashtbl.hash. *)
  let rec hash_node = function
    | Num x -> hash_float (Number.to_float x)
    | Load v -> 1 + 1237 * hash_variable v
    | Store (v, x) -> 2 * hash_variable v - 2345 * hash_node x
    | Plus l -> 5 + 23451 * sum_list (List.map Hashtbl.hash l)
    | Times (a, b) -> 31415 * Hashtbl.hash a + 2718 * Hashtbl.hash b
    | Uminus x -> 42 + 12345 * (hash_node x)
end

open Hash

module LittleSimplifier = struct
  (*
   * The LittleSimplifier module implements a subset of the simplifications
   * of the AlgSimp module. These simplifications can be executed
   * quickly here, while they would take a long time using the heavy
   * machinery of AlgSimp.
   *
   * For example, 0 * x is simplified to 0 tout court by the LittleSimplifier.
   * On the other hand, AlgSimp would first simplify x, generating lots
   * of common subexpressions, storing them in a table etc, just to
   * discard all the work later. Similarly, the LittleSimplifier
   * reduces the constant FFT in Rader's algorithm to a constant sequence.
   *)

  (* Smart constructors: each builds a node, folding constants where it can. *)
  let rec makeNum = function
    | n -> Num n

  (* Double negation cancels; the negation of a constant is folded. *)
  and makeUminus = function
    | Uminus a -> a
    | Num a -> makeNum (Number.negate a)
    | a -> Uminus a

  (* Products: two constants are folded; a constant is moved to the left
     operand and merged with a constant left factor of a Times; products
     with 0, 1 and -1 are simplified; a constant times a negation absorbs
     the sign into the constant. *)
  and makeTimes = function
    | (Num a, Num b) -> makeNum (Number.mul a b)
    | (Num a, Times (Num b, c)) -> makeTimes (makeNum (Number.mul a b), c)
    | (Num a, _) when Number.is_zero a -> makeNum (Number.zero)
    | (Num a, b) when Number.is_one a -> b
    | (Num a, b) when Number.is_mone a -> makeUminus b
    | (Num a, Uminus b) -> Times (makeUminus (Num a), b)
    | (a, (Num _ as b')) -> makeTimes (b', a)
    | (a, b) -> Times (a, b)

  (* Sums: all constants are merged into one trailing constant, which is
     dropped if zero.  An empty sum is 0 and a one-term sum is the term
     itself.  A two-term sum of the same physical node becomes 2 * x, and
     a two-term sum a*x + c*x with numeric a, c becomes (a + c) * x. *)
  and makePlus l =
    let rec reduceSum x = match x with
      | [] -> []
      | [Num a] -> if Number.is_zero a then [] else x
      | (Num a) :: (Num b) :: c -> reduceSum ((makeNum (Number.add a b)) :: c)
      | ((Num _) as a') :: b :: c -> b :: reduceSum (a' :: c)
      | a :: s -> a :: reduceSum s
    in
    match reduceSum l with
    | [] -> makeNum (Number.zero)
    | [a] -> a
    | [a; b] when a == b -> makeTimes (Num Number.two, a)
    | [Times (Num a, b); Times (Num c, d)] when b == d ->
        makeTimes (makePlus [Num a; Num c], b)
    | a -> Plus a
end

(*************************************************************
 * Functional associative table
 *************************************************************)

(*
 * This module implements a functional associative table.
 * The table is parametrized by an equality predicate and
 * a hash function, with the restriction that (equal a b) ==>
 * hash a == hash b.
 * The table is purely functional and implemented using a binary
 * search tree (not balanced for now), ordered by hash value with
 * larger hashes to the left.  Keys whose hashes coincide share one
 * tree node and are kept in a list, newest binding first.
 *)
module AssocTable : sig
  type ('a, 'b) elem =
    | Leaf
    | Node of int * ('a, 'b) elem * ('a, 'b) elem * ('a * 'b) list
  val empty : ('a, 'b) elem
  val lookup :
    ('a -> int) -> ('a -> 'b -> bool) -> 'a -> ('b, 'c) elem -> 'c option
  val insert : ('a -> int) -> 'a -> 'c -> ('a, 'c) elem -> ('a, 'c) elem
end = struct
  type ('a, 'b) elem =
    | Leaf
    | Node of int * ('a, 'b) elem * ('a, 'b) elem * ('a * 'b) list

  let empty = Leaf

  (* [lookup hash equal key table] returns Some v for the most recently
     inserted binding (k, v) with [equal key k], or None.  Only the bucket
     whose hash equals [hash key] is searched. *)
  let lookup hash equal key table =
    let h = hash key in
    let rec look = function
      | Leaf -> None
      | Node (hash_key, left, right, this_list) ->
          if hash_key < h then look left
          else if hash_key > h then look right
          else
            let rec loop = function
              | [] -> None
              | (a, b) :: rest -> if equal key a then Some b else loop rest
            in
            loop this_list
    in
    look table

  (* [insert hash key value table] returns a new table with (key, value)
     added; [table] itself is unchanged.  An older binding in the same
     bucket is shadowed for lookup, not removed. *)
  let insert hash key value table =
    let h = hash key in
    let rec ins = function
      | Leaf -> Node (h, Leaf, Leaf, [(key, value)])
      | Node (hash_key, left, right, this_list) ->
          if hash_key < h then Node (hash_key, ins left, right, this_list)
          else if hash_key > h then Node (hash_key, left, ins right, this_list)
          else Node (hash_key, left, right, (key, value) :: this_list)
    in
    ins table
end

(* Tables keyed by node identity (==) and hashed with hash_node.  These are
   eta-expanded so each use gets its own value type instead of sharing one
   weakly polymorphic instantiation. *)
let node_insert key value table = AssocTable.insert hash_node key value table
let node_lookup key table = AssocTable.lookup hash_node (==) key table

(*************************************************************
 * Monads
 *************************************************************)

(*
 * Phil Wadler has many well written papers about monads. See
 * http://cm.bell-labs.com/cm/cs/who/wadler/
 *)

(* vanilla state monad: a computation maps a state to a (result, state)
   pair *)
module StateMonad = struct
  let returnM x = fun s -> (x, s)

  (* Run m, then run k on its result, threading the updated state. *)
  let (>>=) = fun m k ->
    fun s ->
      let (a', s') = m s in
      let (a'', s'') = k a' s' in
      (a'', s'')

  (* Sequence two computations, discarding the first result. *)
  let (>>) = fun m k ->
    m >>= fun _ -> k

  (* Apply f to each element from left to right, threading the state, and
     collect the results in order. *)
  let rec mapM f = function
    | [] -> returnM []
    | a :: b ->
        f a >>= fun a' ->
        mapM f b >>= fun b' ->
        returnM (a' :: b')

  (* Run [m x] from [initial_state] and return only the result. *)
  let runM m x initial_state =
    let (a, _) = m x initial_state in
    a

  let fetchState =
    fun s -> s, s

  let storeState newState =
    fun _ -> (), newState
end

(* monad with built-in memoizing capabilities *)
module MemoMonad = struct
  open StateMonad

  (* [memoizing lookupM insertM f k]: return the value lookupM finds for k;
     otherwise compute [f k], record it with insertM, and return it. *)
  let memoizing lookupM insertM f k =
    lookupM k >>= fun vMaybe ->
    match vMaybe with
    | Some value -> returnM value
    | None ->
        f k >>= fun value ->
        insertM k value >> returnM value

  (* StateMonad.runM with the initial state as the first argument. *)
  let runM initial_state m x = StateMonad.runM m x initial_state
end

(* Evaluates nodes numerically, giving each variable a random value, and
   reports whether a node value is approximately the negation of the first
   value seen that matches it up to sign. *)
module Oracle : sig
  val should_flip_sign : node -> bool
end = struct
  open AssocTable

  (* Returns a memoizing combinator with its own private table.  The first
     call for a key k stores [f k]; later calls with an equal key return the
     stored value, regardless of the f they pass. *)
  let make_memoizer hash equal =
    let table = ref empty in
    fun f k ->
      match lookup hash equal k !table with
      | Some value -> value
      | None ->
          let value = f k in
          table := insert hash k value !table;
          value

  (* x and y agree within an absolute tolerance of 1e-8, or within a
     tolerance of 1e-8 relative to |x| + |y|. *)
  let almost_equal x y =
    let epsilon = 1.0E-8 in
    (abs_float (x -. y) < epsilon) ||
    (abs_float (x -. y) < epsilon *. (abs_float x +. abs_float y))

  (* Memoizer over floats up to sign: a matches b when a is almost equal to
     b or to -b. *)
  let memoizing_numbers = make_memoizer
      (fun x -> hash_float (abs_float x))
      (fun a b -> almost_equal a b || almost_equal (-. a) b)

  (* absid v returns the first value stored that matches v up to sign. *)
  let absid = memoizing_numbers (fun x -> x)

  let memoizing_variables = make_memoizer hash_variable Variable.same
  let memoizing_nodes = make_memoizer hash_node (==)

  (* A value in [0, 1) per variable, drawn from 30 random bits at first use
     and then fixed. *)
  let random_oracle = memoizing_variables
      (fun _ -> (float (Random.bits())) /. 1073741824.0)

  let sum_list l = List.fold_right (+.) l 0.0

  (* Numeric value of a node, memoized per node identity.  A Store
     evaluates like a Load of its variable. *)
  let rec eval x =
    memoizing_nodes
      (function
        | Num n -> Number.to_float n
        | Load v -> random_oracle v
        | Store (v, _) -> random_oracle v
        | Plus l -> sum_list (List.map eval l)
        | Times (a, b) -> (eval a) *. (eval b)
        | Uminus y -> -. (eval y))
      x

  (* True when absid maps the value of node to a representative that is
     not almost equal to it, hence approximately its negation. *)
  let should_flip_sign node =
    let v = eval node in
    let v' = absid v in
    not (almost_equal v v')
end

(* Reverses the DAG.  The dual of a node is the sum, over its parents, of
   the dual of the parent:
     - unchanged through Plus and Store,
     - negated through Uminus,
     - multiplied by a through Times (a, node).
   Only the right operand of a Times receives a contribution.  Non-twiddle
   Loads turn into Stores and Stores turn into Loads. *)
module Reverse = struct
  open StateMonad
  open MemoMonad
  open AssocTable
  open LittleSimplifier

  (* The monadic state is the table of duals computed so far. *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun table ->
    returnM (node_lookup key table)

  let insertDualsM key value =
    fetchDuals >>= fun table ->
    storeDuals (node_insert key value table)

  (* Visits every node reachable from the pending list once (by identity).
     Returns the visited nodes, most recent first, together with a table
     mapping each node to its parents; a parent appears once per edge. *)
  let rec visit visited vtable parent_table = function
    | [] -> (visited, parent_table)
    | node :: rest ->
        match AssocTable.lookup hash_node (==) node vtable with
        | Some _ -> visit visited vtable parent_table rest
        | None ->
            let children = match node with
              | Store (_, n) -> [n]
              | Plus l -> l
              | Times (a, b) -> [a; b]
              | Uminus x -> [x]
              | _ -> []
            in
            (* record node as a parent of each of its children *)
            let rec loop t = function
              | [] -> t
              | a :: rest ->
                  (match AssocTable.lookup hash_node (==) a t with
                   | None -> loop (AssocTable.insert hash_node a [node] t) rest
                   | Some c ->
                       loop (AssocTable.insert hash_node a (node :: c) t) rest)
            in
            visit
              (node :: visited)
              (AssocTable.insert hash_node node () vtable)
              (loop parent_table children)
              (children @ rest)

  (* Given the parent table built by visit, returns dualM, the memoized
     state-monad computation of a node's dual. *)
  let make_reverser parent_table =
    (* Contribution of candidate_parent to the dual of node, as a list of
       zero or one terms. *)
    let rec termM node candidate_parent =
      match candidate_parent with
      | Store (_, n) when n == node ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Plus (l) when List.memq node l ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Times (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [makeTimes (a, x')]
      | Uminus n when n == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [makeUminus x']
      | _ -> returnM []

    (* Simplified sum of the contributions of all parents of this_node.
       Raises Failure if this_node has no entry in parent_table. *)
    and dualExpressionM this_node =
      mapM (termM this_node)
        (match AssocTable.lookup hash_node (==) this_node parent_table with
         | Some a -> a
         | None -> failwith "bug in dualExpressionM")
      >>= fun l ->
      returnM (makePlus (List.flatten l))

    (* Dual of a node, memoized in the monadic state:
         - a Load of a variable satisfying Variable.is_twiddle gets a fresh
           Load of the same variable;
         - any other Load becomes a Store of its dual expression;
         - a Store becomes a Load of its variable;
         - every other node is its dual expression. *)
    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | Load v as x ->
              if Variable.is_twiddle v then
                returnM (Load v)
              else
                (dualExpressionM x >>= fun d ->
                 returnM (Store (v, d)))
          | Store (v, _) -> returnM (Load v)
          | x -> dualExpressionM x)
        this_node
    in
    dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Computes the dual of every reachable node, sharing one memo table, and
     keeps the duals that are Stores as the roots of the reversed DAG. *)
  let reverse (Dag dag) =
    let (all_nodes, parent_table) = visit [] empty empty dag in
    let reverserM = make_reverser parent_table in
    let mapReverserM = mapM reverserM in
    let duals = runM empty mapReverserM all_nodes in
    let roots = filter is_store duals in
    Dag roots
end

(*************************************************************
 * Various dag statistics
 *************************************************************)
module Stats : sig
  type complexity
  val complexity : dag -> complexity
  val same_complexity : complexity -> complexity -> bool
end = struct
  (* Counts of (Load, Store, Plus, Times, Uminus, Num) nodes. *)
  type complexity = int * int * int * int * int * int

  (* Nodes reachable from the pending list, each listed once (by identity). *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
        match AssocTable.lookup hash_node (==) node vtable with
        | Some _ -> visit visited vtable rest
        | None ->
            let children = match node with
              | Store (_, n) -> [n]
              | Plus l -> l
              | Times (a, b) -> [a; b]
              | Uminus x -> [x]
              | _ -> []
            in
            visit (node :: visited)
              (AssocTable.insert hash_node node () vtable)
              (children @ rest)

  (* Number of distinct nodes of each kind reachable from the roots. *)
  let complexity (Dag dag) =
    let rec loop (load, store, plus, times, uminus, num) = function
      | [] -> (load, store, plus, times, uminus, num)
      | node :: rest ->
          loop
            (match node with
             | Load _ -> (load + 1, store, plus, times, uminus, num)
             | Store _ -> (load, store + 1, plus, times, uminus, num)
             | Plus _ -> (load, store, plus + 1, times, uminus, num)
             | Times _ -> (load, store, plus, times + 1, uminus, num)
             | Uminus _ -> (load, store, plus, times, uminus + 1, num)
             | Num _ -> (load, store, plus, times, uminus, num + 1))
            rest
    in
    loop (0, 0, 0, 0, 0, 0) (visit [] AssocTable.empty dag)

  let same_complexity a b = (a = b)
end

(*************************************************************
 * Algebraic simplifier/elimination of common subexpressions
 *************************************************************)
module AlgSimp :
sig val algsimp : dag -> dagend = struct open StateMonad open MemoMonad open AssocTable let fetchSimp = fetchState >>= fun (s, _) -> returnM s let storeSimp s = fetchState >>= (fun (_, c) -> storeState (s, c)) let lookupSimpM key = fetchSimp >>= fun table -> returnM (node_lookup key table) let insertSimpM key value = fetchSimp >>= fun table -> storeSimp (node_insert key value table) let subset a b = List.for_all (fun x -> List.exists (fun y -> x == y) b) a let equalCSE a b = match (a, b) with (Num a, Num b) -> Number.equal a b | (Load a, Load b) -> Variable.same a b && (!Magic.collect_common_twiddle or not (Variable.is_twiddle a)) && (!Magic.collect_common_inputs or not (Variable.is_input a)) | (Times (a, a'), Times (b, b')) -> ((a == b) && (a' == b')) or ((a == b') && (a' == b)) | (Plus a, Plus b) -> subset a b && subset b a | (Uminus a, Uminus b) -> (a == b) | _ -> false let fetchCSE = fetchState >>= fun (_, c) -> returnM c let storeCSE c = fetchState >>= (fun (s, _) -> storeState (s, c)) let lookupCSEM key = fetchCSE >>= fun table -> returnM (AssocTable.lookup hash_node equalCSE key table) let insertCSEM key value = fetchCSE >>= fun table -> storeCSE (AssocTable.insert hash_node key value table) (* memoize both x and Uminus x (unless x is already negated) *) let identityM x = let memo x = memoizing lookupCSEM insertCSEM returnM x in match x with Uminus _ -> memo x | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x' let makeNode = identityM
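(* NOTE(review): the keyboard-shortcut panel of the web code viewer this
   file was copied from was captured here.  It is not part of the source,
   and the listing is cut off at this point (AlgSimp has no closing end).
     Keyboard shortcuts:
       Copy code            Ctrl + C
       Search code          Ctrl + F
       Full-screen mode     F11
       Toggle theme         Ctrl + Shift + D
       Show shortcuts       ?
       Increase font size   Ctrl + =
       Decrease font size   Ctrl + -
*)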