use core::panic;
use std::{cell::UnsafeCell, mem, sync::Arc};

/// Built-in hexadecimal font: 16 glyphs (`0`-`F`), 5 bytes per glyph, one byte per row.
///
/// Only the high nibble of each byte is ever set, so every glyph is 4 pixels wide and
/// 5 rows tall. `Chip8::new` copies this table into memory at `START_NUMBER_FONT`, and
/// `FX29` points `I` at `START_NUMBER_FONT + digit * 5` to select a glyph.
static NUMBER_FONT: [u8; 5 * 16] = [
    0xF0, 0x90, 0x90, 0x90, 0xF0, // 0
    0x20, 0x60, 0x20, 0x20, 0x70, // 1
    0xF0, 0x10, 0xF0, 0x80, 0xF0, // 2
    0xF0, 0x10, 0xF0, 0x10, 0xF0, // 3
    0x90, 0x90, 0xF0, 0x10, 0x10, // 4
    0xF0, 0x80, 0xF0, 0x10, 0xF0, // 5
    0xF0, 0x80, 0xF0, 0x90, 0xF0, // 6
    0xF0, 0x10, 0x20, 0x40, 0x40, // 7
    0xF0, 0x90, 0xF0, 0x90, 0xF0, // 8
    0xF0, 0x90, 0xF0, 0x10, 0xF0, // 9
    0xF0, 0x90, 0xF0, 0x90, 0x90, // A
    0xE0, 0x90, 0xE0, 0x90, 0xE0, // B
    0xF0, 0x80, 0x80, 0x80, 0xF0, // C
    0xE0, 0x90, 0x90, 0x90, 0xE0, // D
    0xF0, 0x80, 0xF0, 0x80, 0xF0, // E
    0xF0, 0x80, 0xF0, 0x80, 0x80, // F
];

/// A key on the 16-key hexadecimal CHIP-8 keypad.
///
/// Each variant's discriminant is the key's hexadecimal value (`0x0..=0xF`).
/// [`KeyCode::to_u8`] returns that value, and `Display` prints it as a single
/// uppercase hex digit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum KeyCode {
    Zero = 0x0,
    One = 0x1,
    Two = 0x2,
    Three = 0x3,
    Four = 0x4,
    Five = 0x5,
    Six = 0x6,
    Seven = 0x7,
    /// Key `8`. The name is a misspelling of "Eight"; it is kept so existing
    /// callers keep compiling.
    Height = 0x8,
    Nine = 0x9,
    A = 0xA,
    B = 0xB,
    C = 0xC,
    D = 0xD,
    E = 0xE,
    F = 0xF,
}

impl core::fmt::Display for KeyCode {
    /// Writes the key as one uppercase hex digit (`0`-`9`, `A`-`F`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:X}", self.to_u8())
    }
}

impl KeyCode {
    /// Returns the key's hexadecimal value, in `0x0..=0xF`.
    pub fn to_u8(&self) -> u8 {
        // Fieldless `#[repr(u8)]` enum: a plain cast yields the discriminant,
        // so no `unsafe` is needed.
        *self as u8
    }
}

impl TryFrom<char> for KeyCode {
    type Error = ();

    /// Maps a typed character to a keypad key, or returns `Err(())` if the
    /// character is not mapped (matching is case-sensitive).
    ///
    /// The 4x4 keypad is laid over a 4x4 block of keyboard keys:
    ///
    /// ```text
    /// keyboard      keypad
    /// & é " '       1 2 3 C
    /// a z e r       4 5 6 D
    /// q s d f       7 8 9 E
    /// w x c v       A 0 B F
    /// ```
    ///
    /// NOTE(review): the characters look like the unshifted left-hand block of a
    /// French AZERTY keyboard — confirm that is the intended layout.
    fn try_from(value: char) -> Result<Self, Self::Error> {
        Ok(match value {
            '&' => KeyCode::One,
            'é' => KeyCode::Two,
            '"' => KeyCode::Three,
            '\'' => KeyCode::C,
            'a' => KeyCode::Four,
            'z' => KeyCode::Five,
            'e' => KeyCode::Six,
            'r' => KeyCode::D,
            'q' => KeyCode::Seven,
            's' => KeyCode::Height,
            'd' => KeyCode::Nine,
            'f' => KeyCode::E,
            'w' => KeyCode::A,
            'x' => KeyCode::Zero,
            'c' => KeyCode::B,
            'v' => KeyCode::F,
            _ => return Err(()),
        })
    }
}

