Skip to content

Commit

Permalink
Merge pull request #449 from FStarLang/protz_infer_mut
Browse files Browse the repository at this point in the history
Infer mutability of borrows during Rust extraction
  • Loading branch information
R1kM authored Jul 19, 2024
2 parents 65aab55 + 4514866 commit facdca2
Show file tree
Hide file tree
Showing 12 changed files with 1,127 additions and 134 deletions.
121 changes: 31 additions & 90 deletions lib/AstToMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,8 @@ let translate_unknown_lid (m, n) =
let m = compress_prefix m in
List.map String.lowercase_ascii m @ [ n ]

let borrow_kind_of_bool b: MiniRust.borrow_kind =
if b (* const *) then
Shared
else
Mut
let borrow_kind_of_bool _b: MiniRust.borrow_kind =
Shared

type config = {
box: bool;
Expand Down Expand Up @@ -511,7 +508,7 @@ let lookup_split (env: env) (v_base: MiniRust.db_index) (path: Splits.root_or_pa
| Some (v_base', path') when v_base' = v_base - ofs - 1 ->
begin match Splits.accessible_via path path' with
| Some pe ->
MiniRust.(Field (Var ofs, Splits.string_of_path_elem pe))
MiniRust.(Field (Var ofs, Splits.string_of_path_elem pe, None))
| None ->
find (ofs + 1) bs
end
Expand Down Expand Up @@ -590,7 +587,7 @@ and translate_array (env: env) is_toplevel (init: Ast.expr): env * MiniRust.expr
necessitate the insertion of conversions *)
and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env * MiniRust.expr =
let erase_lifetime_info = (object(self)
inherit [_] MiniRust.map
inherit [_] MiniRust.DeBruijn.map
method! visit_Ref env _ bk t = Ref (None, bk, self#visit_typ env t)
method! visit_tname _ n _ = Name (n, [])
end)#visit_typ ()
Expand All @@ -615,18 +612,17 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
do retain them in our Function type -- so we need to relax the comparison here *)
x
(* More conversions due to box-ing types. *)
| _, App (Name (["Box"], _), [Slice _]), Ref (_, Mut, Slice _) ->
Borrow (Mut, Deref x)
| _, Ref (_, Mut, Slice _), App (Name (["Box"], _), [Slice _]) ->
| _, App (Name (["Box"], _), [Slice _]), Ref (_, k, Slice _) ->
Borrow (k, Deref x)
| _, Ref (_, _, Slice _), App (Name (["Box"], _), [Slice _]) ->
MethodCall (Borrow (Shared, Deref x), ["into"], [])
| _, Ref (_, Shared, Slice _), App (Name (["Box"], _), [Slice _]) ->
MethodCall (x, ["into"], [])
(* | _, Ref (_, Shared, Slice _), App (Name (["Box"], _), [Slice _]) -> *)
(* MethodCall (x, ["into"], []) *)
| _, Vec _, App (Name (["Box"], _), [Slice _]) ->
MethodCall (MethodCall (x, ["try_into"], []), ["unwrap"], [])

(* More conversions due to vec-ing types *)
| _, Ref (_, Mut, Slice _), Vec _
| _, Ref (_, Shared, Slice _), Vec _ ->
| _, Ref (_, _, Slice _), Vec _ ->
MethodCall (x, ["to_vec"], [])
| _, Array _, Vec _ ->
Call (Name ["Vec"; "from"], [], [x])
Expand All @@ -644,8 +640,8 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
| _, Ref (_, _, t), t' when t = t' ->
Deref x

| Borrow (Mut, e) , Ref (_, _, t), Ref (_, _, Slice t') when t = t' ->
Borrow (Mut, Array (List [ e ]))
| Borrow (k, e) , Ref (_, _, t), Ref (_, _, Slice t') when t = t' ->
Borrow (k, Array (List [ e ]))

| _ ->
(* If we reach this case, we perform one last try by erasing the lifetime
Expand Down Expand Up @@ -727,8 +723,9 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env

| EApp ({ node = EQualified ([ "LowStar"; "BufferOps" ], s); _ }, e1 :: e2 :: _ ) when KString.starts_with s "op_Star_Equals__" ->
let env, e1 = translate_expr env e1 in
let t2 = translate_type env e2.typ in
let env, e2 = translate_expr env e2 in
env, Assign (Index (e1, MiniRust.zero_usize), e2)
env, Assign (Index (e1, MiniRust.zero_usize), e2, t2)

| EApp ({ node = ETApp (e, cgs, cgs', ts); _ }, es) ->
assert (cgs @ cgs' = []);
Expand Down Expand Up @@ -805,10 +802,10 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
| [] ->
failwith "[ELet] unexpected: path not found"
in
env, MiniRust.(Field (find 0 env.vars, Splits.string_of_path_elem path_elem))
env, MiniRust.(Field (find 0 env.vars, Splits.string_of_path_elem path_elem, None))
in

let split_at = match b.typ with TBuf (_, true) -> "split_at" | _ -> "split_at_mut" in
let split_at = "split_at" in
let e1 = MiniRust.MethodCall (e_nearest , [split_at], [ index ]) in
let t = translate_type env b.typ in
let binding : MiniRust.binding * Splits.info =
Expand All @@ -832,7 +829,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env

let env, e1, t = translate_array env false init in
(* KPrint.bprintf "Let %s: %a\n" b.node.name PrintMiniRust.ptyp t; *)
let binding: MiniRust.binding = { name = b.node.name; typ = t; mut = true } in
let binding: MiniRust.binding = { name = b.node.name; typ = t; mut = false } in
let env = push env binding in
env0, Let (binding, e1, snd (translate_expr_with_type env e2 t_ret))

Expand All @@ -847,6 +844,8 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
| TQualified lid when Idents.LidSet.mem lid env.heap_structs -> true
| _ -> false
in
(* TODO how does this play out with the new "translate as non-mut by
default" strategy? *)
let mut = b.node.mut || is_owned_struct in
(* Here, the idea is to detect forbidden move-outs that are certain to result in a compilation
error. Typically, selecting a field, dereferencing an array, etc. when the underlying type
Expand All @@ -858,7 +857,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
cannot be mutated in place in Low*, so it's ok to borrow instead of copy. *)
let e1, t = match e1 with
| (Field _ | Index _) when is_owned_struct ->
MiniRust.(Borrow (Mut, e1), Ref (None, Mut, t))
MiniRust.(Borrow (Shared, e1), Ref (None, Shared, t))
| _ ->
e1, t
in
Expand All @@ -883,7 +882,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
in
let env, e1 = translate_expr_with_type env e1 lvalue_type in
let env, e2 = translate_expr_with_type env e2 lvalue_type in
env, Assign (e1, e2)
env, Assign (e1, e2, lvalue_type)
| EBufCreate _ ->
failwith "unexpected: EBufCreate"
| EBufCreateL _ ->
Expand All @@ -895,19 +894,16 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
| EBufWrite (e1, e2, e3) ->
let env, e1 = translate_expr env e1 in
let env, e2 = translate_expr_with_type env e2 (Constant SizeT) in
let t3 = translate_type env e3.typ in
let env, e3 = translate_expr env e3 in
env, Assign (Index (e1, e2), e3)
env, Assign (Index (e1, e2), e3, t3)
| EBufSub (e1, e2) ->
(* This is a fallback for the analysis above. Happens if, for instance, the pointer arithmetic
appears in subexpression position (like, function call), in which case there's a chance
this might still work! *)
let is_const_tbuf = match e1.typ with
| TBuf (_, b) -> b
| _ -> false
in
let env, e1 = translate_expr env e1 in
let env, e2 = translate_expr_with_type env e2 (Constant SizeT) in
env, Borrow ((if is_const_tbuf then Shared else Mut), Index (e1, Range (Some e2, None, false)))
env, Borrow (Shared, Index (e1, Range (Some e2, None, false)))
| EBufDiff _ ->
failwith "unexpected: EBufDiff"
(* Silly pattern in Low*: for historical reasons, the blit operations takes a
Expand Down Expand Up @@ -944,7 +940,7 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
(* Rather than error out, we do nothing, as some functions may allocate then free. *)
env, Unit
| EBufNull ->
env, possibly_convert (Borrow (Mut, Array (List []))) (translate_type env e.typ)
env, possibly_convert (Borrow (Shared, Array (List []))) (translate_type env e.typ)
| EPushFrame ->
failwith "unexpected: EPushFrame"
| EPopFrame ->
Expand Down Expand Up @@ -1000,7 +996,8 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env

| EField (e, f) ->
let env, e_ = translate_expr env e in
env, possibly_convert (Field (e_, f)) (field_type env e f)
let t = translate_type env e.typ in
env, possibly_convert (Field (e_, f, Some t)) (field_type env e f)

| EBreak ->
failwith "TODO: EBreak"
Expand All @@ -1011,66 +1008,10 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
| EWhile _ ->
failwith "TODO: EWhile"

(* Loop with all constant bounds. *)
| EFor (b,
({ node = EConstant (UInt32, init as k_init); _ } as e_init),
{ node = EApp (
{ node = EOp (Lt, UInt32); _ },
[{ node = EBound 0; _ };
({ node = EConstant (UInt32, max); _ })]); _},
{ node = EAssign (
{ node = EBound 0; _ },
{ node = EApp (
{ node = EOp (Add, UInt32); _ },
[{ node = EBound 0; _ };
({ node = EConstant (UInt32, incr as k_incr); _ })]); _}); _},
body)
when (
let init = int_of_string init in
let max = int_of_string max in
let incr = int_of_string incr in
let n_loops = (max - init + incr - 1) / incr in
n_loops <= !Options.unroll_loops
)
->
(* Keep initial environment to return after translation *)
let env0 = env in

let init = int_of_string init in
let max = int_of_string max in
let incr = int_of_string incr in
let n_loops = (max - init + incr - 1) / incr in

if n_loops = 0 then
env, Unit

else if n_loops = 1 then
let body = DeBruijn.subst e_init 0 body in
translate_expr env body

else begin
let unused = snd !(b.node.mark) = AtMost 3 in
(* We do an ad-hoc thing since this didn't go through lowstar.ignore
insertion. Rust uses the OCaml convention (which I recall I did suggest
to Graydon back in 2010). *)
let unused = if unused then "_" else "" in
let b: MiniRust.binding = { name = unused ^ b.node.name; typ = translate_type env b.typ; mut = false } in
let _, body = translate_expr (push env b) body in

(* This is a weird node, because it contains a binder, but does not rely
on a MiniRust.binding to encode that fact. For that reason, the
printer needs to be kept in sync and catch this special application
node. We could do this rewrite on the fly when pretty-printing, but
we'd have to retain the number of uses (above) in the node to figure
out whether to make the binder ignored or not. *)
env0, Call (Name ["krml"; "unroll_for!"], [], [
Constant (CInt, string_of_int n_loops);
ConstantString b.name;
Constant k_init;
Constant k_incr;
body ])
end

(* The introduction of the unroll_loops macro requires a "fake" binder
for the iterated value, which messes up with variable substitutions
mutability inference. We instead perform it in RustMacroize, after
all substitutions are done. *)
| EFor (b, e_start, e_test, e_incr, e_body) ->
(* Keep initial environment to return after translation *)
let env0 = env in
Expand Down
Loading

0 comments on commit facdca2

Please sign in to comment.