From mathcomp Require Import all_ssreflect ssralg ssrint.
Import GRing.Theory.
Local Open Scope R.

Require Import nondeterminism word memory peripherals.
From RecordUpdate Require Import RecordSet.

(* In this file, I precisely describe the behavior of the Atari 2600.
    I then define what it means for a state of the Atari to be reachable.
    This is necessary to even *state* the theorem we want to prove;
    to prove anything about minimum speedrun times, we need to define what it means to run a program!

    Note that if there is a bug in this specification,
    the final theorem will instead prove a minimum speedrun time *when running Dragster on a buggy Atari*.
    This is an unavoidable limitation to the proof.
*)


(* Flags *)

(* The processor-status flags tracked by this model.
    The break bit (B, bit 4) and bit 5 have no field here:
    [word_of_flags] takes the break bit as an argument and always sets bit 5,
    and [flags_of_word] ignores both. *)
Record flags := {
    FlagN : bool;
    FlagV : bool; (* WARNING: I assume the SO pin is unused. *)
    FlagD : bool;
    FlagI : bool;
    FlagZ : bool;
    FlagC : bool;
}.

(* Enables [set FlagX (fun old => new) f] to update a single field of a [flags] record. *)
Instance set_flags : Settable _ := settable! Build_flags
    <FlagN; FlagV; FlagD; FlagI; FlagZ; FlagC>.

(* Pack [f] into a processor-status byte.
    Tuple entries run from bit 0 (least significant) up to bit 7:
    C, Z, I, D, the caller-chosen break bit, a constant 1 in bit 5, V, N. *)
Definition word_of_flags (break : bool) (f : flags) : word 8 :=
    word_of_tuple
        [tuple FlagC f; FlagZ f; FlagI f; FlagD f; break; true; FlagV f; FlagN f].

(* Unpack a processor-status byte into [flags].
    Bit 4 (break) and bit 5 are discarded.
    Each [_] hole is the bound proof for an [Ordinal] index, which [by] discharges.
    Defined (not Qed) so the record stays transparent for computation. *)
Definition flags_of_word (w : word 8) : flags.
by refine {|
    FlagC := bit w (Ordinal (m := 0) _);
    FlagZ := bit w (Ordinal (m := 1) _);
    FlagI := bit w (Ordinal (m := 2) _);
    FlagD := bit w (Ordinal (m := 3) _);
    FlagV := bit w (Ordinal (m := 6) _);
    FlagN := bit w (Ordinal (m := 7) _);
|}.
Defined.

(* Packing flags into a byte and unpacking them again loses nothing,
    whichever break bit [b] is used for the packing. *)
Lemma word_of_flagsK b : cancel (word_of_flags b) flags_of_word.
Proof.
    (* Destruct the record so both sides compute to concrete fields. *)
    move=> [n v d i z c].
    rewrite /word_of_flags; simpl.
    by rewrite /flags_of_word /bit word_of_tupleK /tnth.
Qed.

(* State *)

(* The complete machine state for a given cartridge [rom]:
    the peripherals, through which every memory access goes (see [read]/[write]),
    plus the CPU flags and registers. *)
Record state (rom : ROM) := {
    Peripherals : peripherals.state rom;
    Flags : flags;
    ProgramCounter : word 16;
    RegA : word 8;
    RegX : word 8;
    RegY : word 8;
    RegS : word 8; (* Stack Pointer *)
}.

(* Make [rom] implicit in every field projection. *)
Arguments Peripherals {rom}.
Arguments Flags {rom}.
Arguments ProgramCounter {rom}.
Arguments RegA {rom}.
Arguments RegX {rom}.
Arguments RegY {rom}.
Arguments RegS {rom}.

(* Enables [set Field (fun old => new) s] to update a single field of a [state]. *)
Instance set_state {rom} : Settable _ := settable! (Build_state rom)
    <Peripherals; Flags; ProgramCounter; RegA; RegX; RegY; RegS>.

(* Read one byte at [addr], delegating to the peripherals.
    The result may be nondeterministic.
    Only a byte is returned, so a read never changes [s]. *)
