//! A pretty printer for [`beancount`](https://beancount.github.io) files.

extern crate alloc;

use beancount_types::Acc;
use beancount_types::Amount;
use beancount_types::Balance;
use beancount_types::Close;
use beancount_types::CostSpec;
use beancount_types::Directive;
use beancount_types::Open;
use beancount_types::Posting;
use beancount_types::Price;
use beancount_types::Transaction;
use bstr::ByteSlice as _;
use core::ops::RangeInclusive;
use fixed_decimal::FixedDecimal;
use fixed_decimal::Sign;
use icu_decimal::FixedDecimalFormatter;
use icu_locid::locale;
use once_cell::sync::Lazy;
use rayon::iter::Either;
use rayon::prelude::IntoParallelIterator;
use rayon::prelude::IntoParallelRefIterator as _;
use rayon::prelude::ParallelIterator;
use rust_decimal::Decimal;
use serde::Deserialize;
use std::io::BufWriter;
use std::io::Result;
use std::io::Write;
use unicode_segmentation::UnicodeSegmentation as _;

/// Baked ICU data for pretty printing decimals in the English locale.
#[allow(clippy::all, clippy::pedantic, unreachable_pub)]
mod icu_data;

// TODO consider having signs line up in a single column

/// Column layout used when printing an [`Amount`].
///
/// All columns are zero-based positions within the current output line.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct AmountConfig {
    /// Column at which the sign (`-`, or a blank for non-negative amounts) is printed.
    pub sign_column: usize,
    /// Column at which the decimal separator is printed; the integral digits are
    /// right-aligned so that they end just before this column.
    pub decimal_separator_column: usize,
    /// Column at which the commodity is printed.
    pub commodity_column: usize,
}

impl AmountConfig {
    fn from_metrics(metrics: AmountMetrics) -> Self {
        const SIGN_WIDTH: usize = 1;

        let AmountMetrics {
            magnitude_range,
            start_column,
        } = metrics;

        let sign_column = start_column;

        let magnitude: usize = (*magnitude_range.end()).try_into().unwrap_or_default();
        let integral_digits = ((4 * magnitude) / 3) + 1;
        let left_width = integral_digits + SIGN_WIDTH;

        let decimal_separator_column = sign_column + left_width;

        let decimals = (-magnitude_range.start()).try_into().unwrap_or(0);
        let right_width = decimals + 1;
        let commodity_column = decimal_separator_column + right_width + 1;

        Self {
            sign_column,
            decimal_separator_column,
            commodity_column,
        }
    }
}

// TODO format cost basis and prices as well

/// Layout configuration for a [`PrettyPrinter`].
///
/// It can be derived from the directives to print via
/// [`Config::derive_from_directives`], or deserialized.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct Config {
    /// Presumably the column at which posting accounts start.
    ///
    /// NOTE(review): not read by the printing code in this file — confirm intent.
    pub account_column: usize,

    /// Column layout for the amounts of balance directives and postings.
    pub amount: AmountConfig,

    /// Presumably the column at which posting flags are printed.
    ///
    /// NOTE(review): not read by the printing code in this file — confirm intent.
    pub flag_column: usize,
}

impl Config {
    /// Derives a layout wide enough to align every amount in `directives`.
    ///
    /// The metrics are gathered in parallel.
    pub fn derive_from_directives<'d>(
        directives: impl IntoParallelIterator<Item = &'d Directive>,
    ) -> Self {
        let metrics = DirectiveMetrics::derive_from_directives(directives);

        // Two columns (the flag and its trailing space) are only reserved when some
        // posting actually carries a flag.
        let flag_column = if metrics.posting_flag { 2 } else { 0 };

        Self {
            account_column: flag_column + 2,
            amount: AmountConfig::from_metrics(metrics.amount),
            flag_column,
        }
    }
}

/// Writes beancount [`Directive`]s to `W`, aligning amounts according to a [`Config`].
pub struct PrettyPrinter<W> {
    /// Column layout that amounts are aligned to.
    config: Config,

    /// Output writer, wrapped so that the current column is always known.
    inner: TrackingWriter<W>,
}

