//! This is based on a copy of [`Decimal::from_str_exact`], adapted to parse German-formatted
//! decimals: `,` is the decimal separator and `.` the thousands separator.

use rust_decimal::Decimal;
use rust_decimal::Error;

pub mod serde;

/// First value that no longer fits in 96 bits. The 128-bit slow path must stay strictly
/// below it, because the final mantissa is handed to `Decimal` as three 32-bit words.
const OVERFLOW_U96: u128 = 1 << 96;
/// Conservative limit for the 64-bit accumulator: any value below it can take one more
/// `* 10 + digit` step without overflowing a `u64`.
const WILL_OVERFLOW_U64: u64 = u64::MAX / 10 - (u8::MAX as u64);
/// Inputs shorter than this many bytes contain at most 17 digits, which always fit in a
/// `u64`, so they can use the parser variant that skips the overflow checks.
const BYTES_TO_OVERFLOW_U64: usize = 18; // a tighter bound is probably possible

/// Parses a German-formatted decimal string into a [`Decimal`].
///
/// The input may start with a single `+` or `-`. It uses `,` as the decimal separator and
/// may contain `.` thousands separators after a digit, e.g. `-1.234,56`.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// * the input is empty or contains no digits;
/// * it contains more than one `,` or any other unexpected character;
/// * the integer part is too large for a [`Decimal`];
/// * it has more fractional digits than a [`Decimal`] can represent
///   (reported as [`Error::Underflow`]).
#[inline]
pub fn parse(value: &str) -> Result<Decimal, Error> {
    // Long inputs may overflow the 64-bit accumulator and need the checked variant.
    // Short inputs cannot hold enough digits for that, so they skip the checks.
    if value.len() >= BYTES_TO_OVERFLOW_U64 {
        return parse_str_radix_10_dispatch::<true>(value.as_bytes());
    }
    parse_str_radix_10_dispatch::<false>(value.as_bytes())
}

/// Entry point of the byte-by-byte state machine.
///
/// Feeds the first byte to `byte_dispatch_u64` with an empty accumulator and all flags
/// cleared except `FIRST`, or reports an error for an empty input.
#[inline]
fn parse_str_radix_10_dispatch<const BIG: bool>(bytes: &[u8]) -> Result<Decimal, Error> {
    match bytes.split_first() {
        Some((&head, tail)) => byte_dispatch_u64::<false, false, false, BIG, true>(tail, 0, 0, head),
        None => tail_error("invalid decimal: empty"),
    }
}

/// Returns `true` once the 64-bit accumulator has reached `WILL_OVERFLOW_U64`.
///
/// At that point another `* 10 + digit` step is no longer guaranteed to fit in a `u64`.
#[inline]
fn overflow_64(val: u64) -> bool {
    WILL_OVERFLOW_U64 <= val
}

/// Returns `true` when `val` needs more than 96 bits.
///
/// Such a value cannot be split into the three 32-bit words used to build the `Decimal`.
#[inline]
fn overflow_128(val: u128) -> bool {
    OVERFLOW_U96 <= val
}

/// Dispatch the next byte, or finish parsing when the input is exhausted.
///
/// * `SAW_DECIMAL_SEPARATOR` - a decimal separator (`,`) has been seen
/// * `NEGATIVE` - we've encountered a `-` and the number is negative
/// * `SAW_DIGIT` - a digit has been encountered; running out of input while this is
///   `false` yields the "no digits found" error
/// * `BIG` - the input is long enough that the 64-bit accumulator must be checked for
///   overflow
///
/// This function is only reached after at least one byte has been consumed, so the byte
/// it hands on is never the first one and is dispatched with `FIRST = false`.
#[inline]
fn dispatch_next<
    const SAW_DECIMAL_SEPARATOR: bool,
    const NEGATIVE: bool,
    const SAW_DIGIT: bool,
    const BIG: bool,
>(
    bytes: &[u8],
    data64: u64,
    scale: u8,
) -> Result<Decimal, Error> {
    if let Some((next, bytes)) = bytes.split_first() {
        byte_dispatch_u64::<SAW_DECIMAL_SEPARATOR, NEGATIVE, SAW_DIGIT, BIG, false>(
            bytes, data64, scale, *next,
        )
    } else {
        handle_data::<NEGATIVE, SAW_DIGIT>(data64 as u128, scale)
    }
}

/// Handles a byte that is neither a digit nor an accepted decimal `,`.
///
/// Two cases are valid:
///
/// * a leading `-` or `+`, only as the very first byte;
/// * a `.` thousands separator, which must follow a digit and may not appear after the
///   decimal `,`.
///
/// Anything else is reported through `tail_invalid_digit`.
#[inline(never)]
fn non_digit_dispatch_u64<
    const SAW_DECIMAL_SEPARATOR: bool,
    const NEG: bool,
    const NON_EMPTY: bool,
    const BIG: bool,
    const FIRST: bool,
