//! Bytecode and related structures

// All bytecode must fit within a single u32, so every instruction uses one of
// the argument shapes documented on `OpCode` below.
// NOTE Endianness is undefined for now as this is not meant to be serialized.

/// Bytecode instructions for a virtual machine
///
/// An encoded instruction is a single `u32`: the op-code occupies the most
/// significant 8 bits and its arguments occupy the remaining 24 bits.
///
/// ## Shapes
/// If a shape is not mentioned, the instruction is in shape A.
/// - INST (8bits) A (24bits) (shape A)
/// - INST (8bits) A (8bits) B (16bits) (shape B)
/// - INST (8bits) A (8bits) B (8bits) C (8bits) (shape C)
///
/// The discriminants are written out explicitly because they *are* the encoded
/// op-code values; they must stay in sync with `TryFrom<u8> for OpCode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum OpCode {
    /// Do nothing.
    NoOp = 0,
    /// Push argument as an unsigned integer
    PushInt = 1,
    /// Push argument as a signed integer
    PushSignedInt = 2,
    /// Push from constant table
    // NOTE(review): `ByteCode::parse` decodes this op-code as shape B, while no
    // shape is documented here (which implies shape A) — confirm which is intended.
    PushConst = 3,
    /// (Shape B) Create a list object, and set the SRL register at (A != 0) to that object (for self-reference)
    ///
    /// It will simply create an empty list object if A == 0 (without setting an SRL register)
    PushList = 4,
    /// (Shape B) Create a list object using a number of items (B) from the stack, with the
    /// top-most item being the last item of the list
    ///
    /// If SRL register at (A != 0) is set, instead set the items of the list in the SRL register
    /// to the items that would have been used to create a new list, and unset the register
    CreateList = 5,
    /// Creates a new stack frame, pops (A) arguments from the stack, then pops
    /// a procedure and calls it
    Call = 6,
    /// Does the same thing as [`OpCode::Call`], but replaces the current stack frame
    TailCall = 7,
    /// Return from the current stack frame, pushing the value at the top of the
    /// current stack frame's stack to the stack of the stack frame below.
    ///
    /// At the top-level, this stops execution (TODO).
    /// If the stack is empty, returns the empty list.
    ///
    // NOTE(impl): A Call followed by this (a Return) should *always*
    // be rewritten into a TailCall.
    Return = 8,
    /// Pushes the local item referenced by the symbol (A).
    LoadLocal = 9,
    /// Pops a value and stores it in the local referenced by the symbol (A)
    StoreLocal = 10,
    /// Push the current continuation to the stack
    // TODO (maybe use A as a pointer offset from the current instruction?)
    PushCurrentContinuation = 11,
    /// (Shape B) Pops an item from the stack; if it is false (`#f`), jumps (B)
    /// instructions ahead when (A == 0), or (B) instructions back when (A == 1)
    JumpIfFalse = 12,
}

/*
    A continuation from this vm's perspective might be:
    - module code to reference
    - current environment
    - index of the next instruction to execute
    - symbol table (for use with debugging, all functions have a symbol table)
    - local table

    REMEMBER a continuation would *not* contain the stack of the VM, as
    that represents the *data* of the execution (maybe? I'd have to implement to be sure)
*/

impl From<OpCode> for u8 {
    /// Encodes an op-code as its `u8` discriminant.
    fn from(op: OpCode) -> u8 {
        // `OpCode` is a field-less `#[repr(u8)]` enum, so this cast is exact.
        op as Self
    }
}

impl TryFrom<u8> for OpCode {
    // There is only 1 possible error: the u8 doesn't represent a valid op-code
    type Error = ();
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::NoOp),
            1 => Ok(Self::PushInt),
            2 => Ok(Self::PushSignedInt),
            3 => Ok(Self::PushConst),
            4 => Ok(Self::PushList),
            5 => Ok(Self::CreateList),
            6 => Ok(Self::Call),
            7 => Ok(Self::TailCall),
            8 => Ok(Self::Return),
            9 => Ok(Self::LoadLocal),
            10 => Ok(Self::StoreLocal),
            11 => Ok(Self::PushCurrentContinuation),
            12 => Ok(Self::JumpIfFalse),
            _ => Err(()),
        }
    }
}