impl<W> PrettyPrinter<W>
where
    W: Write,
{
    /// Creates a printer that writes directly to `inner` using the layout in `config`.
    ///
    /// Every write is forwarded to `inner` immediately; see
    /// [`PrettyPrinter::buffered`] for a buffered variant.
    pub fn unbuffered(config: Config, inner: W) -> Self {
        Self {
            config,
            inner: TrackingWriter::new(inner),
        }
    }
}

impl<W: Write> PrettyPrinter<BufWriter<W>> {
    /// Creates a printer that wraps `inner` in a [`BufWriter`] before writing to it.
    pub fn buffered(config: Config, inner: W) -> Self {
        let buffered = BufWriter::new(inner);
        Self::unbuffered(config, buffered)
    }
}

impl<W> PrettyPrinter<W>
where
    W: Write,
{
    /// Prints a `balance` directive with its aligned amount, followed by its
    /// metadata as indented `key: value` lines. No trailing newline is written.
    pub fn print_balance(&mut self, balance: &Balance) -> Result<()> {
        let Balance {
            date,
            account,
            amount,
            meta,
        } = balance;

        write!(self.inner, "{date} balance {account}")?;

        let amount_config = self.config.amount;
        self.print_amount(amount_config, amount)?;

        for (key, value) in meta {
            write!(self.inner, "\n  {key}: {value}")?;
        }

        Ok(())
    }

    /// Prints a single directive without a trailing newline.
    pub fn print_directive(&mut self, directive: &Directive) -> Result<()> {
        match directive {
            Directive::Transaction(transaction) => self.print_transaction(transaction),
            Directive::Balance(balance) => self.print_balance(balance),
            Directive::Price(price) => self.print_price(price),
            Directive::Open(open) => self.print_open(open),
            Directive::Close(close) => self.print_close(close),
        }
    }

    /// Prints all `directives`. Each one is terminated by a newline and separated
    /// from the next by an empty line; the writer is then flushed.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error encountered. The writer is not flushed in that case.
    pub fn print_directives<'d>(
        &mut self,
        directives: impl IntoIterator<Item = &'d Directive>,
    ) -> Result<()> {
        for (position, directive) in directives.into_iter().enumerate() {
            if position != 0 {
                // Empty line between consecutive directives
                self.inner.write_all(b"\n")?;
            }

            self.print_directive(directive)?;
            self.inner.write_all(b"\n")?;
        }

        self.inner.flush()
    }

    /// Prints a `close` directive using its `Display` implementation.
    fn print_close(&mut self, close: &Close) -> Result<()> {
        self.inner.write_fmt(format_args!("{close}"))
    }

    /// Prints an `open` directive using its `Display` implementation.
    fn print_open(&mut self, open: &Open) -> Result<()> {
        self.inner.write_fmt(format_args!("{open}"))
    }

    /// Prints a `price` directive using its `Display` implementation.
    pub fn print_price(&mut self, price: &Price) -> Result<()> {
        self.inner.write_fmt(format_args!("{price}"))
    }

    /// Prints a transaction in this order: its header line, then one indented line
    /// per link, then one per metadata entry, then its postings. No trailing
    /// newline is written.
    pub fn print_transaction(&mut self, transaction: &Transaction) -> Result<()> {
        let Transaction {
            date,
            flag,
            payee,
            narration,
            links,
            meta,
            postings,
        } = transaction;

        write!(self.inner, "{date} {flag}")?;

        // A payee without a narration is followed by an empty narration string,
        // presumably so that the lone string is not read as the narration.
        if let Some(payee) = payee {
            write!(self.inner, r" {payee:?}")?;
            match narration {
                Some(narration) => write!(self.inner, r" {narration:?}")?,
                None => self.inner.write_all(br#" """#)?,
            }
        } else if let Some(narration) = narration {
            write!(self.inner, r" {narration:?}")?;
        }

        for link in links {
            write!(self.inner, "\n  {link}")?;
        }

        for (key, value) in meta {
            write!(self.inner, "\n  {key}: {value}")?;
        }

        for posting in postings.iter() {
            self.print_posting(posting)?;
        }

        Ok(())
    }
}

