From mathcomp Require Import all_ssreflect.
Require Import word.

(* An efficient implementation of `word n -> V`.

   `memory 0 V` is `V` itself, and `memory n.+1 V` is `memory n (V * V)`:
   each step doubles the payload by pairing the value type with itself.
   For example `memory 2 V` unfolds to `(V * V) * (V * V)`, so a value of
   type `memory n V` packs 2^n values of type `V` into nested pairs. *)
Fixpoint memory (n : nat) (V : Type) : Type :=
    if n is k.+1 then memory k (V * V) else V.

(* Read the value stored in a memory at a given address.

   Built by induction on `n`, with `V` generalized because the recursive
   call is made at value type `V * V`:
   - for `n = 0` the memory is the single stored value and the address is
     ignored;
   - for `n.+1` the address is split by `popLow` into a boolean `b` and a
     word `w` of length `n`; the pair stored at `w` is fetched recursively
     and `b` selects its second component when true, its first when false.

   Closed with `Defined` so the definition stays transparent and can be
   unfolded by `simpl` / `compute` in the lemmas below. *)
Definition get {n V} : memory n V -> (word n -> V).
Proof.
    elim: n V => [V v _ | n recurse V m /popLow [b w]].
    - exact v.
    - case: (recurse (V*V)%type m w) => x y.
        exact (if b then y else x).
Defined.

(* Tabulate a function on addresses into a memory.

   Built by induction on `n`, with `V` generalized because the recursive
   call is made at value type `V * V`:
   - for `n = 0` the memory is just `f trivial_word`;
   - for `n.+1` we recursively tabulate the function that sends a word `w`
     of length `n` to the pair
       (f (pushLow (false, w)), f (pushLow (true, w))),
     i.e. the `false` extension goes in the first component and the `true`
     extension in the second.

   Closed with `Defined` so the definition stays transparent. *)
Definition fromFun {n V} : (word n -> V) -> memory n V.
Proof.
    elim: n V => [V f | n recurse V f].
    - exact (f trivial_word).
    - apply recurse => w.
        exact (f (pushLow (false, w)), f (pushLow (true, w))).
Defined.

(* `fromFun` only depends on the pointwise behaviour of its argument:
   functions that agree on every address are tabulated to equal memories.
   This lets later proofs avoid any appeal to functional extensionality. *)
Lemma fromFun_respect_eq1 {n V f g} : f =1 g -> fromFun f = fromFun (n := n) (V := V) g.
Proof.
    (* Induction on n with V, f and g generalized; the base case reduces to
       a single instance of the pointwise hypothesis. *)
    elim: n V f g => [| n IH V f g fg]; first by simpl.
    (* Inductive step: by IH it suffices to compare the paired functions
       pointwise, and each of the two components is one use of `fg`. *)
    by apply IH => w; rewrite 2!fg.
Qed.

(* `fromFun` is a left inverse of `get`: unfolding `cancel`, the statement is
   `fromFun (get m) = m` for every memory `m`. *)
Lemma getK {n V} : cancel get (fromFun (n := n) (V := V)).
Proof.
    (* Induction on n with V generalized, since the inductive step needs the
       hypothesis at value type V * V. *)
    move: V; elim: n.
    - done.
    - move=> n IH V m.
        (* Use IH right-to-left on occurrence 2 of `m` only, so that both
           sides of the goal become an application of `fromFun`. *)
        rewrite -{2}(IH _ m).
        simpl.
        (* It is now enough to show the two tabulated functions agree at
           every address w. *)
        apply: fromFun_respect_eq1 => w.
        (* NOTE(review): `pushLowK` comes from `word`; presumably it cancels
           `popLow (pushLow _)` -- confirm against that library. *)
        rewrite pushLowK pushLowK.
        (* Two pairs are equal as soon as both of their projections are. *)
        by apply: injective_projections.
Qed.

(* `get` undoes `fromFun` up to pointwise equality:
   `get (fromFun f) w = f w` for every function `f` and address `w`. *)
