mod data;
mod inference;
mod model;
mod training;

use crate::model::ModelConfig;
use burn::{
    backend::{
        wgpu::{AutoGraphicsApi, WgpuDevice},
        Autodiff, Wgpu,
    },
    data::dataset::{vision::MNISTDataset, Dataset},
    optim::AdamConfig,
};
use clap::Parser;
use crossterm::{
    event::{self, KeyCode, KeyEventKind},
    terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
    ExecutableCommand,
};
use humantime::format_duration;
use image::{imageops, GrayImage};
use ratatui::{
    layout::{Constraint, Layout, Margin},
    prelude::{CrosstermBackend, Stylize, Terminal},
    style::{Modifier, Style},
    text::{Line, Span, Text},
    widgets::{Block, Borders, Gauge, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
    Frame,
};
use ratatui_image::{picker::Picker, protocol::StatefulProtocol, StatefulImage};
use std::{
    io::stdout,
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        mpsc::channel,
        Arc,
    },
    time::{Duration, Instant},
};

/// State for the interactive results browser shown once inference has finished.
struct App {
    /// When `true`, the browser's event loop stops.
    should_close: bool,
    /// Index into `results` of the entry currently on screen.
    item: usize,
    /// One entry per inferred image: `(renderable image, predicted label, expected label)`.
    results: Vec<(Box<dyn StatefulProtocol>, i32, u8)>,

    /// Vertical split of the body: a one-line header, the image, and a one-line footer.
    body_layout: Layout,
    /// Scrollbar position; re-synced with `item` on every draw.
    scrollbar_state: ScrollbarState,
}

impl App {
    /// Creates a browser positioned on the first result.
    fn new(results: Vec<(Box<dyn StatefulProtocol>, i32, u8)>) -> Self {
        Self {
            should_close: false,
            item: 0,
            body_layout: Layout::vertical([
                Constraint::Length(1),
                Constraint::Min(1),
                Constraint::Length(1),
            ]),
            scrollbar_state: ScrollbarState::new(results.len()),
            results,
        }
    }

    /// Renders the current result.
    ///
    /// The view has a "Predicted/Expected" header, the image, a right-hand scrollbar and a
    /// 1-based "i out of N" footer. The predicted value is green when it matches the expected
    /// label and red otherwise. When there are no results, only a short notice is drawn.
    fn draw(&mut self, frame: &mut Frame) {
        let area = frame.size();
        let [text_area, im_area, bottom_area] =
            self.body_layout.areas(area.inner(&Margin::new(1, 1)));

        // An empty result set (e.g. inference cancelled before anything finished) would
        // otherwise panic on the index below.
        if self.results.is_empty() {
            frame.render_widget(Paragraph::new("No results to display").white(), text_area);
            return;
        }

        let (ref mut im, predicted, label) = self.results[self.item];
        let verdict_color = if predicted == i32::from(label) {
            ratatui::style::Color::Green
        } else {
            ratatui::style::Color::Red
        };

        let header = Line::from(vec![
            Span::raw("Predicted: "),
            Span::styled(predicted.to_string(), Style::default().fg(verdict_color)),
            Span::raw(format!(", Expected: {label}")),
        ]);
        frame.render_widget(Paragraph::new(Text::from(vec![header])).white(), text_area);

        frame.render_stateful_widget(StatefulImage::new(None), im_area, im);
        frame.render_widget(Block::new().borders(Borders::RIGHT), area);

        self.scrollbar_state = self.scrollbar_state.position(self.item);
        let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight)
            .begin_symbol(Some(""))
            .end_symbol(Some(""));
        frame.render_stateful_widget(
            scrollbar,
            area.inner(&Margin {
                horizontal: 0,
                vertical: 1,
            }),
            &mut self.scrollbar_state,
        );

        // 1-based position so the first item reads "1 out of N" and the last "N out of N".
        frame.render_widget(
            Paragraph::new(format!("{} out of {}", self.item + 1, self.results.len())),
            bottom_area,
        );
    }
}

/// Error adapter for boxed errors that are not `Send + Sync`, so they can be propagated through
/// `color_eyre::Result`. An example is the error returned by `Picker::from_termios`.
///
/// `color_eyre::Report` only accepts `Send + Sync` errors. Asserting those bounds with
/// `unsafe impl` would be unsound: the boxed error may hold thread-affine data, and a `Mutex`
/// does not make a non-`Send` value `Send`. Instead, the error's message is rendered once, on
/// the thread that produced it, and only that `String` is kept. The wrapped error's `source()`
/// chain is not preserved.
#[derive(Debug)]
struct ThreadsafeBoxedError(String);

