//! Lexer that takes an input string and splits it into tokens

// We (at least initially) use logos and winnow
use logos::{Lexer, Logos};
use std::{
    collections::HashMap,
    num::{ParseFloatError, ParseIntError},
    sync::OnceLock,
};
// Use tiny winnow parsers to parse intra-element stuff
use winnow::{
    ascii::{digit1, hex_digit1},
    combinator::{alt, delimited, opt, preceded, repeat},
    error::ContextError,
    token::{any, one_of},
    PResult, Parser,
};

/// A lexing error
#[derive(thiserror::Error, Clone, PartialEq, Default, Debug)]
pub enum LexError {
    /// A char escape was used that is not implemented
    // NOTE if we didn't *need* the Clone impl due to Logos,
    // I would have stuck a Box here, but there isn't a point
    // in multiple instances of this *same* error referring
    // to unique allocations of the *same* character sequence
    #[error("Invalid character escape `{0}`")]
    InvalidCharEscape(std::sync::Arc<str>),
    /// An error occurred parsing an integer
    #[error(transparent)]
    IntError(#[from] ParseIntError),
    /// Lexer encountered an unknown character
    #[default]
    #[error("unknown")]
    Unknown,
    /// A winnow parse error occurred
    #[error("{1} around offset {0}")]
    ParseError(usize, ContextError),
}

fn symbol_element_parse(input: &mut &str) -> PResult<Vec<char>> {
    repeat(
        0..,
        alt((
            delimited(('\\', one_of(['x', 'X'])), hex_digit1, r";").try_map(|s: &str| {
                Ok::<_, ParseIntError>(
                    char::from_u32(u32::from_str_radix(s, 16)?)
                        .unwrap_or(char::REPLACEMENT_CHARACTER),
                )
            }),
            preceded(r"\", one_of(['a', 'b', 't', 'n', 'r'])).map(|c| match c {
                'a' => 7 as char,
                'b' => 8 as char,
                't' => 9 as char,
                'n' => 0xa as char,
                'r' => 0xd as char,
                _ => unreachable!(r"Scheme R7RS only handles \a, \b, \t, \n, and \r"),
            }),
            any,
        )),
    )
    .parse_next(input)
}

fn parse_to_symbol_elements(lex: &mut Lexer<Token>) -> Result<String, LexError> {
    let slice = lex.slice();
    let target = &slice[1..slice.len() - 1];
    symbol_element_parse
        .parse(target)
        .map(|vc| vc.into_iter().collect())
        .map_err(|e| LexError::ParseError(e.offset(), e.into_inner()))
}

fn string_element_parse(input: &mut &str) -> PResult<Vec<char>> {
    repeat(
        0..,
        alt((
            delimited(('\\', one_of(['x', 'X'])), hex_digit1, r";")
                .try_map(|s: &str| {
                    Ok::<_, ParseIntError>(
                        char::from_u32(u32::from_str_radix(s, 16)?)
                            .unwrap_or(char::REPLACEMENT_CHARACTER),
                    )
                })
                .map(Some),
            preceded(r"\", one_of(['a', 'b', 't', 'n', 'r', '"', '\\', '|']))
                .map(|c| match c {
                    'a' => '\u{7}', // alarm (bell)
                    'b' => '\u{8}', // backspace
                    't' => '\t',
                    'n' => '\n',
                    'r' => '\r',
                    '"' => '"',
                    '\\' => '\\',
                    '|' => '|',
                    _ => unreachable!(
                        r#"Scheme R7RS strings only handle \a, \b, \t, \n, \r, \", \\, and \|"#
                    ),
                })
                .map(Some),
            // R7RS line continuation: `\` + optional intraline whitespace +
            // newline + optional intraline whitespace yields nothing
            preceded(
                '\\',
                delimited(
                    repeat::<_, _, (), _, _>(0.., one_of([' ', '\t', '\r'])),
                    '\n'.void(),
                    repeat::<_, _, (), _, _>(0.., one_of([' ', '\t', '\r'])),
                ),
            )
            .map(|_| None),
            any.map(Some),
        )),
    )
    .map(|vc: Vec<Option<char>>| vc.into_iter().flatten().collect())
    .parse_next(input)
}

fn parse_to_string(lex: &mut Lexer<Token>) -> Result<String, LexError> {
    let slice = lex.slice();
    let target = &slice[1..slice.len() - 1];
    string_element_parse
        .parse(target)
        .map(|vc| vc.into_iter().collect())
        .map_err(|e| LexError::ParseError(e.offset(), e.into_inner()))
}

fn integer_number_parser(input: &mut &str) -> PResult<(i64, u32)> {
    let sign = opt(one_of(['-', '+']));
    let digits_with_underscores = || {
        repeat(1.., alt((digit1.map(Some), '_'.map(|_| None))))
            .map(|vd: Vec<Option<&str>>| vd.into_iter().flatten().collect::<Vec<_>>().join(""))
    };
    let signed_digits_with_underscores =
        (sign, digits_with_underscores()).map(|(sign, digits): (_, String)| {
            if let Some(sign) = sign {
                format!("{sign}{digits}")
            } else {
                digits
            }
        });
    let exponent = preceded(
        one_of(['e', 'E']),
        (opt('+').void(), digits_with_underscores()),
    );
    (signed_digits_with_underscores, opt(exponent))
        .try_map(|(base, maybe_exp)| {
            let base = base.parse::<i64>()?;
            let exponent = if let Some((_, exponent)) = maybe_exp {
                exponent.parse::<u32>()?
            } else {
                0
            };
            Ok::<_, ParseIntError>((base, exponent))
        })
        .parse_next(input)
}

fn parse_integer(lexer: &mut Lexer<Token>) -> Result<(i64, u32), LexError> {
    let target = lexer.slice();
    integer_number_parser
        .parse(target)
        .map_err(|e| LexError::ParseError(e.offset(), e.into_inner()))
}

/// Errors possible when parsing a decimal number
#[derive(thiserror::Error, Debug)]
pub enum DecimalParseError {
    /// Parsing a floating-point number failed
    #[error(transparent)]
    FloatError(#[from] ParseFloatError),
    /// Parsing an integer failed
    #[error(transparent)]
    IntError(#[from] ParseIntError),
}

fn decimal_number_parser(input: &mut &str) -> PResult<(f64, i32)> {
    let sign = || opt(one_of(['-', '+']));
    let digits_with_underscores = || {
        repeat(1.., alt((digit1.map(Some), '_'.map(|_| None))))
            .map(|vd: Vec<Option<&str>>| vd.into_iter().flatten().collect::<Vec<_>>().join(""))
    };
    let float_digits = || {
        (
            opt(digits_with_underscores()),
            '.',
            digits_with_underscores(),
        )
            .map(|(maybe_predot, _, postdot)| {
                maybe_predot
                    .map(|predot| format!("{predot}.{postdot}"))
                    .unwrap_or_else(|| format!(".{postdot}"))
            })
    };
    let signed_base = (sign(), alt((float_digits(), digits_with_underscores()))).map(
        |(sign, digits): (_, String)| {
            if let Some(sign) = sign {
                format!("{sign}{digits}")
            } else {
                digits
            }
        },
    );
    let exponent = preceded(one_of(['e', 'E']), (sign(), digits_with_underscores()));
    (signed_base, opt(exponent))
        .try_map(|(base, maybe_exp)| {
            let base = base.parse::<f64>()?;
            let exponent = if let Some((sign, exponent)) = maybe_exp {
                let signed = if let Some(sign) = sign {
                    format!("{sign}{exponent}")
                } else {
                    exponent
                };
                signed.parse::<i32>()?
            } else {
                0
            };
            Ok::<_, DecimalParseError>((base, exponent))
        })
        .parse_next(input)
}

fn parse_float(lexer: &mut Lexer<Token>) -> Result<(f64, i32), LexError> {
    let target = lexer.slice();
    decimal_number_parser
        .parse(target)
        .map_err(|e| LexError::ParseError(e.offset(), e.into_inner()))
}

fn parse_character(lexer: &mut Lexer<Token>) -> Result<char, LexError> {
    let target = &lexer.slice()[2..];
    static NAME_MAP: OnceLock<HashMap<Box<str>, char>> = OnceLock::new();

    let name_map = NAME_MAP.get_or_init(|| {
        let mut map = HashMap::new();
        map.insert("alarm".into(), '\u{7}');
        map.insert("backspace".into(), '\u{8}');
        map.insert("delete".into(), '\u{7f}');
        map.insert("escape".into(), '\u{1b}');
        map.insert("newline".into(), '\n');
        map.insert("null".into(), '\0');
        map.insert("return".into(), '\r');
        map.insert("space".into(), ' ');
        map.insert("tab".into(), '\t');
        // non-standard character names
        map.insert("lambda".into(), '\u{3bb}');
        map.insert("Lambda".into(), '\u{39b}');
        map
    });

    match target {
        c if name_map.contains_key(c) => name_map
            .get(c)
            .copied()
            .ok_or_else(|| LexError::InvalidCharEscape(c.into())),
        c if c.starts_with(['x', 'X'])
            && c.chars().skip(1).all(|c| c.is_ascii_hexdigit())
            // Don't match '#\x'
            && c.chars().count() > 1 =>
        {
            // We know that x is 1 byte long, so chop it off
            let hex_str = &c[1..];
            char::from_u32(u32::from_str_radix(hex_str, 16)?)
                .ok_or_else(|| LexError::InvalidCharEscape(c.into()))
        }
        c if c.chars().count() == 1 => c.chars().next().ok_or(LexError::Unknown),
        c => Err(LexError::InvalidCharEscape(c.into())),
    }
}

/// The smallest unit of lexical meaning
#[derive(Debug, Clone, PartialEq, Logos)]
#[logos(error = LexError)]
pub enum Token {
    /// Left parenthesis
    #[token("(")]
    LParen,
    /// Right parenthesis
    #[token(")")]
    RParen,
    /// Start vector token
    #[token("#(")]
    VectorStart,
    /// Start bytevector token
    #[token("#u8(")]
    #[token("#U8(")]
    BytevectorStart,

    /// Contiguous whitespace
    #[regex(r"[ \t\r]+")]
    Whitespace,

    /// Boolean literal
    #[regex(r"#[tT]([rR][uU][eE])?", |_| true)]
    #[regex(r"#[fF]([aA][lL][sS][eE])?", |_| false)]
    Boolean(bool),
    /// Character literal
    #[regex(r"#\\(.|[a-zA-Z]+|x[0-9a-fA-F]+)", parse_character)]
    Character(char),
    /// Identifier (or "symbol literal")
    #[regex(r"[A-Za-z!$%&*+\-./:<=>?@^_-][A-Za-z0-9!$%&*+\-./:<=>?@^_-]*", |lex| lex.slice().to_string())]
    #[regex(r"\|(?:\\x[0-9a-fA-F]+;|\\[abtnr]|[^|])*\|", parse_to_symbol_elements)]
    Identifier(String),
    /// String literal
    #[regex(
        r#""(?:\\[xX][0-9a-fA-F]+;|\\[abtnr"\\|]|\\[ \t\r]*\n[ \t\r]*|[^"])*""#,
        parse_to_string
    )]
    String(String),
    // TODO lexers for num, complex, real
    /// Integer literal in `(base, exponent)` form
    ///
    /// The number represented is `base * 10 ** exponent`
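    ///
    /// For example, `12e3` lexes as `Integer((12, 3))`.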
    #[regex(r"[+-]?[0-9][0-9_]*", priority = 3, callback = parse_integer)]
    #[regex(r"[+-]?[0-9][0-9_]*[eE]\+?[0-9_]+", callback = parse_integer)]
    Integer((i64, u32)),
    /// Decimal literal in `(base, exponent)` form
    ///
    /// The number represented is `base * 10 ** exponent`
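    ///
    /// For example, `0e-1` lexes as `Decimal((0.0, -1))`.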
    #[regex(r"[+-]?[0-9][0-9_]*\.[0-9][0-9_]*", parse_float)]
    #[regex(r"[+-]?[0-9][0-9_]*\.[0-9][0-9_]*[eE][+-]?[0-9_]+", parse_float)]
    #[regex(r"[+-]?\.[0-9][0-9_]*", parse_float)]
    #[regex(r"[+-]?\.[0-9][0-9_]*[eE][+-]?[0-9_]+", parse_float)]
    #[regex(r"[+-]?[0-9][0-9_]*[eE]\-[0-9_]+", parse_float)] // Negative exponent
    Decimal((f64, i32)),

    /// Introduce a label for the following datum
    #[regex(r"#[0-9]+=", |lex| lex.slice()[1..lex.slice().len() - 1].to_string())]
    DatumLabeler(String),
    /// A placeholder for a datum referenced by a given label
    #[regex(r"#[0-9]+#", |lex| lex.slice()[1..lex.slice().len() - 1].to_string())]
    DatumLabel(String),

    /// Positive infinity
    #[regex(r"\+[iI][nN][fF]\.0")]
    InfinityPos,
    /// Negative infinity
    #[regex(r"\-[iI][nN][fF]\.0")]
    InfinityNeg,
    /// Not a number (positive)
    #[regex(r"\+[nN][aA][nN]\.0")]
    NotANumberPos,
    /// Not a number (negative)
    #[regex(r"\-[nN][aA][nN]\.0")]
    NotANumberNeg,
    /// Positive imaginary unit (`sqrt(-1)`)
    #[regex(r"\+[iI]")]
    ImaginaryPos,
    /// Negative imaginary unit (`-sqrt(-1)`)
    #[regex(r"\-[iI]")]
    ImaginaryNeg,
    /// A dot
    #[token(".", priority = 3)]
    Dot,
    /// Symbol used to represent `quote` special form
    #[token("'")]
    QuoteSymbol,
    /// Symbol used to represent `quasiquote` special form
    #[token("`")]
    QuasiquoteSymbol,
    /// Symbol used to represent `unquote` special form
    #[token(",")]
    UnquoteSymbol,
    /// Symbol used to represent `unquote-splicing` special form
    #[token(",@")]
    UnquoteSplicingSymbol,

    /// Used to start a comment that ends with `\n`
    #[token(";")]
    LineCommentStart,
    /// Contiguous line ends
    #[regex(r"\n+")]
    LineEnd,
    /// Used to start a comment that ends at the end of the following datum
    #[token("#;")]
    DatumCommentStart,
    /// Used to start a nestable block comment
    #[token("#|")]
    BlockCommentStart,
    /// Used to end a nestable block comment
    #[token("|#")]
    BlockCommentEnd,
}

/// Produce a lexer that lexes [`Token`]s from an input string
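///
/// A minimal usage sketch (marked `ignore` since this module's crate path
/// isn't fixed here):
///
/// ```ignore
/// let mut tokens = lexer("(+ 1 2)");
/// assert_eq!(tokens.next(), Some(Ok(Token::LParen)));
/// assert_eq!(tokens.next(), Some(Ok(Token::Identifier("+".to_string()))));
/// ```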
pub fn lexer(input: &str) -> Lexer<'_, Token> {
    Token::lexer(input)
}

#[cfg(test)]
mod tests {
    use super::Token;
    use assert2::check;
    use logos::Logos;

    fn lexer_test_without_whitespace<S, I>(source: S, expected: I)
    where
        S: AsRef<str>,
        I: IntoIterator<Item = (Token, std::ops::Range<usize>, &'static str)>,
    {
        let source = source.as_ref();
        let lexer = Token::lexer(source);
        for ((tok, span), (lexed, tspan, tslice)) in lexer
            .spanned()
            .filter(|(tok, _)| tok.is_err() || tok.as_ref().is_ok_and(|v| v != &Token::Whitespace))
            .zip(expected.into_iter())
        {
            let slice = &source[span.clone()];
            check!(tok == Ok(lexed));
            check!(span == tspan);
            check!(slice == tslice);
        }
    }

    #[test]
    fn lex_character_literals() {
        let source = r"#\t #\alarm #\lambda #\x03bb";
        let data = [
            (Token::Character('t'), (0..3), r"#\t"),
            (Token::Character('\u{7}'), (4..11), r"#\alarm"),
            (Token::Character('\u{3bb}'), (12..20), r"#\lambda"),
            (Token::Character('\u{3bb}'), (21..28), r"#\x03bb"),
        ];
        lexer_test_without_whitespace(source, data)
    }

    #[test]
    fn lex_booleans() {
        let source = r"#t #f #true #false";
        let data = [
            (Token::Boolean(true), (0..2), "#t"),
            (Token::Boolean(false), (3..5), "#f"),
            (Token::Boolean(true), (6..11), "#true"),
            (Token::Boolean(false), (12..18), "#false"),
        ];
        lexer_test_without_whitespace(source, data)
    }

    #[test]
    fn lex_datum_labels() {
        let source = r"#0=(#0# 2 3)";
        let data = [
            (Token::DatumLabeler("0".to_string()), (0..3), "#0="),
            (Token::LParen, (3..4), "("),
            (Token::DatumLabel("0".to_string()), (4..7), "#0#"),
            (Token::Integer((2, 0)), (8..9), "2"),
            (Token::Integer((3, 0)), (10..11), "3"),
            (Token::RParen, (11..12), ")"),
        ];

        lexer_test_without_whitespace(source, data)
    }
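
    #[test]
    fn lex_integer() {
        // A small sketch of the integer forms: sign, `_` separators
        // (stripped before parsing), and the positive-exponent form.
        let source = r"42 -7 1_000 12e3";
        let data = [
            (Token::Integer((42, 0)), (0..2), "42"),
            (Token::Integer((-7, 0)), (3..5), "-7"),
            (Token::Integer((1000, 0)), (6..11), "1_000"),
            (Token::Integer((12, 3)), (12..16), "12e3"),
        ];
        lexer_test_without_whitespace(source, data)
    }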

    #[test]
    fn lex_decimal() {
        let source = r"0.0 -0.0e0 -.0 .0e0 0e-1";
        let data = [
            ((0.0, 0), (0..3), "0.0"),
            ((-0.0, 0), (4..10), "-0.0e0"),
            ((-0.0, 0), (11..14), "-.0"),
            ((0.0, 0), (15..19), ".0e0"),
            ((0.0, -1), (20..24), "0e-1"),
        ];
        let lexer = Token::lexer(source);
        for ((_, (tok, span)), (parsed, tspan, tslice)) in lexer
            .spanned()
            .enumerate()
            // Skip all odd-indexed tokens (they're whitespace)
            .filter(|(idx, _)| idx % 2 == 0)
            .zip(data.into_iter())
        {
            let slice = &source[span.clone()];
            check!(tok == Ok(Token::Decimal(parsed)));
            check!(span == tspan);
            check!(slice == tslice);
        }
    }
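
    #[test]
    fn lex_special_numbers() {
        // These spellings are also syntactically valid identifiers; the
        // dedicated tokens are expected to win on logos priority.
        let source = r"+inf.0 -inf.0 +nan.0 -nan.0";
        let data = [
            (Token::InfinityPos, (0..6), "+inf.0"),
            (Token::InfinityNeg, (7..13), "-inf.0"),
            (Token::NotANumberPos, (14..20), "+nan.0"),
            (Token::NotANumberNeg, (21..27), "-nan.0"),
        ];
        lexer_test_without_whitespace(source, data)
    }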

    #[test]
    fn literal_identifier_equivalence() {
        let source = r"|\t\t||\x9;\x9;|";
        let mut lexer = Token::lexer(source);

        let tok1 = lexer.next();
        let tok1_slice = lexer.slice();
        let tok2 = lexer.next();
        let tok2_slice = lexer.slice();
        check!(tok1 == tok2);
        check!(tok1_slice != tok2_slice);
    }

    #[test]
    fn lex_literal_identifier() {
        let source = r"|H\x65;llo|";
        let mut lexer = Token::lexer(source);
        check!(lexer.next() == Some(Ok(Token::Identifier("Hello".to_string()))));
        check!(lexer.span() == (0..source.len()));
        check!(lexer.slice() == "|H\\x65;llo|");
    }

    #[test]
    fn lex_identifier_or_number() {
        let source = "3let -i";
        let mut lexer = Token::lexer(source);
        check!(lexer.next() == Some(Ok(Token::Integer((3, 0)))));
        check!(lexer.span() == (0..1));
        check!(lexer.slice() == "3");

        check!(lexer.next() == Some(Ok(Token::Identifier("let".to_string()))));
        check!(lexer.span() == (1..4));
        check!(lexer.slice() == "let");

        check!(lexer.next() == Some(Ok(Token::Whitespace)));
        check!(lexer.span() == (4..5));
        check!(lexer.slice() == " ");

        check!(lexer.next() == Some(Ok(Token::ImaginaryNeg)));
        check!(lexer.span() == (5..7));
        check!(lexer.slice() == "-i");
    }

    #[test]
    fn lex_identifier() {
        let source = "+soup+ $?bama";
        let mut lexer = Token::lexer(source);
        check!(lexer.next() == Some(Ok(Token::Identifier("+soup+".to_string()))));
        check!(lexer.span() == (0..6));
        check!(lexer.slice() == "+soup+");

        check!(lexer.next() == Some(Ok(Token::Whitespace)));
        check!(lexer.span() == (6..7));
        check!(lexer.slice() == " ");

        check!(lexer.next() == Some(Ok(Token::Identifier("$?bama".to_string()))));
        check!(lexer.span() == (7..13));
        check!(lexer.slice() == "$?bama");
    }
}