>(
    bytes: &[u8],
    data64: u64,
    scale: u8,
    b: u8,
) -> Result<Decimal, Error> {
    // A sign is only meaningful before anything else has been read.
    let sign_allowed = FIRST && !NON_EMPTY;
    // Thousands separators belong to the integer part, after at least one digit.
    let group_separator_allowed = NON_EMPTY && !SAW_DECIMAL_SEPARATOR;

    match b {
        b'-' if sign_allowed => dispatch_next::<false, true, false, BIG>(bytes, data64, scale),
        b'+' if sign_allowed => dispatch_next::<false, false, false, BIG>(bytes, data64, scale),
        b'.' if group_separator_allowed => {
            handle_separator::<SAW_DECIMAL_SEPARATOR, NEG, BIG>(bytes, data64, scale)
        }
        other => tail_invalid_digit(other),
    }
}

/// Routes a single byte while the value is still held in the 64-bit accumulator.
///
/// Digits are accumulated and the first `,` switches into the fractional part. Every
/// other byte (signs, `.` thousands separators, invalid input) is delegated to
/// `non_digit_dispatch_u64`.
///
/// * `SAW_DECIMAL_SEPARATOR` - a `,` has already been consumed
/// * `NEGATIVE` - a leading `-` was consumed
/// * `NON_EMPTY` - at least one digit has been consumed
/// * `BIG` - the input is long enough that overflow checks are required
/// * `FIRST` - `b` is the first byte of the input
#[inline]
fn byte_dispatch_u64<
    const SAW_DECIMAL_SEPARATOR: bool,
    const NEGATIVE: bool,
    const NON_EMPTY: bool,
    const BIG: bool,
    const FIRST: bool,
>(
    bytes: &[u8],
    data64: u64,
    scale: u8,
    b: u8,
) -> Result<Decimal, Error> {
    if b.is_ascii_digit() {
        let digit = b - b'0';
        return handle_digit_64::<SAW_DECIMAL_SEPARATOR, NEGATIVE, BIG>(
            bytes, data64, scale, digit,
        );
    }
    if b == b',' && !SAW_DECIMAL_SEPARATOR {
        return handle_point::<NEGATIVE, NON_EMPTY, BIG>(bytes, data64, scale);
    }
    non_digit_dispatch_u64::<SAW_DECIMAL_SEPARATOR, NEGATIVE, NON_EMPTY, BIG, FIRST>(
        bytes, data64, scale, b,
    )
}

/// Appends one decimal digit to the 64-bit accumulator and continues with the next byte.
///
/// For `BIG` inputs, two checks happen before the next byte is processed:
///
/// * reaching 28 fractional digits with more input left fails with `Error::Underflow`;
/// * an accumulator at the `u64` threshold hands the remaining work to `handle_full_128`.
#[inline(never)]
fn handle_digit_64<const SAW_DECIMAL_SEPARATOR: bool, const NEGATIVE: bool, const BIG: bool>(
    bytes: &[u8],
    data64: u64,
    scale: u8,
    digit: u8,
) -> Result<Decimal, Error> {
    // Cannot overflow: short inputs hold too few digits, and BIG inputs leave this path
    // via `overflow_64` before the accumulator gets too large.
    let accumulated = data64 * 10 + u64::from(digit);
    // Digits before the separator never contribute to the scale.
    let new_scale = if SAW_DECIMAL_SEPARATOR { scale + 1 } else { 0 };

    match bytes.split_first() {
        None => handle_data::<NEGATIVE, true>(u128::from(accumulated), new_scale),
        Some((&following, remaining)) => {
            if BIG && SAW_DECIMAL_SEPARATOR && new_scale >= 28 {
                Err(Error::Underflow)
            } else if BIG && overflow_64(accumulated) {
                handle_full_128::<SAW_DECIMAL_SEPARATOR, NEGATIVE>(
                    u128::from(accumulated),
                    remaining,
                    new_scale,
                    following,
                )
            } else {
                byte_dispatch_u64::<SAW_DECIMAL_SEPARATOR, NEGATIVE, true, BIG, false>(
                    remaining,
                    accumulated,
                    new_scale,
                    following,
                )
            }
        }
    }
}

/// Consumes the decimal separator `,` and continues with the fractional part.
///
/// Only the compile-time state changes (`SAW_DECIMAL_SEPARATOR` becomes `true`). The
/// accumulated value, the scale and the digit-seen flag are passed through untouched.
#[inline(never)]
fn handle_point<const NEGATIVE: bool, const SAW_DIGIT: bool, const BIG: bool>(
    bytes: &[u8],
    data64: u64,
    scale: u8,
) -> Result<Decimal, Error> {
    dispatch_next::<true, NEGATIVE, SAW_DIGIT, BIG>(bytes, data64, scale)
}