impl ThreadsafeBoxedError {
    /// Captures the `Display` text of `boxed` and drops the original error.
    fn new(boxed: Box<dyn std::error::Error>) -> Self {
        Self(boxed.to_string())
    }
}

impl std::fmt::Display for ThreadsafeBoxedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for ThreadsafeBoxedError {}

/// Train the model, or run it over the MNIST test set and browse the predictions.
#[derive(clap::Parser, Debug)]
struct Cli {
    #[command(subcommand)]
    mode: RunMode,
}

// Subcommands. The `///` comments on the variants below are what clap shows as their help
// text (plain `//` comments are invisible to clap).
#[derive(clap::Subcommand, Debug)]
enum RunMode {
    /// Train the model
    Train,
    /// Run inference on the MNIST test set and browse the results interactively
    Infer,
}

/// Backend used for inference (`inference::infer::<Backend>`): Burn's WGPU backend with an
/// automatically selected graphics API. The `f32`/`i32` parameters are presumably the float and
/// int element types — confirm against burn's `Wgpu` docs.
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
/// Backend used for training (`training::train::<AutodiffBackend>`): the same WGPU backend
/// wrapped in Burn's `Autodiff` decorator.
type AutodiffBackend = Autodiff<Backend>;

fn main() -> color_eyre::Result<()> {
    color_eyre::install()?;
    let cli = Cli::parse();
    match cli.mode {
        RunMode::Train => {
            train();
            Ok(())
        }
        RunMode::Infer => infer(),
    }
}

/// Trains the model on the autodiff-enabled WGPU backend, using the default device.
///
/// `/tmp/guide` is handed to `training::train` as its first argument, presumably the artifact
/// directory; `infer` passes the same path to `inference::infer`. `ModelConfig::new(10, 512)` is
/// assumed to be (number of classes, hidden size) — TODO confirm against `model.rs`.
fn train() {
    let gpu = WgpuDevice::default();
    let model_config = ModelConfig::new(10, 512);
    let training_config = training::TrainingConfig::new(model_config, AdamConfig::new());
    training::train::<AutodiffBackend>("/tmp/guide", training_config, gpu);
}

/// Drops the sub-millisecond part of `duration`, truncating rather than rounding to nearest.
///
/// Durations whose millisecond count does not fit in a `u64` saturate to `u64::MAX` ms.
fn round_duration_to_millisecond(duration: Duration) -> Duration {
    match u64::try_from(duration.as_millis()) {
        Ok(whole_millis) => Duration::from_millis(whole_millis),
        // Only reachable for absurdly long durations (~584 million years and up).
        Err(_) => Duration::from_millis(u64::MAX),
    }
}