impl<W> PrettyPrinter<W>
where
    W: Write,
{
    /// Prints `amount` aligned to `config`. The sign goes in `config.sign_column`,
    /// the decimal separator in `config.decimal_separator_column` and the commodity
    /// in `config.commodity_column`.
    ///
    /// A column may already have been passed, e.g. because `config` was
    /// deserialized or derived from other directives. In that case a single space
    /// is written instead, so the amount never runs into the preceding account and
    /// the commodity never runs into the number.
    fn print_amount(&mut self, config: AmountConfig, amount: &Amount) -> Result<()> {
        let Amount { amount, commodity } = amount;

        self.print_decimal_aligned(config.sign_column, config.decimal_separator_column, amount)?;

        self.ensure_separated_column(config.commodity_column)?;
        write!(self.inner, "{commodity}")?;

        Ok(())
    }

    /// Prints the cost specification of a posting, if there is one, preceded by a space.
    fn print_cost(&mut self, cost: &Option<CostSpec>) -> Result<()> {
        // TODO align amounts?
        if let Some(cost) = cost {
            write!(self.inner, " {cost}")?;
        }

        Ok(())
    }

    /// Prints `decimal`, formatted for the `en` locale, so that its sign lands in
    /// `sign_column` and its decimal separator in `decimal_separator_column`.
    ///
    /// Numbers whose sign is not negative get a blank in the sign column. If the
    /// integral part is too wide to end before `decimal_separator_column`, it is
    /// printed directly after the sign instead of underflowing the target column.
    fn print_decimal_aligned(
        &mut self,
        sign_column: usize,
        decimal_separator_column: usize,
        decimal: &Decimal,
    ) -> Result<()> {
        static FORMATTER: Lazy<FixedDecimalFormatter> = Lazy::new(|| {
            FixedDecimalFormatter::try_new_unstable(
                &icu_data::BakedDataProvider,
                &locale!("en").into(),
                Default::default(),
            )
            .expect("baked ICU data should provide decimal symbols for the `en` locale")
        });

        // Keep the sign apart from the preceding account even when the account
        // reaches past `sign_column`; otherwise a `-` would be glued to the account.
        self.ensure_separated_column(sign_column)?;
        self.inner.write_all(if decimal.is_sign_negative() {
            b"-"
        } else {
            b" "
        })?;

        // The sign has been printed, so only the magnitude is formatted from here on.
        let decimal = fixed_decimal_from(&decimal.abs());

        // Only the integral part has to fit left of the separator; the fractional
        // digits simply follow it. `saturating_sub` keeps a number wider than the
        // configured layout from underflowing the target column.
        self.inner
            .ensure_column(decimal_separator_column.saturating_sub(left_width(&decimal)))?;

        let decimal = FORMATTER.format(&decimal);
        write!(self.inner, "{decimal}")
    }

    /// Prints a single posting on its own line, indented by two spaces.
    ///
    /// The line holds the optional flag, the account, and (if present) the aligned
    /// amount followed by cost and price. Metadata lines follow, indented by four
    /// spaces.
    fn print_posting(&mut self, posting: &Posting) -> Result<()> {
        let Posting {
            flag,
            account,
            amount,
            cost,
            price,
            meta,
        } = posting;
        write!(self.inner, "\n  ")?;

        if let Some(flag) = flag {
            write!(self.inner, "{flag} ")?;
        }

        write!(self.inner, "{account}")?;

        if let Some(amount) = amount {
            self.print_amount(self.config.amount, amount)?;

            self.print_cost(cost)?;

            if let Some(price) = price {
                write!(self.inner, " {price}")?;
            }
        }

        for (key, value) in meta {
            write!(self.inner, "\n    {key}: {value}")?;
        }

        Ok(())
    }

    /// Pads the current line with spaces up to `column`.
    ///
    /// At least one space is always written, so the next token stays separated
    /// from the previous one even when `column` has already been passed.
    fn ensure_separated_column(&mut self, column: usize) -> Result<()> {
        let column = column.max(self.inner.column + 1);
        self.inner.ensure_column(column)
    }
}

