Skip to content

Commit

Permalink
Merge pull request formal-land#196 from pedrotst/grab_existentials
Browse files Browse the repository at this point in the history
Disable grab of existentials
  • Loading branch information
clarus authored Oct 27, 2021
2 parents 53ea6b8 + d728028 commit 7ba3143
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 73 deletions.
47 changes: 47 additions & 0 deletions doc/docs/attributes.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,53 @@ Inductive gadt_list : Set :=

One possible reason to force a type to be a GADT is to make sure that all the inductive types in a mutually recursive type definition have the same (zero) arity, as it is expected by Coq.

## coq_grab_existentials
When translating terms that mentions existential variables it might be necessary to make that existential variable explicit.
To achieve this we use the `[@coq_grab_existentials]` attribute. Here is an example:

```ocaml
type wrap1 =
| Cw1 : ('a -> 'b) -> wrap1
type wrap2 =
| Cw2 : ('a -> 'a) -> wrap2
let w2_of_w1 (w : wrap2) : wrap1 =
match [@coq_grab_existentials]w with
| Cw2 f ->
Cw1 (fun y -> f y)
```

Notice that the type of `inj` is `'a -> 'a` for some existential variable `'a`.
Since `coq-of-ocaml` always generates fully anotated code, we need to explicitely
name `'a` in order to properly anotate the type of `y` in the body of `Cw1`.
This gives us the following translation:

```coq
Inductive wrap1 : Set :=
| Cw1 : forall {a b : Set}, (a -> b) -> wrap1.
Inductive wrap2 : Set :=
| Cw2 : forall {a : Set}, (a -> a) -> wrap2.
Definition w2_of_w1 (w : wrap2) : wrap1 :=
let 'Cw2 f := w in
let 'existT _ __Cw2_'a f as exi :=
existT (A := Set) (fun __Cw2_'a => __Cw2_'a -> __Cw2_'a) _ f
return
let fst := projT1 exi in
let __Cw2_'a := fst in
wrap1 in
Cw1 (fun (y : __Cw2_'a) => f y).
```

In the coq side we use an `existT` to grab these existential variables. The key
here is that this allows us to explicitely name `'a` as `__Cw2_'a`.

The return clause is used to bind this new name in the return type of the term
that is being built, in this example it wouldn't be necessary but we generate
the same code for a simpler boilerplate.