Definition read {rom} (addr : word 16) (s : state rom) : nondet (word 8)
    := peripherals.read addr (Peripherals s).

(* Write byte [w] at [addr]. Only the [Peripherals] component of the state changes.
    The possible outcomes are exactly those of [peripherals.write]. *)
Definition write {rom} (addr : word 16) (w : word 8) (s : state rom) : nondet (state rom)
    := p <- peripherals.write addr w (Peripherals s); determined (set Peripherals (fun=>p) s).

(* Flags *)

(* Update the N and Z flags from a result byte [w]:
    N becomes the top bit of [w], and Z becomes whether [w] is all zeros.
    Every other part of the state is left unchanged. *)
Definition setZN {rom} (w : word 8) : state rom -> state rom
    := set Flags
        (fun f => set FlagN (fun=> bit w ord_max)
            (set FlagZ (fun=> w == op0 false) f)).

(* Stack *)

(* Stack primitives: push/pull one byte or one 16-bit word.
    TODO: these are [Admitted], i.e. postulated without a definition.
    Nothing about stack memory or the stack pointer [RegS] can therefore be proved
    from this file.
    NOTE(review): on a real 6502 the stack lives in page $01 (address $0100 + S).
    S is decremented after each pushed byte and incremented before each pulled byte,
    and a 16-bit value is pushed high byte first.
    Confirm against the hardware docs before implementing. *)
Definition push8 {rom} (w : word 8) (s : state rom) : nondet (state rom). Admitted.
Definition push16 {rom} (w : word 16) (s : state rom) : nondet (state rom). Admitted.
Definition pull8 {rom} (s : state rom) : nondet (word 8 * state rom). Admitted.
Definition pull16 {rom} (s : state rom) : nondet (word 16 * state rom). Admitted.

(* Addressing Modes *)

(* The 6502 addressing modes.
    [mode_width] gives the number of operand bytes each one takes. *)
Inductive addressing_mode :=
| Accumulator
| Immediate
| Implied
| Relative
| Absolute
| ZeroPage
| Indirect
| AbsoluteX
| AbsoluteY
| ZeroPageX
| ZeroPageY
| IndexedIndirect
| IndirectIndexed
.

(* How many operand bytes follow the opcode in each addressing mode.
    Multiple patterns are expanded by Coq into one branch per constructor,
    so this is the same match term as listing every mode separately. *)
Definition mode_width (mode : addressing_mode) : nat :=
    match mode with
    | Accumulator | Implied => 0
    | Immediate | Relative | ZeroPage | ZeroPageX | ZeroPageY
    | IndexedIndirect | IndirectIndexed => 1
    | Absolute | Indirect | AbsoluteX | AbsoluteY => 2
    end.

(* Many addressing modes pick out an address in memory. Return that address.
    [bytes] are the instruction's operand bytes; their width depends on [mode].
    Accumulator, Immediate and Implied name no address and are [unspecified].
    The zero-page modes add in 8 bits first and then use a zero high byte,
    so the address wraps within page zero.
    [Relative] sign-extends its offset (the high byte is filled with the offset's
    top bit) and adds it to the current program counter. *)
