(* blob: 2b25ab51a77c2dbb1a74ce5b4f97dfd48d585933 [file] [log] [blame] [edit] *)
open Ast
open Print
open Walk
open Util
open Source
open Xl
(* Binding operator for the option monad: [let* x = o in k] evaluates to
   [None] when [o] is [None], and to [k] with [x] bound otherwise. *)
let ( let* ) opt k = Option.bind opt k
(* Substitutions: finite maps from variable names to expressions,
   together with the operation that applies one to an expression. *)
module Subst = struct
  include Map.Make(String)

  (* [subst_exp s e] walks [e] and replaces every [VarE id] that is bound
     in [s] by its binding. The replacement itself is not walked again.
     All other nodes are handled by the default walker. *)
  let subst_exp s e =
    let replace_var self exp =
      match exp.it with
      | VarE id ->
        (match find_opt id s with
        | Some bound -> bound
        | None -> base_walker.walk_expr self exp
        )
      | _ -> base_walker.walk_expr self exp
    in
    let w = {base_walker with walk_expr = replace_var} in
    w.walk_expr w e
end
(* [get_subst lhs rhs s] structurally matches the pattern [lhs] against
   [rhs] and extends the substitution [s] with a binding for every variable
   of [lhs]. Returns [None] when the two expressions do not match.

   - A variable on the left binds to whatever is on the right; an existing
     binding for the same name is overwritten without a consistency check.
   - Compound nodes are matched component-wise.
   - Anything else matches only if both sides are equal per [Eq.eq_expr]. *)
let rec get_subst lhs rhs s =
  (* Pairwise matching of two expression lists. Lists of different length
     do not match; without this check [List.fold_right2] would raise
     [Invalid_argument] instead of reporting a failed match. *)
  let get_subst_list el1 el2 s =
    if List.compare_lengths el1 el2 <> 0 then None
    else
      List.fold_right2 (fun e1 e2 s -> let* s = s in get_subst e1 e2 s)
        el1 el2 (Some s)
  in
  match lhs.it, rhs.it with
  | VarE id, _ -> Some (Subst.add id rhs s)
  | CvtE (e1, nt11, nt12), CvtE (e2, nt21, nt22) when nt11 = nt21 && nt12 = nt22 ->
    get_subst e1 e2 s
  | UnE (op1, e1), UnE (op2, e2) when op1 = op2 -> get_subst e1 e2 s
  | OptE (Some e1), OptE (Some e2) ->
    get_subst e1 e2 s
  | BinE (op1, e11, e12), BinE (op2, e21, e22) when op1 = op2 ->
    let* s = s |> get_subst e11 e21 in
    get_subst e12 e22 s
  | CompE (e11, e12), CompE (e21, e22) | CatE (e11, e12), CatE (e21, e22) ->
    let* s = s |> get_subst e11 e21 in
    get_subst e12 e22 s
  | TupE el1, TupE el2 | ListE el1, ListE el2 ->
    get_subst_list el1 el2 s
  | CaseE (name1, el1), CaseE (name2, el2) when name1 = name2 ->
    get_subst_list el1 el2 s
  | StrE r1, StrE r2 ->
    (* Every field of the pattern record is matched against the field of
       the same key in [r2].
       NOTE(review): the behaviour of [Record.find] on a missing key is not
       visible here - confirm it cannot escape with an exception. *)
    List.fold_left
      (fun s (k, e) -> let* s = s in get_subst !e (Record.find k r2) s)
      (Some s) r1
  (* | IterE _, _ -> (* TODO *) s *)
  | _, _ when Eq.eq_expr lhs rhs -> Some s
  | _ -> None
(* [get_subst_arg param arg s] matches a formal parameter against an actual
   argument. Only expression arguments contribute bindings; every other
   combination leaves [s] unchanged and succeeds. *)
let get_subst_arg param arg s =
  match param.it with
  | ExpA pattern ->
    (match arg.it with
    | ExpA actual -> get_subst pattern actual s
    | _ -> Some s
    )
  | _ -> Some s
(* [it' $> e] is [e] with its [it] field replaced by [it']; all other
   fields of [e] are kept. *)