/// Size of the CHIP-8 address space, in bytes.
const MEMORY_SIZE: usize = 4096;

/// Byte-addressable RAM covering the 4 KiB CHIP-8 address space.
struct Memory {
    buf: [u8; MEMORY_SIZE],
}

impl Memory {
    /// Returns a memory with every byte set to zero.
    fn new() -> Self {
        Self {
            buf: [0; MEMORY_SIZE],
        }
    }

    /// Copies `bytes` into memory, starting at address `at`.
    ///
    /// Panics if the copy would run past the end of memory.
    fn load_at(&mut self, at: u16, bytes: &[u8]) {
        let dst = &mut self.buf[usize::from(at)..][..bytes.len()];
        dst.copy_from_slice(bytes);
    }

    /// Returns the byte at address `at`; panics if `at` is out of range.
    fn read(&self, at: u16) -> u8 {
        let idx = usize::from(at);
        self.buf[idx]
    }

    /// Stores `val` at address `at`; panics if `at` is out of range.
    fn write(&mut self, at: u16, val: u8) {
        let slot = &mut self.buf[usize::from(at)];
        *slot = val;
    }
}

/// Width of the display, in pixels.
pub const WIDTH: usize = 64;
/// Height of the display, in pixels.
pub const HEIGHT: usize = 32;

/// Monochrome frame buffer of `WIDTH * HEIGHT` pixels, stored row-major
/// (pixel `(x, y)` lives at index `y * WIDTH + x`).
///
/// A pixel is `u32::MIN` (0) when off and `u32::MAX` when on. Interior
/// mutability lets the interpreter draw through a shared `Arc<Display>`.
pub struct Display {
    buf: UnsafeCell<[u32; WIDTH * HEIGHT]>,
}

// `Display` is already `Send` automatically (`UnsafeCell<T>: Send` for `T: Send`).
//
// NOTE(review): this `Sync` impl is not actually justified. `get_buffer` hands out a
// plain `&[u32]` while `put_px`/`clear` write through the same `UnsafeCell` from
// `&self`; if another thread reads that slice while the interpreter draws, that is a
// data race (undefined behavior). A sound design needs atomics or a lock, which
// requires changing `get_buffer`'s signature.
unsafe impl Sync for Display {}

impl Display {
    /// Creates a display with every pixel off.
    fn new() -> Self {
        Self {
            buf: UnsafeCell::new([u32::MIN; WIDTH * HEIGHT]),
        }
    }

    /// Returns the whole pixel buffer, row-major, `WIDTH * HEIGHT` entries long.
    pub fn get_buffer(&self) -> &[u32] {
        // SAFETY (NOTE(review): not enforced): the slice aliases memory that
        // `put_px`/`clear` mutate through `&self`. Callers must not keep it across
        // such a call, nor read it from another thread while one runs.
        unsafe { &*self.buf.get() }
    }

    /// Writes `val` into pixel slot `idx`.
    ///
    /// Panics if `idx >= WIDTH * HEIGHT`.
    fn store(&self, idx: usize, val: u32) {
        // SAFETY: the write goes through the cell's raw pointer with built-in,
        // bounds-checked array indexing, so no `&mut` to the whole buffer is
        // created. It must not overlap a live slice from `get_buffer`.
        unsafe { (*self.buf.get())[idx] = val };
    }

    /// Turns every pixel off.
    fn clear(&self) {
        // SAFETY: same contract as `store`; the temporary `&mut` only lives
        // for the duration of `fill`.
        unsafe { (*self.buf.get()).fill(u32::MIN) };
    }

    /// Returns 1 if pixel `(x, y)` is on, 0 otherwise.
    ///
    /// Coordinates must be on-screen (checked by `debug_assert!` only; an
    /// off-screen position may still hit the final index check or read a
    /// different pixel).
    fn get_px(&self, x: u8, y: u8) -> u8 {
        let (x, y) = (usize::from(x), usize::from(y));
        debug_assert!(x < WIDTH);
        debug_assert!(y < HEIGHT);

        u8::from(self.get_buffer()[y * WIDTH + x] != u32::MIN)
    }

