From Coq Require Import FSets.FMapPositive.
From Coq Require Import PArith.
From Velus Require Import Common.
From Velus Require Import CommonTyping.
From Velus Require Import Operators.
From Velus Require Import Obc.ObcSyntax.
From Velus Require Import Obc.ObcSemantics.
From Velus Require Import Obc.ObcInvariants.
From Velus Require Import Obc.ObcTyping.
From Velus Require Import Obc.Equiv.
From Velus Require Import VelusMemory.
From Coq Require Import List.
Import List.ListNotations.
Open Scope list_scope.
(** This pass rewrites every switch that does not have a default branch into
    a switch whose last branch is the default branch. *)
Module Type OBCSWITCHESNORMALIZATION
(
Import Ids :
IDS)
(
Import Op :
OPERATORS)
(
Import OpAux :
OPERATORS_AUX Ids Op)
(
Import SynObc:
OBCSYNTAX Ids Op OpAux)
(
Import ComTyp:
COMMONTYPING Ids Op OpAux)
(
Import SemObc:
OBCSEMANTICS Ids Op OpAux SynObc)
(
Import InvObc:
OBCINVARIANTS Ids Op OpAux SynObc SemObc)
(
Import TypObc:
OBCTYPING Ids Op OpAux SynObc ComTyp SemObc)
(
Import Equ :
EQUIV Ids Op OpAux SynObc ComTyp SemObc TypObc).
Import List.
(** [has_default_branch branches] tests whether [branches] contains a
    "default" entry.  As stated by [has_default_branch_true], the result
    is [true] exactly when some entry of [branches] is [None]. *)
Definition has_default_branch : list (option stmt) -> bool :=
  existsb (or_default_with true (fun _ => false)).
(** [normalize_branches branches] detaches the last entry of [branches].

    It returns a pair made of
    - the input list in which the last entry has been replaced by [None]
      (the empty list stays empty), and
    - the entry that was detached ([None] for the empty list).

    The singleton pattern must be tried before the general [_ :: _]
    pattern, since both match a one-element list. *)