let ( $> ) it' e = {e with it = it'}
(* Projects the boolean out of a [BoolE] node; [None] for any other node. *)
let of_bool_exp e' =
  match e' with
  | BoolE b -> Some b
  | _ -> None
(* Projects the number out of a [NumE] node; [None] for any other node. *)
let of_num_exp e' =
  match e' with
  | NumE n -> Some n
  | _ -> None
(* Injects a boolean into a [BoolE] node. *)
let to_bool_exp flag = BoolE flag
(* Injects a number into a [NumE] node. *)
let to_num_exp num = NumE num
(* Returns the payload of an [OptE] expression.
   Raises [Failure] on any other expression. *)
let as_opt_exp = function
  | {it = OptE eo; _} -> eo
  | _ -> failwith "as_opt_exp"
(* Returns the elements of a [ListE] expression.
   Raises [Failure] on any other expression. *)
let as_list_exp = function
  | {it = ListE es; _} -> es
  | _ -> failwith "as_list_exp"
(* An expression is head-normal when its outermost node is a literal, a
   [SubE], or a constructor of structured data. The sub-expressions are
   not inspected. *)
let is_head_normal_exp {it; _} =
  match it with
  | BoolE _ | NumE _ | SubE _ -> true
  | OptE _ | ListE _ | TupE _ | CaseE _ | StrE _ -> true
  | _ -> false
(* An expression is normal when it is head-normal all the way down:
   literals and [SubE] are normal, and structured data is normal when all
   of its components are. *)
let rec is_normal_exp e =
  let all_normal es = List.for_all is_normal_exp es in
  match e.it with
  | BoolE _ | NumE _ | SubE _ | OptE None -> true
  | OptE (Some e1) -> is_normal_exp e1
  | ListE es | TupE es | CaseE (_, es) -> all_normal es
  | StrE fields -> List.for_all (fun (_, field) -> is_normal_exp !field) fields
  | _ -> false
(* [reduce_exp s e] partially evaluates [e]: sub-expressions are reduced
   first, then the node itself is simplified when its operands are
   sufficiently evaluated. A node that cannot be simplified is rebuilt from
   its reduced children, so the result is always an expression.

   [s] is only threaded through the recursive calls; this function never
   looks a variable up in it ([VarE] is returned unchanged).

   Each call is traced through [Debug_log] under the label al.reduce_exp. *)