impl<W> core::fmt::Debug for PrettyPrinter<W> {
    /// Shows the configuration only; the wrapped writer need not implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("PrettyPrinter");
        builder.field("config", &self.config);
        builder.finish_non_exhaustive()
    }
}

/// Layout requirements of one or more amounts, combined with [`AmountMetrics::merge`].
#[derive(Clone, Debug)]
struct AmountMetrics {
    /// Range of decimal magnitudes spanned by the amounts, as reported by
    /// `FixedDecimal::magnitude_range`. Merging keeps the smallest range covering both.
    magnitude_range: RangeInclusive<i16>,
    /// Column at which the amount can start: after the text printed before it on
    /// the line plus a separation gap. Merging keeps the maximum.
    start_column: usize,
}

impl AmountMetrics {
    const ZERO: Self = Self::zero();

    fn derive(start_column: usize, amount: &Amount) -> Self {
        Self::derive_opt(start_column, Some(amount))
    }

    fn derive_opt(start_column: usize, amount: Option<&Amount>) -> Self {
        let magnitude_range = amount.map_or(0..=0, |amount| {
            fixed_decimal_from(&amount.amount).magnitude_range()
        });

        Self {
            start_column,
            magnitude_range,
        }
    }

    const fn zero() -> Self {
        Self {
            start_column: 0,
            magnitude_range: 0..=0,
        }
    }
}

impl AmountMetrics {
    /// Combines two metrics into one that satisfies both.
    ///
    /// The result keeps the later start column and the smallest magnitude range
    /// covering both inputs.
    fn merge(self, other: Self) -> Self {
        Self {
            start_column: other.start_column.max(self.start_column),
            magnitude_range: merge_ranges(self.magnitude_range, other.magnitude_range),
        }
    }
}

/// Layout requirements gathered from directives, combined with [`DirectiveMetrics::merge`].
#[derive(Clone, Debug)]
struct DirectiveMetrics {
    /// Requirements of all amounts seen so far.
    amount: AmountMetrics,

    /// Whether any posting seen so far carries a flag.
    posting_flag: bool,
}

impl DirectiveMetrics {
    /// Number of columns between the end of an account name and the start of the
    /// amount that follows it.
    const ACCOUNT_AMOUNT_SEPARATION: usize = 2;

    /// Metrics with start column 0, magnitude range `0..=0` and no posting flag.
    const ZERO: Self = Self::zero();
}

impl DirectiveMetrics {
    fn derive_from_directives<'d>(
        directives: impl IntoParallelIterator<Item = &'d Directive>,
    ) -> Self {
        directives
            .into_par_iter()
            .flat_map(Self::derive_from_directive)
            .reduce(Self::zero, Self::merge)
    }

    fn derive_from_directive(directive: &Directive) -> impl ParallelIterator<Item = Self> + '_ {
        match directive {
            Directive::Close(_) | Directive::Open(_) | Directive::Price(_) => {
                Either::Left(rayon::iter::once(Self::zero()))
            }

            Directive::Balance(balance) => {
                Either::Left(rayon::iter::once(Self::derive_from_balance(balance)))
            }
            Directive::Transaction(transaction) => {
                Either::Right(Self::derive_from_transaction(transaction))
            }
        }
    }

    fn derive_from_balance(balance: &Balance) -> Self {
        const PREFIX_WIDTH: usize = "YYYY-MM-DD balance ".len();

        let Balance {
            account, amount, ..
        } = balance;

        let amount_start_column =
            PREFIX_WIDTH + account_width(account) + Self::ACCOUNT_AMOUNT_SEPARATION;
        let amount = AmountMetrics::derive(amount_start_column, amount);

        Self {
            amount,
            ..Self::ZERO
        }
    }

    fn derive_from_transaction(
        transaction: &Transaction,
    ) -> impl ParallelIterator<Item = Self> + '_ {
        transaction
            .postings
            .par_iter()
            .map(Self::derive_from_posting)
    }

    fn derive_from_posting(posting: &Posting) -> Self {
        const INDENT: usize = 2;

        let Posting {
            flag,
            account,
            amount,
            ..
        } = posting;

        let posting_flag = flag.is_some();

        let amount_start_column = INDENT
            + (2 * posting_flag as usize)
            + account_width(account)
            + Self::ACCOUNT_AMOUNT_SEPARATION;
        let amount = AmountMetrics::derive_opt(amount_start_column, amount.as_ref());

        Self {
            amount,
            posting_flag,
        }
    }

    const fn zero() -> Self {
        Self {
            amount: AmountMetrics::ZERO,
            posting_flag: false,
        }
    }
}