/// Arguments for an [`OpCode`]
///
/// Every shape packs into the same 24 argument bits of an instruction. The
/// accessors [`OpArguments::a`], [`OpArguments::b`] and [`OpArguments::c`]
/// reinterpret those bits in whichever shape the caller asks for.
///
/// Equality is derived, so it is structural: `A(0x01_0203)` and `C(1, 2, 3)`
/// encode the same bits but compare unequal. Compare the results of `a()` to
/// compare encodings instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpArguments {
    /// Shape A: one 24-bit argument. Only the low 24 bits of the value are used.
    A(u32),
    /// Shape B: an 8-bit argument followed by a 16-bit argument.
    B(u8, u16),
    /// Shape C: three 8-bit arguments, most significant first.
    C(u8, u8, u8),
}

impl OpArguments {
    /// Mask selecting the argument bits (the low 24 bits) of a raw bytecode
    pub const ARGUMENT_MASK: u32 = 0x00ff_ffff;

    /// Treat arguments as a Shape A: the whole 24-bit argument field.
    ///
    /// Bits above the low 24 of an [`OpArguments::A`] value are discarded.
    /// This is the single place where shapes are packed into bits; every other
    /// accessor is derived from it.
    pub fn a(self) -> u32 {
        match self {
            Self::A(a) => a & Self::ARGUMENT_MASK,
            Self::B(a, b) => (u32::from(a) << 16) | u32::from(b),
            Self::C(a, b, c) => (u32::from(a) << 16) | (u32::from(b) << 8) | u32::from(c),
        }
    }

    /// Treat arguments as a Shape B: `(high 8 bits, low 16 bits)` of the argument field.
    pub fn b(self) -> (u8, u16) {
        let bits = self.a();
        // `as` truncation keeps exactly the bits wanted for each component.
        ((bits >> 16) as u8, bits as u16)
    }

    /// Treat arguments as a Shape C: the argument field split into three bytes,
    /// most significant first.
    pub fn c(self) -> (u8, u8, u8) {
        let bits = self.a();
        ((bits >> 16) as u8, (bits >> 8) as u8, bits as u8)
    }

    /// Encodes the arguments into the low 24 bits of a bytecode word.
    ///
    /// Kept private because [`ByteCode::stitch`] is the public encoder; it is
    /// identical to [`OpArguments::a`].
    fn stitch(self) -> u32 {
        self.a()
    }

    /// Decodes the low 24 bits of `n` as Shape A, ignoring the op-code byte.
    fn parse_a(n: u32) -> Self {
        Self::A(n & Self::ARGUMENT_MASK)
    }

    /// Decodes the low 24 bits of `n` as Shape B, ignoring the op-code byte.
    fn parse_b(n: u32) -> Self {
        let (a, b) = Self::parse_a(n).b();
        Self::B(a, b)
    }

    /// Decodes the low 24 bits of `n` as Shape C, ignoring the op-code byte.
    fn parse_c(n: u32) -> Self {
        let (a, b, c) = Self::parse_a(n).c();
        Self::C(a, b, c)
    }
}

/// Representation of valid bytecode
///
/// Pairs a decoded [`OpCode`] with its [`OpArguments`]. Use [`ByteCode::stitch`]
/// to encode it into a raw `u32`, and [`ByteCode::parse`] (or `TryFrom<u32>`) to
/// decode one.
///
/// The fields are public, so a hand-built value may carry an argument shape that
/// differs from the one `parse` produces for its op-code. Because equality is
/// derived (structural), such a value compares unequal to its parsed counterpart
/// even when both stitch to the same `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ByteCode {
    /// The [`OpCode`] of a given bytecode
    pub op_code: OpCode,
    /// The arguments of a given bytecode
    pub arguments: OpArguments,
}