let rec reduce_exp s e : expr =
  Debug_log.(log "al.reduce_exp"
    (fun _ -> fmt "%s" (string_of_expr e))
    (fun e' -> fmt "%s" (string_of_expr e'))
  ) @@ fun _ ->
  match e.it with
  | VarE _ | BoolE _ | NumE _ -> e
  | CvtE (e1, _, nt) ->
    let e1' = reduce_exp s e1 in
    (* NOTE(review): when the operand is not a numeric literal, or the
       conversion is undefined, the conversion node is dropped and the bare
       operand is returned - confirm this is intended. *)
    (match e1'.it with
    | NumE n ->
      (match Num.cvt nt n with
      | Some n' -> NumE n' $> e
      | None -> e1'
      )
    | _ -> e1'
    )
  | UnE (op, e1) ->
    let e1' = reduce_exp s e1 in
    (match op, e1'.it with
    | #Bool.unop as op', BoolE b1 -> BoolE (Bool.un op' b1) $> e
    | #Num.unop as op', NumE n1 ->
      (match Num.un op' n1 with
      | Some n -> NumE n
      | None -> UnE (op, e1')
      ) $> e
    (* A doubled negation or doubled minus cancels out. *)
    | `NotOp, UnE (`NotOp, e11') -> e11'
    | `MinusOp, UnE (`MinusOp, e11') -> e11'
    | _ -> UnE (op, e1') $> e
    )
  | BinE (op, e1, e2) ->
    let e1' = reduce_exp s e1 in
    let e2' = reduce_exp s e2 in
    (match op with
    | #Bool.binop as op' ->
      (match Bool.bin_partial op' e1'.it e2'.it of_bool_exp to_bool_exp with
      | None -> BinE (op, e1', e2')
      | Some e' -> e'
      )
    | #Num.binop as op' ->
      (match Num.bin_partial op' e1'.it e2'.it of_num_exp to_num_exp with
      | None -> BinE (op, e1', e2')
      | Some e' -> e'
      )
    | #Num.cmpop as op' ->
      (match of_num_exp e1'.it, of_num_exp e2'.it with
      | Some n1, Some n2 ->
        (match Num.cmp op' n1 n2 with
        | Some b -> to_bool_exp b
        | None -> BinE (op, e1', e2')
        )
      | _, _ -> BinE (op, e1', e2')
      )
    (* Equal operands decide the comparison outright. Unequal operands only
       decide it when both are fully normal. *)
    | `EqOp when Eq.eq_expr e1' e2' -> BoolE true
    | `NeOp when Eq.eq_expr e1' e2' -> BoolE false
    | `EqOp when is_normal_exp e1' && is_normal_exp e2' -> BoolE false
    | `NeOp when is_normal_exp e1' && is_normal_exp e2' -> BoolE true
    | #Bool.cmpop -> BinE (op, e1', e2')
    ) $> e
  | AccE (e1, p) ->
    (match p.it with
    | IdxP e2 ->
      let e1' = reduce_exp s e1 in
      let e2' = reduce_exp s e2 in
      (match e1'.it, e2'.it with
      | ListE es, NumE (`Nat i) when i < Z.of_int (List.length es) ->
        List.nth es (Z.to_int i)
      | _ -> AccE (e1', IdxP e2' $ p.at) $> e
      )
    | SliceP (e2, e3) ->
      let e1' = reduce_exp s e1 in
      let e2' = reduce_exp s e2 in
      let e3' = reduce_exp s e3 in
      (match e1'.it, e2'.it, e3'.it with
      (* The slice takes n elements starting at index i, so it is in bounds
         exactly when i + n does not exceed the length. This includes a
         slice that ends at the last element. *)
      | ListE es, NumE (`Nat i), NumE (`Nat n) when Z.(i + n) <= Z.of_int (List.length es) ->
        ListE (Lib.List.take (Z.to_int n) (Lib.List.drop (Z.to_int i) es))
      | _ -> AccE (e1', SliceP (e2', e3') $ p.at)
      ) $> e
    | DotP atom ->
      let e1' = reduce_exp s e1 in
      let stuck () = AccE (e1', DotP atom $ p.at) $> e in
      (match e1'.it with
      | StrE efs ->
        (* A record literal without the requested field leaves the access
           unreduced instead of escaping with [Not_found]. *)
        (match List.find_opt (fun (atomN, _) -> Atom.eq atomN atom) efs with
        | Some (_, eN) -> !eN
        | None -> stuck ()
        )
      | _ -> stuck ()
      )
    )
  | UpdE (e1, p, e2) ->
    let e1' = reduce_exp s e1 in
    let e2' = reduce_exp s e2 in
    (* The continuation receives the sub-expression that was reached and
       the part of the path that could not be followed. An empty rest means
       the target was found and is replaced by the new value. *)
    reduce_path s e1' p
      (fun e' ps -> if ps = [] then e2' else UpdE (e', ps, e2') $> e')
  | ExtE (_e1, _p, _e2, _dir) -> e
    (* TODO
    let e1' = reduce_exp s e1 in
    let e2' = reduce_exp s e2 in
    reduce_path s e1' p
      (fun e' p' ->
        if p'.it = RootP
        then reduce_exp s (CatE (e', e2') $> e')
        else ExtE (e', p', e2') $> e'
      )
    *)
  | StrE efs -> StrE (List.map (reduce_expfield s) efs) $> e
  | CompE (e1, e2) ->
    (* TODO(4, rossberg): avoid overlap with CatE? *)
    let e1' = reduce_exp s e1 in
    let e2' = reduce_exp s e2 in
    (match e1'.it, e2'.it with
    | ListE es1, ListE es2 -> ListE (es1 @ es2)
    | OptE None, OptE _ -> e2'.it
    | OptE _, OptE None -> e1'.it
    | StrE efs1, StrE efs2 ->
      (* Records are composed field by field; both records must list the
         same atoms in the same order. *)
      let merge (atom1, e1) (atom2, e2) =
        assert (Atom.eq atom1 atom2);
        (atom1, ref (reduce_exp s (CompE (!e1, !e2) $> !e1)))
      in StrE (List.map2 merge efs1 efs2)
    | _ -> CompE (e1', e2')
    ) $> e
  | MemE (e1, e2) ->
    let e1' = reduce_exp s e1 in
    let e2' = reduce_exp s e2 in
    (* Membership is decided positively as soon as an equal element is
       found, and negatively only when everything involved is normal. *)
    (match e2'.it with
    | OptE None -> BoolE false
    | OptE (Some e2') when Eq.eq_expr e1' e2' -> BoolE true
    | OptE (Some e2') when is_normal_exp e1' && is_normal_exp e2' -> BoolE false
    | ListE [] -> BoolE false
    | ListE es2' when List.exists (Eq.eq_expr e1') es2' -> BoolE true
    | ListE es2' when is_normal_exp e1' && List.for_all is_normal_exp es2' -> BoolE false
    | _ -> MemE (e1', e2')
    ) $> e
  | LiftE e1 ->
    let e1' = reduce_exp s e1 in
    (match e1'.it with
    | OptE None -> ListE []
    | OptE (Some e11) -> ListE [e11]
    | _ -> LiftE e1'
    ) $> e
  | LenE e1 ->
    let e1' = reduce_exp s e1 in
    (match e1'.it with
    | ListE es -> NumE (`Nat (Z.of_int (List.length es)))
    | _ -> LenE e1'
    ) $> e
  | TupE es -> TupE (List.map (reduce_exp s) es) $> e
  | CallE (id, args) ->
    let args' = List.map (reduce_arg s) args in
    (match reduce_call id args' with
    | None -> CallE (id, args') $> e
    | Some e -> e
    )
  | IterE (e1, iterexp) ->
    let e1' = reduce_exp s e1 in
    let (iter', xes') as iterexp' = reduce_iterexp s iterexp in
    let ids, es' = List.split xes' in
    (* Unfolding needs every iterated source in head-normal form. With no
       iterated variables at all, only the [ListN] form can be unfolded. *)
    if not (List.for_all is_head_normal_exp es') || iter' <= List1 && es' = [] then
      IterE (e1', iterexp') $> e
    else
      (match iter' with
      | Opt ->
        let eos' = List.map as_opt_exp es' in
        if List.for_all Option.is_none eos' then
          OptE None $> e
        else if List.for_all Option.is_some eos' then
          let es1' = List.map Option.get eos' in
          let s = List.fold_right2 Subst.add ids es1' Subst.empty in
          (* NOTE(review): the instantiated body is returned without an
             [OptE (Some _)] wrapper - confirm this is intended. *)
          reduce_exp s (Subst.subst_exp s e1')
        else
          IterE (e1', iterexp') $> e
      | List | List1 ->
        (* Rewritten to the explicit-length form, using the length of the
           first iterated list, and reduced again. *)
        let n = List.length (as_list_exp (List.hd es')) in
        if iter' = List || n >= 1 then
          let en = NumE (`Nat (Z.of_int n)) $$ e.at % (Il.Ast.NumT `NatT $ e.at) in
          reduce_exp s (IterE (e1', (ListN (en, None), xes')) $> e)
        else
          IterE (e1', iterexp') $> e
      | ListN ({it = NumE (`Nat n'); _}, ido) ->
        let ess' = List.map as_list_exp es' in
        let ns = List.map List.length ess' in
        let n = Z.to_int n' in
        if List.for_all ((=) n) ns then
          (* One instance of the body per index: the iterated variables are
             bound to the i-th elements, and the optional index variable to
             the literal i.
             NOTE(review): the instances are collected in a [TupE], not a
             [ListE] - confirm this is intended. *)
          (TupE (List.init n (fun i ->
            let esI' = List.map (fun es -> List.nth es i) ess' in
            let s = List.fold_right2 Subst.add ids esI' Subst.empty in
            let s' =
              Option.fold ido ~none:s ~some:(fun id ->
                let en = NumE (`Nat (Z.of_int i)) $$ no_region % (Il.Ast.NumT `NatT $ no_region) in
                Subst.add id en s
              )
            in Subst.subst_exp s' e1'
          )) $> e) |> reduce_exp s
        else
          IterE (e1', iterexp') $> e
      | ListN _ ->
        IterE (e1', iterexp') $> e
      )
  | OptE eo -> OptE (Option.map (reduce_exp s) eo) $> e
  | ListE es -> ListE (List.map (reduce_exp s) es) $> e
  | CatE (e1, e2) ->
    let e1' = reduce_exp s e1 in
    let e2' = reduce_exp s e2 in
    (match e1'.it, e2'.it with
    | ListE es1, ListE es2 -> ListE (es1 @ es2)
    | OptE None, OptE _ -> e2'.it
    | OptE _, OptE None -> e1'.it
    | _ -> CatE (e1', e2')
    ) $> e
  | CaseE (op, es) -> CaseE (op, List.map (reduce_exp s) es) $> e
  | SubE _ -> e
  | _ -> e
(* Reduces the length expression of a [ListN] iterator; every other
   iterator is returned as is. *)
and reduce_iter s iter =
  match iter with
  | ListN (len, ido) -> ListN (reduce_exp s len, ido)
  | other -> other
(* Reduces an iteration descriptor: the iterator itself and the source
   expression of every iterated variable. *)
and reduce_iterexp s (iter, xes) =
  let reduce_binding (id, e) = (id, reduce_exp s e) in
  (reduce_iter s iter, List.map reduce_binding xes)
(* Reduces the expression stored in a record field and returns the field
   with a fresh reference cell; the original cell is not modified. *)
and reduce_expfield s (atom, e) : (atom * expr ref) =
  let reduced = reduce_exp s !e in
  (atom, ref reduced)
(* [reduce_path s e p f] follows the access path [p] (a list of steps,
   outermost first) into [e] as far as the shape of [e] allows, then calls
   the continuation [f e' rest]. [e'] is the sub-expression that was
   reached and [rest] is the suffix of the path that could not be followed,
   expressed relative to [e']. The result of [f] is plugged back into the
   surrounding structure.

   The path is peeled from its last step: for each step a continuation
   [f'] is built and the remaining prefix is handled recursively. Inside
   [f'], the argument [p1'] is the part of the prefix still unapplied to
   [e']. The current step may therefore be applied to [e'] only when [p1']
   is empty; otherwise it is appended to [p1'], so that the steps stay in
   order and an already-followed prefix is not applied a second time. *)
and reduce_path s e p f =
  match Lib.List.split_last_opt p with
  | None -> f e []
  | Some (ps, p') ->
    (match p'.it with
    | IdxP e1 ->
      let e1' = reduce_exp s e1 in
      let f' e' p1' =
        match p1', e'.it, e1'.it with
        | [], ListE es, NumE (`Nat i) when i < Z.of_int (List.length es) ->
          ListE (List.mapi (fun j eJ -> if Z.of_int j = i then f eJ [] else eJ) es) $> e'
        | _ ->
          f e' (p1' @ [IdxP (e1') $> p'])
      in
      reduce_path s e ps f'
    | SliceP (e1, e2) ->
      let e1' = reduce_exp s e1 in
      let e2' = reduce_exp s e2 in
      let f' e' p1' =
        match p1', e'.it, e1'.it, e2'.it with
        (* In bounds when i + n does not exceed the length; this includes a
           slice that ends at the last element. *)
        | [], ListE es, NumE (`Nat i), NumE (`Nat n) when Z.(i + n) <= Z.of_int (List.length es) ->
          (* Split into the part before the slice, the slice itself and the
             part after it; rebuild by concatenation around the updated
             slice. *)
          let before = ListE Lib.List.(take (Z.to_int i) es) $> e' in
          let slice = ListE Lib.List.(take (Z.to_int n) (drop (Z.to_int i) es)) $> e' in
          let after = ListE Lib.List.(drop Z.(to_int (i + n)) es) $> e' in
          reduce_exp s (CatE (before, CatE (f slice [], after) $> e') $> e')
        | _ ->
          f e' (p1' @ [SliceP (e1', e2') $> p'])
      in
      reduce_path s e ps f'
    | DotP atom ->
      let f' e' p1' =
        match p1', e'.it with
        | [], StrE efs ->
          StrE (List.map (fun (atomI, eI) ->
            if Atom.eq atomI atom then (atomI, ref (f !eI [])) else (atomI, eI)) efs) $> e'
        | _ ->
          f e' (p1' @ [DotP (atom) $> p'])
      in
      reduce_path s e ps f'
    )
(* Reduces a call argument. Only expression arguments are reduced; type
   and definition arguments are returned unchanged. Each call is traced
   through [Debug_log] under the label al.reduce_arg. *)
and reduce_arg s a : arg =
  let reduce _ =
    match a.it with
    | ExpA e ->
      let e' = reduce_exp s e in
      ExpA e' $ a.at
    | TypA _t -> a  (* types are reduced on demand *)
    | DefA _id -> a
    (* | GramA _g -> a *)
  in
  Debug_log.(log "al.reduce_arg"
    (fun _ -> fmt "%s" (string_of_arg a))
    (fun a' -> fmt "%s" (string_of_arg a'))
  ) reduce
(* [reduce_call id args] tries to evaluate a call to the function [id] with
   the already reduced arguments [args].

   The definition is looked up among the [FuncA] entries of [!Lang.al], its
   parameters are matched against the arguments, and its body is
   interpreted under the resulting substitution.

   Returns [None], meaning the call should be left as is, when:
   - no function named [id] is defined,
   - the number of arguments differs from the number of parameters,
   - some parameter does not match its argument, or
   - the body cannot be evaluated to a result. *)
and reduce_call id args : expr option =
  let defines_id al =
    match al.it with
    | FuncA (fname, _, _) -> fname = id
    | RuleA _ -> false
  in
  match List.find_opt defines_id !Lang.al with
  | Some {it = FuncA (_, params, il); _}
    when List.compare_lengths params args = 0 ->
    (* The length check above keeps [List.fold_right2] from raising
       [Invalid_argument]. *)
    let* s = List.fold_right2
      (fun p a s -> let* s = s in get_subst_arg p a s)
      params args (Some Subst.empty) in
    reduce_instrs s il
  | _ -> None
(* [reduce_instrs s il] interprets the instruction list of a function body
   under the substitution [s] and returns the expression it returns, or
   [None] when the body cannot be evaluated statically: a condition that
   does not reduce to a boolean literal, a binding whose pattern does not
   match, a possibly side-effecting instruction, or the end of the list
   without a return.

   [reduce_exp] does not look variables up in [s]. Every expression is
   therefore instantiated with [Subst.subst_exp s] before it is reduced, so
   that parameters and let-bound variables are replaced by their values in
   conditions, bound expressions and the returned result. *)
and reduce_instrs s : instr list -> expr option = function
  | [] -> None
  | instr :: t ->
    let eval e = reduce_exp s (Subst.subst_exp s e) in
    match instr.it with
    | ReturnI expr_opt -> Option.map eval expr_opt
    | LetI (expr1, expr2) ->
      (* The left-hand side is a pattern and is not instantiated. *)
      let* s' = get_subst expr1 (eval expr2) s in
      reduce_instrs s' t
    | IfI (expr, il1, il2) ->
      (* TODO: consider iter *)
      (match (eval expr).it with
      | BoolE true -> reduce_instrs s (il1 @ t)
      | BoolE false -> reduce_instrs s (il2 @ t)
      | _ -> None
      )
    (* Can have side effect *)
    | EitherI _ | PerformI _ | ReplaceI _ | AppendI _ -> None
    (* Invalid instruction in FuncA *)
    | EnterI _ | PushI _ | PopI _ | PopAllI _ | TrapI | FailI | ThrowI _
    | ExecuteI _ | ExecuteSeqI _ | ExitI _ | ForEachI _ | OtherwiseI _ | YetI _ -> assert (false)
    (* Nop *)
    | (AssertI _ | NopI) -> reduce_instrs s t