Fixpoint normalize_branches (branches : list (option stmt))
  : (list (option stmt) * option stmt) :=
  match branches with
  | [] => ([], None)
  | [lst] => ([None], lst)
  | br :: rest =>
      let (rest', lst) := normalize_branches rest in
      (br :: rest', lst)
  end.
(** [normalize_switch branches default] computes the branch list and the
    default statement of the normalized switch.

    - If [branches] already has a default branch (see [has_default_branch]),
      both arguments are returned unchanged.
    - Otherwise the last branch is detached with [normalize_branches] and
      combined with [default] through [or_default]; presumably this selects
      the detached statement when there is one and [default] otherwise
      (confirm against the definition of [or_default]). *)
Definition normalize_switch (branches : list (option stmt)) (default : stmt)
  : (list (option stmt) * stmt) :=
  if has_default_branch branches
  then (branches, default)
  else
    let (branches', lst) := normalize_branches branches in
    (branches', or_default default lst).
(** [normalize_stmt s] normalizes every switch occurring in [s].

    - [Switch]: the branches and the default statement are normalized
      recursively first, then handed to [normalize_switch].
    - [Comp]: both components are normalized.
    - Any other statement is returned as is. *)
Fixpoint normalize_stmt (s : stmt) : stmt :=
  match s with
  | Switch e branches default =>
      let (branches', default') :=
        normalize_switch
          (map (option_map normalize_stmt) branches)
          (normalize_stmt default)
      in
      Switch e branches' default'
  | Comp s1 s2 =>
      Comp (normalize_stmt s1) (normalize_stmt s2)
  | _ => s
  end.
(** [normalize_method m] rebuilds [m] with its body replaced by
    [normalize_stmt body]; every other component of [mk_method], including
    the two proof components, is reused unchanged. *)
Definition normalize_method (m : method) : method :=
  match m with
  | mk_method name ins vars out body Hnodup Hgood =>
      mk_method name ins vars out (normalize_stmt body) Hnodup Hgood
  end.
(** Mapping [normalize_method] over a list of methods does not change the
    list of their names. *)
Lemma map_m_name_normalize_method:
  forall methods,
    map m_name (map normalize_method methods) = map m_name methods.
Proof.
  induction methods as [|m ms IH]; simpl.
  - reflexivity.
  - (* Tails agree by induction; heads agree once [m] is a constructor. *)
    rewrite IH; destruct m; reflexivity.
Qed.
(** [normalize_class c] rebuilds [c] with [normalize_method] mapped over its
    methods.  The [name], [mems], [objs] and [nodup] components are reused
    unchanged; the last two components of [mk_class] are left as holes
    ([_]) to be discharged as [Program] obligations. *)
Program Definition normalize_class (
c:
class) :
class :=
match c with
mk_class name mems objs methods nodup _ _ =>
mk_class name mems objs (
map normalize_method methods)
nodup _ _
end.
Next Obligation.
(** Declares [normalize_class] as the [transform_unit] of the
    [TransformUnit class class] instance; the remaining fields of the
    instance are left as [Program] obligations. *)
Global Program Instance normalize_class_transform_unit:
TransformUnit class class :=
{
transform_unit :=
normalize_class }.
Next Obligation.
(** [TransformStateUnit class class] instance for the pass; all of its
    fields are produced as [Program] obligations. *)
Global Program Instance normalize_class_transform_state_unit:
TransformStateUnit class class.
Next Obligation.
(** Entry point of the pass on whole programs: [transform_units] used at
    type [program -> program].  Presumably the per-class transformation is
    the one registered by the [TransformUnit class class] instance, i.e.
    [normalize_class] (confirm against the definition of
    [transform_units]). *)
Definition normalize_switches : program -> program := transform_units.
(** Characterization of [has_default_branch]: it returns [true] exactly
    when some entry of [branches] is [None]. *)
Lemma has_default_branch_true:
forall branches,
has_default_branch branches =
true <->
Exists (
fun os =>
os =
None)
branches.
Proof.
(** Dual characterization: [has_default_branch] returns [false] exactly
    when every entry of [branches] is different from [None]. *)
Corollary has_default_branch_false:
forall branches,
has_default_branch branches =
false <->
Forall (
fun os =>
os <>
None)
branches.
Proof.
(** Specification of [normalize_branches] in terms of the list functions
    [removelast] and [last].  Writing [(branches', default)] for the result:
    - [branches'] is [[]] when [branches] is empty, and
      [removelast branches ++ [None]] otherwise;
    - in both cases [default] is [last branches None]. *)
Lemma normalize_branches_spec:
forall branches,
let (
branches',
default) :=
normalize_branches branches in
match branches with
| [] =>
branches' = []
|
_ =>
branches' =
removelast branches ++ [
None]
end
/\
default =
last branches None.
Proof.
induction branches as [|? []]; simpl in *; auto; cases.
Qed.
(** Semantic correctness of statement normalization: [normalize_stmt s] and
    [s] are related by [stmt_eval_eq].

    The proof is by induction on [s] using the principle [stmt_ind2'].
    After [simpl; try reflexivity] two cases remain, handled by the two
    top-level bullets: [Switch] and [Comp]. *)
Lemma normalize_stmt_eq:
forall s,
stmt_eval_eq (
normalize_stmt s)
s.
Proof.
induction s using stmt_ind2';
simpl;
try reflexivity.
(* [Switch] case: split on whether the recursively normalized branches
   already contain a default branch. *)
-
unfold normalize_switch.
destruct (
has_default_branch (
map (
Datatypes.option_map normalize_stmt)
ss))
eqn:
E.
(* A default branch exists: [normalize_switch] returns its arguments
   unchanged, so branches and default are related pointwise through the
   induction hypotheses. *)
+
apply stmt_eval_eq_Switch_Proper;
eauto.
apply Forall2_map_1,
Forall2_same,
Forall_forall.
intros os ?;
take (
Forall _ _)
and eapply Forall_forall in it;
eauto.
destruct os;
simpl in *;
constructor;
auto.
(* No default branch: go through an intermediate switch (second goal of
   [etransitivity], discharged first), then use [normalize_branches_spec]
   to describe the rewritten branch list and default. *)
+
etransitivity.
2: {
apply stmt_eval_eq_Switch_Proper;
eauto.
apply Forall2_same,
Forall_forall;
auto with datatypes.
}
apply has_default_branch_false in E.
pose proof (
normalize_branches_spec (
map (
Datatypes.option_map normalize_stmt)
ss))
as Norm.
destruct (
normalize_branches (
map (
Datatypes.option_map normalize_stmt)
ss)).
destruct Norm;
subst.
Opaque removelast last.
split;
inversion_clear 1
as [| | | | |???????????
Nth|];
destruct ss;
simpl in *;
subst;
eauto using stmt_eval;
rewrite <-
map_cons,
Forall_map in E.
(* First direction of the [split]: case analysis comparing the selected
   index [c] with [length ss] (less than, equal, greater than). *)
*
destruct (
Compare_dec.lt_eq_lt_dec c (
length ss))
as [[|]|].
--
rewrite nth_error_app1,
nth_error_removelast, <-
map_cons,
map_nth_error'
in Nth.
++
apply option_map_inv in Nth as (
os &?&?);
subst.
econstructor;
eauto.
take (
nth_error _ _ =
_)
and apply nth_error_In in it.
repeat (
take (
Forall _ _)
and eapply Forall_forall in it;
eauto).
destruct os;
simpl in *;
try contradiction.
now apply it.
++
simpl;
rewrite map_length;
lia.
++
now rewrite length_removelast_cons,
map_length.
--
rewrite nth_error_app3 in Nth;
inv Nth;
simpl in *.
++
econstructor.
**
eauto.
**
erewrite app_removelast_last with (
l :=
o ::
ss) (
d :=
None);
try discriminate.
rewrite nth_error_app3;
eauto.
now rewrite length_removelast_cons.
**
take (
stmt_eval _ _ _ _ _)
and rewrite <-
map_cons in it.
change None with (
option_map normalize_stmt None)
in it.
rewrite CommonList.map_last in it.
assert (
In (
last (
o ::
ss)
None) (
o ::
ss))
by apply last_In_cons.
repeat (
take (
Forall _ _)
and eapply Forall_forall in it;
eauto).
destruct (
last (
o ::
ss)
None);
simpl in *;
try contradiction.
now apply it.
++
now rewrite length_removelast_cons,
map_length.
--
contradict Nth.
apply not_Some_is_None,
nth_error_None.
rewrite app_length,
length_removelast_cons,
map_length;
simpl;
lia.
(* Converse direction, with the same three-way case analysis on [c]. *)
*
pose proof Nth as Hin;
apply nth_error_In in Hin.
repeat (
take (
Forall _ _)
and eapply Forall_forall in it;
eauto).
destruct (
Compare_dec.lt_eq_lt_dec c (
length ss))
as [[|]|].
--
econstructor;
eauto.
++
rewrite nth_error_app1,
nth_error_removelast, <-
map_cons,
map_nth_error',
Nth;
simpl;
eauto.
**
rewrite map_length;
lia.
**
now rewrite length_removelast_cons,
map_length.
++
destruct s0;
simpl in *;
try contradiction.
now apply it.
--
rewrite nth_error_last with (
d :=
None)
in Nth;
auto.
inv Nth.
econstructor;
eauto.
++
rewrite nth_error_app3;
eauto.
now rewrite length_removelast_cons,
map_length.
++
simpl.
rewrite <-
map_cons.
change None with (
option_map normalize_stmt None).
rewrite CommonList.map_last.
destruct (
last (
o ::
ss)
None);
simpl in *;
try contradiction.
now apply it.
--
contradict Nth.
apply not_Some_is_None,
nth_error_None.
simpl;
lia.
(* [Comp] case: rewrite with both induction hypotheses. *)
-
now rewrite IHs1,
IHs2.
Qed.
(** [find_class] commutes with the pass: if looking up [id] in [p] yields
    the pair [(c, p')], then looking it up in [normalize_switches p] yields
    [(normalize_class c, normalize_switches p')]. *)
Lemma normalize_switches_find_class:
forall p id c p',
find_class id p =
Some (
c,
p') ->
find_class id (
normalize_switches p) =
Some (
normalize_class c,
normalize_switches p').
Proof.
(** [normalize_class] leaves the [c_objs] field of a class unchanged. *)
Lemma normalize_class_c_objs:
forall c,
(
normalize_class c).(
c_objs) =
c.(
c_objs).
Proof.
(** [normalize_method] leaves the [m_name] field of a method unchanged. *)
Lemma normalize_method_m_name:
  forall m, (normalize_method m).(m_name) = m.(m_name).
Proof.
  (* Both sides compute to the same component once [m] is a constructor. *)
  intros []; reflexivity.
Qed.
(** [normalize_method] leaves the [m_in] field of a method unchanged. *)
Lemma normalize_method_in:
  forall m, (normalize_method m).(m_in) = m.(m_in).
Proof.
  (* Both sides compute to the same component once [m] is a constructor. *)
  intros []; reflexivity.
Qed.
(** [normalize_method] leaves the [m_out] field of a method unchanged. *)
Lemma normalize_method_out:
  forall m, (normalize_method m).(m_out) = m.(m_out).
Proof.
  (* Both sides compute to the same component once [m] is a constructor. *)
  intros []; reflexivity.
Qed.
(** The body of a normalized method is the normalization of the original
    body. *)
Lemma normalize_method_body:
  forall fm,
    (normalize_method fm).(m_body) = normalize_stmt fm.(m_body).
Proof.
  (* Both sides compute to the same term once [fm] is a constructor. *)
  intros []; reflexivity.
Qed.
(** [find_method] commutes with the pass: if [f] is found as [fm] among the
    methods of [cls], then it is found as [normalize_method fm] among the
    methods of [normalize_class cls]. *)
Lemma normalize_switches_find_method:
forall f fm cls,
find_method f cls.(
c_methods) =
Some fm ->
find_method f (
normalize_class cls).(
c_methods) =
Some (
normalize_method fm).
Proof.
(** Method calls are preserved by the pass: any [stmt_call_eval] derivation
    in [p] gives one in [normalize_switches p] with the same class [n],
    method [f], arguments [xss], memories [me], [me'] and results [rs]. *)
Lemma normalize_switches_call:
forall p n me me'
f xss rs,
stmt_call_eval p me n f xss me'
rs ->
stmt_call_eval (
normalize_switches p)
me n f xss me'
rs.
Proof.
(** [loop_call] started at index [0] is preserved by the pass.

    The proof generalizes the start index [0] and then proceeds by
    coinduction ([cofix]); each step is rebuilt with
    [normalize_switches_call]. *)
Corollary normalize_switches_loop_call:
forall f c ins outs prog me,
loop_call prog c f ins outs 0
me ->
loop_call (
normalize_switches prog)
c f ins outs 0
me.
Proof.
intros ?????;
generalize 0%
nat.
cofix COINDHYP.
intros n me Hdo.
destruct Hdo.
econstructor;
eauto using normalize_switches_call.
Qed.
(** Switches normalization preserves well-typing. *)
(** An expression that is well typed with respect to [p] and the typing
    environments [Γm] and [Γv] is also well typed with respect to
    [normalize_switches p] and the same environments.

    [Γm] and [Γv] are single identifiers: written as separate tokens
    ([Γ m], [Γ v]) the statement would bind [Γ] twice and apply [wt_exp]
    to two extra arguments. *)
Lemma wt_exp_normalize_switches:
  forall p Γm Γv e,
    wt_exp p Γm Γv e ->
    wt_exp (normalize_switches p) Γm Γv e.
Proof.
  induction e; inversion_clear 1; eauto using wt_exp.
Qed.
(** A statement that is well typed with respect to [p] (with instances
    [insts] and environments [Γm], [Γv]) is also well typed with respect to
    [normalize_switches p], all other parameters being unchanged.

    [Γm] and [Γv] are single identifiers: written as separate tokens
    ([Γ m], [Γ v]) the statement would bind [Γ] twice and apply [wt_stmt]
    to two extra arguments. *)
Lemma wt_stmt_normalize_switches:
  forall p insts Γm Γv s,
    wt_stmt p insts Γm Γv s ->
    wt_stmt (normalize_switches p) insts Γm Γv s.
Proof.
(** Normalizing a well-typed statement yields a well-typed statement, with
    respect to the same program [p], instances [insts] and environments
    [Γm], [Γv].

    [Γm] and [Γv] are single identifiers: written as separate tokens
    ([Γ m], [Γ v]) the statement would bind [Γ] twice and apply [wt_stmt]
    to two extra arguments. *)
Lemma wt_normalize_stmt:
  forall p insts Γm Γv s,
    wt_stmt p insts Γm Γv s ->
    wt_stmt p insts Γm Γv (normalize_stmt s).
Proof.
(** [normalize_method] does not change [meth_vars]. *)
Lemma meth_vars_normalize_method:
forall m,
meth_vars (
normalize_method m) =
meth_vars m.
Proof.
(** The pass preserves well-typing of whole programs: if [p] satisfies
    [wt_program], so does [normalize_switches p]. *)
Lemma normalize_switches_wt_program:
forall p,
wt_program p ->
wt_program (
normalize_switches p).
Proof.
(** The pass preserves well-typing of memories: a memory that is well typed
    for [p] and the [c_mems]/[c_objs] of [c] is well typed for
    [normalize_switches p] and the [c_mems]/[c_objs] of
    [normalize_class c]. *)
Lemma normalize_switches_wt_memory:
  forall me p c,
    wt_memory me p c.(c_mems) c.(c_objs) ->
    wt_memory me (normalize_switches p)
              (normalize_class c).(c_mems) (normalize_class c).(c_objs).
Proof.
  intros me p c Hwt.
  (* Specialize the generic result about [transform_units] to this pass. *)
  pose proof transform_units_wt_memory' as Hspec.
  simpl in Hspec.
  apply Hspec in Hwt; auto.
Qed.
(** Switches normalization preserves [Can_write_in]. *)
(** [Can_write_in_var x] holds of a statement if and only if it holds of
    its normalization. *)
Lemma Can_write_in_var_normalize_stmt:
forall s x ,
Can_write_in_var x s <->
Can_write_in_var x (
normalize_stmt s).
Proof.
(** For a well-typed program [p]: if, for every method, [Can_write_in_var x]
    fails on the method body for each [x] in [map fst (m_in m)], then the
    same property holds for every method of [normalize_switches p]. *)
Lemma normalize_switches_cannot_write_inputs:
forall p,
wt_program p ->
Forall_methods (
fun m =>
Forall (
fun x => ~
Can_write_in_var x (
m_body m)) (
map fst (
m_in m)))
p ->
Forall_methods (
fun m =>
Forall (
fun x => ~
Can_write_in_var x (
m_body m)) (
map fst (
m_in m)))
(
normalize_switches p).
Proof.
(** Switches normalization preserves [No_Overwrites]. *)
(** [No_Overwrites] is preserved by statement normalization. *)
Lemma No_Overwrites_normalize_stmt:
forall s,
No_Overwrites s ->
No_Overwrites (
normalize_stmt s).
Proof.
(** For a well-typed program [p]: if every method body of [p] satisfies
    [No_Overwrites], then so does every method body of
    [normalize_switches p]. *)
Lemma normalize_switches_No_Overwrites:
forall p,
wt_program p ->
Forall_methods (
fun m =>
No_Overwrites (
m_body m))
p ->
Forall_methods (
fun m =>
No_Overwrites (
m_body m)) (
normalize_switches p).
Proof.
End OBCSWITCHESNORMALIZATION.
(** Functor producing an implementation of [OBCSWITCHESNORMALIZATION]: it
    takes the same nine module parameters as the module type and simply
    [Include]s the module type applied to them. *)
Module ObcSwitchesNormalizationFun
       (Ids    : IDS)
       (Op     : OPERATORS)
       (OpAux  : OPERATORS_AUX Ids Op)
       (SynObc : OBCSYNTAX Ids Op OpAux)
       (ComTyp : COMMONTYPING Ids Op OpAux)
       (SemObc : OBCSEMANTICS Ids Op OpAux SynObc)
       (InvObc : OBCINVARIANTS Ids Op OpAux SynObc SemObc)
       (TypObc : OBCTYPING Ids Op OpAux SynObc ComTyp SemObc)
       (Equ    : EQUIV Ids Op OpAux SynObc ComTyp SemObc TypObc)
<: OBCSWITCHESNORMALIZATION Ids Op OpAux SynObc ComTyp SemObc InvObc TypObc Equ.
  Include OBCSWITCHESNORMALIZATION Ids Op OpAux SynObc ComTyp SemObc InvObc TypObc Equ.
End ObcSwitchesNormalizationFun.