fn infer() -> color_eyre::Result<()> {
    // Adjust the panic hook to leave terminal specialness
    let hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |panic_info| {
        disable_raw_mode().unwrap();
        stdout().execute(LeaveAlternateScreen).unwrap();
        (hook)(panic_info);
    }));

    let device = WgpuDevice::default();

    let mnist_data = MNISTDataset::test();
    stdout().execute(EnterAlternateScreen)?;
    enable_raw_mode()?;
    let mut terminal = Terminal::new(CrosstermBackend::new(stdout()))?;
    terminal.clear()?;

    let mut picker = Picker::from_termios().map_err(ThreadsafeBoxedError::new)?;
    picker.guess_protocol();

    let data_len = mnist_data.len();
    let inferred_count = Arc::new(AtomicUsize::new(0));
    let (tx, rx) = channel();
    let immediate_stop = Arc::new(AtomicBool::new(false));

    let handles: Vec<_> = (0..num_cpus::get())
        .map(|i| {
            let tx = tx.clone();
            let inferred_count = inferred_count.clone();
            let dataset = mnist_data
                .iter()
                .skip(i * data_len / 4)
                .take(data_len / 4)
                .collect::<Vec<_>>();
            let device = device.clone();
            let immediate_stop = immediate_stop.clone();
            std::thread::spawn(move || {
                for item in dataset {
                    if immediate_stop.load(Ordering::Acquire) {
                        break;
                    }

                    let dyn_img = GrayImage::from_raw(
                        28,
                        28,
                        item.image
                            .iter()
                            .flatten()
                            .map(|f| (255.0 * *f) as u8)
                            .collect(),
                    )
                    .unwrap();
                    // Resize the image to 150x150
                    let dyn_img =
                        imageops::resize(&dyn_img, 150, 150, imageops::FilterType::Nearest);
                    // Wrap as dynamic image
                    let dyn_img = image::DynamicImage::ImageLuma8(dyn_img);
                    let image = picker.new_resize_protocol(dyn_img);

                    let (predicted, label) =
                        inference::infer::<Backend>("/tmp/guide", device.clone(), item);

                    tx.send((image, predicted, label)).unwrap();
                    inferred_count.fetch_add(1, Ordering::AcqRel);
                }
            })
        })
        .collect();

    let loading_layout = Layout::vertical([
        Constraint::Length(1),
        Constraint::Length(1),
        Constraint::Min(1),
    ]);
    let mut running_avg_dur = Duration::ZERO;
    let mut previous_count = 0;
    let mut start = Instant::now();
    while inferred_count.load(Ordering::Acquire) < data_len {
        if event::poll(std::time::Duration::from_millis(16))? {
            if let event::Event::Key(key) = event::read()? {
                if key.kind == KeyEventKind::Press && key.code == KeyCode::Char('q') {
                    immediate_stop.store(true, Ordering::Release);
                    break;
                }
            }
        }
        let inferred_count = inferred_count.load(Ordering::Acquire);
        // Update the running average (around every 100 items)
        if inferred_count > previous_count + 100 {
            let dur = start.elapsed();
            start = Instant::now();
            running_avg_dur = Duration::from_secs_f64(
                (running_avg_dur * previous_count.try_into().unwrap() + dur).as_secs_f64()
                    / inferred_count as f64,
            );
            previous_count = inferred_count;
        }
        terminal.draw(|f| {
            let [text_area, bar_area, _] = loading_layout.areas(f.size());
            let avg_secs_per_item = running_avg_dur.as_secs_f64();
            let throughput = if avg_secs_per_item > 0.0 {
                format!("{:.2} items/sec", avg_secs_per_item.recip())
            } else {
                "--.-- items/sec".to_string()
            };
            f.render_widget(
                Paragraph::new(format!(
                    "Loading... ({}/{}) [{throughput}] ({})",
                    inferred_count,
                    data_len,
                    format_duration(round_duration_to_millisecond(
                        running_avg_dur * previous_count.try_into().unwrap() + start.elapsed()
                    ))
                )),
                text_area,
            );
            f.render_widget(
                Gauge::default()
                    .gauge_style(
                        Style::default()
                            .fg(ratatui::style::Color::White)
                            .bg(ratatui::style::Color::Black)
                            .add_modifier(Modifier::ITALIC),
                    )
                    .use_unicode(true)
                    .ratio(inferred_count as f64 / data_len as f64),
                bar_area,
            )
        })?;
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let mut results = Vec::with_capacity(data_len);
    drop(tx);
    results.extend(rx);
    let mut app = App::new(results);

    while !app.should_close {
        if event::poll(std::time::Duration::from_millis(16))? {
            if let event::Event::Key(key) = event::read()? {
                if key.kind == KeyEventKind::Press && key.code == KeyCode::Char('q') {
                    break;
                }

                if matches!(key.code, KeyCode::Char('u') | KeyCode::Up)
                    && matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
                {
                    app.item += app.results.len() - 1;
                    app.item %= app.results.len();
                }

                if matches!(key.code, KeyCode::Char('d') | KeyCode::Down)
                    && matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
                {
                    app.item += 1;
                    app.item %= app.results.len();
                }

                if matches!(key.code, KeyCode::Char('n') | KeyCode::Right)
                    && matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
                {
                    let mut mistakes = app
                        .results
                        .iter()
                        .enumerate()
                        .cycle()
                        .skip(app.item)
                        .take(app.results.len())
                        .filter(|(idx, (_, p, l))| *idx != app.item && *p != *l as i32);
                    if let Some((idx, _)) = mistakes.next() {
                        app.item = idx;
                    }
                }

                if matches!(key.code, KeyCode::Char('p') | KeyCode::Left)
                    && matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
                {
                    let mut mistakes = app
                        .results
                        .iter()
                        .enumerate()
                        .rev()
                        .cycle()
                        .skip(app.results.len() - app.item)
                        .take(app.results.len())
                        .filter(|(idx, (_, p, l))| *idx != app.item && *p != *l as i32);
                    if let Some((idx, _)) = mistakes.next() {
                        app.item = idx;
                    }
                }
            }
        }

        terminal.draw(|f| app.draw(f))?;
    }

    stdout().execute(LeaveAlternateScreen)?;
    disable_raw_mode()?;
    Ok(())
}