    /// Sets pixel `(x, y)` on (`px == 1`) or off (`px == 0`).
    ///
    /// Any non-zero `px` turns the pixel on in release builds; coordinates
    /// follow the same rules as [`Display::get_px`].
    fn put_px(&self, x: u8, y: u8, px: u8) {
        debug_assert!(px == 0 || px == 1);
        let (x, y) = (usize::from(x), usize::from(y));
        debug_assert!(x < WIDTH);
        debug_assert!(y < HEIGHT);

        let val = if px == 0 { u32::MIN } else { u32::MAX };
        self.store(y * WIDTH + x, val);
    }
}

/// Maximum depth of the subroutine call stack.
const STACK_LEN: usize = 16;

/// Fixed-capacity LIFO of return addresses, used by `2NNN` (call) and `00EE` (return).
struct Stack {
    buf: [u16; STACK_LEN],
    /// Number of occupied slots; the top of the stack is `buf[len - 1]`.
    len: usize,
}

impl Stack {
    /// Creates an empty stack.
    fn new() -> Self {
        Self {
            buf: [0; STACK_LEN],
            len: 0,
        }
    }

    /// Pushes `elem`, or returns `None` (leaving the stack unchanged) when full.
    fn push(&mut self, elem: u16) -> Option<()> {
        // `get_mut` fails exactly when all `STACK_LEN` slots are taken.
        let slot = self.buf.get_mut(self.len)?;
        *slot = elem;
        self.len += 1;
        Some(())
    }

    /// Pops the most recently pushed element, or returns `None` when empty.
    fn pop(&mut self) -> Option<u16> {
        self.len = self.len.checked_sub(1)?;
        Some(self.buf[self.len])
    }
}

/// A decoded CHIP-8 instruction.
///
/// In the encodings below, `X`/`Y` are register indices (`0x0..=0xF`), `NN` is an
/// 8-bit immediate, `NNN` a 12-bit address and `N` a 4-bit immediate.
#[derive(Debug, Clone, Copy)]
enum Opcode {
    /// `00E0`: clear the screen.
    ClearScreen,
    /// `1NNN`: jump to `NNN`.
    Jump(u16),
    /// `6XNN`: `Vx = NN`.
    SetRegNN(u8, u8),
    /// `7XNN`: `Vx += NN`.
    AddRegNN(u8, u8),
    /// `ANNN`: `I = NNN`.
    SetRegI(u16),
    /// `DXYN`: draw an `N`-row sprite read from memory at `I` at position (Vx, Vy).
    Draw(u8, u8, u8),
    /// `00EE`: return from a subroutine.
    Ret,
    /// `2NNN`: call the subroutine at `NNN`.
    Call(u16),
    /// `3XNN`: skip the next instruction if `Vx == NN`.
    SkipEqNN(u8, u8),
    /// `4XNN`: skip the next instruction if `Vx != NN`.
    SkipNeNN(u8, u8),
    /// `9XY0`: skip the next instruction if `Vx != Vy`.
    SkipNeReg(u8, u8),
    /// `5XY0`: skip the next instruction if `Vx == Vy`.
    SkipEqReg(u8, u8),
    /// `8XY0`: `Vx = Vy`.
    SetRegReg(u8, u8),
    /// `8XY1`: `Vx |= Vy`.
    Or(u8, u8),
    /// `8XY2`: `Vx &= Vy`.
    And(u8, u8),
    /// `8XY3`: `Vx ^= Vy`.
    Xor(u8, u8),
    /// `8XY4`: `Vx += Vy`, with the carry in VF.
    AddRegReg(u8, u8),
    /// `8XY5`: `Vx = Vx - Vy`, with the borrow flag in VF.
    SubRegXY(u8, u8),
    /// `8XY7`: `Vx = Vy - Vx`, with the borrow flag in VF.
    SubRegYX(u8, u8),
    /// `8XYE`: shift left by one, with the shifted-out bit in VF.
    ShiftLeft(u8, u8),
    /// `8XY6`: shift right by one, with the shifted-out bit in VF.
    ShiftRight(u8, u8),
    /// `BNNN`: jump to `NNN + V0`.
    JumpRegNNN(u16),
    /// `CXNN`: `Vx = random & NN`.
    RndNN(u8, u8),
    /// `EXA1`: skip the next instruction if key `Vx` is not pressed.
    SkipNotPressed(u8),
    /// `EX9E`: skip the next instruction if key `Vx` is pressed.
    SkipPressed(u8),
    /// `FX07`: `Vx = delay timer`.
    SetRegDelay(u8),
    /// `FX15`: `delay timer = Vx`.
    SetDelayReg(u8),
    /// `FX18`: `sound timer = Vx`.
    SetSoundReg(u8),
    /// `FX1E`: `I += Vx`.
    AddIReg(u8),
    /// `FX0A`: wait for a key press and store the key in `Vx`.
    GetKey(u8),
    /// `FX29`: point `I` at the built-in font glyph for digit `Vx`.
    SetIFont(u8),
    /// `FX33`: store the decimal digits of `Vx` at `I`, `I + 1`, `I + 2`.
    Bcd(u8),
    /// `FX55`: store `V0..=Vx` to memory starting at `I`.
    Store(u8),
    /// `FX65`: load `V0..=Vx` from memory starting at `I`.
    Load(u8),
}