/// Skips a `.` thousands separator and continues with the next byte.
///
/// The separator carries no value. The digit-seen flag is forced to `true`, because a
/// separator is only accepted after a digit.
#[inline(never)]
fn handle_separator<const IN_FRACTION: bool, const NEGATIVE: bool, const BIG: bool>(
    bytes: &[u8],
    data64: u64,
    scale: u8,
) -> Result<Decimal, Error> {
    dispatch_next::<IN_FRACTION, NEGATIVE, true, BIG>(bytes, data64, scale)
}

#[cold]
fn tail_error(from: &'static str) -> Result<Decimal, Error> {
    Err(from.into())
}

/// Builds the error for a byte that is not valid at its position.
///
/// A `,` only reaches this function when a decimal separator has already been consumed,
/// so it is reported as a duplicate. Every other byte is reported as unknown.
#[inline(never)]
#[cold]
fn tail_invalid_digit(digit: u8) -> Result<Decimal, Error> {
    let message = if digit == b',' {
        "invalid decimal: two decimal points"
    } else {
        "invalid decimal: unknown character"
    };
    tail_error(message)
}

/// Slow path used once the value may no longer fit the 64-bit accumulator.
///
/// Accumulates into a `u128` that must stay below 2^96 (`OVERFLOW_U96`). `next_byte` is
/// the byte to process and `bytes` is everything after it.
///
/// The accepted syntax matches the 64-bit path:
///
/// * digits;
/// * a single `,` decimal separator;
/// * `.` thousands separators, in the integer part only.
///
/// # Errors
///
/// * An integer-part digit that would push the value to 2^96 or beyond returns
///   "invalid decimal: overflow from too many digits".
/// * A fractional digit that would push the value to 2^96 or beyond returns
///   [`Error::Underflow`]. So does any further input after the 28th fractional digit.
/// * A second `,` returns "invalid decimal: two decimal points".
/// * Any other byte, including a `.` after the `,`, returns
///   "invalid decimal: unknown character".
#[inline(never)]
#[cold]
fn handle_full_128<const SAW_DECIMAL_SEPARATOR: bool, const NEG: bool>(
    mut data: u128,
    bytes: &[u8],
    scale: u8,
    next_byte: u8,
) -> Result<Decimal, Error> {
    let b = next_byte;
    match b {
        b'0'..=b'9' => {
            let digit = u32::from(b - b'0');

            // `data` is always below 2^96 here, so `data * 10 + digit` cannot overflow
            // `u128`; it only has to be checked against the 96-bit limit.
            let next = (data * 10) + digit as u128;
            if overflow_128(next) {
                if !SAW_DECIMAL_SEPARATOR {
                    tail_error("invalid decimal: overflow from too many digits")
                } else {
                    Err(Error::Underflow)
                }
            } else {
                data = next;
                let scale = scale + SAW_DECIMAL_SEPARATOR as u8;
                if let Some((next, bytes)) = bytes.split_first() {
                    let next = *next;
                    if SAW_DECIMAL_SEPARATOR && scale >= 28 {
                        Err(Error::Underflow)
                    } else {
                        handle_full_128::<SAW_DECIMAL_SEPARATOR, NEG>(data, bytes, scale, next)
                    }
                } else {
                    handle_data::<NEG, true>(data, scale)
                }
            }
        }
        b',' if !SAW_DECIMAL_SEPARATOR => {
            // Switch to the fractional part; the value and scale are unchanged.
            if let Some((next, bytes)) = bytes.split_first() {
                handle_full_128::<true, NEG>(data, bytes, scale, *next)
            } else {
                handle_data::<NEG, true>(data, scale)
            }
        }
        // Thousands separators are only valid in the integer part. The guard keeps this
        // path consistent with the 64-bit path, which also rejects `.` after the `,`.
        b'.' if !SAW_DECIMAL_SEPARATOR => {
            if let Some((next, bytes)) = bytes.split_first() {
                handle_full_128::<SAW_DECIMAL_SEPARATOR, NEG>(data, bytes, scale, *next)
            } else {
                handle_data::<NEG, true>(data, scale)
            }
        }
        b => tail_invalid_digit(b),
    }
}

/// Error for an input that ended without containing a single digit.
///
/// This covers inputs such as a lone sign or a lone decimal separator.
#[inline(never)]
fn tail_empty() -> Result<Decimal, Error> {
    const MESSAGE: &str = "invalid decimal: no digits found";
    tail_error(MESSAGE)
}

#[inline]
fn handle_data<const NEG: bool, const HAS: bool>(data: u128, scale: u8) -> Result<Decimal, Error> {
    debug_assert_eq!(data >> 96, 0);

    if !HAS {
        tail_empty()
    } else {
        Ok(Decimal::from_parts(
            data as u32,
            (data >> 32) as u32,
            (data >> 64) as u32,
            NEG,
            scale as u32,
        ))
    }
}