(** * Module ObcAddDefaults *)

From Coq Require Import FSets.FMapPositive.
From Coq Require Import PArith.
From Velus Require Import Common.
From Velus Require Import Environment.
From Velus Require Import Operators.
From Velus Require Import Obc.ObcSyntax.
From Velus Require Import Obc.ObcSemantics.
From Velus Require Import Obc.ObcInvariants.
From Velus Require Import Obc.ObcTyping.
From Velus Require Import Obc.Equiv.

From Velus Require Import VelusMemory.

From Coq Require Import List.
Import List.ListNotations.
Open Scope list_scope.

Import Env.Notations.

From Coq Require Import Morphisms.

(** Module type packaging the [add_defaults] transformation on Obc programs
    together with its typing-preservation and refinement lemmas. It is
    parameterized by the identifier, operator, Obc syntax, semantics,
    invariants, typing and equivalence modules listed below.

    NOTE(review): in this copy each [Proof.] and [Next Obligation.] is
    followed directly by the next declaration; the proof scripts appear to
    be missing and must be restored before the file can be compiled. *)
Module Type OBCADDDEFAULTS
       (Import Ids : IDS)
       (Import Op : OPERATORS)
       (Import OpAux : OPERATORS_AUX Op)
       (Import SynObc: Velus.Obc.ObcSyntax.OBCSYNTAX Ids Op OpAux)
       (Import SemObc: Velus.Obc.ObcSemantics.OBCSEMANTICS Ids Op OpAux SynObc)
       (Import InvObc: Velus.Obc.ObcInvariants.OBCINVARIANTS Ids Op OpAux SynObc SemObc)
       (Import TypObc: Velus.Obc.ObcTyping.OBCTYPING Ids Op OpAux SynObc SemObc)
       (Import Equ : Velus.Obc.Equiv.EQUIV Ids Op OpAux SynObc SemObc TypObc).

  (** ** AddDefault functions *)


  (** The functions of this section are parameterized by [type_of_var], a
      partial map giving the declared type of a variable. Variables without
      a type are never given a default write. *)
  Section AddDefaults.

    Variable type_of_var : ident -> option type.

    (** Prefix [s] with [x := init_type ty] when [x] has type [ty];
        leave [s] unchanged when [x] has no known type. *)
    Definition add_write x s :=
      match type_of_var x with
      | None => s
      | Some ty => Comp (Assign x (Const (init_type ty))) s
      end.

    (** Prefix [s] with one default write per (typed) variable of [W],
        by folding [add_write] over the set. *)
    Definition add_writes W s := PS.fold add_write W s.

    (** Step function for [fold_right] over call arguments. A variable
        argument [Var x ty] is wrapped as [Valid x ty] and [x] is added to
        the accumulated set; any other expression is kept as is. *)
    Definition add_valid (e : exp) (esreq : list exp * PS.t) :=
      match e with
      | Var x ty => (Valid x ty :: fst esreq, PS.add x (snd esreq))
      | _ => (e :: fst esreq, snd esreq)
      end.

    (** [add_defaults_stmt required s] returns [(t, required', sometimes, always)]:
        - [t]: the rewritten statement;
        - [required']: the requirement set seen before [s], computed
          backwards from [required] (writes remove variables, e.g. [Assign],
          and call arguments that are variables add them);
        - [sometimes]: variables written by [s] on some but not necessarily
          all paths;
        - [always]: variables written by [s] on every path.
        State assignments ([AssignSt]) are not tracked. *)
    Fixpoint add_defaults_stmt (required: PS.t) (s: stmt) : stmt * PS.t * PS.t * PS.t :=
      match s with
      | Skip => (s, required, PS.empty, PS.empty)
      | Assign x e => (s, PS.remove x required, PS.empty, PS.singleton x)
      | AssignSt x e => (s, required, PS.empty, PS.empty)
      | Call xs f o m es =>
        (* Outputs [xs] are written by the call; variable arguments become
           [Valid] and are added to the requirement set. *)
        let (es', required') := fold_right add_valid
                                           ([], ps_removes xs required) es
        in (Call xs f o m es', required', PS.empty, PSP.of_list xs)
      | Comp s1 s2 =>
        (* Requirements flow backwards: analyse [s2] first, then [s1]
           against what [s2] requires. *)
        let '(t2, required2, sometimes2, always2) := add_defaults_stmt required s2 in
        let '(t1, required1, sometimes1, always1) := add_defaults_stmt required2 s1 in
        (Comp t1 t2,
         required1,
         PS.union (PS.diff sometimes1 always2) (PS.diff sometimes2 always1),
         PS.union always1 always2)
      | Ifte e s1 s2 =>
        (* Each branch is analysed with an empty requirement set; the
           requirements are handled here, at the conditional. *)
        let '(t1, required1, sometimes1, always1) := add_defaults_stmt PS.empty s1 in
        let '(t2, required2, sometimes2, always2) := add_defaults_stmt PS.empty s2 in
        let always1_req := PS.inter always1 required in
        let always2_req := PS.inter always2 required in
        (* [w1]: required variables always written by [s2] but not by [s1];
           they are given defaults at the start of the then-branch.
           [w2] is symmetric for the else-branch. *)
        let w1 := PS.diff always2_req always1_req in
        let w2 := PS.diff always1_req always2_req in
        (* [w]: required variables only sometimes written by a branch (and
           not already covered by [w1]/[w2]); they are given defaults
           before the whole conditional. *)
        let w := PS.union (PS.diff (PS.inter sometimes1 required) w1)
                          (PS.diff (PS.inter sometimes2 required) w2) in
        let always1' := PS.union always1 w1 in
        let always2' := PS.union always2 w2 in
        let sometimes1' := PS.diff sometimes1 w1 in
        let sometimes2' := PS.diff sometimes2 w2 in
        (add_writes w (Ifte e (add_writes w1 t1) (add_writes w2 t2)),
         PS.diff (PS.union
                    (PS.union
                       ((PS.diff (PS.diff required always1_req) always2_req))
                       required1)
                    required2)
                 w,
         PS.diff (PS.union sometimes1'
                           (PS.union sometimes2'
                                     (PS.union
                                        (PS.diff always1' always2')
                                        (PS.diff always2' always1')))) w,
         PS.union (PS.inter always1' always2') w)
      end.

  End AddDefaults.

  (** Rewrite the body of a method; everything else is kept unchanged.
      The typing environment is built from the inputs, locals and outputs.
      The body is analysed with the output names as the initial requirement
      set. Default writes are then prepended for the remaining required
      variables, excluding the inputs. The sometimes/always sets of the
      body are not used here. *)
  Definition add_defaults_method (m: method): method :=
    match m with
      mk_method name ins vars outs body nodup good =>
      mk_method name ins vars outs
         (let tyenv := fun x => Env.find x
                (Env.adds' outs (Env.adds' vars (Env.from_list ins))) in
          let '(body', required, sometimes, always) :=
              add_defaults_stmt tyenv (PSP.of_list (map fst outs)) body in
          add_writes tyenv (ps_removes (map fst ins) required) body')
         nodup good
    end.

  (** [add_defaults_method] only rewrites the body: the name, inputs,
      outputs and local variables of the method are unchanged. *)
  Lemma add_defaults_method_m_name:
    forall m,
      (add_defaults_method m).(m_name) = m.(m_name).
Proof.

  Lemma add_defaults_method_m_in:
    forall m, (add_defaults_method m).(m_in) = m.(m_in).
Proof.

  Lemma add_defaults_method_m_out:
    forall m, (add_defaults_method m).(m_out) = m.(m_out).
Proof.

  Lemma add_defaults_method_m_vars:
    forall m, (add_defaults_method m).(m_vars) = m.(m_vars).
Proof.

  (** Apply [add_defaults_method] to every method of a class. The name,
      memories, objects and the [nodup]/[cgood] proofs are reused as is.
      The argument in the position of [nodupm] is left as a hole ([_]).
      Program turns it into an obligation, because the method list has
      changed. *)
  Program Definition add_defaults_class (c: class): class :=
    match c with
      mk_class name mems objs methods nodup nodupm cgood =>
      mk_class name mems objs (map add_defaults_method methods) nodup _ cgood
    end.
Next Obligation.

  (** Apply [add_defaults_class] to every class of a program. *)
  Definition add_defaults := map add_defaults_class.

  (** A method found in [ms] is found, transformed, in the mapped list. *)
  Lemma find_method_add_defaults_method:
    forall n ms m,
      find_method n ms = Some m ->
      find_method n (map add_defaults_method ms) = Some (add_defaults_method m).
Proof.

  (** Conversely, any method found in the mapped list is the image of a
      method with the same name in [ms]. *)
  Lemma find_method_map_add_defaults_method':
    forall n ms fm,
      find_method n (map add_defaults_method ms) = Some fm
      -> exists fm',
        find_method n ms = Some fm' /\ fm = add_defaults_method fm'.
Proof.

  (** Looking up in the mapped methods of [c] is the same as looking up in
      the methods of [add_defaults_class c]. *)
  Lemma find_method_map_add_defaults_method:
    forall n c,
      find_method n (map add_defaults_method c.(c_methods))
      = find_method n (add_defaults_class c).(c_methods).
Proof.

  (** [add_defaults_class] keeps the class name, objects and memories. *)
  Lemma add_defaults_class_c_name:
    forall c, (add_defaults_class c).(c_name) = c.(c_name).
Proof.

  Lemma add_defaults_class_c_objs:
    forall c, (add_defaults_class c).(c_objs) = c.(c_objs).
Proof.

  Lemma add_defaults_class_c_mems:
    forall c, (add_defaults_class c).(c_mems) = c.(c_mems).
Proof.

  (** Class lookup commutes with [add_defaults]. Both the class found and
      the program [p'] returned alongside it are transformed. *)
  Lemma find_class_add_defaults_class:
    forall p n c p',
      find_class n p = Some (c, p') ->
      find_class n (add_defaults p)
      = Some (add_defaults_class c, add_defaults p').
Proof.

  (** A class that exists in [p] still exists in [add_defaults p]. *)
  Lemma find_class_add_defaults_class_not_None:
    forall n p,
      find_class n p <> None ->
      find_class n (add_defaults p) <> None.
Proof.

  (** Infix notations for membership, union, intersection and difference
      on [PS] sets, used in the set-identity lemmas below. *)
  Notation "x '∈' y" := (PS.In x y) (at level 10).
  Notation "x '∪' y" := (PS.union x y) (at level 11, right associativity).
  Notation "x '∩' y" := (PS.inter x y) (at level 11, right associativity).
  Notation "x '—' y" := (PS.diff x y) (at level 11).

  (** Repeatedly rewrite union, intersection and difference membership
      (and negations of them) into propositional connectives, in both
      hypotheses and the goal, and split conjunctive hypotheses.
      [PS_not_inter], [not_or'] and [not_not_in] are defined elsewhere. *)
  Ltac PS_split :=
    repeat match goal with
           | H: context [ PS.union _ _ ] |- _ => setoid_rewrite PS.union_spec in H
           | H: context [ ~(PS.inter _ _) ] |- _ => setoid_rewrite PS_not_inter in H
           | H: context [ PS.inter _ _ ] |- _ => setoid_rewrite PS.inter_spec in H
           | H: context [ PS.diff _ _ ] |- _ => setoid_rewrite PS.diff_spec in H
           | H: context [ ~(_ \/ _) ] |- _ => setoid_rewrite not_or' in H
           | H: context [ ~~PS.In _ _ ] |- _ => setoid_rewrite not_not_in in H
           | H:_ /\ _ |- _ => destruct H
           | |- context [ PS.union _ _ ] => setoid_rewrite PS.union_spec
           | |- context [ ~(PS.inter _ _) ] => setoid_rewrite PS_not_inter
           | |- context [ PS.inter _ _ ] => setoid_rewrite PS.inter_spec
           | |- context [ PS.diff _ _ ] => setoid_rewrite PS.diff_spec
           | |- context [ ~(_ \/ _) ] => setoid_rewrite not_or'
           | |- context [ ~~PS.In _ _ ] => setoid_rewrite not_not_in
           end.

  (** Push negations inside hypotheses. Negated conjunctions and double
      negations are eliminated with the [Decidable] lemmas; the
      decidability side goals are closed by [intuition]. *)
  Ltac PS_negate :=
    repeat match goal with
           | H:~(_ /\ _) |- _ => apply Decidable.not_and in H; [|now intuition]
           | H:~~_ |- _ => apply Decidable.not_not in H; [|now intuition]
           | H: context [ ~~PS.In _ _ ] |- _ => setoid_rewrite not_not_in in H
           end.

  Lemma simplify_write_sets:
    forall w w1 w2 al1 al2 st1 st2 rq,
      w1 = (al2rq) — (al1rq) ->
      w2 = (al1rq) — (al2rq) ->
      w = ((st1rq) — w1) ∪ ((st2rq) — w2) ->
      PS.Equal ((((st1w1)
                    ∪ (st2w2)
                    ∪ ((al1w1) — (al2w2))
                    ∪ (al2w2) — (al1w1)) — w)
                  ∪ ((al1w1) ∩ al2w2) ∪ w)
               (ww1w2al1al2st1st2).
Proof.

  (** Basic lemmas about [add_defaults_class] and [add_defaults_method]. *)


  (** Looking up a method in a transformed class yields the transformed
      result of the lookup in the original class. *)
  Lemma add_defaults_class_find_method:
    forall f c,
      find_method f (add_defaults_class c).(c_methods)
      = option_map (add_defaults_method) (find_method f c.(c_methods)).
Proof.

  (** Folding [add_valid] only ever adds to the set component: anything in
      the initial set is still in the final one. *)
  Lemma In_snd_fold_right_add_valid:
    forall x s,
      PS.In x s ->
      forall es xs,
        PS.In x (snd (fold_right add_valid (xs, s) es)).
Proof.

  (** The per-expression rewrite performed by [add_valid]: variables become
      [Valid]; other expressions are unchanged. *)
  Definition add_valid' e := match e with Var x ty => Valid x ty | _ => e end.

  (** The expression list built by folding [add_valid] is [map add_valid']
      of the arguments, prepended to the initial list. *)
  Lemma add_valid_add_valid':
    forall es S es',
      fst (fold_right add_valid (es', S) es) = map add_valid' es ++ es'.
Proof.

  Lemma Forall2_exp_eval_refines_with_valid:
    forall me ve1 ve2 es vos,
      ve2ve1 ->
      Forall (fun e => match e with Var x _ => Env.In x ve1 | _ => True end) es ->
      Forall2 (exp_eval me ve2) es vos ->
      exists vos',
        Forall2 (exp_eval me ve1) (map add_valid' es) vos'
        /\ Forall2 (fun vo vo' => forall v, vo = Some v -> vo' = Some v) vos vos'.
Proof.

  (** Evaluating [add_writes tyenv W s] is the same as first evaluating the
      default writes alone ([add_writes tyenv W Skip]) and then [s]. *)
  Lemma stmt_eval_add_writes_split:
    forall tyenv p s W me ve me'' ve'',
      stmt_eval p me ve (add_writes tyenv W s) (me'', ve'') <->
      (exists me' ve',
          stmt_eval p me ve (add_writes tyenv W Skip) (me', ve')
          /\ stmt_eval p me' ve' s (me'', ve'')).
Proof.

  (** Prepending default writes neither creates nor removes naked variable
      reads. *)
  Lemma No_Naked_Vars_add_writes:
    forall tyenv W s,
      No_Naked_Vars s <-> No_Naked_Vars (add_writes tyenv W s).
Proof.

  (** Assume calls with fully defined arguments return fully defined
      results, and every variable of [W] has a type. Then after evaluating
      [add_writes tyenv W s] (with [s] free of naked variables), every
      variable of [W] is defined in the final environment. *)
  Lemma stmt_eval_add_writes:
    forall p,
      (forall ome ome' clsid f vos rvos,
          Forall (fun vo => vo <> None) vos ->
          stmt_call_eval p ome clsid f vos ome' rvos ->
          Forall (fun x => x <> None) rvos) ->
      forall tyenv s W me ve me' ve',
        PS.For_all (fun x => tyenv x <> None) W ->
        No_Naked_Vars s ->
        stmt_eval p me ve (add_writes tyenv W s) (me', ve') ->
        (forall x, PS.In x W -> Env.In x ve').
Proof.

  (** The default writes for [W] leave every variable outside [W]
      unchanged. *)
  Lemma stmt_eval_add_writes_Skip_other:
    forall p tyenv W me ve me' ve',
      stmt_eval p me ve (add_writes tyenv W Skip) (me', ve') ->
      forall x, ~PS.In x W ->
                Env.find x ve' = Env.find x ve.
Proof.

  (** Structural invariants of [add_defaults_stmt]:
      - the sometimes- and always-sets are disjoint;
      - each required variable is either always written or still required;
      - exactly the variables of the two sets can be written by [s]
        (stated in both directions via [Can_write_in]);
      - the result has no naked variable reads. *)
  Lemma add_defaults_stmt_inv1:
    forall tyenv s t req req' stimes always,
      add_defaults_stmt tyenv req s = (t, req', stimes, always) ->
      PS.Empty (PS.inter stimes always)
      /\ (forall x, PS.In x req -> PS.In x always \/ PS.In x req')
      /\ (forall x, PS.In x (PS.union stimes always) -> Can_write_in x s)
      /\ (forall x, ~Can_write_in x s -> ~PS.In x (PS.union stimes always))
      /\ No_Naked_Vars t.
Proof.

  (** Prepending default writes keeps every write of [s]. *)
  Lemma Can_write_in_add_writes_mono:
    forall tyenv s W x,
      Can_write_in x s ->
      Can_write_in x (add_writes tyenv W s).
Proof.

  (** Any write of [add_writes tyenv W s] is either a default write (in [W])
      or a write of [s]. *)
  Lemma Can_write_in_add_writes:
    forall tyenv s W x,
      Can_write_in x (add_writes tyenv W s) ->
      PS.In x W \/ Can_write_in x s.
Proof.

  (** The original and the rewritten statement write the same variables. *)
  Lemma Can_write_in_add_defaults_stmt:
    forall tyenv s req t req' st al,
      add_defaults_stmt tyenv req s = (t, req', st, al) ->
      (forall x, Can_write_in x s <-> Can_write_in x t).
Proof.

  (** Evaluating [s] does not change variables outside the computed
      sometimes- and always-sets. *)
  Lemma add_defaults_stmt_no_write:
    forall p tyenv s t me me' ve ve' req req' stimes always,
      add_defaults_stmt tyenv req s = (t, req', stimes, always) ->
      stmt_eval p me ve s (me', ve') ->
      forall x, ~PS.In x (PS.union stimes always) ->
                Env.find x ve' = Env.find x ve.
Proof.


  (** A method well typed in [p] is still well typed in [add_defaults p]. *)
  Lemma wt_method_add_defaults:
    forall p insts mem m,
      wt_method p insts mem m ->
      wt_method (add_defaults p) insts mem m.
Proof.

  (** Typing and semantic lemmas for [add_defaults_stmt], relative to a
      fixed program [p], instances [insts], memories [mems] and variable
      declarations [vars]. *)
  Section AddDefaultsStmt.

    Variables (p : list class)
              (insts : list (ident * ident))
              (mems : list (ident * type))
              (vars : list (ident * type))
              (tyenv : ident -> option type).

    (** [tyenv] agrees exactly with the declarations [vars]. *)
    Hypothesis wf_vars_tyenv:
      (forall x ty, In (x, ty) vars <-> tyenv x = Some ty).

    (** A variable is declared exactly when [tyenv] gives it a type. *)
    Lemma wf_vars_tyenv':
      forall x, InMembers x vars <-> tyenv x <> None.
Proof.

    (** Removing the default writes preserves well-typedness. *)
    Lemma add_writes_wt':
      forall W s,
        wt_stmt p insts mems vars (add_writes tyenv W s) ->
        wt_stmt p insts mems vars s.
Proof.

    (** When every variable of [W] has a type, adding the default writes
        neither breaks nor creates well-typedness. *)
    Lemma add_writes_wt:
      forall W s,
        PS.For_all (fun x => tyenv x <> None) W ->
        (wt_stmt p insts mems vars s <->
         wt_stmt p insts mems vars (add_writes tyenv W s)).
Proof.

    (** [add_defaults_stmt] preserves typing. The sometimes- and
        always-sets contain only declared variables; each variable of the
        new requirement set was either already required or is declared. *)
    Lemma add_defaults_stmt_wt:
      forall s t req req' stimes always,
        add_defaults_stmt tyenv req s = (t, req', stimes, always) ->
        wt_stmt p insts mems vars s ->
        wt_stmt p insts mems vars t
        /\ PS.For_all (fun x => InMembers x vars) stimes
        /\ PS.For_all (fun x => InMembers x vars) always
        /\ PS.For_all (fun x => PS.In x req \/ InMembers x vars) req'.
Proof.


    (** Semantic invariant of the rewritten statement [t].

        Start from an environment where the new requirement set [req'] is
        defined, and assume calls with defined arguments return defined
        results. Then evaluating [t] leaves variables outside
        [stimes ∪ always] unchanged, and defines every variable of
        [always]. *)
    Lemma add_defaults_stmt_inv2:
      forall s t me me' ve ve' req req' stimes always,
        add_defaults_stmt tyenv req s = (t, req', stimes, always) ->
        stmt_eval p me ve t (me', ve') ->
        wt_stmt p insts mems vars s ->
        (forall x, PS.In x req' -> Env.In x ve) ->
        (forall ome ome' clsid f vos rvos,
            Forall (fun vo => vo <> None) vos ->
            stmt_call_eval p ome clsid f vos ome' rvos ->
            Forall (fun x => x <> None) rvos) ->
        (forall x, ~PS.In x (PS.union stimes always) -> Env.find x ve' = Env.find x ve)
        /\ (forall x, PS.In x always -> Env.In x ve').
Proof.

    (** [in1_notin2 xs1 xs2 ve1 ve2]: every variable of [xs1] is defined in
        [ve1], and no variable of [xs2] is defined in [ve2]. *)
    Definition in1_notin2 xs1 xs2 (ve1 ve2 : Env.t val) :=
      (forall x, PS.In x xs1 -> Env.In x ve1)
      /\ (forall x, PS.In x xs2 -> ~Env.In x ve2).

    (* Brings [impl] into scope for the morphism signature below. *)
    Import Basics.

    (** [in1_notin2] respects set equality. It is preserved when [ve1] is
        replaced by an environment refining it, and when [ve2] is replaced
        by an environment that [ve2] refines (the flipped arrow). *)
    Instance in1_notin2_Proper1:
      Proper (PS.Equal ==> PS.Equal ==> Env.refines eq ==> Env.refines eq --> impl)
             in1_notin2.
Proof.

    Instance in1_notin2_Proper2:
      Proper (PS.Equal ==> PS.Equal ==> eq ==> eq ==> iff) in1_notin2.
Proof.

    (** Split an added element off the first set. *)
    Lemma in1_notin2_add1:
      forall ve1 ve2 x S1 S2,
        in1_notin2 (PS.add x S1) S2 ve1 ve2 ->
        in1_notin2 S1 S2 ve1 ve2 /\ Env.In x ve1.
Proof.

    (** If the set produced by folding [add_valid] over [es] is defined in
        [ve'], then every variable argument of [es] is defined in [ve']. *)
    Lemma in1_notin2_Var_In:
      forall ve' ve es acc S,
        in1_notin2 (snd (fold_right add_valid acc es)) S ve' ve ->
        Forall (fun e => match e with Var x _ => Env.In x ve' | _ => True end) es.
Proof.

    (** List form of the second component for a set built from a list. *)
    Lemma in1_notin2_Var_Not_In:
      forall ys s1 ve' ve,
        in1_notin2 s1 (PSP.of_list ys) ve' ve ->
        Forall (fun x => ~ Env.In x ve) ys.
Proof.

    (** Weakening: [in1_notin2] holds for any subsets of its two sets. *)
    Lemma in1_notin2_infer:
      forall ve1 ve2 S1 S2 Z1 Z2,
        in1_notin2 S1 S2 ve1 ve2 ->
        (forall x, PS.In x Z1 -> PS.In x S1) ->
        (forall x, PS.In x Z2 -> PS.In x S2) ->
        in1_notin2 Z1 Z2 ve1 ve2.
Proof.

    Lemma stmt_eval_add_writes_Skip:
      forall me w ve0' ve0,
        ve0ve0' ->
        PS.For_all (fun x => ~Env.In x ve0) w ->
        PS.For_all (fun x => InMembers x vars) w ->
        exists ve1',
          ve0ve1'
          /\ stmt_eval p me ve0' (add_writes tyenv w Skip) (me, ve1')
          /\ (forall x, Env.In x ve0' -> Env.In x ve1')
          /\ PS.For_all (fun x => Env.In x ve1') w.
Proof.

    (** [all_in1 xs ve1 ve2]: the variables defined in [ve1] are exactly the
        names of [xs], and every variable defined in [ve2] is a name of
        [xs]. It is used below as the environment relation of
        [program_refines]. *)
    Definition all_in1 (xs : list (ident * type)) (ve1 ve2 : Env.t val) :=
      (forall x, InMembers x xs <-> Env.In x ve1)
      /\ (forall x, Env.In x ve2 -> InMembers x xs).

    (** Main correctness result for statements.

        Assume [p] refines [p'] (with [all_in1]) and calls with defined
        arguments return defined results. If [s] has no overwrites and is
        well typed, then the rewritten [t] refines [s] from environments
        where [req'] is defined in the new environment and
        [st ∪ al] is undefined in the old one. *)
    Lemma add_defaults_stmt_correct:
      forall p' s req t req' st al,
        program_refines (fun _ _ => all_in1) p p' ->
        (forall ome ome' clsid f vos rvos,
            Forall (fun vo => vo <> None) vos ->
            stmt_call_eval p ome clsid f vos ome' rvos ->
            Forall (fun x => x <> None) rvos) ->
        No_Overwrites s ->
        wt_stmt p insts mems vars s ->
        add_defaults_stmt tyenv req s = (t, req', st, al) ->
        stmt_refines p p' (in1_notin2 req' (PS.union st al)) t s.
Proof.

  End AddDefaultsStmt.


  (** Assume the declarations [min ++ mvars ++ mout] have no duplicate
      names, and every element of [S] is an output name or a declared name.
      Then every element of [S] other than an input is bound in the
      environment built from those declarations. *)
  Lemma output_or_local_in_typing_env:
    forall {A} (min mvars mout : list (ident * A)) S,
      NoDupMembers (min ++ mvars ++ mout) ->
      PS.For_all
        (fun x => PS.In x (PSP.of_list (map fst mout)) \/
                  InMembers x (min ++ mvars ++ mout)) S ->
      PS.For_all (fun x => Env.find x (Env.from_list (min ++ mvars ++ mout)) <> None)
                 (ps_removes (map fst min) S).
Proof.

  (** In the transformed version of a well-typed program, a call with
      fully defined arguments returns fully defined results. *)
  Lemma stmt_call_eval_add_defaults_class_not_None:
    forall p,
      wt_program p ->
      forall ome ome' clsid f vos rvos,
        Forall (fun vo => vo <> None) vos ->
        stmt_call_eval (add_defaults p) ome clsid f vos ome' rvos ->
        Forall (fun x => x <> None) rvos.
Proof.

  (** [add_defaults_method] preserves method typing. *)
  Lemma wt_add_defaults_method:
    forall p objs mems m,
      wt_method p objs mems m ->
      wt_method p objs mems (add_defaults_method m).
Proof.

  (** Memory typing is preserved by the transformation. *)
  Lemma wt_mem_add_defaults:
    forall p c me,
      wt_mem me p c ->
      wt_mem me (add_defaults p) (add_defaults_class c).
Proof.

  (** [add_defaults] preserves program typing. *)
  Lemma wt_add_defaults_class:
    forall p,
      wt_program p ->
      wt_program (add_defaults p).
Proof.

  (** The rewritten body of a method refines the original body. This holds
      starting from environments where the inputs are defined in the new
      environment, and the outputs and locals are undefined in the old
      one. *)
  Lemma add_defaults_stmt_refines:
    forall p1 p2 insts mems m,
      program_refines (fun _ _ => all_in1) p1 p2 ->
      wt_method p2 insts mems m ->
      No_Overwrites m.(m_body) ->
      Forall (fun x => ~Can_write_in x m.(m_body)) (map fst m.(m_in)) ->
      (forall ome ome' clsid f vos rvos,
          Forall (fun vo => vo <> None) vos ->
          stmt_call_eval p1 ome clsid f vos ome' rvos ->
          Forall (fun x => x <> None) rvos) ->
      stmt_refines p1 p2
                   (in1_notin2 (PSP.of_list (map fst m.(m_in)))
                               (PSP.of_list (map fst (m.(m_out) ++ m.(m_vars)))))
                   (add_defaults_method m).(m_body) m.(m_body).
Proof.

  (** Transformed method bodies never contain naked variable reads. *)
  Lemma No_Naked_Vars_add_defaults_method:
    forall m, No_Naked_Vars (add_defaults_method m).(m_body).
Proof.

  (** Method-level refinement, lifted from [add_defaults_stmt_refines]. *)
  Lemma add_defaults_method_refines:
    forall p1 p2 insts mems m,
      program_refines (fun _ _ => all_in1) p1 p2 ->
      wt_method p2 insts mems m ->
      No_Overwrites m.(m_body) ->
      Forall (fun x => ~ Can_write_in x m.(m_body)) (map fst m.(m_in)) ->
      (forall ome ome' clsid f vos rvos,
          Forall (fun vo => vo <> None) vos ->
          stmt_call_eval p1 ome clsid f vos ome' rvos ->
          Forall (fun x => x <> None) rvos) ->
      method_refines p1 p2 all_in1 (add_defaults_method m) m.
Proof.

  (** Class-level refinement: every method satisfies the hypotheses of
      [add_defaults_method_refines]. *)
  Lemma add_defaults_class_refines:
    forall p1 p2 c,
      program_refines (fun _ _ => all_in1) p1 p2 ->
      wt_class p2 c ->
      Forall (fun m => No_Overwrites m.(m_body)) c.(c_methods) ->
      Forall (fun m => Forall (fun x => ~ Can_write_in x m.(m_body))
                              (map fst m.(m_in))) c.(c_methods) ->
      (forall ome ome' clsid f vos rvos,
          Forall (fun vo => vo <> None) vos ->
          stmt_call_eval p1 ome clsid f vos ome' rvos ->
          Forall (fun x => x <> None) rvos) ->
      class_refines p1 p2 (fun _ => all_in1) (add_defaults_class c) c.
Proof.

  (** Program-level refinement. A well-typed program whose method bodies
      have no overwrites and never write their inputs is refined by its
      transformation. *)
  Lemma add_defaults_program_refines:
    forall p,
      wt_program p ->
      Forall_methods (fun m => No_Overwrites m.(m_body)) p ->
      Forall_methods (fun m => Forall (fun x => ~ Can_write_in x m.(m_body))
                                      (map fst m.(m_in))) p ->
      program_refines (fun _ _ => all_in1) (add_defaults p) p.
Proof.

  (** Every method of a transformed program is free of naked variable
      reads. *)
  Lemma No_Naked_Vars_add_defaults_class:
    forall p,
      Forall_methods (fun m => No_Naked_Vars m.(m_body)) (add_defaults p).
Proof.

  (** End-to-end preservation for a single call. Under the program-wide
      hypotheses of [add_defaults_program_refines], a call of method [m] of
      class [f] that maps fully defined inputs to fully defined outputs in
      [p] behaves identically in [add_defaults p]. *)
  Theorem stmt_call_eval_add_defaults:
    forall p me f m vs me' rvs,
      wt_program p ->
      Forall_methods (fun m => No_Overwrites m.(m_body)) p ->
      Forall_methods (fun m => Forall (fun x => ~ Can_write_in x m.(m_body))
                                      (map fst m.(m_in))) p ->
      stmt_call_eval p me f m (map Some vs) me' (map Some rvs) ->
      stmt_call_eval (add_defaults p) me f m (map Some vs) me' (map Some rvs).
Proof.

  (** The same preservation result for [loop_call], starting at instant 0
      with memory [me] and fully defined input/output streams. *)
  Corollary loop_call_add_defaults:
    forall f c ins outs p me,
      wt_program p ->
      Forall_methods (fun m => No_Overwrites m.(m_body)) p ->
      Forall_methods (fun m => Forall (fun x => ~ Can_write_in x m.(m_body))
                                      (map fst m.(m_in))) p ->
      loop_call p c f (fun n => map Some (ins n)) (fun n => map Some (outs n)) 0 me ->
      loop_call (add_defaults p) c f (fun n => map Some (ins n)) (fun n => map Some (outs n)) 0 me.
Proof.

End OBCADDDEFAULTS.

(** Concrete instance of [OBCADDDEFAULTS]: it takes the same parameters
    and simply includes the module type's definitions and lemmas. *)
Module ObcAddDefaultsFun
       (Import Ids : IDS)
       (Import Op : OPERATORS)
       (Import OpAux : OPERATORS_AUX Op)
       (Import SynObc: Velus.Obc.ObcSyntax.OBCSYNTAX Ids Op OpAux)
       (Import SemObc: Velus.Obc.ObcSemantics.OBCSEMANTICS Ids Op OpAux SynObc)
       (Import InvObc: Velus.Obc.ObcInvariants.OBCINVARIANTS Ids Op OpAux SynObc SemObc)
       (Import TypObc: Velus.Obc.ObcTyping.OBCTYPING Ids Op OpAux SynObc SemObc)
       (Import Equ : Velus.Obc.Equiv.EQUIV Ids Op OpAux SynObc SemObc TypObc)
       <: OBCADDDEFAULTS Ids Op OpAux SynObc SemObc InvObc TypObc Equ.

  Include OBCADDDEFAULTS Ids Op OpAux SynObc SemObc InvObc TypObc Equ.

End ObcAddDefaultsFun.