/// Aborts decoding of an instruction word this interpreter does not implement,
/// reporting the word in the panic message.
#[cold]
fn unsupported_instruction(instruction: u16) -> ! {
    panic!("unsupported CHIP-8 instruction {:#06X}", instruction)
}

impl From<u16> for Opcode {
    /// Decodes one CHIP-8 instruction word (high byte first in memory).
    ///
    /// # Panics
    ///
    /// Panics, naming the instruction, on words with no matching arm: `0NNN`
    /// other than `00E0`/`00EE`, and undefined `8XY?`, `EX??` or `FX??`
    /// sub-opcodes. The low nibble of `5XY?`/`9XY?` is not checked.
    fn from(instruction: u16) -> Self {
        // Field layout: [op:4][x:4][y:4][n:4]; nn = low byte, nnn = low 12 bits.
        let op = instruction >> 12;
        let x = ((instruction >> 8) & 0xF) as u8;
        let y = ((instruction >> 4) & 0xF) as u8;
        let n = (instruction & 0xF) as u8;
        let nn = (instruction & 0xFF) as u8;
        let nnn = instruction & 0xFFF;

        match op {
            0x0 => match nnn {
                0x0E0 => Opcode::ClearScreen,
                0x0EE => Opcode::Ret,
                _ => unsupported_instruction(instruction),
            },
            0x1 => Opcode::Jump(nnn),
            0x2 => Opcode::Call(nnn),
            0x3 => Opcode::SkipEqNN(x, nn),
            0x4 => Opcode::SkipNeNN(x, nn),
            0x5 => Opcode::SkipEqReg(x, y),
            0x6 => Opcode::SetRegNN(x, nn),
            0x7 => Opcode::AddRegNN(x, nn),
            0x8 => match n {
                0x0 => Opcode::SetRegReg(x, y),
                0x1 => Opcode::Or(x, y),
                0x2 => Opcode::And(x, y),
                0x3 => Opcode::Xor(x, y),
                0x4 => Opcode::AddRegReg(x, y),
                0x5 => Opcode::SubRegXY(x, y),
                0x6 => Opcode::ShiftRight(x, y),
                0x7 => Opcode::SubRegYX(x, y),
                0xE => Opcode::ShiftLeft(x, y),
                _ => unsupported_instruction(instruction),
            },
            0x9 => Opcode::SkipNeReg(x, y),
            0xA => Opcode::SetRegI(nnn),
            0xB => Opcode::JumpRegNNN(nnn),
            0xC => Opcode::RndNN(x, nn),
            0xD => Opcode::Draw(x, y, n),
            0xE => match nn {
                0x9E => Opcode::SkipPressed(x),
                0xA1 => Opcode::SkipNotPressed(x),
                _ => unsupported_instruction(instruction),
            },
            0xF => match nn {
                0x07 => Opcode::SetRegDelay(x),
                0x15 => Opcode::SetDelayReg(x),
                0x18 => Opcode::SetSoundReg(x),
                0x1E => Opcode::AddIReg(x),
                0x0A => Opcode::GetKey(x),
                0x29 => Opcode::SetIFont(x),
                0x33 => Opcode::Bcd(x),
                0x55 => Opcode::Store(x),
                0x65 => Opcode::Load(x),
                _ => unsupported_instruction(instruction),
            },
            // `op` is the top 4 bits of a u16, so every value is handled above.
            _ => unreachable!("opcode nibble out of range: {:#X}", op),
        }
    }
}