Lemma fromFunK {n V} : forall f, get (fromFun (n := n) (V := V) f) =1 f.
Proof.
    (* Induction on n with V generalized, since the inductive step uses the
       hypothesis at value type V * V. *)
    move: V; elim: n.
    (* NOTE(review): `word0` comes from `word`; presumably it identifies any
       length-0 word with `trivial_word` -- confirm. *)
    - move=> V f w. by rewrite word0.
    - move=> n IH V f w.
        simpl.
        (* Destructure `popLow w` as (b, w') while remembering the equation
           `eq : popLow w = (b, w')`: the trivial equation is pushed first and
           `{-1}` generalizes every occurrence except the first one. *)
        have: popLow w = popLow w by []; move: {-1}(popLow w) => [b w'] eq.
        rewrite IH.
        (* Case split on b (same remember-the-equation idiom), then turn
           (b, w') back into `popLow w` via `eq` and finish with `popLowK`.
           NOTE(review): `popLowK` presumably states
           `pushLow (popLow w) = w` -- confirm in `word`. *)
        have: b = b by []; case: {-1}b => <-; move: eq => <-; by rewrite popLowK.
Qed.


(* Apply `f` to the value stored at one address, leaving the rest untouched
   (see `update_spec` for the precise statement).

   Built by induction on `n`, with `V` generalized because the recursive
   call is made at value type `V * V`:
   - for `n = 0` the memory is a single value, so the result is `f m` and
     the address is ignored;
   - for `n.+1` the address is split by `popLow` into a boolean `b` and a
     word `w` of length `n`; we recursively update the pair stored at `w`
     with a function `f'` that applies `f` to the second component when `b`
     is true and to the first component otherwise (matching the component
     that `get` selects).

   Closed with `Defined` so the definition stays transparent. *)
Definition update {n V} : memory n V -> (V -> V) -> word n -> memory n V.
Proof.
    elim: n V => [V m f _ | n recurse V m f /popLow [b w]]; first exact (f m).
    move: (fun '(x,y) => if b then (x, f y) else (f x, y)) => f'.
    exact (recurse (V*V)%type m f' w).
Defined.

(* Specification of `update`: reading address `w` after updating address `w'`
   with `f` yields `f (get m w)` when the two addresses are equal, and the
   old value `get m w` otherwise. *)
Lemma update_spec {n V} (f : V -> V) (w w' : word n) m
    : get (update m f w') w = if w == w' then f (get m w) else get m w.
Proof.
    (* Induction on n with everything else generalized, since the inductive
       step uses the hypothesis at value type V * V. *)
    move: V f w w' m; elim: n.
    (* NOTE(review): `word0` comes from `word`; presumably it identifies any
       length-0 word with `trivial_word`, making `w == w'` trivial -- confirm. *)
    - move=> V f w w' m. by rewrite word0 word0.
    - move=> n IH V f w w' m.
        (* Rewrite occurrence 2 of w and of w' as `pushLow (popLow _)` so the
           comparison in the goal is between two `pushLow` forms. *)
        rewrite -{2}[w]popLowK -{2}[w']popLowK.
        simpl.
        (* Replace the two `popLow` results by generic pairs (b, w), (b', w'). *)
        move: (popLow w) (popLow w'); clear w w'; move=> [b w] [b' w'].
        rewrite IH.

        (* Equality of pushed words is equality of both components; proved by
           unfolding `pushLow` and `eq_op` and simplifying.
           NOTE(review): this depends on how `pushLow` and the eqType on
           words are defined in `word`. *)
        have->: (pushLow (b, w) == pushLow (b', w')) = andb (b == b') (w == w')
            by rewrite /pushLow; do 3 (rewrite /eq_op; simpl).
        
        (* Name the pair stored at w, then check every combination of
           (w == w'), b and b' by computation. *)
        move: (get (n := n) m w) => [x y].
        by case: (w == w'); case: b; case: b'.
Qed.



(* A memory of pairs has the same type as a pair of memories:
   `memory n (V * V) = memory n V * memory n V` (an equality between types).

   Closed with `Defined` rather than `Qed`: `from_bytes` rewrites with this
   equality, and the example below is checked with `compute`, so the proof
   term has to be transparent. *)
Lemma shift_pair {n V} : (memory n (V * V) = memory n V * memory n V)%type.
Proof.
    (* Generalize V: the inductive step instantiates the hypothesis at a
       different value type. *)
    move: V.
    (* Base case: both sides are `V * V` once `memory 0` is unfolded. *)
    elim n; first done.
    move=> n' IH V.
    (* `memory n'.+1 T` unfolds to `memory n' (T * T)`, so the goal is an
       instance of IH (at value type V * V). *)
    apply IH.
Defined.

(* Build a `memory n V` from its contents given as curried arguments.

   The result type is computed by `nat_rect`: at 0 it is `A -> T`, and each
   successor step replaces `f` by `fun T => f (f T)`, so for a given `n` the
   definition takes 2^n arguments of type `A` and returns a `memory n V`
   (e.g. 8 arguments for `n = 3`, as in the example below).  Each argument is
   passed through `convert : A -> V` before being stored.

   Closed with `Defined` so that applications can be evaluated. *)
Definition from_bytes (n : nat) {A V} (convert : A -> V) :
    nat_rect
        (fun _ => Type -> Type)
        (fun T => A -> T)
        (fun _ f T => f (f T))
        n
        (memory n V).
Proof.
    (* Generalize the final result type to an arbitrary X together with a
       continuation `memory n V -> X` (initially the identity), so that the
       induction hypothesis can be used with different continuations. *)
    set X := memory n V.
    have: memory n V -> X by [].
    move: X.

    elim n; simpl.
    (* n = 0: take one argument, convert it, and hand it to the continuation. *)
    - by move=> X f /convert /f.
    (* n.+1: use the hypothesis twice to obtain two memories of size n0
       (mem1 from the earlier arguments, mem2 from the later ones), then view
       the continuation's input `memory n0 (V * V)` as a pair of memories via
       `shift_pair` and feed it (mem1, mem2). *)
    - move=> n0 recurse X pair.
        apply: (recurse) => mem1.
        apply: recurse => mem2.
        rewrite shift_pair in pair.
        exact: pair (mem1, mem2).
Defined.

(* How do we know that `from_bytes` puts the bytes in the right order?
    I don't know how to phrase this as a general theorem, but here is a
    concrete example:
*)

(* Example: with the identity as `convert`, the eight arguments given to
   `from_bytes 3` appear, in order, as the leaves of three levels of nested
   pairs.  Checked purely by evaluation. *)
Goal from_bytes 3 (fun x => x) 0 1 2 3 4 5 6 7 = (((0,1),(2,3)),((4,5),(6,7))).
Proof. by compute. Qed.

(* Unfinished example: reading address `word_of_nat 6` from the memory built
   above should return 6.  The proof is abandoned with `Abort`, so nothing is
   added to the environment. *)
Goal get (from_bytes 3 (fun x => x) 0 1 2 3 4 5 6 7) (word_of_nat 6) = 6.
Proof.
    (* First replace the `from_bytes` call by its evaluated nested-pair form. *)
    have -> : from_bytes 3 (fun x => x) 0 1 2 3 4 5 6 7 = (((0,1),(2,3)),((4,5),(6,7))) by compute.
    simpl.
    (* Actually, I don't know how to evaluate this cleanly. *)
Abort.