Definition mode_addr {rom} (mode : addressing_mode)
:= match mode return word (mode_width mode * 8) -> state rom -> nondet (word 16) with
    | Accumulator => fun _ _ => unspecified
    | Immediate => fun _ _ => unspecified
    | Implied => fun _ _ => unspecified
    | Relative => fun offset s =>
        determined (addw
            (ProgramCounter s)
            (concat (offset, op0 (bit offset ord_max))))
    | Absolute => fun addr s => determined addr
    | ZeroPage => fun addr s => determined (concat (addr, op0 false))
    | Indirect => fun addr s =>
        (* Annoying edge case. I could specify the behavior here too, but why bother? *)
        (* I.e. a pointer whose low byte is $FF is left unspecified. *)
        if (split addr).1 == op0 (n := 8) true then unspecified else
        lo <- read addr s;
        hi <- read (increment 1 addr) s;
        determined (concat (lo, hi))
    | AbsoluteX => fun addr s => determined (addw addr (concat (RegX s, op0 false)))
    | AbsoluteY => fun addr s => determined (addw addr (concat (RegY s, op0 false)))
    | ZeroPageX => fun addr s => determined (concat (addw addr (RegX s), op0 false))
    | ZeroPageY => fun addr s => determined (concat (addw addr (RegY s), op0 false))
    | IndexedIndirect => fun addr s => 
        (* (zp,X): index first, then read a 16-bit pointer from page zero. *)
        let addr := addw addr (RegX s) in
        lo <- read (concat (addr, op0 false)) s;
        hi <- read (concat (increment 1 addr, op0 false)) s;
        determined (concat (lo, hi))
    | IndirectIndexed => fun addr s =>
        (* (zp),Y: read a 16-bit pointer from page zero, then add Y to it. *)
        lo <- read (concat (addr, op0 false)) s;
        hi <- read (concat (increment 1 addr, op0 false)) s;
        determined (addw (concat (lo, hi)) (concat (RegY s, op0 false)))
end.

(* Read data, using this addressing mode.
    Accumulator mode reads register A.
    Immediate mode returns the operand byte itself.
    Implied mode has no operand, so the result is [unspecified].
    Every other mode reads memory at the address given by [mode_addr]. *)
Definition mode_read {rom} (mode : addressing_mode)
:= match mode return word (mode_width mode * 8) -> state rom -> nondet (word 8) with
    | Accumulator => fun _ s => determined (RegA s)
    | Immediate => fun w _ => determined w
    | Implied => fun _ _ => unspecified
    | mode => fun bytes s =>
        addr <- mode_addr mode bytes s;
        read addr s
end.

(* Write data, using this addressing mode.
    Accumulator mode stores into register A.
    Immediate and Implied modes have nowhere to store, so the result is [unspecified].
    Every other mode writes memory at the address [mode_addr] computes,
    evaluated against the state [s] passed here. *)
Definition mode_write {rom} (mode : addressing_mode)
:= match mode return word (mode_width mode * 8) -> word 8 -> state rom -> nondet (state rom) with
    | Accumulator => fun _ w s => determined (set RegA (fun=>w) s)
    | Immediate => fun _ _ _ => unspecified
    | Implied => fun _ _ _ => unspecified
    | mode => fun bytes w s =>
        addr <- mode_addr mode bytes s;
        write addr w s
end.


(* Instructions *)

(* The 56 6502 instruction mnemonics modelled here.
    The addressing mode is kept separately (see [parse_opcode]). *)
Inductive instruction :=
| ADC | AND | ASL | BCC | BCS | BEQ | BIT | BMI | BNE | BPL | BRK | BVC | BVS | CLC
| CLD | CLI | CLV | CMP | CPX | CPY | DEC | DEX | DEY | EOR | INC | INX | INY | JMP
| JSR | LDA | LDX | LDY | LSR | NOP | ORA | PHA | PHP | PLA | PLP | ROL | ROR | RTI
| RTS | SBC | SEC | SED | SEI | STA | STX | STY | TAX | TAY | TSX | TXA | TXS | TYA
.

(* Execute one decoded instruction and return every possible successor state.

    [instr] and [mode] come from [parse_opcode].
    [bytes] are the operand bytes that followed the opcode, in the order [step] reads them.
    When called from [step], [s] is the state *after* the program counter has
    already been advanced past the opcode and all of its operand bytes.
    Branches, JSR and BRK rely on that.
*)
Definition run_instruction {rom}
    (instr : instruction)
    (mode : addressing_mode)
    (bytes : word (mode_width mode * 8))
    (s : state rom)
    : nondet (state rom)