/// Number of general-purpose registers (`V0`..=`VF`).
const REGISTER_COUNT: usize = 16;

/// The sixteen 8-bit general-purpose registers `V0`..=`VF`.
///
/// Index `0xF` (`VF`) is also written by the interpreter as a flag register.
struct RegisterFile {
    regs: [u8; REGISTER_COUNT],
}

impl RegisterFile {
    /// Creates a register file with every register set to zero.
    fn new() -> Self {
        Self {
            regs: [0; REGISTER_COUNT],
        }
    }

    /// Returns the value of register `V{reg}`.
    ///
    /// Only needs a shared borrow; panics if `reg >= 16`.
    pub fn read(&self, reg: u8) -> u8 {
        self.regs[usize::from(reg)]
    }

    /// Sets register `V{reg}` to `nn`; panics if `reg >= 16`.
    fn write(&mut self, reg: u8, nn: u8) {
        self.regs[usize::from(reg)] = nn;
    }
}

/// Address at which the program is loaded and where execution starts (`pc`'s initial value).
const START_PROGRAM: u16 = 0x200;
/// Address at which the built-in hex font (`NUMBER_FONT`) is loaded.
const START_NUMBER_FONT: u16 = 0x50;

/// A CHIP-8 interpreter: memory, display, call stack, registers, keypad state and timers.
pub struct Chip8 {
    // 4 KiB of RAM holding the font and the loaded program.
    mem: Memory,
    /// Frame buffer the interpreter draws into. It is wrapped in `Arc`,
    /// presumably so a renderer can hold its own handle — confirm with the caller.
    pub display: Arc<Display>,
    // Return addresses pushed by `2NNN` and popped by `00EE`.
    stack: Stack,
    // General-purpose registers `V0`..=`VF`.
    regs: RegisterFile,
    // Program counter: address of the next instruction to fetch.
    pc: u16,
    // The `I` address register.
    i: u16,
    // Value (`0x0..=0xF`) of the key last reported via `set_key_pressed`, if any.
    key_pressed: Option<u8>,
    // Delay timer; `None` until first set, and `decr_timers` turns `Some(0)` into `None`.
    delay_timer: Option<u8>,
    // Sound timer; same representation as `delay_timer`.
    sound_timer: Option<u8>,
}

impl Chip8 {
    /// Creates an interpreter with the built-in font loaded at `START_NUMBER_FONT`
    /// and `program` loaded at `START_PROGRAM`, where execution begins.
    ///
    /// # Panics
    ///
    /// Panics if `program` does not fit between `START_PROGRAM` and the end of
    /// the 4 KiB memory.
    pub fn new(program: &[u8]) -> Self {
        let mut mem = Memory::new();
        mem.load_at(START_NUMBER_FONT, &NUMBER_FONT[..]);
        mem.load_at(START_PROGRAM, program);
        Chip8 {
            mem,
            display: Arc::new(Display::new()),
            stack: Stack::new(),
            regs: RegisterFile::new(),
            pc: START_PROGRAM,
            i: 0,
            key_pressed: None,
            delay_timer: None,
            sound_timer: None,
        }
    }

