diff --git a/lib/AstToMiniRust.ml b/lib/AstToMiniRust.ml index dc64d70b..d27a9dec 100644 --- a/lib/AstToMiniRust.ml +++ b/lib/AstToMiniRust.ml @@ -410,11 +410,8 @@ let translate_unknown_lid (m, n) = let m = compress_prefix m in List.map String.lowercase_ascii m @ [ n ] -let borrow_kind_of_bool b: MiniRust.borrow_kind = - if b (* const *) then - Shared - else - Mut +let borrow_kind_of_bool _b: MiniRust.borrow_kind = + Shared type config = { box: bool; @@ -511,7 +508,7 @@ let lookup_split (env: env) (v_base: MiniRust.db_index) (path: Splits.root_or_pa | Some (v_base', path') when v_base' = v_base - ofs - 1 -> begin match Splits.accessible_via path path' with | Some pe -> - MiniRust.(Field (Var ofs, Splits.string_of_path_elem pe)) + MiniRust.(Field (Var ofs, Splits.string_of_path_elem pe, None)) | None -> find (ofs + 1) bs end @@ -590,7 +587,7 @@ and translate_array (env: env) is_toplevel (init: Ast.expr): env * MiniRust.expr necessitate the insertion of conversions *) and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env * MiniRust.expr = let erase_lifetime_info = (object(self) - inherit [_] MiniRust.map + inherit [_] MiniRust.DeBruijn.map method! visit_Ref env _ bk t = Ref (None, bk, self#visit_typ env t) method! visit_tname _ n _ = Name (n, []) end)#visit_typ () @@ -615,18 +612,17 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env do retain them in our Function type -- so we need to relax the comparison here *) x (* More conversions due to box-ing types. *) - | _, App (Name (["Box"], _), [Slice _]), Ref (_, Mut, Slice _) -> - Borrow (Mut, Deref x) - | _, Ref (_, Mut, Slice _), App (Name (["Box"], _), [Slice _]) -> + | _, App (Name (["Box"], _), [Slice _]), Ref (_, k, Slice _) -> + Borrow (k, Deref x) + | _, Ref (_, _, Slice _), App (Name (["Box"], _), [Slice _]) -> MethodCall (Borrow (Shared, Deref x), ["into"], []) - | _, Ref (_, Shared, Slice _), App (Name (["Box"], _), [Slice _]) -> - MethodCall (x, ["into"], []) + (* | _, Ref (_, Shared, Slice _), App (Name (["Box"], _), [Slice _]) -> *) + (* MethodCall (x, ["into"], []) *) | _, Vec _, App (Name (["Box"], _), [Slice _]) -> MethodCall (MethodCall (x, ["try_into"], []), ["unwrap"], []) (* More conversions due to vec-ing types *) - | _, Ref (_, Mut, Slice _), Vec _ - | _, Ref (_, Shared, Slice _), Vec _ -> + | _, Ref (_, _, Slice _), Vec _ -> MethodCall (x, ["to_vec"], []) | _, Array _, Vec _ -> Call (Name ["Vec"; "from"], [], [x]) @@ -644,8 +640,8 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env | _, Ref (_, _, t), t' when t = t' -> Deref x - | Borrow (Mut, e) , Ref (_, _, t), Ref (_, _, Slice t') when t = t' -> - Borrow (Mut, Array (List [ e ])) + | Borrow (k, e) , Ref (_, _, t), Ref (_, _, Slice t') when t = t' -> + Borrow (k, Array (List [ e ])) | _ -> (* If we reach this case, we perform one last try by erasing the lifetime @@ -727,8 +723,9 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env | EApp ({ node = EQualified ([ "LowStar"; "BufferOps" ], s); _ }, e1 :: e2 :: _ ) when KString.starts_with s "op_Star_Equals__" -> let env, e1 = translate_expr env e1 in + let t2 = translate_type env e2.typ in let env, e2 = translate_expr env e2 in - env, Assign (Index (e1, MiniRust.zero_usize), e2) + env, Assign (Index (e1, MiniRust.zero_usize), e2, t2) | EApp ({ node = ETApp (e, cgs, cgs', ts); _ }, es) -> assert (cgs @ cgs' = []); @@ -805,10 +802,10 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: 
MiniRust.typ): env
       | [] ->
           failwith "[ELet] unexpected: path not found"
       in
-      env, MiniRust.(Field (find 0 env.vars, Splits.string_of_path_elem path_elem))
+      env, MiniRust.(Field (find 0 env.vars, Splits.string_of_path_elem path_elem, None))
     in
-    let split_at = match b.typ with TBuf (_, true) -> "split_at" | _ -> "split_at_mut" in
+    let split_at = "split_at" in
     let e1 = MiniRust.MethodCall (e_nearest , [split_at], [ index ]) in
     let t = translate_type env b.typ in
     let binding : MiniRust.binding * Splits.info =
@@ -832,7 +829,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
       let env, e1, t = translate_array env false init in
       (* KPrint.bprintf "Let %s: %a\n" b.node.name PrintMiniRust.ptyp t; *)
-      let binding: MiniRust.binding = { name = b.node.name; typ = t; mut = true } in
+      let binding: MiniRust.binding = { name = b.node.name; typ = t; mut = false } in
       let env = push env binding in
       env0, Let (binding, e1, snd (translate_expr_with_type env e2 t_ret))

@@ -847,6 +844,8 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
         | TQualified lid when Idents.LidSet.mem lid env.heap_structs -> true
         | _ -> false
       in
+      (* TODO how does this play out with the new "translate as non-mut by
+         default" strategy? *)
       let mut = b.node.mut || is_owned_struct in
       (* Here, the idea is to detect forbidden move-outs that are certain to result in a compilation
         error. Typically, selecting a field, dereferencing an array, etc. when the underlying type
@@ -858,7 +857,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
         cannot be mutated in place in Low*, so it's ok to borrow instead of copy. *)
       let e1, t = match e1 with
         | (Field _ | Index _) when is_owned_struct ->
-            MiniRust.(Borrow (Mut, e1), Ref (None, Mut, t))
+            MiniRust.(Borrow (Shared, e1), Ref (None, Shared, t))
         | _ ->
            e1, t
      in
@@ -883,7 +882,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
      in
      let env, e1 = translate_expr_with_type env e1 lvalue_type in
      let env, e2 = translate_expr_with_type env e2 lvalue_type in
-      env, Assign (e1, e2)
+      env, Assign (e1, e2, lvalue_type)

  | EBufCreate _ ->
      failwith "unexpected: EBufCreate"
  | EBufCreateL _ ->
@@ -895,19 +894,16 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
  | EBufWrite (e1, e2, e3) ->
      let env, e1 = translate_expr env e1 in
      let env, e2 = translate_expr_with_type env e2 (Constant SizeT) in
+      let t3 = translate_type env e3.typ in
      let env, e3 = translate_expr env e3 in
-      env, Assign (Index (e1, e2), e3)
+      env, Assign (Index (e1, e2), e3, t3)
  | EBufSub (e1, e2) ->
      (* This is a fallback for the analysis above. Happens if, for instance, the
         pointer arithmetic appears in subexpression position (like, function call), in
         which case there's a chance this might still work! *)
-      let is_const_tbuf = match e1.typ with
-        | TBuf (_, b) -> b
-        | _ -> false
-      in
      let env, e1 = translate_expr env e1 in
      let env, e2 = translate_expr_with_type env e2 (Constant SizeT) in
-      env, Borrow ((if is_const_tbuf then Shared else Mut), Index (e1, Range (Some e2, None, false)))
+      env, Borrow (Shared, Index (e1, Range (Some e2, None, false)))
  | EBufDiff _ ->
      failwith "unexpected: EBufDiff"
  (* Silly pattern in Low*: for historical reasons, the blit operation takes a
@@ -944,7 +940,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
      (* Rather than error out, we do nothing, as some functions may allocate then free.
*)
      env, Unit
  | EBufNull ->
-      env, possibly_convert (Borrow (Mut, Array (List []))) (translate_type env e.typ)
+      env, possibly_convert (Borrow (Shared, Array (List []))) (translate_type env e.typ)
  | EPushFrame ->
      failwith "unexpected: EPushFrame"
  | EPopFrame ->
@@ -1000,7 +996,8 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env

  | EField (e, f) ->
      let env, e_ = translate_expr env e in
-      env, possibly_convert (Field (e_, f)) (field_type env e f)
+      let t = translate_type env e.typ in
+      env, possibly_convert (Field (e_, f, Some t)) (field_type env e f)

  | EBreak ->
      failwith "TODO: EBreak"
@@ -1011,66 +1008,10 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
  | EWhile _ ->
      failwith "TODO: EWhile"

-  (* Loop with all constant bounds. *)
-  | EFor (b,
-      ({ node = EConstant (UInt32, init as k_init); _ } as e_init),
-      { node = EApp (
-        { node = EOp (Lt, UInt32); _ },
-        [{ node = EBound 0; _ };
-        ({ node = EConstant (UInt32, max); _ })]); _},
-      { node = EAssign (
-        { node = EBound 0; _ },
-        { node = EApp (
-          { node = EOp (Add, UInt32); _ },
-          [{ node = EBound 0; _ };
-          ({ node = EConstant (UInt32, incr as k_incr); _ })]); _}); _},
-      body)
-    when (
-      let init = int_of_string init in
-      let max = int_of_string max in
-      let incr = int_of_string incr in
-      let n_loops = (max - init + incr - 1) / incr in
-      n_loops <= !Options.unroll_loops
-    )
-    ->
-      (* Keep initial environment to return after translation *)
-      let env0 = env in
-
-      let init = int_of_string init in
-      let max = int_of_string max in
-      let incr = int_of_string incr in
-      let n_loops = (max - init + incr - 1) / incr in
-
-      if n_loops = 0 then
-        env, Unit
-
-      else if n_loops = 1 then
-        let body = DeBruijn.subst e_init 0 body in
-        translate_expr env body
-
-      else begin
-        let unused = snd !(b.node.mark) = AtMost 3 in
-        (* We do an ad-hoc thing since this didn't go through lowstar.ignore
-           insertion. Rust uses the OCaml convention (which I recall I did suggest
-           to Graydon back in 2010). *)
-        let unused = if unused then "_" else "" in
-        let b: MiniRust.binding = { name = unused ^ b.node.name; typ = translate_type env b.typ; mut = false } in
-        let _, body = translate_expr (push env b) body in
-
-        (* This is a weird node, because it contains a binder, but does not rely
-           on a MiniRust.binding to encode that fact. For that reason, the
-           printer needs to be kept in sync and catch this special application
-           node. We could do this rewrite on the fly when pretty-printing, but
-           we'd have to retain the number of uses (above) in the node to figure
-           out whether to make the binder ignored or not. *)
-        env0, Call (Name ["krml"; "unroll_for!"], [], [
-          Constant (CInt, string_of_int n_loops);
-          ConstantString b.name;
-          Constant k_init;
-          Constant k_incr;
-          body ])
-      end
-
+  (* The introduction of the unroll_for! macro requires a "fake" binder
+     for the iterated value, which interferes with the variable substitutions
+     performed during mutability inference. We instead perform the unrolling
+     in RustMacroize, after all substitutions are done.
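+
+     As a concrete illustration (a hand-written sketch of the node built in
+     RustMacroize.unroll_loops, assuming a counter named i): a constant loop
+     from 0 to 4 with step 1 ultimately becomes
+
+       Call (Name ["krml"; "unroll_for!"], [], [
+         Constant (CInt, "4"); ConstantString "i";
+         Constant (UInt32, "0"); Constant (UInt32, "1"); body ])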
*) | EFor (b, e_start, e_test, e_incr, e_body) -> (* Keep initial environment to return after translation *) let env0 = env in diff --git a/lib/MiniRust.ml b/lib/MiniRust.ml index 0fc57af7..c2f07500 100644 --- a/lib/MiniRust.ml +++ b/lib/MiniRust.ml @@ -1,5 +1,10 @@ (* A minimalistic representation of Rust *) +module Name = struct + type t = string list + let compare = compare +end + type borrow_kind_ = Mut | Shared [@@deriving show] @@ -7,7 +12,7 @@ type borrow_kind = borrow_kind_ [@ opaque ] and constant = Constant.t [@ opaque ] and width = Constant.width [@ opaque ] and op = Constant.op [@ opaque ] -and name = string list [@ opaque ] +and name = Name.t [@ opaque ] (* Some design choices. - We don't intend to perform any deep semantic transformations on this Rust @@ -48,7 +53,10 @@ let box t = App (Name (["Box"], []), [t]) let bool = Constant Bool +let u8 = Constant UInt8 +let u16 = Constant UInt16 let u32 = Constant UInt32 +let u64 = Constant UInt64 let usize = Constant SizeT type binding = { name: string; typ: typ; mut: bool } @@ -79,7 +87,7 @@ and expr = | Unit | Panic of string | IfThenElse of expr * expr * expr option - | Assign of expr * expr + | Assign of expr * expr * typ | As of expr * typ | For of binding * expr * expr | While of expr * expr @@ -90,8 +98,12 @@ and expr = (* Place expressions *) | Var of db_index + | Open of open_var | Index of expr * expr - | Field of expr * string + (* The type corresponds to the structure we are accessing. + We will store None when accessing a native Rust tuple, + corresponding to an array slice *) + | Field of expr * string * typ option (* Operator expressions *) | Operator of op @@ -104,6 +116,12 @@ and pat = | Wildcard | StructP of name (* TODO *) +and open_var = { + name: string; + atom: atom_t +} + +and atom_t = Atom.t [@ visitors.opaque] (* TODO: visitors incompatible with inline records *) type decl = @@ -158,48 +176,92 @@ and trait = (* Some visitors for name management *) -(* A usable map where the user can hook up to extend, called every time a new - binding is added to the environment *) -class ['self] map = object (self: 'self) - inherit [_] map_expr as super - - (* To be overridden by the user *) - method extend env _ = env +module DeBruijn = struct + + (* A usable map where the user can hook up to extend, called every time a new + binding is added to the environment *) + class ['self] map = object (self: 'self) + inherit [_] map_expr as _super + + (* To be overridden by the user *) + method extend env _ = env + + (* We list all binding nodes and feed those binders to the environment *) + method! visit_Let env b e1 e2 = + let e1 = self#visit_expr env e1 in + let e2 = self#visit_expr (self#extend env b) e2 in + Let (b, e1, e2) + + method! visit_For env b e1 e2 = + let e1 = self#visit_expr env e1 in + let e2 = self#visit_expr (self#extend env b) e2 in + For (b, e1, e2) + end + + class map_counting = object + (* The environment [i] has type [int]. *) + inherit [_] map + + (* The environment [i] keeps track of how many binders have been + entered. It is incremented at each binder. *) + method! extend (i: int) (_: binding) = + i + 1 + end + + class lift (k: int) = object + inherit map_counting + (* A local variable (one that is less than [i]) is unaffected; + a free variable is lifted up by [k]. *) + method! visit_Var i j = + if j < i then + Var j + else + Var (j + k) + end + + class close (a: Atom.t) (e: expr) = object + inherit map_counting + + method! 
visit_Open i ({ atom; _ } as v) =
+      if Atom.equal a atom then
+        (new lift i)#visit_expr 0 e
+      else
+        Open v
+  end
+
+  class subst e2 = object
+    inherit map_counting
+
+    method! visit_Var i j =
+      if j = i then
+        (new lift i)#visit_expr 0 e2
+      else
+        Var (if j < i then j else j - 1)
+  end

-  (* We list all binding nodes and feed those binders to the environment *)
-  method! visit_Let env b e1 e2 =
-    super#visit_Let (self#extend env b) b e1 e2
-
-  method! visit_For env b e1 e2 =
-    super#visit_For (self#extend env b) b e1 e2
-end
-
-class map_counting = object
-  (* The environment [i] has type [int]. *)
-  inherit [_] map
-
-  (* The environment [i] keeps track of how many binders have been
-     entered. It is incremented at each binder. *)
-  method! extend (i: int) (_: binding) =
-    i + 1
-end
-
-class lift (k: int) = object
-  inherit map_counting
-  (* A local variable (one that is less than [i]) is unaffected;
-     a free variable is lifted up by [k]. *)
-  method! visit_Var i j =
-    if j < i then
-      Var j
-    else
-      Var (j + k)
 end
+
+(* Lift `expr` by `k` places so as to place it underneath `k` additional
+   binders. *)
 let lift (k: int) (expr: expr): expr =
   if k = 0 then
     expr
   else
-    (new lift k)#visit_expr 0 expr
+    (new DeBruijn.lift k)#visit_expr 0 expr
+
+(* Close `a`, replacing it on the fly with `e2` in `e1` *)
+let close a e2 e1 =
+  (new DeBruijn.close a e2)#visit_expr 0 e1
+
+(* Substitute `e2` for bound variable `i` in `e1` *)
+let subst e2 i e1 =
+  (new DeBruijn.subst e2)#visit_expr i e1
+
+(* Open b in e2, replacing occurrences of a bound variable with the
+   corresponding atom. *)
+let open_ (b: binding) e2 =
+  let atom = Atom.fresh () in
+  atom, subst (Open { atom; name = b.name }) 0 e2

 (* Helpers *)
diff --git a/lib/OptimizeMiniRust.ml b/lib/OptimizeMiniRust.ml
new file mode 100644
index 00000000..494d77e3
--- /dev/null
+++ b/lib/OptimizeMiniRust.ml
@@ -0,0 +1,789 @@
+(* AstToMiniRust generates code that only uses shared borrows; that is obviously
+   incorrect, and so the purpose of this phase is to infer the minimum number of
+   variables that need to be marked as `mut`, and the minimum number of borrows
+   that need themselves to be `&mut`.
+
+   This improves on an earlier iteration of the compilation scheme where
+   everything was marked as mutable by default, a conservative, but suboptimal
+   choice.
+
+   We proceed as follows. We carry two sets of variables:
+   - V stands for mutable variables, i.e. the set of variables that need to be
+     marked as mut using `let mut x = ...`. A variable needs to be marked as mut
+     if it is mutably-borrowed, i.e. if `&mut x` occurs.
+   - R stands for mutable references, i.e. the set of variables that have type
+     `&mut T`. R is initially populated with function parameters.
+   This is the state of our transformation, and as such, we return an augmented
+   state after performing our inference, so that the caller can mark variables
+   accordingly.
+
+   An environment keeps track of the functions that have been visited already,
+   along with their updated types.
+
+   Finally, the transformation receives a contextual flag as an input parameter;
+   the flag indicates whether the subexpression being visited (e.g. &x) needs to
+   return a mutable borrow, meaning it gets rewritten (e.g. into &mut x) and the
+   set V increases (because the Rust rule is that you can only write `&mut x` if
+   `x` itself is declared with `let mut`).
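+
+   As a small hand-written illustration (not actual compiler output): if `f`
+   is already known to take a `Ref (None, Mut, Slice u32)`, then visiting
+
+     Let (x, ..., Call (Name ["f"], [], [ Borrow (Shared, Open x) ]))
+
+   with that expected argument type rewrites the borrow into
+   `Borrow (Mut, Open x)` and adds x's atom to V, so that the binding is
+   eventually printed back as `let mut x = ...`.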
+*)
+
+open MiniRust
+
+module NameMap = Map.Make(Name)
+module VarSet = Set.Make(Atom)
+
+type env = {
+  seen: typ list NameMap.t;
+  structs: MiniRust.struct_field list NameMap.t;
+}
+
+type known = {
+  structs: MiniRust.struct_field list NameMap.t;
+  v: VarSet.t;
+  r: VarSet.t;
+}
+
+let assert_borrow = function
+  | Ref (_, _, t) -> t
+  | _ -> failwith "impossible: assert_borrow"
+
+let assert_name (t: typ option) = match t with
+  | Some (Name (n, _)) -> n
+  | _ -> failwith "impossible: assert_name"
+
+let add_mut_var a known =
+  (* KPrint.bprintf "%s is let mut\n" (Ast.show_atom_t a); *)
+  { known with v = VarSet.add a known.v }
+
+let add_mut_borrow a known =
+  (* KPrint.bprintf "%s is &mut\n" (Ast.show_atom_t a); *)
+  { known with r = VarSet.add a known.r }
+
+let want_mut_var a known =
+  VarSet.mem a known.v
+
+let want_mut_borrow a known =
+  VarSet.mem a known.r
+
+let is_mut_borrow = function
+  | Ref (_, Mut, _) -> true
+  (* Special-case for tuples; they should only occur with array slices *)
+  | Tuple [Ref (_, Mut, _); Ref (_, Mut, _)] -> true
+  | _ -> false
+
+let make_mut_borrow = function
+  | Ref (l, _, t) -> Ref (l, Mut, t)
+  | Tuple [Ref (l1, _, t1); Ref (l2, _, t2)] -> Tuple [Ref (l1, Mut, t1); Ref (l2, Mut, t2)]
+  | Vec t -> Vec t
+  | _ -> failwith "impossible: make_mut_borrow"
+
+let add_mut_field ty f known =
+  let n = assert_name ty in
+  let fields = NameMap.find n known.structs in
+  (* Update the mutability of the field element *)
+  let fields = List.map (fun (sf: MiniRust.struct_field) ->
+    if sf.name = f then {sf with typ = make_mut_borrow sf.typ} else sf) fields in
+  {known with structs = NameMap.add n fields known.structs}
+
+let retrieve_pair_type = function
+  | Tuple [e1; e2] -> assert (e1 = e2); e1
+  | _ -> failwith "impossible: retrieve_pair_type"
+
+let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr =
+  match e with
+  | Borrow (k, e) ->
+      (* If we expect this borrow to be a mutable borrow, then we make it a mutable borrow
+         (obviously!), and also remember that the variable x itself needs to be `let mut` *)
+      if is_mut_borrow expected then
+        match e with
+        | Var _ ->
+            failwith "impossible: missing open"
+        | Open { atom; _ } ->
+            add_mut_var atom known, Borrow (Mut, e)
+        | Index (e1, (Range _ as r)) ->
+            let known, e1 = infer env expected known e1 in
+            known, Borrow (Mut, Index (e1, r))
+
+        | Field (Open _, "0", None)
+        | Field (Open _, "1", None) -> failwith "TODO: borrowing slices"
+
+        | Field (Open {atom; _}, _, _) ->
+            add_mut_var atom known, Borrow (Mut, e)
+
+        | Field (Deref (Open {atom; _}), _, _) ->
+            add_mut_borrow atom known, Borrow (Mut, e)
+
+        | Field (Index (Open {atom; _}, _), _, _) ->
+            add_mut_borrow atom known, Borrow (Mut, e)
+
+        | _ ->
+            KPrint.bprintf "[infer-mut, borrow] borrowing %a is not supported\n" PrintMiniRust.pexpr e;
+            failwith "TODO: borrowing something other than a variable"
+      else
+        let known, e = infer env (assert_borrow expected) known e in
+        known, Borrow (k, e)
+
+  | Open { atom; _ } ->
+      (* If we expect this variable to be a mutable borrow, then we remember it and let the caller
+         act accordingly.
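+         For example (sketch): with expected = Ref (None, Mut, Slice u32),
+         visiting Open { atom = a; _ } adds a to R; the enclosing Let case
+         then re-types a's binding via make_mut_borrow, turning it into a
+         `&mut [u32]`.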
+      *)
+      if is_mut_borrow expected then
+        add_mut_borrow atom known, e
+      else
+        known, e
+
+  | Let (b, e1, e2) ->
+      (* KPrint.bprintf "[infer-mut,let] %a\n" PrintMiniRust.pexpr e; *)
+      let a, e2 = open_ b e2 in
+      (* KPrint.bprintf "[infer-mut,let] opened %s[%s]\n" b.name (show_atom_t a); *)
+      let known, e2 = infer env expected known e2 in
+      let mut_var = want_mut_var a known in
+      let mut_borrow = want_mut_borrow a known in
+      (* KPrint.bprintf "[infer-mut,let-done-e2] %s[%s]: %a let mut ? %b &mut ? %b\n" b.name
+        (show_atom_t a)
+        PrintMiniRust.ptyp b.typ mut_var mut_borrow; *)
+      let t1 = if mut_borrow then make_mut_borrow b.typ else b.typ in
+      let known, e1 = infer env t1 known e1 in
+      known, Let ({ b with mut = mut_var; typ = t1 }, e1, close a (Var 0) (lift 1 e2))
+
+  | Call (Name n, targs, es) ->
+      if NameMap.mem n env.seen then
+        (* TODO: substitute targs in ts -- for now, we assume we don't have a type-polymorphic
+           function that gets instantiated with a reference type *)
+        let ts = NameMap.find n env.seen in
+        let known, es = List.fold_left2 (fun (known, es) e t ->
+          let known, e = infer env t known e in
+          known, e :: es
+        ) (known, []) es ts
+        in
+        let es = List.rev es in
+        known, Call (Name n, targs, es)
+      else if n = ["lowstar";"ignore";"ignore"] then
+        (* Since we do not have type-level substitutions in MiniRust, we special-case ignore here.
+           Ideally, it would be added to builtins with `Bound 0` as a suitable type for the
+           argument. *)
+        let known, e = infer env (KList.one targs) known (KList.one es) in
+        known, Call (Name n, targs, [ e ])
+      else if n = [ "lib"; "memzero0"; "memzero" ] then (
+        (* Same as ignore above *)
+        assert (List.length es = 2);
+        let e1, e2 = KList.two es in
+        let known, e1 = infer env (Ref (None, Mut, Slice (KList.one targs))) known e1 in
+        let known, e2 = infer env u32 known e2 in
+        known, Call (Name n, targs, [ e1; e2 ])
+      ) else (
+        KPrint.bprintf "[infer-mut,call] recursing on %s\n" (String.concat " :: " n);
+        failwith "TODO: recursion"
+      )
+
+  | Call (Operator o, [], _) -> begin match o with
+      (* Most operators are wrapping and were translated to a methodcall.
+         We list the few remaining ones here *)
+      | Add | Sub
+      | BOr | BAnd | BXor | BNot
+      | Eq | Neq | Lt | Lte | Gt | Gte
+      | And | Or | Xor | Not -> known, e
+      | _ ->
+          KPrint.bprintf "[infer-mut,call] %a not supported\n" PrintMiniRust.pexpr e;
+          failwith "TODO: operator not supported"
+      end
+
+  | Call _ ->
+      failwith "TODO: Call"
+
+  (* atom = e3 *)
+  | Assign (Open { atom; _ } as e1, e3, t) ->
+      (* KPrint.bprintf "[infer-mut,assign] %a\n" PrintMiniRust.pexpr e; *)
+      let known, e3 = infer env t known e3 in
+      add_mut_var atom known, Assign (e1, e3, t)
+
+  (* atom[e2] = e3 *)
+  | Assign (Index (Open { atom; _ } as e1, e2), e3, t)
+
+  (* Special-case when we perform a field assignment that comes from
+     a slice. This is the only case where we use native Rust tuples.
+     In this case, we mark the atom as mutable, and will replace
+     the corresponding call to split_at with split_at_mut when we reach
+     the let-binding.
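+
+     For instance (sketch): given `let s = x.split_at(i)` followed by an
+     assignment through `s.1`, the case below adds s's atom to R; once we
+     reach s's let-binding, its pair type is made mutable, and re-inferring
+     the right-hand side hits the MethodCall case, which rewrites
+     ["split_at"] into ["split_at_mut"].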
+  *)
+  (* atom.0[e2] = e3 *)
+  | Assign (Index (Field (Open {atom;_}, "0", None) as e1, e2), e3, t)
+  (* atom.1[e2] = e3 *)
+  | Assign (Index (Field (Open {atom;_}, "1", None) as e1, e2), e3, t) ->
+      (* KPrint.bprintf "[infer-mut,assign] %a\n" PrintMiniRust.pexpr e; *)
+      let known = add_mut_borrow atom known in
+      let known, e2 = infer env usize known e2 in
+      let known, e3 = infer env t known e3 in
+      known, Assign (Index (e1, e2), e3, t)
+
+  (* x.f[e2] = e3 *)
+  | Assign (Index (Field (_, f, st) as e1, e2), e3, t) ->
+      let known = add_mut_field st f known in
+      let known, e2 = infer env usize known e2 in
+      let known, e3 = infer env t known e3 in
+      known, Assign (Index (e1, e2), e3, t)
+
+  (* (&atom)[e2] = e3 *)
+  | Assign (Index (Borrow (_, (Open { atom; _ } as e1)), e2), e3, t) ->
+      (* KPrint.bprintf "[infer-mut,assign] %a\n" PrintMiniRust.pexpr e; *)
+      let known = add_mut_var atom known in
+      let known, e2 = infer env usize known e2 in
+      let known, e3 = infer env t known e3 in
+      known, Assign (Index (Borrow (Mut, e1), e2), e3, t)
+
+  | Assign (Field (_, "0", None), _, _)
+  | Assign (Field (_, "1", None), _, _) ->
+      failwith "TODO: assignment on slice"
+
+  (* atom[e2].f = e3 *)
+  | Assign (Field (Index ((Open {atom; _} as e1), e2), f, st), e3, t) ->
+      let known = add_mut_borrow atom known in
+      let known, e2 = infer env usize known e2 in
+      let known, e3 = infer env t known e3 in
+      known, Assign (Field (Index (e1, e2), f, st), e3, t)
+
+  (* (&n)[e2] = e3 *)
+  | Assign (Index (Borrow (_, Name n), e2), e3, t) ->
+      (* This case should only occur for globals. For now, we simply mutably borrow it *)
+      let known, e2 = infer env usize known e2 in
+      let known, e3 = infer env t known e3 in
+      known, Assign (Index (Borrow (Mut, Name n), e2), e3, t)
+
+  (* (&(&atom)[e2])[e3] = e4 *)
+  | Assign (Index (Borrow (_, Index (Borrow (_, (Open {atom; _} as e1)), e2)), e3), e4, t) ->
+      let known = add_mut_var atom known in
+      let known, e2 = infer env usize known e2 in
+      let known, e3 = infer env usize known e3 in
+      let known, e4 = infer env t known e4 in
+      known, Assign (Index (Borrow (Mut, Index (Borrow (Mut, e1), e2)), e3), e4, t)
+
+  | Assign _ ->
+      KPrint.bprintf "[infer-mut,assign] %a unsupported\n" PrintMiniRust.pexpr e;
+      failwith "TODO: unknown assignment"
+
+  | Var _
+  | Array _
+  | VecNew _
+  | Name _
+  | Constant _
+  | ConstantString _
+  | Unit
+  | Panic _
+  | Operator _ ->
+      known, e
+
+  | IfThenElse (e1, e2, e3) ->
+      let known, e1 = infer env bool known e1 in
+      let known, e2 = infer env expected known e2 in
+      let known, e3 =
+        match e3 with
+        | Some e3 ->
+            let known, e3 = infer env expected known e3 in
+            known, Some e3
+        | None ->
+            known, None
+      in
+      known, IfThenElse (e1, e2, e3)
+
+  | As (e, t) ->
+      (* Not really correct, but As is only used for integer casts *)
+      let known, e = infer env t known e in
+      known, As (e, t)
+
+  | For (b, e1, e2) ->
+      let known, e2 = infer env Unit known e2 in
+      known, For (b, e1, e2)
+
+  | While (e1, e2) ->
+      let known, e2 = infer env Unit known e2 in
+      known, While (e1, e2)
+
+  | MethodCall (e1, m, e2) ->
+      (* There are only a few instances of these generated by AstToMiniRust, so we just review them
+         all here. Note that there are two possible strategies: AstToMiniRust could use an IndexMut
+         AST node to indicate e.g. that the destination of `copy_from_slice` ought to be mutable, or
+         we just bake that knowledge in right here.
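+         For example (sketch): for `dst[i..j].copy_from_slice(&src)`, the
+         copy_from_slice case below infers the receiver against
+         Ref (None, Mut, Unit), thereby marking dst as mutable, while the
+         source argument stays a shared borrow.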
+      *)
+      begin match m with
+      | [ "wrapping_add" ] | [ "wrapping_div" ] | [ "wrapping_mul" ]
+      | [ "wrapping_neg" ] | [ "wrapping_rem" ] | [ "wrapping_shl" ]
+      | [ "wrapping_shr" ] | [ "wrapping_sub" ]
+      | [ "to_vec" ] ->
+          known, MethodCall (e1, m, e2)
+      | ["split_at"] ->
+          assert (List.length e2 = 1);
+          let known, e2 = infer env usize known (List.hd e2) in
+          let t1 = retrieve_pair_type expected in
+          let known, e1 = infer env t1 known e1 in
+          if is_mut_borrow expected then
+            known, MethodCall (e1, ["split_at_mut"], [e2])
+          else
+            known, MethodCall (e1, m, [e2])
+      | ["copy_from_slice"] -> begin match e1 with
+          | Index (dst, range) ->
+              assert (List.length e2 = 1);
+              (* We do not have access to the types of e1 and e2. However, the concrete
+                 type should not matter during mut inference, so we use Unit as a default *)
+              let known, dst = infer env (Ref (None, Mut, Unit)) known dst in
+              let known, e2 = infer env (Ref (None, Shared, Unit)) known (List.hd e2) in
+              known, MethodCall (Index (dst, range), m, [e2])
+          (* The AstToMiniRust translation should always introduce an index
+             as the left argument of copy_from_slice *)
+          | _ -> failwith "ill-formed copy_from_slice"
+          end
+      | [ "push" ] -> begin match e1 with
+          | Open {atom; _} -> add_mut_var atom known, MethodCall (e1, m, e2)
+          | _ -> failwith "TODO: push on complex expressions"
+          end
+      | _ ->
+          KPrint.bprintf "%a unsupported\n" PrintMiniRust.pexpr e;
+          failwith "TODO: MethodCall"
+      end
+
+  | Range (e1, e2, b) ->
+      known, Range (e1, e2, b)
+
+  | Struct (name, _es) ->
+      (* The declaration of the struct should have been traversed beforehand, hence
+         it should be in the map *)
+      let _fields_mut = NameMap.find name known.structs in
+      (* TODO: This should be modified depending on the current struct
+         in known. *)
+      known, e
+
+  | Match (e, arms) ->
+      (* For now, all pattern-matching occurs on simple terms, e.g., an enum for an
+         alg, hence we do not mutify the scrutinee. If this happens to be needed,
+         we would need to add the expected type of the scrutinee to the Match node,
+         similar to what is done for Assign and Field, in order to recurse on
+         the scrutinee *)
+      let known, arms = List.fold_left_map (fun known (pat, e) ->
+        let known, e = infer env expected known e in
+        known, (pat, e)
+      ) known arms in
+      known, Match (e, arms)
+
+  | Index (e1, e2) ->
+      (* The cases where we perform an assignment on an index should be caught
+         earlier. This should therefore only occur when reading an element of
+         an array *)
+      let expected = Ref (None, Shared, expected) in
+      let known, e1 = infer env expected known e1 in
+      let known, e2 = infer env usize known e2 in
+      known, Index (e1, e2)
+
+  (* Special case for array slices. This occurs, e.g., when calling a function with
+     a struct field *)
+  | Field (Open { atom; _ }, "0", None) | Field (Open { atom; _}, "1", None) ->
+      if is_mut_borrow expected then
+        add_mut_borrow atom known, e
+      else known, e
+
+  | Field _ ->
+      (* We should be able to ignore this case, on the basis that we are not going to get any
+         mutability constraints from a field expression. However, we need to modify all of the cases
+         above (such as assignment) to handle the case where the assignee is a field. *)
+      known, e
+
+  | Deref _ ->
+      failwith "TODO: Deref"
+
+(* We store here a list of builtins, with the types of their arguments.
+   Type substitutions are currently not supported; functions with generic
+   args should be added directly to the Call case in infer *)
+let builtins : (name * typ list) list = [
+  (* EverCrypt.TargetConfig.
The following two functions are handwritten, + while the rest of EverCrypt is generated *) + [ "evercrypt"; "targetconfig"; "has_vec128_not_avx" ], []; + [ "evercrypt"; "targetconfig"; "has_vec256_not_avx2" ], []; + + (* FStar.UInt8 *) + [ "fstar"; "uint8"; "eq_mask" ], [ u8; u8 ]; + [ "fstar"; "uint8"; "gte_mask" ], [ u8; u8 ]; + + (* FStar.UInt16 *) + [ "fstar"; "uint16"; "eq_mask" ], [ u16; u16 ]; + [ "fstar"; "uint16"; "gte_mask" ], [ u16; u16 ]; + + (* FStar.UInt32 *) + [ "fstar"; "uint32"; "eq_mask" ], [ u32; u32 ]; + [ "fstar"; "uint32"; "gte_mask" ], [ u32; u32 ]; + + (* FStar.UInt64 *) + [ "fstar"; "uint64"; "eq_mask" ], [ u64; u64 ]; + [ "fstar"; "uint64"; "gte_mask" ], [ u64; u64 ]; + + + (* FStar.UInt128 *) + [ "fstar"; "uint128"; "add" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "add_mod" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "sub" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "sub_mod" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "logand" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "logxor" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "logor" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "lognot" ], + [Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "shift_left" ], + [Name (["fstar"; "uint128"; "uint128"], []); u32]; + [ "fstar"; "uint128"; "shift_right" ], + [Name (["fstar"; "uint128"; "uint128"], []); u32]; + [ "fstar"; "uint128"; "eq" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "gt" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "lt" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "gte" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "lte" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "eq_mask" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "gte_mask" ], + [Name (["fstar"; "uint128"; "uint128"], []); Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "uint64_to_uint128" ], [u64]; + [ "fstar"; "uint128"; "uint128_to_uint64" ], [Name (["fstar"; "uint128"; "uint128"], [])]; + [ "fstar"; "uint128"; "mul32" ], [u64; u32]; + [ "fstar"; "uint128"; "mul_wide" ], [u64; u32]; + + (* Lib.Inttypes_Intrinsics *) + [ "lib"; "inttypes_intrinsics"; "add_carry_u32"], [ u32; u32; u32; Ref (None, Mut, Slice u32) ]; + [ "lib"; "inttypes_intrinsics"; "sub_borrow_u32"], [ u32; u32; u32; Ref (None, Mut, Slice u32) ]; + [ "lib"; "inttypes_intrinsics"; "add_carry_u64"], [ u64; u64; u64; Ref (None, Mut, Slice u64) ]; + [ "lib"; "inttypes_intrinsics"; "sub_borrow_u64"], [ u64; u64; u64; Ref (None, Mut, Slice u64) ]; + + + (* Lib.IntVector_intrinsics, Vec128 *) + [ "lib"; "intvector_intrinsics"; "vec128_add32"], + [Name 
(["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_add64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_and"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_eq64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_extract64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec128_gt64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_insert64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u64; u32]; + [ "lib"; "intvector_intrinsics"; "vec128_interleave_low32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_interleave_low64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_interleave_high32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_interleave_high64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_load32"], [u32]; + [ "lib"; "intvector_intrinsics"; "vec128_load32s"], [u32; u32; u32; u32]; + [ "lib"; "intvector_intrinsics"; "vec128_load32_be"], [Ref (None, Shared, Slice u8)]; + [ "lib"; "intvector_intrinsics"; "vec128_load32_le"], [Ref (None, Shared, Slice u8)]; + [ "lib"; "intvector_intrinsics"; "vec128_load64"], [u64]; + [ "lib"; "intvector_intrinsics"; "vec128_load64_le"], [Ref (None, Shared, Slice u8)]; + [ "lib"; "intvector_intrinsics"; "vec128_lognot"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_mul32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_mul64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_or"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_rotate_left32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec128_rotate_right32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec128_rotate_right_lanes32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec128_shift_left64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec128_shift_right32"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec128_shift_right64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u32]; + [ "lib"; 
"intvector_intrinsics"; "vec128_smul64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); u64]; + [ "lib"; "intvector_intrinsics"; "vec128_store32_be"], + [Ref (None, Mut, Slice u8); Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_store32_le"], + [Ref (None, Mut, Slice u8); Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_sub64"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + [ "lib"; "intvector_intrinsics"; "vec128_xor"], + [Name (["lib"; "intvector_intrinsics"; "vec128"], []); + Name (["lib"; "intvector_intrinsics"; "vec128"], [])]; + + (* Lib.IntVector_intrinsics, Vec256 *) + [ "lib"; "intvector_intrinsics"; "vec256_add32"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_add64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_and"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_eq64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_extract64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_gt64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_insert64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u64; u32]; + [ "lib"; "intvector_intrinsics"; "vec256_interleave_low32"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_interleave_low64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_interleave_low128"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_interleave_high32"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_interleave_high64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_interleave_high128"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_load32"], [u32]; + [ "lib"; "intvector_intrinsics"; "vec256_load32s"], [u32; u32; u32; u32; u32; u32; u32; u32]; + [ "lib"; "intvector_intrinsics"; "vec256_load32_be"], [Ref (None, Shared, Slice u8)]; + [ "lib"; "intvector_intrinsics"; "vec256_load32_le"], [Ref (None, Shared, Slice u8)]; + [ "lib"; "intvector_intrinsics"; "vec256_load64"], [u64]; + [ "lib"; "intvector_intrinsics"; "vec256_load64s"], [u64; u64; u64; u64]; + [ "lib"; "intvector_intrinsics"; "vec256_load64_be"], [Ref (None, Shared, Slice u8)]; + [ "lib"; "intvector_intrinsics"; "vec256_load64_le"], [Ref (None, Shared, Slice u8)]; + [ "lib"; 
"intvector_intrinsics"; "vec256_lognot"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_mul64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_or"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_rotate_left32"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_rotate_right32"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_rotate_right64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_rotate_right_lanes64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_shift_left64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_shift_right"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_shift_right32"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_shift_right64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u32]; + [ "lib"; "intvector_intrinsics"; "vec256_smul64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); u64]; + [ "lib"; "intvector_intrinsics"; "vec256_store32_be"], + [Ref (None, Mut, Slice u8); Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_store32_le"], + [Ref (None, Mut, Slice u8); Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_store64_be"], + [Ref (None, Mut, Slice u8); Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_store64_le"], + [Ref (None, Mut, Slice u8); Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_sub64"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + [ "lib"; "intvector_intrinsics"; "vec256_xor"], + [Name (["lib"; "intvector_intrinsics"; "vec256"], []); + Name (["lib"; "intvector_intrinsics"; "vec256"], [])]; + + (* Lib.RandomBuffer_System *) + [ "lib"; "randombuffer_system"; "randombytes"], [Ref (None, Mut, Slice u8); u32]; + + (* LowStar.Endianness, little-endian *) + [ "lowstar"; "endianness"; "load16_le" ], [Ref (None, Shared, Slice u8)]; + [ "lowstar"; "endianness"; "store16_le" ], [Ref (None, Mut, Slice u8); u16]; + [ "lowstar"; "endianness"; "load32_le" ], [Ref (None, Shared, Slice u8)]; + [ "lowstar"; "endianness"; "store32_le" ], [Ref (None, Mut, Slice u8); u32]; + [ "lowstar"; "endianness"; "load64_le" ], [Ref (None, Shared, Slice u8)]; + [ "lowstar"; "endianness"; "store64_le" ], [Ref (None, Mut, Slice u8); u64]; + + (* LowStar.Endianness, big-endian *) + [ "lowstar"; "endianness"; "store16_be" ], [Ref (None, Mut, Slice u8); u16]; + [ "lowstar"; "endianness"; "load32_be" ], [Ref (None, Shared, Slice u8)]; + [ "lowstar"; "endianness"; "store32_be" ], [Ref (None, Mut, Slice u8); u32]; + [ "lowstar"; "endianness"; "load64_be" ], [Ref (None, Shared, Slice u8)]; + [ "lowstar"; "endianness"; "store64_be" ], [Ref (None, Mut, Slice u8); u64]; + [ "lowstar"; "endianness"; "load128_be" ], [Ref (None, Shared, 
Slice u8)]; + [ "lowstar"; "endianness"; "store128_be" ], + [Ref (None, Mut, Slice u8); Name (["fstar"; "uint128"; "uint128"], [])]; + + (* Vec *) + [ "Vec"; "new" ], []; + + (* Vale assembly functions marked as extern. This should probably be handled earlier *) + [ "vale"; "stdcalls_x64_sha"; "sha256_update"], [ + Ref (None, Mut, Slice u32); Ref (None, Shared, Slice u8); u32; + Ref (None, Shared, Slice u32) + ]; + [ "vale"; "inline_x64_fadd_inline"; "add_scalar" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64) + ]; + [ "vale"; "stdcalls_x64_fadd"; "add_scalar_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64) + ]; + [ "vale"; "inline_x64_fadd_inline"; "fadd" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64) + ]; + [ "vale"; "stdcalls_x64_fadd"; "fadd_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64) + ]; + [ "vale"; "inline_x64_fadd_inline"; "fsub" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64) + ]; + [ "vale"; "stdcalls_x64_fsub"; "fsub_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64) + ]; + [ "vale"; "inline_x64_fmul_inline"; "fmul" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); + Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64); + ]; + [ "vale"; "stdcalls_x64_fmul"; "fmul_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); + Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64); + ]; + [ "vale"; "inline_x64_fmul_inline"; "fmul2" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); + Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64); + ]; + [ "vale"; "stdcalls_x64_fmul"; "fmul2_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); + Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64); + ]; + [ "vale"; "inline_x64_fmul_inline"; "fmul_scalar" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); u64 + ]; + [ "vale"; "stdcalls_x64_fmul"; "fmul_scalar_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); u64 + ]; + [ "vale"; "inline_x64_fsqr_inline"; "fsqr" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64) + ]; + [ "vale"; "stdcalls_x64_fsqr"; "fsqr_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64) + ]; + [ "vale"; "inline_x64_fsqr_inline"; "fsqr2" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64) + ]; + [ "vale"; "stdcalls_x64_fsqr"; "fsqr2_e" ], [ + Ref (None, Mut, Slice u64); Ref (None, Shared, Slice u64); Ref (None, Mut, Slice u64) + ]; + [ "vale"; "inline_x64_fswap_inline"; "cswap2" ], [ + u64; Ref (None, Mut, Slice u64); Ref (None, Mut, Slice u64) + ]; + [ "vale"; "stdcalls_x64_fswap"; "cswap2_e" ], [ + u64; Ref (None, Mut, Slice u64); Ref (None, Mut, Slice u64) + ]; + + (* TODO: These functions are recursive, and should be handled properly. 
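+     (A proper fix would presumably iterate the inference to a fixpoint over
+     the call graph, re-visiting each function until its inferred parameter
+     types stabilize.)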
+ For now, we hardcode their expected type and mutability in HACL *) + [ "hacl"; "bignum"; "bn_karatsuba_mul_uint32"], [ + u32; Ref (None, Shared, Slice u32); Ref (None, Shared, Slice u32); + Ref (None, Mut, Slice u32); Ref (None, Mut, Slice u32) + ]; + [ "hacl"; "bignum"; "bn_karatsuba_mul_uint64"], [ + u64; Ref (None, Shared, Slice u64); Ref (None, Shared, Slice u64); + Ref (None, Mut, Slice u64); Ref (None, Mut, Slice u64) + ]; + [ "hacl"; "bignum"; "bn_karatsuba_sqr_uint32"], [ + u32; Ref (None, Shared, Slice u32); + Ref (None, Mut, Slice u32); Ref (None, Mut, Slice u32) + ]; + [ "hacl"; "bignum"; "bn_karatsuba_sqr_uint64"], [ + u64; Ref (None, Shared, Slice u64); + Ref (None, Mut, Slice u64); Ref (None, Mut, Slice u64) + ]; + + +] + +let infer_mut_borrows files = + (* Map.of_list is only available from OCaml 5.1 onwards *) + let env = { seen = List.to_seq builtins |> NameMap.of_seq; structs = NameMap.empty } in + let known = { structs = NameMap.empty; v = VarSet.empty; r = VarSet.empty } in + let env, files = + List.fold_left (fun (env, files) (filename, decls) -> + let env, decls = List.fold_left (fun (env, decls) decl -> + match decl with + | Function ({ name; body; return_type; parameters; _ } as f) -> + (* KPrint.bprintf "[infer-mut] visiting %s\n%a\n" (String.concat "." name) + PrintMiniRust.pdecl decl; *) + let atoms, body = + List.fold_right (fun binder (atoms, e) -> + let a, e = open_ binder e in + (* KPrint.bprintf "[infer-mut] opened %s[%s]\n%a\n" binder.name (show_atom_t a) PrintMiniRust.pexpr e; *) + a :: atoms, e + ) parameters ([], body) + in + (* KPrint.bprintf "[infer-mut] done opening %s\n%a\n" (String.concat "." name) + PrintMiniRust.pexpr body; *) + (* Start the analysis with the current state of struct mutability *) + let known, body = infer env return_type {known with structs = env.structs} body in + let parameters, body = + List.fold_left2 (fun (parameters, e) (binder: binding) atom -> + let e = close atom (Var 0) (lift 1 e) in + (* KPrint.bprintf "[infer-mut] closed %s[%s]\n%a\n" binder.name (show_atom_t atom) PrintMiniRust.pexpr e; *) + let mut = want_mut_var atom known in + let typ = if want_mut_borrow atom known then make_mut_borrow binder.typ else binder.typ in + { binder with mut; typ } :: parameters, e + ) ([], body) parameters atoms + in + let parameters = List.rev parameters in + (* We update the environment in two ways. First, we add the function declaration, + with the mutability of the parameters inferred during the analysis. + Second, we propagate the information about the mutability of struct fields + inferred while traversing this function to the global environment. 
Note, since + the traversal does not add or remove any bindings, but only increases the + mutability, we can do a direct replacement instead of a more complex merge *) + let env = { seen = NameMap.add name (List.map (fun (x: binding) -> x.typ) parameters) env.seen; structs = known.structs } in + env, Function { f with body; parameters } :: decls + | Struct ({name; fields; _}) -> + {env with structs = NameMap.add name fields env.structs}, decl :: decls + | _ -> + env, decl :: decls + ) (env, []) decls in + let decls = List.rev decls in + env, (filename, decls) :: files + ) (env, []) files + in + + (* We traverse all declarations again, and update the structure decls + with the new mutability info *) + List.map (fun (filename, decls) -> filename, List.map (function + | Struct ({ name; _ } as s) -> Struct { s with fields = NameMap.find name env.structs } + | x -> x + ) decls + ) (List.rev files) diff --git a/lib/PrintMiniRust.ml b/lib/PrintMiniRust.ml index 744ca0aa..1922079a 100644 --- a/lib/PrintMiniRust.ml +++ b/lib/PrintMiniRust.ml @@ -372,7 +372,7 @@ and print_expr env (context: int) (e: expr): document = | None -> empty end - | Assign (e1, e2) -> + | Assign (e1, e2, _) -> let mine, left, right = 18, 17, 18 in paren_if mine @@ group (print_expr env left e1 ^^ space ^^ equals ^^ @@ -427,7 +427,7 @@ and print_expr env (context: int) (e: expr): document = let mine = 4 in paren_if mine @@ print_expr env mine p ^^ group (brackets (print_expr env max_int e)) - | Field (e, s) -> + | Field (e, s, _) -> group (print_expr env 3 e ^^ dot ^^ string s) | Deref e -> let mine = 6 in @@ -442,6 +442,8 @@ and print_expr env (context: int) (e: expr): document = group (print_pat env p ^/^ string "=>") ^^ group (nest 2 (break1 ^^ print_expr env max_int e)) ) patexprs) + | Open { name; _ } -> at ^^ string name + and print_pat env (p: pat) = match p with | Literal c -> print_constant c @@ -544,3 +546,4 @@ let print_decls ns ds = let pexpr = printf_of_pprint (print_expr debug max_int) let ptyp = printf_of_pprint (print_typ debug) +let pdecl = printf_of_pprint (fun x -> snd (print_decl debug x)) diff --git a/lib/RustMacroize.ml b/lib/RustMacroize.ml new file mode 100644 index 00000000..30263c6b --- /dev/null +++ b/lib/RustMacroize.ml @@ -0,0 +1,54 @@ +(* Rewritings on the Rust AST after translation and mutability inference *) + +open MiniRust + +(* Loop unrolling introduces an implicit binder that does not interact well + with the substitutions occurring in mutability inference. + We perform it after the translation *) +let unroll_loops = object + inherit [_] map_expr as super + method! 
visit_For _ b e1 e2 =
+    let e2 = super#visit_expr () e2 in
+
+    match e1 with
+    | Range (Some (Constant (UInt32, init as k_init) as e_init), Some (Constant (UInt32, max)), false)
+      when (
+        let init = int_of_string init in
+        let max = int_of_string max in
+        let n_loops = max - init in
+        n_loops <= !Options.unroll_loops
+      ) ->
+        let init = int_of_string init in
+        let max = int_of_string max in
+        let n_loops = max - init in
+
+        if n_loops = 0 then Unit
+
+        else if n_loops = 1 then subst e_init 0 e2
+
+        else Call (Name ["krml"; "unroll_for!"], [], [
+          Constant (CInt, string_of_int n_loops);
+          ConstantString b.name;
+          Constant k_init;
+          Constant (UInt32, "1");
+          e2
+        ])
+
+    | _ -> For (b, e1, e2)
+end
+
+let macroize files =
+  let files =
+    List.fold_left (fun files (filename, decls) ->
+      let decls = List.fold_left (fun decls decl ->
+        match decl with
+        | Function ({ body; _ } as f) ->
+            let body = unroll_loops#visit_expr () body in
+            Function {f with body} :: decls
+        | _ -> decl :: decls
+      ) [] decls in
+      let decls = List.rev decls in
+      (filename, decls) :: files
+    ) [] files
+  in
+  List.rev files
diff --git a/lib/lib/KList.ml b/lib/lib/KList.ml
index 7cf08ba3..231a9ad9 100644
--- a/lib/lib/KList.ml
+++ b/lib/lib/KList.ml
@@ -72,6 +72,11 @@ let one l =
   | [ x ] -> x
   | _ -> invalid_arg ("one: argument is of length " ^ string_of_int (List.length l))

+let two l =
+  match l with
+  | [ x; y ] -> (x, y)
+  | _ -> invalid_arg ("two: argument is of length " ^ string_of_int (List.length l))
+
 (* NOTE: provided by {!Stdlib.List} in OCaml 5.1. *)
 let is_empty = function
   | [] -> true
diff --git a/src/Karamel.ml b/src/Karamel.ml
index 55f9104a..2fcbb7a0 100644
--- a/src/Karamel.ml
+++ b/src/Karamel.ml
@@ -750,6 +750,8 @@ Supported options:|}
     if Options.debug "rs" then
       print PrintAst.print_files files;
     let files = AstToMiniRust.translate_files files in
+    let files = OptimizeMiniRust.infer_mut_borrows files in
+    let files = RustMacroize.macroize files in
     OutputRust.write_all files
   else
diff --git a/test/Makefile b/test/Makefile
index efab490f..30b48487 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -88,7 +88,7 @@ FSTAR = $(FSTAR_EXE) --cache_checked_modules \
   --trivial_pre_for_unannotated_effectful_fns false \
   --cmi --warn_error -274

-all: $(FILES) $(RUST_FILES) $(WASM_FILES) $(CUSTOM) ctypes-test sepcomp-test
+all: $(FILES) rust $(WASM_FILES) $(CUSTOM) ctypes-test sepcomp-test

 # Needs node
 wasm: $(WASM_FILES)
@@ -298,14 +298,14 @@ WasmTrap.wasm-test: NEGATIVE = true
 .PRECIOUS: %.rs
 %.rs: $(ALL_KRML_FILES) $(KRML_BIN)
-	$(KRML) -minimal -bundle $(notdir $*)=\* \
+	$(KRML) -minimal -bundle $(notdir $(subst rust,Rust,$*))=\* \
 	  -backend rust $(EXTRA) -tmpdir $(dir $@) $(filter %.krml,$^)
-	$(SED) -i 's/\(assignments..\)/\1\nmod lowstar { pub mod ignore { pub fn ignore(_x: T) {}}}\n/' $@
+	$(SED) -i 's/\(mutation..\)/\1\nmod lowstar { pub mod ignore { pub fn ignore(_x: T) {}}}\n/' $@
 	echo 'fn main () { let r = main_ (); if r != 0 { println!("main_ returned: {}\\n", r); panic!() } }' >> $@

 %.rust-test: $(OUTPUT_DIR)/%.rs
 	rustc $< && ./$*

-rust: $(RUST_FILES) $(patsubst %.fst,%.rust-test,$(filter-out Rust1.fst Rust2.fst Rust3.fst,$(wildcard Rust*.fst)))
+rust: $(RUST_FILES) $(patsubst Rust%.fst,rust%.rust-test,$(filter-out Rust1.fst Rust2.fst Rust3.fst,$(wildcard Rust*.fst)))

 RUST_FILES = $(patsubst %.rs,%.rust-test,$(wildcard Rust*.fst))
diff --git a/test/Rust4.fst b/test/Rust4.fst
index 6e65a448..8d047dda 100644
--- a/test/Rust4.fst
+++ b/test/Rust4.fst
@@ -2,6 +2,7 @@ module Rust4

 let f (): HyperStack.ST.St
UInt32.t = 1ul + let main_ () = if not (f () = 0ul) then 0l diff --git a/test/Rust5.fst b/test/Rust5.fst index e3f85643..3f02573c 100644 --- a/test/Rust5.fst +++ b/test/Rust5.fst @@ -6,6 +6,7 @@ module B = LowStar.Buffer let ignore #a (x: a): Stack unit (fun h0 -> True) (fun h0 r h1 -> h0 == h1) = () + let main_ (): St Int32.t = push_frame (); let base = B.alloca 0l 2ul in diff --git a/test/Rust6.fst b/test/Rust6.fst index 81815bc7..d09cea3c 100644 --- a/test/Rust6.fst +++ b/test/Rust6.fst @@ -6,6 +6,7 @@ module B = LowStar.Buffer module C = LowStar.ConstBuffer module HS = FStar.HyperStack + inline_for_extraction noextract val sub_len: b: C.const_buffer UInt32.t -> diff --git a/test/Rust7.fst b/test/Rust7.fst new file mode 100644 index 00000000..dd3a84d4 --- /dev/null +++ b/test/Rust7.fst @@ -0,0 +1,134 @@ +module Rust7 + +module U32 = FStar.UInt32 +module B = LowStar.Buffer +open LowStar.BufferOps + +open FStar +open FStar.HyperStack.ST + + +val add_carry_u32: + x:U32.t + -> y:U32.t + -> r:B.lbuffer U32.t 1 + -> p:B.lbuffer U32.t 1 -> + Stack U32.t + (requires fun h -> B.live h r /\ B.live h p) + (ensures fun h0 c h1 -> True) + // modifies1 r h0 h1 /\ v c <= 1 /\ + // (let r = Seq.index (as_seq h1 r) 0 in + // v r + v c * pow2 (bits t) == v x + v y + v cin)) + +let add_carry_u32 x y r p = + let z = B.index p 0ul in + let res = U32.add_mod x y in + let res = U32.add_mod res z in + // let c = (U32.shift_right res 32ul) in + B.upd r 0ul res; + 0ul + +let test_alloca (x: UInt32.t) : Stack UInt32.t + (requires (fun h0 -> True)) + (ensures (fun h0 r h1 -> True)) = + push_frame(); + let ptr = B.alloca 0ul 10ul in + B.upd ptr 0ul x; + let res = B.index ptr 0ul in + pop_frame(); + x + +// simple for loop example - note that there is no framing +let loop (ptr:B.lbuffer U32.t 10) : Stack UInt32.t + (requires (fun h0 -> B.live h0 ptr)) + (ensures (fun h0 r h1 -> True)) = + push_frame(); + C.Loops.for 0ul 0ul + (fun h i -> B.live h ptr) + (fun i -> B.upd ptr i 1ul); + C.Loops.for 0ul 1ul + (fun h i -> B.live h ptr) + (fun i -> B.upd ptr i 1ul); + C.Loops.for 0ul 10ul + (fun h i -> B.live h ptr) + (fun i -> B.upd ptr i 1ul); + let sum = B.index ptr 0ul in + pop_frame(); + sum + + +let loop_alloc () : Stack UInt32.t + (requires (fun h0 -> True)) + (ensures (fun h0 r h1 -> True)) = + push_frame(); + let ptr = B.alloca 0ul 10ul in + C.Loops.for 0ul 9ul + (fun h i -> B.live h ptr) + (fun i -> B.upd ptr i 1ul); + let sum = B.index ptr 0ul in + pop_frame(); + sum + +let touch (#a: Type) (x: a): Stack unit (fun _ -> True) (fun _ _ _ -> True) = + () + +let upd (x: B.buffer UInt64.t): Stack unit (fun h -> B.live h x /\ B.length x >= 1) + (fun h0 _ h1 -> B.modifies (B.loc_buffer x) h0 h1) = + B.upd x 0ul 0UL + +let root_alias (): Stack unit (fun _ -> True) (fun _ _ _ -> True) = + push_frame (); + let x = B.alloca 0UL 6ul in + let x0 = B.sub x 0ul 2ul in + let x1 = B.sub x 2ul 2ul in + + let x00 = B.sub x0 0ul 1ul in + let x01 = B.sub x0 1ul 1ul in + + touch x0; + touch x1; + touch x00; + touch x01; + + pop_frame() + +let slice_upd (): Stack unit (fun _ -> True) (fun _ _ _ -> True) = + push_frame (); + let x = B.alloca 0UL 6ul in + let x0 = B.sub x 0ul 2ul in + let x1 = B.sub x 2ul 2ul in + + let x00 = B.sub x0 0ul 1ul in + let x01 = B.sub x0 1ul 1ul in + + upd x00; + + pop_frame() + +let basic_copy1 (): Stack unit (fun _ -> True) (fun _ _ _ -> True) = + push_frame (); + let x = B.alloca 0ul 6ul in + let y = B.alloca 1ul 6ul in + B.blit y 0ul x 0ul 6ul; + pop_frame() + +let basic_copy2 (): Stack unit (fun _ -> 
True) (fun _ _ _ -> True) = + push_frame (); + let x = B.alloca 0ul 6ul in + let y = B.alloca 1ul 6ul in + let x0 = B.sub x 0ul 2ul in + let x1 = B.sub x0 0ul 1ul in + + B.upd x1 0ul 5ul; + B.blit x0 0ul y 0ul 2ul; + pop_frame() + +noeq +type point = { x : B.lbuffer U32.t 1; y : B.lbuffer U32.t 1 } + +let struct_upd (p: point) : Stack UInt32.t (fun h -> B.live h p.x) (fun _ _ _ -> True) = + B.upd p.x 0ul 0ul; + 0ul + + +let main_ () = 0l