Skip to content

Commit

Permalink
Merge pull request #358 from FStarLang/_taramana_mutual_struct
Browse files Browse the repository at this point in the history
Some extraction tests for mutually recursive structs
  • Loading branch information
tahina-pro authored Jul 6, 2023
2 parents fabda97 + 66e0e1f commit bffa630
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 34 deletions.
128 changes: 94 additions & 34 deletions src/Monomorphization.ml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ let monomorphize_data_types map = object(self)
val mutable current_file = ""
(* Possibly populated with something relevant *)
val mutable best_hint: node * lident = (dummy_lid, []), dummy_lid
(* For forward references, a map from lid to its pending monomorphizations
(type arguments) *)
val pending_monomorphizations = Hashtbl.create 41
val seen_declarations = Hashtbl.create 41

method sanity_check () =
Hashtbl.iter (fun lid args ->
KPrint.bprintf "Missing monomorphization: %a<%a>\n" plid lid ptyps args
) pending_monomorphizations;
if Hashtbl.length pending_monomorphizations > 0 then
Warn.fatal_error "Internal error: missing monomorphizations"

(* Record a new declaration. *)
method private record (d: decl) =
Expand All @@ -96,15 +107,15 @@ let monomorphize_data_types map = object(self)
method! visit_TTuple () args =
match Hashtbl.find state (tuple_lid, args) with
| exception Not_found ->
let args = List.map (self#visit_typ ()) args in
let args = List.map (self#visit_typ false) args in
TTuple args
| _, chosen_lid ->
TQualified chosen_lid

method! visit_TApp () lid args =
match Hashtbl.find state (lid, args) with
| exception Not_found ->
let args = List.map (self#visit_typ ()) args in
let args = List.map (self#visit_typ false) args in
TApp (lid, args)
| _, chosen_lid ->
TQualified chosen_lid
Expand Down Expand Up @@ -139,15 +150,17 @@ let monomorphize_data_types map = object(self)
(* Visit a given node in the graph, modifying [pending] to append in reverse
* order declarations as they are needed, including that of the node we are
* visiting. *)
method private visit_node (n: node) =
method private visit_node (under_ref: bool) (n: node) =
let lid, args = n in
(* White, gray or black? *)
match Hashtbl.find state n with
| exception Not_found ->
if Options.debug "data-types-traversal" then
KPrint.bprintf "visiting %a<%a>: Not_found\n" plid (fst n) ptyps (snd n);
let chosen_lid, flag = self#lid_of n in
if lid = tuple_lid then begin
Hashtbl.add state n (Gray, chosen_lid);
let args = List.map (self#visit_typ ()) args in
let args = List.map (self#visit_typ under_ref) args in
(* For tuples, we immediately know how to generate a definition. *)
let fields = List.mapi (fun i arg -> Some (self#field_at i), (arg, false)) args in
self#record (DType (chosen_lid, [ Common.Private ] @ flag, 0, Flat fields));
Expand All @@ -161,25 +174,51 @@ let monomorphize_data_types map = object(self)
) fields in
begin match Hashtbl.find map lid with
| exception Not_found ->
()
Hashtbl.replace state n (Black, chosen_lid)
| flags, (Variant _ | Flat _) when under_ref && not (Hashtbl.mem seen_declarations lid) ->
(* Because this looks up a definition in the global map, the
definitions are reordered according to the traversal order, which
is generally a good idea (we accept more programs!), EXCEPT
when the user relies on mutual recursion behind a reference
(pointer) type. In that case, following the type dependency graph is
generally not a good idea, since we may go from a valid
ordering to an invalid one (see tests/MutualStruct.fst). So,
the intent here (i.e., when under a ref type) is that:
- tuple types ALWAYS get monomorphized on-demand (see
above)
- abbreviations are fine and won't cause further issues
- data types, however, need to have their names allocated and a
forward reference inserted (TODO: at most once), then the
specific choice of type arguments need to be recorded as
something we want to visit later (once we're done with this
particular traversal)... *)
if Options.debug "data-types-traversal" then
KPrint.bprintf "DEFERRING %a<%a>\n" plid (fst n) ptyps (snd n);
self#record (DType (chosen_lid, flags, 0, Forward));
Hashtbl.add pending_monomorphizations (fst n) (snd n);
Hashtbl.remove state n
| flags, Variant branches ->
let branches = List.map (fun (cons, fields) -> cons, subst fields) branches in
let branches = self#visit_branches_t () branches in
self#record (DType (chosen_lid, flag @ flags, 0, Variant branches))
let branches = self#visit_branches_t under_ref branches in
self#record (DType (chosen_lid, flag @ flags, 0, Variant branches));
Hashtbl.replace state n (Black, chosen_lid)
| flags, Flat fields ->
let fields = self#visit_fields_t_opt () (subst fields) in
self#record (DType (chosen_lid, flag @ flags, 0, Flat fields))
let fields = self#visit_fields_t_opt under_ref (subst fields) in
self#record (DType (chosen_lid, flag @ flags, 0, Flat fields));
Hashtbl.replace state n (Black, chosen_lid)
| flags, Abbrev t ->
let t = DeBruijn.subst_tn args t in
let t = self#visit_typ () t in
self#record (DType (chosen_lid, flag @ flags, 0, Abbrev t))
let t = self#visit_typ under_ref t in
self#record (DType (chosen_lid, flag @ flags, 0, Abbrev t));
Hashtbl.replace state n (Black, chosen_lid)
| _ ->
()
end;
Hashtbl.replace state n (Black, chosen_lid)
Hashtbl.replace state n (Black, chosen_lid)
end
end;
chosen_lid
| Gray, chosen_lid ->
if Options.debug "data-types-traversal" then
KPrint.bprintf "visiting %a<%a>: Gray\n" plid (fst n) ptyps (snd n);
begin match Hashtbl.find map lid with
| exception Not_found ->
()
Expand All @@ -188,6 +227,8 @@ let monomorphize_data_types map = object(self)
end;
chosen_lid
| Black, chosen_lid ->
if Options.debug "data-types-traversal" then
KPrint.bprintf "visiting %a<%a>: Black\n" plid (fst n) ptyps (snd n);
chosen_lid

(* Top-level, non-parameterized declarations are root of our graph traversal.
Expand All @@ -197,13 +238,16 @@ let monomorphize_data_types map = object(self)
let name, decls = file in
current_file <- name;
name, KList.map_flatten (fun d ->
if Options.debug "data-types-traversal" then
KPrint.bprintf "decl %a\n" plid (lid_of_decl d);
match d with
| DType (lid, _, n, Abbrev (TTuple args)) when n = 0 && not (Hashtbl.mem state (tuple_lid, args)) ->
Hashtbl.remove map lid;
if Options.debug "monomorphization" then
KPrint.bprintf "%a abbreviation for %a\n" plid lid ptyp (TApp (tuple_lid, args));
best_hint <- (tuple_lid, args), lid;
ignore (self#visit_node (tuple_lid, args));
ignore (self#visit_node false (tuple_lid, args));
Hashtbl.add seen_declarations lid ();
self#clear ()

| DType (lid, _, n, Abbrev (TApp (hd, args) as t)) when n = 0 && not (Hashtbl.mem state (hd, args)) ->
Expand All @@ -227,20 +271,28 @@ let monomorphize_data_types map = object(self)
else
best_hint <- (hd, args), lid;

ignore (self#visit_node (hd, args));
ignore (self#visit_node false (hd, args));

(* And a type abbreviation will automatically be rewritten (see
GcTypes) into `typedef foobar foobar_gc *`. And mitlsffi.ci will be
happy. *)
if abbrev_for_gc_type then
self#record (DType (lid, [], 0, Abbrev (TQualified (fst lid, snd lid ^ "_gc"))));

Hashtbl.add seen_declarations lid ();
self#clear ()

| DType (_, _, n, _) when n > 0 ->
(* Can't do anything useful with this, and will not generate further
monomorphizations. Drop. *)
[]
| DType (lid, _, n, _) when n > 0 ->
(* The type itself cannot be monomorphized, but we may have seen in
the past monomorphic instances of this type that we ought to
generate. *)
List.iter (fun args ->
ignore (self#visit_node false (lid, args));
Hashtbl.remove pending_monomorphizations lid
) (Hashtbl.find_all pending_monomorphizations lid);

Hashtbl.add seen_declarations lid ();
self#clear ()

| DType (lid, _, n, (Flat _ | Variant _ | Abbrev _)) ->
(* Re-inserted by visit_node... don't insert twice. *)
Expand All @@ -250,13 +302,15 @@ let monomorphize_data_types map = object(self)
it recursively needs to trigger monomorphizations, and
side-effectfully register the type as visited in our map for
further uses (but why?). *)
ignore (self#visit_decl () d);
ignore (self#visit_node (lid, []));
ignore (self#visit_decl false d);
ignore (self#visit_node false (lid, []));
Hashtbl.add seen_declarations lid ();
self#clear ()

| _ ->
(* An actual run-time definition, needs to be retained. *)
let d = self#visit_decl () d in
let d = self#visit_decl false d in
Hashtbl.add seen_declarations (lid_of_decl d) ();
self#clear () @ [ d ]
) decls

Expand All @@ -266,25 +320,31 @@ let monomorphize_data_types map = object(self)
else
super#visit_DType env name flags n d

method! visit_ETuple _ es =
EFlat (List.mapi (fun i e -> Some (self#field_at i), self#visit_expr_w () e) es)
method! visit_ETuple under_ref es =
EFlat (List.mapi (fun i e -> Some (self#field_at i), self#visit_expr under_ref e) es)

method! visit_PTuple under_ref pats =
PRecord (List.mapi (fun i p -> self#field_at i, self#visit_pattern under_ref p) pats)

method! visit_PTuple _ pats =
PRecord (List.mapi (fun i p -> self#field_at i, self#visit_pattern_w () p) pats)
method! visit_TTuple under_ref ts =
TQualified (self#visit_node under_ref (tuple_lid, ts))

method! visit_TTuple _ ts =
TQualified (self#visit_node (tuple_lid, ts))
method! visit_TQualified under_ref lid =
TQualified (self#visit_node under_ref (lid, []))

method! visit_TQualified _ lid =
TQualified (self#visit_node (lid, []))
method! visit_TApp under_ref lid ts =
TQualified (self#visit_node under_ref (lid, ts))

method! visit_TApp _ lid ts =
TQualified (self#visit_node (lid, ts))
method! visit_TBuf _ t const =
TBuf (self#visit_typ true t, const)
end

let datatypes files =
let map = build_def_map files in
(monomorphize_data_types map)#visit_files () files
let o = monomorphize_data_types map in
let files = o#visit_files false files in
(* o#sanity_check (); *)
files


(* Type monomorphization of functions. ****************************************)
Expand Down
113 changes: 113 additions & 0 deletions test/MutualStruct.fst
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module MutualStruct
open FStar.HyperStack.ST

#set-options "--__no_positivity" // because FStar.HyperStack.ST.ref does not respect positivity

module U64 = FStar.UInt64
module U8 = FStar.UInt8
module SZ = FStar.SizeT

let main () = C.EXIT_SUCCESS // dummy

// SUCCESS
noeq
type object1_tagged = {
object1_tagged_tag: U64.t;
object1_tagged_payload: ref object1;
}
and object1 = {
object1_type: U8.t;
object1_payload: object1_tagged;
}

(*
// FAIL to compile: struct types are generated in the wrong order, leading to the compiler complaining about `object2_tagged` being an incomplete type
// The order of mutually recursive type
// definitions should match that of C, in the sense that types that
// depend on other types only behind `ref` should be defined first.
// So the correct order for `object2` below is `object1` above.
noeq
type object2 = {
object2_type: U8.t;
object2_payload: object2_tagged;
}
and object2_tagged = {
object2_tagged_tag: U64.t;
object2_tagged_payload: ref object2;
}
// FAIL to compile: same here
noeq
type object3 = {
object3_type: U8.t;
object3_map: object3_map;
}
and object3_pair = {
object3_pair_key: object3;
object3_pair_payload: object3;
}
and object3_map = {
object3_map_entry_count: U64.t;
object3_map_payload: ref object3_pair;
}
// The proper order of `object3` above is `object4` below:
*)

noeq
type object4_map = {
object4_map_entry_count: U64.t;
object4_map_payload: ref object4_pair;
}
and object4 = {
object4_type: U8.t;
object4_map: object4_map;
}
and object4_pair = {
object4_pair_key: object4;
object4_pair_payload: object4;
}

(*
// FAIL to compile: incomplete type, this time because the monomorphized type instance for `object6_map (ref object6_pair)` is not generated
noeq
type object6_map ([@@@strictly_positive] param: Type0) = {
object6_map_entry_count: U64.t;
object6_map_payload: param;
}
noeq
type object6 = {
object6_type: U8.t;
object6_map: object6_map (ref object6_pair);
}
and object6_pair = {
object6_pair_key: object6;
object6_pair_payload: object6;
}
*)

// This test extracts. It should compile, but the C compiler complains with object7_pair incomplete because KaRaMeL extracted it too early

noeq
type object7_tagged = {
object7_tagged_tag: U64.t;
object7_tagged_payload: ref object7;
}
and object7_map = {
object7_map_entry_count: U64.t;
object7_map_payload: ref object7_pair;
}
and object7_case = {
object7_case_tagged: object7_tagged;
object7_case_map: object7_map;
}
and object7 = {
object7_type: U8.t;
object7_payload: object7_case;
}
and object7_pair = {
object7_pair_fst: object7;
object7_pair_snd: object7;
}

0 comments on commit bffa630

Please sign in to comment.