    /// Fetches, decodes and executes the instruction at `pc`.
    ///
    /// Behavioral choices: the shifts (`8XY6`/`8XYE`) operate on `Vx` in place and
    /// ignore `Vy`; `FX55`/`FX65` leave `I` unchanged; `FX0A` re-executes itself
    /// until a key is reported as pressed.
    ///
    /// # Panics
    ///
    /// Panics on instructions the decoder rejects, on call-stack overflow or
    /// underflow, and on memory accesses past the end of the 4 KiB address space.
    pub fn step(&mut self) {
        // Fetch: instructions are two bytes, high byte first.
        let hi = self.mem.read(self.pc);
        let lo = self.mem.read(self.pc + 1);
        let instruction = u16::from_be_bytes([hi, lo]);
        self.pc += 2;

        match Opcode::from(instruction) {
            // --- Flow control ---
            Opcode::ClearScreen => self.display.clear(),
            Opcode::Jump(nnn) => self.pc = nnn,
            Opcode::JumpRegNNN(nnn) => {
                let v0 = self.regs.read(0);
                self.pc = nnn + u16::from(v0);
            }
            Opcode::Call(nnn) => {
                self.stack
                    .push(self.pc)
                    .expect("CHIP-8 call stack overflow (2NNN)");
                self.pc = nnn;
            }
            Opcode::Ret => {
                self.pc = self
                    .stack
                    .pop()
                    .expect("CHIP-8 return (00EE) with an empty call stack");
            }

            // --- Conditional skips ---
            Opcode::SkipEqNN(x, nn) => {
                let vx = self.regs.read(x);
                self.skip_if(vx == nn);
            }
            Opcode::SkipNeNN(x, nn) => {
                let vx = self.regs.read(x);
                self.skip_if(vx != nn);
            }
            Opcode::SkipEqReg(x, y) => {
                let (vx, vy) = (self.regs.read(x), self.regs.read(y));
                self.skip_if(vx == vy);
            }
            Opcode::SkipNeReg(x, y) => {
                let (vx, vy) = (self.regs.read(x), self.regs.read(y));
                self.skip_if(vx != vy);
            }
            Opcode::SkipPressed(x) => {
                let vx = self.regs.read(x);
                let pressed = self.key_pressed == Some(vx);
                self.skip_if(pressed);
            }
            Opcode::SkipNotPressed(x) => {
                // Skip unless key Vx is down, including when no key is down at all.
                let vx = self.regs.read(x);
                let pressed = self.key_pressed == Some(vx);
                self.skip_if(!pressed);
            }

            // --- Register arithmetic ---
            Opcode::SetRegNN(x, nn) => self.regs.write(x, nn),
            Opcode::AddRegNN(x, nn) => {
                // 7XNN wraps on overflow and never touches VF.
                let vx = self.regs.read(x);
                self.regs.write(x, vx.wrapping_add(nn));
            }
            Opcode::SetRegReg(x, y) => {
                let vy = self.regs.read(y);
                self.regs.write(x, vy);
            }
            Opcode::Or(x, y) => self.apply_reg_op(x, y, |vx, vy| vx | vy),
            Opcode::And(x, y) => self.apply_reg_op(x, y, |vx, vy| vx & vy),
            Opcode::Xor(x, y) => self.apply_reg_op(x, y, |vx, vy| vx ^ vy),
            Opcode::AddRegReg(x, y) => {
                let vx = self.regs.read(x);
                let vy = self.regs.read(y);
                let (sum, carry) = vx.overflowing_add(vy);
                self.regs.write(x, sum);
                // VF is written last so the flag wins when x == 0xF.
                self.regs.write(0xF, u8::from(carry));
            }
            Opcode::SubRegXY(x, y) => {
                let vx = self.regs.read(x);
                let vy = self.regs.read(y);
                self.regs.write(x, vx.wrapping_sub(vy));
                // VF = NOT borrow; equal operands do not borrow, hence `>=`.
                self.regs.write(0xF, u8::from(vx >= vy));
            }
            Opcode::SubRegYX(x, y) => {
                let vx = self.regs.read(x);
                let vy = self.regs.read(y);
                self.regs.write(x, vy.wrapping_sub(vx));
                // VF = NOT borrow; equal operands do not borrow, hence `>=`.
                self.regs.write(0xF, u8::from(vy >= vx));
            }
            Opcode::ShiftLeft(x, _y) => {
                let vx = self.regs.read(x);
                self.regs.write(x, vx << 1);
                self.regs.write(0xF, vx >> 7);
            }
            Opcode::ShiftRight(x, _y) => {
                let vx = self.regs.read(x);
                self.regs.write(x, vx >> 1);
                self.regs.write(0xF, vx & 1);
            }
            Opcode::RndNN(x, nn) => {
                let rnd: u8 = rand::random();
                self.regs.write(x, rnd & nn);
            }

            // --- Index register and memory ---
            Opcode::SetRegI(nnn) => self.i = nnn,
            Opcode::AddIReg(x) => {
                let vx = self.regs.read(x);
                self.i = self.i.wrapping_add(u16::from(vx));
            }
            Opcode::SetIFont(x) => {
                // Glyphs are 5 bytes each; only the low nibble of Vx selects one.
                let vx = self.regs.read(x);
                self.i = START_NUMBER_FONT + u16::from(vx & 0xF) * 5;
            }
            Opcode::Bcd(x) => {
                let vx = self.regs.read(x);
                self.mem.write(self.i, vx / 100);
                self.mem.write(self.i + 1, vx / 10 % 10);
                self.mem.write(self.i + 2, vx % 10);
            }
            Opcode::Store(x) => {
                for reg in 0..=x {
                    let val = self.regs.read(reg);
                    self.mem.write(self.i + u16::from(reg), val);
                }
            }
            Opcode::Load(x) => {
                for reg in 0..=x {
                    let val = self.mem.read(self.i + u16::from(reg));
                    self.regs.write(reg, val);
                }
            }

            // --- Timers, input and display ---
            Opcode::SetRegDelay(x) => {
                let delay = self.delay_timer.unwrap_or(0);
                self.regs.write(x, delay);
            }
            Opcode::SetDelayReg(x) => {
                let vx = self.regs.read(x);
                self.delay_timer = Some(vx);
            }
            Opcode::SetSoundReg(x) => {
                let vx = self.regs.read(x);
                self.sound_timer = Some(vx);
            }
            Opcode::GetKey(x) => match self.key_pressed {
                Some(key) => self.regs.write(x, key),
                // No key yet: rewind so this instruction executes again next step.
                None => self.pc -= 2,
            },
            Opcode::Draw(x, y, n) => self.draw_sprite(x, y, n),
        }
    }

