From mathcomp Require Import all_ssreflect.

Section Basics.
    (* Design decision: Representation
        - n.-tuple bool
        - ffun 'I_n -> bool
        - 'I_(2^n)
            - Pros: Ring structure already defined, for n > 0.
            - Cons: Performance.
    *)
    (* An n-bit word: a single-constructor wrapper around a boolean tuple
       of length n. *)
    Inductive word n := word_of_tuple of n.-tuple bool.
    Arguments word_of_tuple {n}.

    (* Projection back to the underlying tuple. *)
    Definition tuple_of_word {n} : word n -> n.-tuple bool := fun '(word_of_tuple t) => t.

    (* Wrapping then unwrapping is the identity on tuples. *)
    Lemma word_of_tupleK {n} : cancel word_of_tuple (tuple_of_word (n := n)).
    Proof. by []. Qed.

    (* Unwrapping then wrapping is the identity on words. *)
    Lemma tuple_of_wordK {n} : cancel tuple_of_word (word_of_tuple (n := n)).
    Proof. by case. Qed.

    (* Design decision: Argument order.
        Advantage of this order: The function is invertible.
        Advantage of other order: Can say `bit 7` as a function
    *)
    (* bit w i is the boolean stored at index i of the word w. *)
    Definition bit {n} (w : word n) (i : 'I_n) : bool := tnth (tuple_of_word w) i.

    (* fromBits f is the word obtained by tabulating f over every index
       below n. *)
    Definition fromBits {n} (f : 'I_n -> bool) : word n := word_of_tuple [tuple f i | i < n].

    (* Reading a tabulated word back recovers the function pointwise. *)
    Lemma fromBitsK {n} : forall f : 'I_n -> bool, bit (fromBits f) =1 f.
    Proof.
        move=> f i.
        by rewrite /bit /fromBits word_of_tupleK tnth_mktuple.
    Qed.

    (* Extensionality for words: two words are equal exactly when they
       agree on every bit. *)
    Lemma wordP {n} {a b : word n} : (forall i, bit a i = bit b i) <-> (a = b).
    Proof.
        (* Right-to-left is immediate by rewriting with the equality. *)
        split; last by move=>->.
        move=> H.
        (* Reduce equality of words to equality of the underlying tuples. *)
        rewrite -[a]tuple_of_wordK -[b]tuple_of_wordK; f_equal.
        (* Reduce equality of tuples to equality of the corresponding finfuns. *)
        rewrite -[tuple_of_word a]finfun_of_tupleK -[tuple_of_word b]finfun_of_tupleK; f_equal.
        (* Finfun equality is pointwise; evaluate both sides at i and use H. *)
        rewrite -ffunP => i.
        rewrite ffunE ffunE.
        apply H.
    Qed.

    (* Reading bit i of a word transported along eq : n = m is the same as
       reading the original word at the index i transported back along the
       symmetric equality. *)
    Lemma rew_bit n w m eq i :
        bit (eq_rect n word w m eq) i = bit w (eq_rect m ordinal i n (Logic.eq_sym eq)).
    (* After generalizing i, case analysis on eq leaves a goal closed by
       the terminating by. *)
    Proof. by move: i; case eq => i. Qed.

    (* Transporting an ordinal along eq : n = m yields the Ordinal built from
       the same underlying number, with only the bound proof transported. *)
    Lemma rew_ord n i m eq:
        eq_rect n ordinal i m eq =
        let '(Ordinal i l) := i in Ordinal (eq_rect n (fun m => i < m) l m eq).
    (* Case analysis on eq, then on the ordinal itself. *)
    Proof. by move: i; case eq; case=> m' ltn. Qed.
End Basics.
(* Repeats the implicit-argument declaration made inside the Basics section;
   presumably needed because that in-section declaration does not persist
   after End Basics - TODO confirm. *)
Arguments word_of_tuple {n}.

Section Logic.

    (* word 0 *)

    (* The word of width 0, built from the empty tuple. *)
    Definition trivial_word : word 0 := word_of_tuple [tuple].

    (* Every word of width 0 is equal to trivial_word. *)
    Lemma word0 : all_equal_to trivial_word.
    Proof. by case=> t; rewrite tuple0. Qed.

    (* word 1 *)

    (* Reads the bit at index ord0 of a 1-bit word. *)
    Definition bool_of_word1 (w : word 1) : bool := bit w ord0.

    (* Wraps a boolean as the 1-bit word holding exactly that bit. *)
    Definition word1_of_bool (b : bool) : word 1 := word_of_tuple [tuple b].

    (* word1_of_bool undoes bool_of_word1. *)
    Lemma bool_of_word1K : cancel bool_of_word1 word1_of_bool.
    Proof.
        move=> w; apply /wordP => i.
        by rewrite ord1 /bit tnth0.
    Qed.

    (* bool_of_word1 undoes word1_of_bool. *)
    Lemma word1_of_boolK : cancel word1_of_bool bool_of_word1.
    Proof.
        move=> b.
        by rewrite /bool_of_word1 /bit word_of_tupleK tnth0.
    Qed.

    (* Bitwise Operations *)

    (* Nullary bitwise operation: the word whose every bit is op. *)
    Definition op0 {n} op := fromBits (fun i => op) : word n.

    Lemma op0_spec {n} op i: bit (op0 op : word n) i = op.
    Proof. by rewrite fromBitsK. Qed.

    (* Unary bitwise operation: applies op to each bit of a. *)
    Definition op1 {n} op (a : word n) := fromBits (fun i => op (bit a i)).

    Lemma op1_spec {n} op (a : word n) i: bit (op1 op a) i = op (bit a i).
    Proof. by rewrite fromBitsK. Qed.

    (* Binary bitwise operation: combines the bits of a and b index by index. *)
    Definition op2 {n} op (a b : word n) := fromBits (fun i => op (bit a i) (bit b i)).

    Lemma op2_spec {n} op (a b : word n) i: bit (op2 op a b) i = op (bit a i) (bit b i).
    Proof. by rewrite fromBitsK. Qed.

    (* Concatenation and Splitting *)

    (* concat (a, b) is the word of width m+n whose bit at index i comes from
       a when the ordinal split of i lands on the left, and from b when it
       lands on the right.
       NOTE(review): the split used in the body is the one on ordinals (the
       word-level split of this file is only defined further down). *)
    Definition concat {m n} (ws : word m * word n) : word (m+n) :=
        fromBits (fun i => match split i with inl i => bit ws.1 i | inr i => bit ws.2 i end).
        (* More efficient:
            word_of_tuple [tuple of cat (tuple_of_word ws.1) (tuple_of_word ws.2)].
        *)
    
    (* Indices of the form lshift _ i read from the first component. *)
    Lemma concat_spec_1 {m n} (ws : word m * word n) (i : 'I_m)
        : bit (concat ws) (lshift _ i) = bit ws.1 i.
    Proof.
        rewrite fromBitsK.
        (* The ordinal split of a left-shifted index is inl i, via unsplitK. *)
        have -> : split (lshift n i) = inl i by rewrite -[inl i]unsplitK.
        done.
    Qed.

    (* Indices of the form rshift _ i read from the second component. *)
    Lemma concat_spec_2 {m n} (ws : word m * word n) (i : 'I_n)
        : bit (concat ws) (rshift _ i) = bit ws.2 i.
    Proof.
        rewrite fromBitsK.
        (* The ordinal split of a right-shifted index is inr i, via unsplitK. *)
        have -> : split (rshift m i) = inr i by rewrite -[inr i]unsplitK.
        done.
    Qed.

    (* split w cuts a word of width m+n into a pair: the first component
       collects the bits at indices lshift _ i, the second the bits at
       indices rshift _ i.
       From here on the unqualified name split denotes this word-level
       function. *)
    Definition split {m n} (w : word (m+n)) : word m * word n := 
        ( fromBits (fun i => bit w (lshift _ i))
        , fromBits (fun i => bit w (rshift _ i))
        ).
        (* A more efficient version is tricky to write, because of the equality proofs involved. *)

    (* Bit i of the first half is bit (lshift _ i) of the original word. *)
    Lemma split_spec_1 {m n} (w : word (m + n)) (i : 'I_m)
        : bit (split w).1 i = bit w (lshift _ i).
    Proof. by rewrite fromBitsK. Qed.

    (* Bit i of the second half is bit (rshift _ i) of the original word. *)
    Lemma split_spec_2 {m n} (w : word (m + n)) (i : 'I_n)
        : bit (split w).2 i = bit w (rshift _ i).
    Proof. by rewrite fromBitsK. Qed.

    (* Splitting a concatenation gives back the original pair. *)
    Lemma concatK {m n} : cancel (concat (m := m) (n := n)) split.
    Proof.
        move=> ws.
        (* Prove equality of the pair one component at a time, bit by bit. *)
        apply injective_projections.
        - by apply wordP => i; rewrite split_spec_1 concat_spec_1.
        - by apply wordP => i; rewrite split_spec_2 concat_spec_2.
    Qed.

    (* Concatenating the two halves of a split gives back the original word. *)
    Lemma splitK {m n} : cancel split (concat (m := m) (n := n)).
    Proof.
        move=> w. apply /wordP => i.
        (* NOTE(review): the splitK rewritten with here cannot be this lemma
           (it is still being proved); presumably it is the ordinal-level
           splitK from fintype, which expresses i through fintype.split i. *)
        rewrite -[i]splitK. case: (fintype.split i) => j.
        - by rewrite concat_spec_1 split_spec_1.
        - by rewrite concat_spec_2 split_spec_2.
    Qed.

    (* Shifts *)

    (* pushLow (b, w) extends w by one bit, placing b at the head of the
       underlying tuple (index ord0). *)
    Definition pushLow {n} : bool * word n -> word n.+1 :=
        fun w => word_of_tuple [tuple of cons w.1 (tuple_of_word w.2)].

    (* The pushed bit sits at index ord0. *)
    Lemma pushLow_spec1 {n} b (w : word n) : bit (pushLow (b,w)) ord0 = b.
    Proof. by rewrite /bit tnth0. Qed.

    (* Every bit of w is found one position up, at index lift ord0 i. *)
    Lemma pushLow_spec2 {n} b (w : word n) i : bit (pushLow (b,w)) (lift ord0 i) = bit w i.
    Proof. by rewrite /bit tnthS. Qed.
        

    (* Case-analysis view on 'I_n.+1: an index is either ord0 or of the form
       lift ord0 i. *)
    Inductive ordLow {n} : 'I_n.+1 -> Type :=
    | ordLow1 : ordLow ord0
    | ordLow2 i : ordLow (lift ord0 i)
    .

    (* Every index of 'I_n.+1 falls under one of the two ordLow cases. *)
    Lemma ordLowP {n} (i : 'I_n.+1) : ordLow i.
    Proof.
        (* Destruct the ordinal, then its underlying number: 0 or a successor. *)
        case: i; case=> [pf | i pf].
        - have->: Ordinal pf = ord0 by apply /eqP.
            apply ordLow1.
        (* In the successor case, first derive the bound i < n for the predecessor. *)
        - have pf2: i < n by rewrite -(ltn_add2l 1).
            have->: Ordinal pf = lift ord0 (Ordinal pf2)
                by apply /eqP; rewrite /eq_op /lift /bump; simpl.
            apply ordLow2.
    Qed.


    (* popLow w returns a pair of a boolean and a word of width n, obtained by
       applying split to w and converting the first component with
       bool_of_word1. pushLowK and popLowK below show it is the inverse of
       pushLow.
       The result is wrapped in locked; the proofs below use unlock to expose
       it. The argument of split is supplied by the last tactic, which
       rewrites with add1n (presumably turning the expected width 1 + n into
       n.+1 so that w fits - TODO confirm). Closed with Defined rather than
       Qed. *)
    Definition popLow {n} (w : word n.+1) : bool * word n.
        apply locked.
        refine (let '(b,w) := split _ in (bool_of_word1 b, w)).
        by rewrite add1n.
    Defined.

    (* Popping the low bit after pushing it returns the original pair. *)
    Lemma pushLowK {n} : cancel pushLow (popLow (n := n)).
    Proof.
        move=> [b w].
        rewrite /popLow; unlock.
        (* Prove the two components of the pair equal separately. *)
        apply pair_equal_spec; split.
        (* Boolean component: restate b through pushLow_spec1, then compare indices. *)
        - rewrite -{2}[b](pushLow_spec1 b w).
            rewrite /bool_of_word1 fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl.
        (* Word component: bitwise, restating through pushLow_spec2. *)
        - apply wordP => i.
            rewrite -(pushLow_spec2 b w).
            rewrite fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl.
    Qed.
    (* Pushing back the result of popLow rebuilds the original word. *)
    Lemma popLowK {n} : cancel popLow (pushLow (n := n)).
    Proof.
        move=> w.
        rewrite /popLow; unlock.
        (* Bitwise, distinguishing index ord0 from indices lift ord0 i. *)
        apply wordP; case /ordLowP.
        - rewrite pushLow_spec1 /bool_of_word1 fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl.
        - move=> i.
            rewrite pushLow_spec2 fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl.
    Qed.


    
    (* pushHigh (w, b) extends w by one bit, appending b at the end of the
       underlying tuple (index ord_max). *)
    Definition pushHigh {n} : word n * bool -> word n.+1
        := fun w => word_of_tuple [tuple of rcons (tuple_of_word w.1) w.2].

    (* The bits of w keep their positions (indices widened into 'I_n.+1). *)
    Lemma pushHigh_spec1 {n} (w : word n) b i
        : bit (pushHigh (w,b)) (widen_ord (leqnSn _) i) = bit w i.
    Proof.
        rewrite /bit (tnth_nth false) (tnth_nth false) nth_rcons size_tuple; simpl.
        by rewrite ltn_ord.
    Qed.

    (* The pushed bit sits at the last index. *)
    Lemma pushHigh_spec2 {n} (w : word n) b : bit (pushHigh (w,b)) ord_max = b.
    Proof. by rewrite /bit (tnth_nth false) nth_rcons size_tuple ltnn eqxx. Qed.


    (* Case-analysis view on 'I_n.+1: an index is either a widened index of
       'I_n or ord_max. *)
    Inductive ordHigh {n} : 'I_n.+1 -> Type :=
    | ordHigh1 (i : 'I_n) : ordHigh (widen_ord (leqnSn _) i)
    | ordHigh2 : ordHigh ord_max
    .

    (* Every index of 'I_n.+1 falls under one of the two ordHigh cases. *)
    Lemma ordHighP {n} (i : 'I_n.+1) : ordHigh i.
    Proof.
        case i => m pf.
        (* Case split on whether m == n, keeping the equation as a hypothesis. *)
        have: (m == n) = (m == n) by []; case: {-1}(m == n) => [eq | neq].
        - have->: Ordinal pf = ord_max by apply /eqP; rewrite /eq_op; simpl.
            apply ordHigh2.
        (* Otherwise rewrite the bound pf into m < n; pf' keeps the original bound. *)
        - move: (pf) => pf'.
            rewrite leq_eqVlt eqSS neq -[m.+1]add1n -[n.+1]add1n ltn_add2l in pf.
            simpl in pf.
            have->: Ordinal pf' = widen_ord (leqnSn _) (Ordinal pf)
                by apply /eqP; rewrite /eq_op; simpl.
            apply ordHigh1.
    Qed.


    (* popHigh w returns a pair of a word of width n and a boolean, obtained by
       applying split to w and converting the second component with
       bool_of_word1. pushHighK and popHighK below show it is the inverse of
       pushHigh.
       The result is wrapped in locked; the proofs below use unlock to expose
       it. The argument of split is supplied by the last tactic, which
       rewrites with addn1 (presumably turning the expected width n + 1 into
       n.+1 so that w fits - TODO confirm). Closed with Defined rather than
       Qed. *)
    Definition popHigh {n} (w : word n.+1) : word n * bool.
        apply locked.
        refine (let '(w,b) := split _ in (w, bool_of_word1 b)).
        by rewrite addn1.
    Defined.

    (* Popping the high bit after pushing it returns the original pair. *)
    Lemma pushHighK {n} : cancel pushHigh (popHigh (n := n)).
    Proof.
        move=> [w b].
        rewrite /popHigh; unlock.
        (* Prove the two components of the pair equal separately. *)
        apply pair_equal_spec; split.
        (* Word component: bitwise, restating through pushHigh_spec1. *)
        - apply wordP => i.
            rewrite -(pushHigh_spec1 w b).
            rewrite fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl.
        (* Boolean component: restate b through pushHigh_spec2, then compare indices. *)
        - rewrite -{2}[b](pushHigh_spec2 w b).
            rewrite /bool_of_word1 fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl; rewrite addn0.
    Qed.

    (* Pushing back the result of popHigh rebuilds the original word. *)
    Lemma popHighK {n} : cancel popHigh (pushHigh (n := n)).
    Proof.
        move=> w.
        rewrite /popHigh; unlock.
        (* Bitwise, distinguishing widened indices from ord_max. *)
        apply wordP; case /ordHighP.
        - move=> i.
            rewrite pushHigh_spec1 fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl.
        - rewrite pushHigh_spec2 /bool_of_word1 fromBitsK rew_bit; f_equal.
            by apply /eqP; rewrite rew_ord /eq_op; simpl; rewrite addn0.
    Qed.

End Logic.

Section Numerics.

    (* Design Decision: nat vs int *)

    From mathcomp Require Import ssralg ssrint intdiv.
    Local Open Scope Z.
    Import GRing.Theory.

    (* Numeric value of a word: the sum over all indices i of bit i (a boolean
       used as a number) times 2^i. *)
    Definition nat_of_word {n} (w : word n) : nat := \sum_i bit w i * 2^i.
    (* The same value, with the result typed as int. *)
    Definition int_of_word {n} (w : word n) : int := nat_of_word w.

    (* word_of_int w is the n-bit word whose bit i is set exactly when the
       quotient divz w (2^i) is odd, that is, when 2 does not divide it.
       The divisibility test has to read (2 %| q), meaning that 2 divides q;
       with the operands the other way round it would ask whether the
       quotient divides 2, which for instance makes every bit of the image
       of 0 true. *)
    Definition word_of_int {n} (w : int) : word n :=
        word_of_tuple [tuple ~~(2%:Z %| divz w (2^i)%:Z)%Z | i < n].
    (* Same conversion, starting from a natural number. *)
    Definition word_of_nat {n} (w : nat) : word n := word_of_int w.

    (* Bitwise *)

    (* The successor of 2^0 + ... + 2^(n-1) is 2^n. *)
    Fact geometric_series {n} : (\sum_(i<n) 2^i).+1 = 2^n.
    Proof.
        (* Induction on n, peeling off the last summand in the inductive step. *)
        elim: n => [|n IH].
        - by rewrite big_ord0 expn0.
        - by rewrite big_ord_recr -addSn IH addnn -mul2n -expnS.
    Qed.

    (* The value of a word plus the value of its bitwise complement, plus one,
       is 2^n. *)
    Lemma nat_of_neg {n} (w : word n) : 
        (nat_of_word (op1 negb w) + nat_of_word w).+1 = 2^n.
    Proof.
        (* Merge the two sums and compare summand by summand with the
           geometric series: at each index exactly one of the two bits is set. *)
        rewrite -geometric_series -big_split.
        f_equal; apply eq_bigr => i _; simpl.
        by rewrite -mulnDl op1_spec addn_negb mul1n.
    Qed.

    (* Integer form: the complement has value 2^n - 1 - value of w. *)
    Theorem int_of_neg {n} (w : word n) : 
        (int_of_word (op1 negb w) = (2^n)%:Z - 1 - int_of_word w)%R.
    Proof.
        by rewrite -(nat_of_neg w) [(_ - 1)%R]addrC -add1n PoszD addKr PoszD addrK.
    Qed.
    

    (* Shifts *)

    (* Pushing a low bit doubles the value of w and adds the new bit. *)
    Lemma nat_of_pushLow {n} b (w : word n) :
        nat_of_word (pushLow (b, w)) = b + double (nat_of_word w).
    Proof.
        (* Peel off the summand at index 0, then treat the two parts separately. *)
        rewrite /nat_of_word big_ord_recl; simpl; f_equal.
        - by rewrite pushLow_spec1 expn0 muln1.
        (* Remaining summands: each one is twice the matching summand of w. *)
        - rewrite -muln2 big_distrl.
            apply: eq_bigr => i _.
            rewrite pushLow_spec2 /lift /bump. simpl.
            by rewrite -mulnA expnS [2 * _]mulnC.
    Qed.

    (* Integer form of nat_of_pushLow. *)
    Lemma int_of_pushLow {n} b (w : word n) :
        (int_of_word (pushLow (b, w)) = b%:Z + int_of_word w * 2)%R.
    Proof. by rewrite /int_of_word nat_of_pushLow -muln2 PoszD PoszM. Qed.

    (* Pushing a high bit adds b * 2^n to the value of w. *)
    Lemma nat_of_pushHigh {n} (w : word n) b :
        nat_of_word (pushHigh (w, b)) = nat_of_word w + b * 2^n.
    Proof.
        (* Peel off the last summand, then treat the two parts separately. *)
        rewrite /nat_of_word big_ord_recr; simpl; f_equal.
        - apply eq_bigr => i _.
            by rewrite pushHigh_spec1.
        - by rewrite pushHigh_spec2.
    Qed.

    (* Integer form of nat_of_pushHigh. *)
    Lemma int_of_pushHigh {n} (w : word n) b :
        (int_of_word (pushHigh (w, b)) = int_of_word w + b%:Z * (2^n)%:Z)%R.
    Proof. by rewrite /int_of_word nat_of_pushHigh PoszD PoszM. Qed.


    (* Addition *)

    (* One-bit full adder, given as an explicit table over its three inputs.
       The result is a pair whose first component is the sum bit and whose
       second component is the carry; full_adder_spec states the arithmetic
       meaning. *)
    Definition full_adder (a b c : bool) : (bool * bool) := match a, b, c with
        (* No input set. *)
        | false, false, false => (false, false)

        (* Exactly one input set. *)
        | false, false, true
        | false, true, false
        | true, false, false => (true, false)

        (* Exactly two inputs set. *)
        | true, true, false
        | true, false, true
        | false, true, true => (false, true)

        (* All three inputs set. *)
        | true, true, true => (true, true)
    end.

    (* Arithmetic reading of full_adder: the three input bits add up to the
       first output plus twice the second output. Checked by exhausting the
       eight input combinations. *)
    Lemma full_adder_spec (a b c : bool) :
        a + b + c = (full_adder a b c).1 + (full_adder a b c).2 * 2.
    Proof.
        by case a; case b; case c.
    Qed.

    (* Adder defined by recursion on the width. Takes an incoming carry and
       two words; returns the sum word together with the outgoing carry.
       At width 0 the carry is passed through unchanged. At width n.+1 the low
       bits of both words are popped, combined with the carry by full_adder,
       and the recursion continues on the remaining bits with the new carry. *)
    Fixpoint addition {n} : bool -> word n -> word n -> word n * bool := match n with
        | 0 => fun carry _ _ => (trivial_word, carry)
        | n.+1 => fun carry a b =>
            let '(a, aa) := popLow a in
            let '(b, bb) := popLow b in
            let '(c, carry') := full_adder carry a b in
            let '(cc, carry'') := addition carry' aa bb in
            (pushLow (c, cc), carry'')
    end.

    (* Correctness of addition over nat: the value of the sum word plus the
       outgoing carry weighted by 2^n equals the incoming carry plus the
       values of both operands. *)
    Theorem nat_of_addition {n} carryin (a b : word n)
        : nat_of_word (addition carryin a b).1 + (addition carryin a b).2 * 2^n
        = carryin + nat_of_word a + nat_of_word b.
    Proof.
        (* Induction on the width, with carry and operands generalized. *)
        move: carryin a b.
        elim: n.
        (* Width 0: all three sums are empty. *)
        - move=> cin a b. simpl.
            by rewrite
                /nat_of_word big_ord0 big_ord0 big_ord0
                expn0 muln1 add0n addn0 addn0.
        - move=> n IH cin a b.
            simpl.
            (* Express a and b on the right-hand side as pushLow of their popLow. *)
            rewrite -{3}[a]popLowK -{3}[b]popLowK.
            move: (popLow a) (popLow b); clear a b; move=> [a aa] [b bb].
            rewrite nat_of_pushLow nat_of_pushLow 
                [cin + (a + _)]addnA addnACA full_adder_spec.
            move: (full_adder cin a b) => [c carry']; simpl.
            (* Regroup so that the induction hypothesis applies to the high parts. *)
            rewrite -muln2 -muln2 -addnA -mulnDl -mulnDl addnA -IH.
            move: (addition carry' aa bb) => [cc carry'']; simpl.
            rewrite nat_of_pushLow -muln2.
            by rewrite expnS [2 * 2^n]mulnC -addnA mulnA -mulnDl.
    Qed.

    (* Integer form of nat_of_addition: the value of the sum word is the full
       sum minus the outgoing carry weighted by 2^n. *)
    Theorem int_of_addition {n} carryin (a b : word n)
        : (int_of_word (addition carryin a b).1
        = carryin%:Z + int_of_word a + int_of_word b
        - ((addition carryin a b).2 * 2^n)%:Z)%R.
    Proof. by rewrite -PoszD -nat_of_addition PoszD addrK. Qed.

    (* Subtraction *)
    
    (* Subtraction implemented as addition of a and the bitwise complement of
       b, with borrowin passed as the incoming carry.
       Operates with inverted borrows: int_of_subtraction states the result in
       terms of negb borrowin and negb of the outgoing flag. *)
    Definition subtraction {n} : bool -> word n -> word n -> word n * bool
        := fun borrowin a b => addition borrowin a (op1 negb b).

    (* Integer specification of subtraction with inverted borrows: the value
       of the result word is a - b, minus the negated incoming borrow flag,
       plus the negated outgoing borrow flag weighted by 2^n. *)
    Theorem int_of_subtraction {n} borrowin (a b : word n)
        : (int_of_word (subtraction borrowin a b).1
        = -(negb borrowin)%:Z + int_of_word a - int_of_word b
        + (negb (subtraction borrowin a b).2 * 2^n)%:Z)%R.
    Proof.
        (* Reduce to the specifications of addition and of the complement. *)
        rewrite /subtraction int_of_addition int_of_neg.
        move: (addition borrowin a (op1 negb b)) => [c borrowout]; simpl.

        (* Reorder both sides so that matching summands can be cancelled. *)
        rewrite [(_ - (borrowout * 2^n)%:Z)%R]addrC [(_ + (~~ borrowout * 2^n)%:Z)%R]addrC.
        do 6 rewrite addrA; f_equal.
        rewrite -addrA [(_ - 1)%R]addrC addrA.
        do 2 rewrite addrC addrA addrA addrA; f_equal.
        rewrite addrC addrA addrA -addrA [(_ - (~~borrowin)%:Z)%R]addrC; f_equal.
        (* Both remaining goals rest on addn_negb, relating a flag and its negation. *)
        - by apply subr0_eq; rewrite opprK addrC addrA -PoszD addn_negb.
        - rewrite -[(2^n)%:Z]mul1r PoszM -mulrBl PoszM; f_equal.
            apply /eqP.
            by rewrite subr_eq -PoszD addn_negb.
    Qed.

    (* Word addition: the sum word of addition with no incoming carry; the
       outgoing carry is discarded. *)
    Definition addw {n} (w w' : word n) : word n
        := (addition false w w').1.

    (* Adds the word obtained from the integer i (via word_of_int) to w. *)
    Definition increment {n} (i : int) (w : word n) : word n
        := addw w (word_of_int i).

End Numerics.

(* Canonical structures for word n. Each mixin is built from the cancellation
   lemma tuple_of_wordK, i.e. through the embedding of words into tuples. *)

(* Equality. *)
Definition word_eqMixin {n} := CanEqMixin (tuple_of_wordK (n := n)).
Canonical word_eqType {n} := EqType (word n) word_eqMixin.

(* Choice. *)
Definition word_choiceMixin {n} := CanChoiceMixin (tuple_of_wordK (n := n)).
Canonical word_choiceType {n} := ChoiceType (word n) word_choiceMixin.

(* Countability. *)
Definition word_countMixin {n} := CanCountMixin (tuple_of_wordK (n := n)).
Canonical word_countType {n} := CountType (word n) word_countMixin.

(* Finiteness. *)
Definition word_finMixin {n} := CanFinMixin (tuple_of_wordK (n := n)).
Canonical word_finType {n} := FinType (word n) word_finMixin.


(* Section ZmodType.

    From mathcomp Require Import ssralg ssrint.
    Local Open Scope Z.
    Import GRing.Theory.

    Definition word_eqMixin {n} := CanEqMixin (tuple_of_wordK (n := n)).
    Canonical word_eqType {n} := EqType (word n) word_eqMixin.

    Definition word_choiceMixin {n} := CanChoiceMixin (tuple_of_wordK (n := n)).
    Canonical word_choiceType {n} := ChoiceType (word n) word_choiceMixin.

    Definition word_countMixin {n} := CanCountMixin (tuple_of_wordK (n := n)).
    Canonical word_countType {n} := CountType (word n) word_countMixin.

    Definition word_finMixin {n} := CanFinMixin (tuple_of_wordK (n := n)).
    Canonical word_finType {n} := FinType (word n) word_finMixin.

    Definition addw {n} (a b : word n) : word n := (addition false a b).1.
    Definition oppw {n} (a : word n) : word n := (subtraction true (op0 false) a).1.

    Lemma addwA {n} : associative (addw (n := n)). Admitted.
    Lemma addwC {n} : commutative (addw (n := n)). Admitted.
    Lemma add0w {n} : left_id (op0 false) (addw (n := n)). Admitted.
    Lemma addKw {n} : left_inverse (op0 false) oppw (addw (n := n)). Admitted.

    Definition word_zmodMixin {n} := ZmodMixin addwA addwC add0w (addKw (n := n)).
    Canonical word_zmodType {n} := ZmodType (word n) word_zmodMixin.

End ZmodType. *)