impl DirectiveMetrics {
    /// Combines the requirements of two (groups of) directives into one that
    /// satisfies both.
    fn merge(self, other: Self) -> Self {
        Self {
            posting_flag: other.posting_flag || self.posting_flag,
            amount: self.amount.merge(other.amount),
        }
    }
}

/// Writer adapter that tracks the zero-based column of the current output line.
struct TrackingWriter<W> {
    /// Number of grapheme clusters written since the last `\n`.
    column: usize,

    /// Wrapped writer that receives all output.
    inner: W,
}

impl<W: Write> TrackingWriter<W> {
    /// Wraps `inner`, assuming output starts at the beginning of a line (column 0).
    fn new(inner: W) -> Self {
        Self { inner, column: 0 }
    }
}

impl<W> TrackingWriter<W>
where
    W: Write,
{
    /// Pads the current line with spaces until it reaches `target`.
    ///
    /// Writes nothing if the line is already at or beyond `target`.
    fn ensure_column(&mut self, target: usize) -> Result<()> {
        if target <= self.column {
            return Ok(());
        }

        let padding = target - self.column;
        write!(self, "{:padding$}", "")
    }
}

impl<W> Write for TrackingWriter<W>
where
    W: Write,
{
    /// Forwards `buf` to the wrapped writer and updates the column from the bytes
    /// it actually accepted.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let written = self.inner.write(buf)?;
        let accepted = &buf[..written];

        // Only the text after the last line break counts towards the column;
        // a line break starts a fresh line at column zero.
        let current_line = match accepted.rsplit_once_str(b"\n") {
            Some((_, after_break)) => {
                self.column = 0;
                after_break
            }
            None => accepted,
        };

        self.column += current_line.graphemes().count();
        Ok(written)
    }

    /// Flushes the wrapped writer.
    fn flush(&mut self) -> Result<()> {
        self.inner.flush()
    }
}

/// Number of extended grapheme clusters in `account`, i.e. the columns it is
/// assumed to occupy when printed.
fn account_width(account: &Acc) -> usize {
    let name: &str = account.as_ref();
    let clusters = name.graphemes(true);
    clusters.count()
}

/// Converts a [`Decimal`] into a [`FixedDecimal`] of the same value.
///
/// The integer mantissa is shifted `scale` places to the right of the decimal point.
fn fixed_decimal_from(decimal: &Decimal) -> FixedDecimal {
    // `rust_decimal` caps the scale at 28, so this conversion cannot fail.
    let scale = i16::try_from(decimal.scale())
        .expect("rust_decimal scales never exceed 28 and always fit in an i16");
    FixedDecimal::from(decimal.mantissa()).multiplied_pow10(-scale)
}

/// Number of columns `decimal` occupies left of its decimal separator when printed.
///
/// This is a `-` for negative numbers plus the integral digits, counting one
/// grouping separator per three digits. Numbers whose most significant nonzero
/// digit lies below the ones place still reserve a single integral column.
fn left_width(decimal: &FixedDecimal) -> usize {
    let negative = matches!(decimal.sign(), Sign::Negative);
    let magnitude = usize::try_from(decimal.nonzero_magnitude_start()).unwrap_or(0);

    usize::from(negative) + magnitude + magnitude / 3 + 1
}

/// Returns the smallest inclusive range that covers both `lhs` and `rhs`.
fn merge_ranges(lhs: RangeInclusive<i16>, rhs: RangeInclusive<i16>) -> RangeInclusive<i16> {
    let lowest = *lhs.start().min(rhs.start());
    let highest = *lhs.end().max(rhs.end());

    lowest..=highest
}