## coq_implicit
The `[@coq_implicit "(A := ...)"]` attribute adds an arbitrary annotation on an OCaml identifier or constructor. We typically use this attribute to help Coq to infer implicit types where there is an ambiguity:
```ocaml
Expand Down
12 changes: 6 additions & 6 deletions src/attribute.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ open Monad.Notations
type t =
| AxiomWithReason
| Cast
| DisableExistential
| ForceGadt
| GrabExistentials
| TaggedGadt
| Implicit of string
| MatchGadt
Expand Down Expand Up @@ -59,8 +59,8 @@ let of_attributes (attributes : Typedtree.attributes) : t list Monad.t =
let* _ = of_payload_string error_message id attr_payload in
return (Some AxiomWithReason)
| "coq_cast" -> return (Some Cast)
| "coq_disable_existential" -> return (Some DisableExistential)
| "coq_force_gadt" -> return (Some ForceGadt)
| "coq_grab_existentials" -> return (Some GrabExistentials)
| "coq_tag_gadt" -> return (Some TaggedGadt)
| "coq_implicit" ->
let error_message =
Expand Down Expand Up @@ -94,15 +94,15 @@ let has_cast (attributes : t list) : bool =
| _ -> false
)

let has_disable_existential (attributes : t list) : bool =
let has_force_gadt (attributes : t list) : bool =
attributes |> List.exists (function
| DisableExistential -> true
| ForceGadt -> true
| _ -> false
)

let has_force_gadt (attributes : t list) : bool =
let has_grab_existentials (attributes : t list) : bool =
attributes |> List.exists (function
| ForceGadt -> true
| GrabExistentials -> true
| _ -> false
)

Expand Down
25 changes: 12 additions & 13 deletions src/exp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type match_existential_cast = {
return_typ : Type.t;
use_axioms : bool;
cast_result : bool;
disable : bool
enable : bool
}

type dependent_pattern_match = {
Expand Down Expand Up @@ -313,9 +313,9 @@ let rec of_expression (typ_vars : Name.t Name.Map.t) (e : expression)
let is_tagged_match = Attribute.has_tagged_match attributes in
let do_cast_results = Attribute.has_match_gadt_with_result attributes in
let is_with_default_case = Attribute.has_match_with_default attributes in
let is_disable_existentials = Attribute.has_disable_existential attributes in
let is_grab_existentials = Attribute.has_grab_existentials attributes in
let* (x, typ, e) =
open_cases typ_vars cases is_gadt_match is_tagged_match do_cast_results is_with_default_case is_disable_existentials in
open_cases typ_vars cases is_gadt_match is_tagged_match do_cast_results is_with_default_case is_grab_existentials in
return (Function (x, typ, e))
| Texp_apply (e_f, e_xs) ->
of_expression typ_vars e_f >>= fun e_f ->
Expand Down Expand Up @@ -437,9 +437,9 @@ let rec of_expression (typ_vars : Name.t Name.Map.t) (e : expression)
let is_tagged_match = Attribute.has_tagged_match attributes in
let do_cast_results = Attribute.has_match_gadt_with_result attributes in
let is_with_default_case = Attribute.has_match_with_default attributes in
let is_disable_existential = Attribute.has_disable_existential attributes in
let is_grab_existential = Attribute.has_grab_existentials attributes in
let* e = of_expression typ_vars e in
of_match typ_vars e cases is_gadt_match is_tagged_match do_cast_results is_with_default_case is_disable_existential
of_match typ_vars e cases is_gadt_match is_tagged_match do_cast_results is_with_default_case is_grab_existential
| Texp_tuple es ->
Monad.List.map (of_expression typ_vars) es >>= fun es ->
return (Tuple es)
Expand Down Expand Up @@ -723,7 +723,7 @@ and of_match :
type k . Name.t Name.Map.t -> t -> k case list -> bool -> bool -> bool ->
bool -> bool -> t Monad.t =
fun typ_vars e cases is_gadt_match is_tagged_match do_cast_results
is_with_default_case is_disable_existential ->
is_with_default_case is_grab_existentials ->
let is_extensible_type_match =
cases |>
List.map (fun { c_lhs; _ } -> c_lhs) |>
Expand Down Expand Up @@ -803,7 +803,7 @@ and of_match :
return_typ = typ;
use_axioms = is_gadt_match;
cast_result = do_cast_results;
disable = is_disable_existential;
enable = is_grab_existentials;
} in

begin match c_guard with
Expand Down Expand Up @@ -904,7 +904,7 @@ and open_cases
(is_tagged_match : bool)
(do_cast_results : bool)
(is_with_default_case : bool)
(is_disable_existential: bool)
(is_grab_existentials: bool)
: (Name.t * Type.t option * t) Monad.t =
let name = Name.FunctionParameter in
let* typ =
Expand All @@ -916,7 +916,7 @@ and open_cases
let e = Variable (MixedPath.of_name name, []) in
let* e =
of_match
typ_vars e cases is_gadt_match is_tagged_match do_cast_results is_with_default_case is_disable_existential in
typ_vars e cases is_gadt_match is_tagged_match do_cast_results is_with_default_case is_grab_existentials in
return (name, typ, e)

and import_let_fun
Expand Down Expand Up @@ -1773,7 +1773,7 @@ and to_coq_cast_existentials
| _ -> to_coq false e in
match existential_cast with
| None -> e
| Some { new_typ_vars; bound_vars; use_axioms; return_typ; disable; _ } ->
| Some { new_typ_vars; bound_vars; use_axioms; return_typ; enable; _ } ->
let variable_names =
Pp.primitive_tuple (bound_vars |> List.map (fun (name, _) ->
Name.to_coq name
Expand All @@ -1787,9 +1787,8 @@ and to_coq_cast_existentials
Pp.primitive_tuple_type (bound_vars |> List.map (fun (_, typ) ->
Type.to_coq None None typ
)) in
begin match (disable, bound_vars, new_typ_vars) with
| (true, _, _) -> e
| (_, [], _) -> e
begin match (enable, bound_vars, new_typ_vars) with
| (false, _, _) | (_, [], _) -> e
| (_, _, []) ->
if use_axioms then
let variable_names_pattern =
Expand Down
8 changes: 1 addition & 7 deletions tests/function_with_named_parameters.v
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@ Require Import CoqOfOCaml.Settings.

Definition option_value {a : Set} (x : option a) (default : a) : a :=
match x with
| Some x =>
let 'existT _ a x as exi := existT (A := Set) (fun a => a) _ x
return
let fst := projT1 exi in
let a := fst in
a in
x
| Some x => x
| None => default
end.

Expand Down
8 changes: 0 additions & 8 deletions tests/gadts2.v
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,4 @@ Fixpoint interp {a : vtag} (t : Test.term a) : int :=
Test.term.T_Pair.x2 := x2;
Test.term.T_Pair.x3 := x3
|} := t in
let 'existT _ [__1, __0] [x3, x2, x1] as exi :=
existT (A := [Set ** Set]) (fun '[__1, __0] => [__1 ** __0 ** int]) [_, _]
[x3, x2, x1]
return
let fst := projT1 exi in
let __0 := Primitive.snd fst in
let __1 := Primitive.fst fst in
int in
x1.
13 changes: 1 addition & 12 deletions tests/gadts_existential.v
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,6 @@ End wrapper.

Definition unwrap (w1 : wrapper) (w2 : wrapper) : int :=
match (w1, w2) with
| (W_exp e1, W_term e2) =>
let 'existT _ [__W_term_'kind, __W_exp_'kind] [e2, e1] as exi :=
existT (A := [vtag ** vtag])
(fun '[__W_term_'kind, __W_exp_'kind] =>
[wrapper.W_term __W_term_'kind ** wrapper.W_exp __W_exp_'kind]) [_, _]
[e2, e1]
return
let fst := projT1 exi in
let __W_exp_'kind := Primitive.snd fst in
let __W_term_'kind := Primitive.fst fst in
int in
2
| (W_exp e1, W_term e2) => 2
| _ => 4
end.
22 changes: 2 additions & 20 deletions tests/gadts_record.v
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,6 @@ Fixpoint interp {a : vtag} (function_parameter : term a) : decode_vtag a :=
match function_parameter with
| T_Int n => n
| T_String s => s
| T_Pair p1 p2 =>
let 'existT _ [__1, __0] [p2, p1] as exi :=
existT (A := [vtag ** vtag]) (fun '[__1, __0] => [term __1 ** term __0])
[_, _] [p2, p1]
return
let fst := projT1 exi in
let __0 := Primitive.snd fst in
let __1 := Primitive.fst fst in
decode_vtag __0 * decode_vtag __1 in
((interp p1), (interp p2))
| T_Rec {| term.T_Rec.x := x; term.T_Rec.y := y |} =>
let 'existT _ [__3, __2] [y, x] as exi :=
existT (A := [Set ** vtag]) (fun '[__3, __2] => [__3 ** term __2]) [_, _]
[y, x]
return
let fst := projT1 exi in
let __2 := Primitive.snd fst in
let __3 := Primitive.fst fst in
decode_vtag __2 * __3 in
((interp x), y)
| T_Pair p1 p2 => ((interp p1), (interp p2))
| T_Rec {| term.T_Rec.x := x; term.T_Rec.y := y |} => ((interp x), y)
end.
8 changes: 1 addition & 7 deletions tests/impredicative_set.v
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,5 @@ Inductive t : Set :=
Fixpoint t_of_list {a : Set} (l : list a) : t :=
match l with
| [] => Empty
| cons _ l =>
let 'existT _ a l as exi := existT (A := Set) (fun a => list a) _ l
return
let fst := projT1 exi in
let a := fst in
t in
Node (t_of_list l)
| cons _ l => Node (t_of_list l)
end.

0 comments on commit 7ba3143

Please sign in to comment.