    /// Advances `pc` past the next instruction when `cond` holds.
    fn skip_if(&mut self, cond: bool) {
        if cond {
            self.pc += 2;
        }
    }

    /// Sets `Vx = op(Vx, Vy)` without touching VF.
    fn apply_reg_op(&mut self, x: u8, y: u8, op: impl FnOnce(u8, u8) -> u8) {
        let vx = self.regs.read(x);
        let vy = self.regs.read(y);
        self.regs.write(x, op(vx, vy));
    }

    /// Executes `DXYN`: XOR-draws an `n`-row, 8-pixel-wide sprite read from memory
    /// at `I` at position (Vx, Vy), and sets VF to 1 if any lit pixel is turned off
    /// (0 otherwise).
    ///
    /// The start position wraps around the screen; sprite pixels past the right or
    /// bottom edge are clipped rather than wrapped.
    fn draw_sprite(&mut self, x: u8, y: u8, n: u8) {
        let x0 = self.regs.read(x) % (WIDTH as u8);
        let y0 = self.regs.read(y) % (HEIGHT as u8);

        // Vx/Vy are read first, so drawing with x or y == 0xF uses the old VF.
        self.regs.write(0xF, 0);
        let mut collision = false;
        for row in 0..n {
            let py = y0 + row;
            if usize::from(py) >= HEIGHT {
                break;
            }
            let sprite_row = self.mem.read(self.i + u16::from(row));
            for bit in 0..8u8 {
                let px = x0 + bit;
                if usize::from(px) >= WIDTH {
                    break;
                }
                // The most significant bit is the leftmost pixel.
                if sprite_row & (0x80 >> bit) != 0 {
                    if self.display.get_px(px, py) == 1 {
                        self.display.put_px(px, py, 0);
                        collision = true;
                    } else {
                        self.display.put_px(px, py, 1);
                    }
                }
            }
        }
        if collision {
            self.regs.write(0xF, 1);
        }
    }

    /// Reports `key` as the currently pressed key.
    ///
    /// The key stays pressed until [`Chip8::clear_key_pressed`] is called or
    /// another key is reported.
    pub fn set_key_pressed(&mut self, key: KeyCode) {
        self.key_pressed = Some(key.to_u8());
    }

    /// Reports that no key is pressed any more.
    pub fn clear_key_pressed(&mut self) {
        self.key_pressed = None;
    }

    /// Decrements the delay and sound timers by one.
    ///
    /// A timer that is already at zero becomes `None` (inactive). Presumably the
    /// host calls this at 60 Hz — confirm with the caller.
    pub fn decr_timers(&mut self) {
        self.delay_timer = self.delay_timer.and_then(|t| t.checked_sub(1));
        self.sound_timer = self.sound_timer.and_then(|t| t.checked_sub(1));
    }

    /// Returns the delay timer, or `None` if it has not been set or has expired.
    pub fn get_delay_timer(&self) -> Option<u8> {
        self.delay_timer
    }

    /// Returns the sound timer, or `None` if it has not been set or has expired.
    ///
    /// Presumably the host plays a tone while this is non-zero — confirm.
    pub fn get_sound_timer(&self) -> Option<u8> {
        self.sound_timer
    }
}