:= match instr with

    (* Organization based on https://masswerk.at/6502/6502_instruction_set.html *)

    (* Transfer Instructions *)
    | LDA => w <- mode_read mode bytes s; determined (set RegA (fun=> w) (setZN w s))
    | LDX => w <- mode_read mode bytes s; determined (set RegX (fun=> w) (setZN w s))
    | LDY => w <- mode_read mode bytes s; determined (set RegY (fun=> w) (setZN w s))
    | STA => mode_write mode bytes (RegA s) s
    | STX => mode_write mode bytes (RegX s) s
    | STY => mode_write mode bytes (RegY s) s
    | TAX => determined (let w := RegA s in set RegX (fun=> w) (setZN w s))
    | TAY => determined (let w := RegA s in set RegY (fun=> w) (setZN w s))
    | TSX => determined (let w := RegS s in set RegX (fun=> w) (setZN w s))
    | TXA => determined (let w := RegX s in set RegA (fun=> w) (setZN w s))
    (* TXS is the only transfer that leaves the flags alone. *)
    | TXS => determined (let w := RegX s in set RegS (fun=> w) s)
    | TYA => determined (let w := RegY s in set RegA (fun=> w) (setZN w s))

    (* Stack Instructions *)
    | PHA => push8 (RegA s) s
    | PHP => push8 (word_of_flags true (Flags s)) s
    | PLA =>
        '(w, s) <- pull8 s;
        determined (set RegA (fun=> w) (setZN w s))
    | PLP =>
        '(w, s) <- pull8 s;
        (* Every modelled flag, I included, is restored from the pulled byte.
            The break bit and bit 5 have no storage in [flags] and are dropped. *)
        determined (set Flags (fun=> flags_of_word w) s)

    (* Decrements and Increments *)
    | DEC =>
        w <- mode_read mode bytes s;
        let w := increment (-1) w in
        mode_write mode bytes w (setZN w s)
    | DEX =>
        let w := increment (-1) (RegX s) in
        determined (set RegX (fun=> w) (setZN w s))
    | DEY =>
        let w := increment (-1) (RegY s) in
        determined (set RegY (fun=> w) (setZN w s))

    | INC =>
        w <- mode_read mode bytes s;
        let w := increment 1 w in
        mode_write mode bytes w (setZN w s)
    | INX =>
        let w := increment 1 (RegX s) in
        determined (set RegX (fun=> w) (setZN w s))
    | INY =>
        let w := increment 1 (RegY s) in
        determined (set RegY (fun=> w) (setZN w s))

    (* Arithmetic Operations *)
    | ADC =>
        if FlagD (Flags s)
        then unspecified (* TODO. Note: Only C should be specified. *)
        else
            w <- mode_read mode bytes s;
            let '(w, carryout) := addition (FlagC (Flags s)) (RegA s) w in
            (* V is not modelled yet: any value is possible. *)
            overflow <- unspecified;
            determined
                (set RegA (fun=> w)
                    (set Flags
                        (fun f =>
                            set FlagC (fun=> carryout)
                                (set FlagV (fun=> overflow) f))
                        (setZN w s)))

    | SBC =>
        if FlagD (Flags s)
        then unspecified (* TODO. Note: Only C should be specified. *)
        else
            w <- mode_read mode bytes s;
            let '(w, carryout) := subtraction (FlagC (Flags s)) (RegA s) w in
            (* V is not modelled yet: any value is possible. *)
            overflow <- unspecified;
            determined
                (set RegA (fun=> w)
                    (set Flags
                        (fun f =>
                            set FlagC (fun=> carryout)
                                (set FlagV (fun=> overflow) f))
                        (setZN w s)))

    (* Logical Operations *)

    | AND =>
        w <- mode_read mode bytes s;
        let w := op2 andb (RegA s) w in
        determined (set RegA (fun=> w) (setZN w s))
    | ORA =>
        w <- mode_read mode bytes s;
        let w := op2 orb (RegA s) w in
        determined (set RegA (fun=> w) (setZN w s))
    | EOR =>
        w <- mode_read mode bytes s;
        (* [addb] is boolean exclusive-or. *)
        let w := op2 addb (RegA s) w in
        determined (set RegA (fun=> w) (setZN w s))

    (* Shift & Rotate Instructions.
        The result goes back to wherever the operand came from.
        In [Accumulator] mode that is register A; in the memory modes it is memory.
        [mode_write] handles both cases, exactly as for INC/DEC. *)

    | ASL =>
        w <- mode_read mode bytes s;
        let '(w, b) := popHigh (pushLow (false, w)) in
        mode_write mode bytes w (set Flags (set FlagC (fun=>b)) (setZN w s))
    | LSR =>
        w <- mode_read mode bytes s;
        let '(b, w) := popLow (pushHigh (w, false)) in
        mode_write mode bytes w (set Flags (set FlagC (fun=>b)) (setZN w s))
    | ROL =>
        w <- mode_read mode bytes s;
        let '(w, b) := popHigh (pushLow (FlagC (Flags s), w)) in
        mode_write mode bytes w (set Flags (set FlagC (fun=>b)) (setZN w s))
    | ROR =>
        w <- mode_read mode bytes s;
        let '(b, w) := popLow (pushHigh (w, FlagC (Flags s))) in
        mode_write mode bytes w (set Flags (set FlagC (fun=>b)) (setZN w s))

    (* Flag Instructions *)
    | CLC => determined (set Flags (set FlagC (fun=> false)) s)
    | CLD => determined (set Flags (set FlagD (fun=> false)) s)
    | CLI => determined (set Flags (set FlagI (fun=> false)) s)
    | CLV => determined (set Flags (set FlagV (fun=> false)) s)
    | SEC => determined (set Flags (set FlagC (fun=> true)) s)
    | SED => determined (set Flags (set FlagD (fun=> true)) s)
    | SEI => determined (set Flags (set FlagI (fun=> true)) s)

    (* Comparisons: subtract with the carry set (no borrow) and discard the difference,
        keeping only the C, Z and N flags. *)

    | CMP =>
        w <- mode_read mode bytes s;
        let '(w, carryout) := subtraction true (RegA s) w in
        determined
            (set Flags
                (set FlagC (fun=> carryout))
                (setZN w s))
    | CPX =>
        w <- mode_read mode bytes s;
        let '(w, carryout) := subtraction true (RegX s) w in
        determined
            (set Flags
                (set FlagC (fun=> carryout))
                (setZN w s))
    | CPY =>
        w <- mode_read mode bytes s;
        let '(w, carryout) := subtraction true (RegY s) w in
        determined
            (set Flags
                (set FlagC (fun=> carryout))
                (setZN w s))

    (* Conditional Branch Instructions.
        The target comes from [mode_addr Relative], which is relative to the
        already-advanced program counter. *)
    | BCS =>
        if   FlagC (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BCC =>
        if ~~FlagC (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BEQ =>
        if   FlagZ (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BNE =>
        if ~~FlagZ (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BVS =>
        if   FlagV (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BVC =>
        if ~~FlagV (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BMI =>
        if   FlagN (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s
    | BPL =>
        if ~~FlagN (Flags s) then
            addr <- mode_addr mode bytes s;
            determined (set ProgramCounter (fun=> addr) s)
        else determined s

    (* Jumps & Subroutines *)
    | JMP =>
        addr <- mode_addr mode bytes s;
        determined (set ProgramCounter (fun=> addr) s)

    | JSR =>
        addr <- mode_addr mode bytes s;
        (* Push the address of the last byte of the JSR; RTS adds one back. *)
        push16 (increment (-1) (ProgramCounter s))
            (set ProgramCounter (fun=> addr) s)

    | RTS =>
        '(addr, s) <- pull16 s;
        determined (set ProgramCounter (fun=> increment 1 addr) s)

    (* Interrupts *)
    | BRK =>
        (* BRK is decoded as a one-byte instruction, but the 6502 pushes
            "PC+2", the address two bytes past the opcode (one padding byte is skipped).
            [ProgramCounter s] is already one past the opcode, so we push it plus one. *)
        s <- push16 (increment 1 (ProgramCounter s)) s;
        s <- push8 (word_of_flags true (Flags s)) s;
        (* Jump through the vector at $FFFE/$FFFF. *)
        lo <- read (pushLow (false, op0 true)) s;
        hi <- read (pushLow (true, op0 true)) s;
        determined
            (set ProgramCounter (fun=> concat (lo : word 8, hi : word 8))
                (set Flags (set FlagI (fun=> true)) s))

    | RTI =>
        (* Restore all modelled flags (including I) from the stack, then the
            return address. Unlike RTS, the address is used as-is. *)
        '(f, s) <- pull8 s;
        '(addr, s) <- pull16 s;
        determined
            (set ProgramCounter (fun=> addr)
                (set Flags (fun=> flags_of_word f) s))


    (* Other *)
    | BIT =>
        (* Z from (A AND M); N and V are copied straight from bits 7 and 6 of M. *)
        w <- mode_read mode bytes s;
        determined (set Flags
            (fun f =>
                (set FlagZ (fun=> op2 andb (RegA s) w == op0 false)
                    (set FlagN (fun=> bit w (Ordinal (m := 7) (leqnn _)))
                        (set FlagV (fun=> bit w (Ordinal (m := 6) (leqnSn _)))
                            f))))
        s)

    | NOP => determined s
end.


(* Given an opcode, return the corresponding instruction and addressing mode.
    Return `None` for any opcode not listed below (treated as invalid).
    The table follows the masswerk 6502 opcode reference.
*)
Definition parse_opcode (w : word 8) : option (instruction * addressing_mode)
:= match nat_of_word w with
    | 0x00 => Some (BRK, Implied)
    | 0x01 => Some (ORA, IndexedIndirect)
    | 0x05 => Some (ORA, ZeroPage)
    | 0x06 => Some (ASL, ZeroPage)
    | 0x08 => Some (PHP, Implied)
    | 0x09 => Some (ORA, Immediate)
    | 0x0a => Some (ASL, Accumulator)
    | 0x0d => Some (ORA, Absolute)
    | 0x0e => Some (ASL, Absolute)
    | 0x10 => Some (BPL, Relative)
    | 0x11 => Some (ORA, IndirectIndexed)
    | 0x15 => Some (ORA, ZeroPageX)
    | 0x16 => Some (ASL, ZeroPageX)
    | 0x18 => Some (CLC, Implied)
    | 0x19 => Some (ORA, AbsoluteY)
    | 0x1d => Some (ORA, AbsoluteX)
    | 0x1e => Some (ASL, AbsoluteX)
    | 0x20 => Some (JSR, Absolute)
    | 0x21 => Some (AND, IndexedIndirect)
    | 0x24 => Some (BIT, ZeroPage)
    | 0x25 => Some (AND, ZeroPage)
    | 0x26 => Some (ROL, ZeroPage)
    | 0x28 => Some (PLP, Implied)
    | 0x29 => Some (AND, Immediate)
    | 0x2a => Some (ROL, Accumulator)
    | 0x2c => Some (BIT, Absolute)
    | 0x2d => Some (AND, Absolute)
    | 0x2e => Some (ROL, Absolute)
    | 0x30 => Some (BMI, Relative)
    | 0x31 => Some (AND, IndirectIndexed)
    | 0x35 => Some (AND, ZeroPageX)
    | 0x36 => Some (ROL, ZeroPageX)
    | 0x38 => Some (SEC, Implied)
    | 0x39 => Some (AND, AbsoluteY)
    | 0x3d => Some (AND, AbsoluteX)
    | 0x3e => Some (ROL, AbsoluteX)
    | 0x40 => Some (RTI, Implied)
    | 0x41 => Some (EOR, IndexedIndirect)
    | 0x45 => Some (EOR, ZeroPage)
    | 0x46 => Some (LSR, ZeroPage)
    | 0x48 => Some (PHA, Implied)
    | 0x49 => Some (EOR, Immediate)
    | 0x4a => Some (LSR, Accumulator)
    | 0x4c => Some (JMP, Absolute)
    | 0x4d => Some (EOR, Absolute)
    | 0x4e => Some (LSR, Absolute)
    | 0x50 => Some (BVC, Relative)
    | 0x51 => Some (EOR, IndirectIndexed)
    | 0x55 => Some (EOR, ZeroPageX)
    | 0x56 => Some (LSR, ZeroPageX)
    | 0x58 => Some (CLI, Implied)
    | 0x59 => Some (EOR, AbsoluteY)
    | 0x5d => Some (EOR, AbsoluteX)
    | 0x5e => Some (LSR, AbsoluteX)
    | 0x60 => Some (RTS, Implied)
    | 0x61 => Some (ADC, IndexedIndirect)
    | 0x65 => Some (ADC, ZeroPage)
    | 0x66 => Some (ROR, ZeroPage)
    | 0x68 => Some (PLA, Implied)
    | 0x69 => Some (ADC, Immediate)
    | 0x6a => Some (ROR, Accumulator)
    | 0x6c => Some (JMP, Indirect)
    | 0x6d => Some (ADC, Absolute)
    | 0x6e => Some (ROR, Absolute)
    | 0x70 => Some (BVS, Relative)
    | 0x71 => Some (ADC, IndirectIndexed)
    | 0x75 => Some (ADC, ZeroPageX)
    | 0x76 => Some (ROR, ZeroPageX)
    | 0x78 => Some (SEI, Implied)
    | 0x79 => Some (ADC, AbsoluteY)
    | 0x7d => Some (ADC, AbsoluteX)
    | 0x7e => Some (ROR, AbsoluteX)
    | 0x81 => Some (STA, IndexedIndirect)
    | 0x84 => Some (STY, ZeroPage)
    | 0x85 => Some (STA, ZeroPage)
    | 0x86 => Some (STX, ZeroPage)
    | 0x88 => Some (DEY, Implied)
    | 0x8a => Some (TXA, Implied)
    | 0x8c => Some (STY, Absolute)
    | 0x8d => Some (STA, Absolute)
    | 0x8e => Some (STX, Absolute)
    | 0x90 => Some (BCC, Relative)
    | 0x91 => Some (STA, IndirectIndexed)
    | 0x94 => Some (STY, ZeroPageX)
    | 0x95 => Some (STA, ZeroPageX)
    | 0x96 => Some (STX, ZeroPageY)
    | 0x98 => Some (TYA, Implied)
    | 0x99 => Some (STA, AbsoluteY)
    | 0x9a => Some (TXS, Implied)
    | 0x9d => Some (STA, AbsoluteX)
    | 0xa0 => Some (LDY, Immediate)
    | 0xa1 => Some (LDA, IndexedIndirect)
    | 0xa2 => Some (LDX, Immediate)
    | 0xa4 => Some (LDY, ZeroPage)
    | 0xa5 => Some (LDA, ZeroPage)
    | 0xa6 => Some (LDX, ZeroPage)
    | 0xa8 => Some (TAY, Implied)
    | 0xa9 => Some (LDA, Immediate)
    | 0xaa => Some (TAX, Implied)
    | 0xac => Some (LDY, Absolute)
    | 0xad => Some (LDA, Absolute)
    | 0xae => Some (LDX, Absolute)
    | 0xb0 => Some (BCS, Relative)
    | 0xb1 => Some (LDA, IndirectIndexed)
    | 0xb4 => Some (LDY, ZeroPageX)
    | 0xb5 => Some (LDA, ZeroPageX)
    | 0xb6 => Some (LDX, ZeroPageY)
    | 0xb8 => Some (CLV, Implied)
    | 0xb9 => Some (LDA, AbsoluteY)
    | 0xba => Some (TSX, Implied)
    | 0xbc => Some (LDY, AbsoluteX)
    | 0xbd => Some (LDA, AbsoluteX)
    | 0xbe => Some (LDX, AbsoluteY)
    | 0xc0 => Some (CPY, Immediate)
    | 0xc1 => Some (CMP, IndexedIndirect)
    | 0xc4 => Some (CPY, ZeroPage)
    | 0xc5 => Some (CMP, ZeroPage)
    | 0xc6 => Some (DEC, ZeroPage)
    | 0xc8 => Some (INY, Implied)
    | 0xc9 => Some (CMP, Immediate)
    | 0xca => Some (DEX, Implied)
    | 0xcc => Some (CPY, Absolute)
    | 0xcd => Some (CMP, Absolute)
    | 0xce => Some (DEC, Absolute)
    | 0xd0 => Some (BNE, Relative)
    | 0xd1 => Some (CMP, IndirectIndexed)
    | 0xd5 => Some (CMP, ZeroPageX)
    | 0xd6 => Some (DEC, ZeroPageX)
    | 0xd8 => Some (CLD, Implied)
    | 0xd9 => Some (CMP, AbsoluteY)
    | 0xdd => Some (CMP, AbsoluteX)
    | 0xde => Some (DEC, AbsoluteX)
    | 0xe0 => Some (CPX, Immediate)
    | 0xe1 => Some (SBC, IndexedIndirect)
    | 0xe4 => Some (CPX, ZeroPage)
    | 0xe5 => Some (SBC, ZeroPage)
    | 0xe6 => Some (INC, ZeroPage)
    | 0xe8 => Some (INX, Implied)
    | 0xe9 => Some (SBC, Immediate)
    | 0xea => Some (NOP, Implied)
    | 0xec => Some (CPX, Absolute)
    | 0xed => Some (SBC, Absolute)
    | 0xee => Some (INC, Absolute)
    | 0xf0 => Some (BEQ, Relative)
    | 0xf1 => Some (SBC, IndirectIndexed)
    | 0xf5 => Some (SBC, ZeroPageX)
    | 0xf6 => Some (INC, ZeroPageX)
    | 0xf8 => Some (SED, Implied)
    | 0xf9 => Some (SBC, AbsoluteY)
    | 0xfd => Some (SBC, AbsoluteX)
    | 0xfe => Some (INC, AbsoluteX)
    | _ => None
end.



(* Program Execution *)

(* The possible states right after power-on.
    The argument to JMP Indirect is all ones with the two low bits cleared
    (presumably $FFFC, the 6502 reset vector — assumes [pushLow] inserts at bit 0).
    Since the starting state is arbitrary, other reset side effects are subsumed. *)
Definition init {rom} : nondet (state rom) :=
    (* Begin in an arbitrary state. *)
    s <- unspecified;
    (* Jump to the address contained in the reset vector. *)
    run_instruction JMP Indirect (pushLow (false, pushLow (false, op0 true))) s.

(* One fetch-decode-execute cycle.
    It reads the opcode at the program counter.
    An opcode that [parse_opcode] rejects makes the successor [unspecified].
    Otherwise the program counter is advanced past the opcode and past each operand byte
    as it is read, and [run_instruction] runs on that advanced state. *)
Definition step {rom} (s : state rom) : nondet (state rom) :=
    (* Read the opcode. *)
    op <- read (ProgramCounter s) s;
    if parse_opcode op is Some (instr, mode)
    then 
        let s := set ProgramCounter (increment 1) s in
        (* Read the remaining bytes in the instruction *)
        '(s, bytes) <-
            nat_rect (* basically a for loop *)
                (fun n => nondet (state rom * word (n * 8)))%type
                (determined (s, trivial_word))
                (fun _ m =>
                    '(s, acc) <- m;
                    w <- read (ProgramCounter s) s;
                    (* I can make a `word (n * 8 + 8)`, but need a word ((n.+1) * 8).
                        `eq_rect` does the conversion.
                    *)
                    (* Each new byte is concatenated after the bytes read so far
                        (the same [concat (lo, hi)] order [mode_addr] uses). *)
                    determined
                        ( set ProgramCounter (increment 1) s
                        , eq_rect _ _ (concat (acc, w)) _ (addnC _ _)
                        )
                )
                (mode_width mode);
        (* Run the instruction. *)
        run_instruction instr mode bytes s
    else unspecified.

(* [reachable s] is evidence that [s] can arise from running the machine.
    Either [s] is a possible result of [init], or it is a possible result of [step]
    from some reachable state.
    It lives in [Type], so a value of it carries the whole chain of intermediate states. *)
Inductive reachable {rom : ROM} : state rom -> Type :=
| Init s : possible init s -> reachable s
| Step s1 s2 : reachable s1 -> possible (step s1) s2 -> reachable s2
.