impl ByteCode {
    /// Produces a single number representing this code
    ///
    /// The op-code occupies the most significant 8 bits and the arguments the
    /// remaining 24 bits (see [`OpArguments::ARGUMENT_MASK`]).
    pub fn stitch(self) -> u32 {
        ((self.op_code as u32) << 24) | self.arguments.stitch()
    }

    /// Parse a valid bytecode from a given raw bytecode
    ///
    /// Returns `None` if the top 8 bits do not name a known [`OpCode`]. Every
    /// other value decodes successfully: the low 24 bits are split into the
    /// argument shape associated with the op-code.
    pub fn parse(n: u32) -> Option<Self> {
        // Shifting a u32 right by 24 leaves exactly the op-code byte, so the
        // cast cannot lose information and no mask is needed.
        let op_code = OpCode::try_from((n >> 24) as u8).ok()?;
        // Decoding the op-code is the only fallible step; once the shape is
        // known, splitting the argument bits is infallible. The match is
        // deliberately exhaustive, so a new op-code forces a shape decision.
        let arguments = match op_code {
            // NOTE(review): PushConst is decoded as shape B although its
            // `OpCode` documentation mentions no shape (implying A) — confirm.
            OpCode::PushConst
            | OpCode::PushList
            | OpCode::CreateList
            | OpCode::JumpIfFalse => OpArguments::parse_b(n),
            // TODO determine shape of PushCurrentContinuation
            OpCode::NoOp
            | OpCode::PushInt
            | OpCode::PushSignedInt
            | OpCode::Call
            | OpCode::TailCall
            | OpCode::Return
            | OpCode::LoadLocal
            | OpCode::StoreLocal
            | OpCode::PushCurrentContinuation => OpArguments::parse_a(n),
        };
        Some(Self { op_code, arguments })
    }
}

impl TryFrom<u32> for ByteCode {
    type Error = ();
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        Self::parse(value).ok_or(())
    }
}

#[cfg(test)]
mod tests {
    use super::{ByteCode, OpArguments, OpCode};
    use assert2::assert;

    #[test]
    fn bytecode_parse() {
        let code = 0x00debeef;
        let op = ByteCode::parse(code);
        assert!(
            op == Some(ByteCode {
                op_code: OpCode::NoOp,
                arguments: OpArguments::A(0xdebeef)
            })
        );

        let code = 0xffadbeef;
        let op = ByteCode::parse(code);
        assert!(op == None);
    }

    #[test]
    fn bytecode_stitch() {
        let bc = ByteCode {
            op_code: OpCode::NoOp,
            arguments: OpArguments::C(0xca, 0xbe, 0xef),
        };
        assert!(bc.stitch() == 0x00cabeef);
    }

    #[test]
    fn consistency() {
        for i in 0..=255u8 {
            let produced: Result<OpCode, _> = i.try_into();
            if let Ok(op) = produced {
                let roundtrip_i: u8 = op.into();
                assert!(i == roundtrip_i);
            }
        }
    }

    /// Every raw word whose top byte is a valid op-code must survive a
    /// parse -> stitch round trip, whatever argument shape the op-code uses.
    #[test]
    fn bytecode_roundtrip() {
        for op in 0..=255u8 {
            if OpCode::try_from(op).is_err() {
                continue;
            }
            let raw = (u32::from(op) << 24) | 0x00ab_cdef;
            let stitched = ByteCode::parse(raw).map(ByteCode::stitch);
            assert!(stitched == Some(raw));
        }
    }

    /// All three argument shapes view the same 24 bits.
    #[test]
    fn argument_shapes_agree() {
        let shapes = [
            OpArguments::A(0x12_3456),
            OpArguments::B(0x12, 0x3456),
            OpArguments::C(0x12, 0x34, 0x56),
        ];
        for args in shapes {
            assert!(args.a() == 0x12_3456);
            assert!(args.b() == (0x12, 0x3456));
            assert!(args.c() == (0x12, 0x34, 0x56));
        }
    }
}