mod data;
mod inference;
mod model;
mod training;
use crate::model::ModelConfig;
use burn::{
backend::{
wgpu::{AutoGraphicsApi, WgpuDevice},
Autodiff, Wgpu,
},
data::dataset::{vision::MNISTDataset, Dataset},
optim::AdamConfig,
};
use clap::Parser;
use crossterm::{
event::{self, KeyCode, KeyEventKind},
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
ExecutableCommand,
};
use humantime::format_duration;
use image::{imageops, GrayImage};
use ratatui::{
layout::{Constraint, Layout, Margin},
prelude::{CrosstermBackend, Stylize, Terminal},
style::{Modifier, Style},
text::{Line, Span, Text},
widgets::{Block, Borders, Gauge, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
Frame,
};
use ratatui_image::{picker::Picker, protocol::StatefulProtocol, StatefulImage};
use std::{
io::stdout,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
mpsc::channel,
Arc,
},
time::{Duration, Instant},
};
/// State of the interactive viewer that browses inference results one at a time.
struct App {
    /// Loop-exit flag read by the viewer loop in `infer`.
    should_close: bool,
    /// Index into `results` of the entry currently displayed (0-based).
    item: usize,
    /// One entry per inferred sample: (terminal image protocol, predicted class, expected label).
    results: Vec<(Box<dyn StatefulProtocol>, i32, u8)>,
    /// Vertical split of the inner area: status line, image, footer.
    body_layout: Layout,
    /// Scrollbar state; its position is synced to `item` on every draw.
    scrollbar_state: ScrollbarState,
}
impl App {
    /// Creates a viewer positioned on the first result.
    fn new(results: Vec<(Box<dyn StatefulProtocol>, i32, u8)>) -> Self {
        Self {
            should_close: false,
            item: 0,
            body_layout: Layout::vertical([
                Constraint::Length(1),
                Constraint::Min(1),
                Constraint::Length(1),
            ]),
            scrollbar_state: ScrollbarState::new(results.len()),
            results,
        }
    }

    /// Renders the current result into `frame`.
    ///
    /// The layout is a "Predicted/Expected" line (prediction green when correct, red when
    /// wrong), the image, a right-hand scrollbar and a 1-based "k out of N" footer. When
    /// `item` does not index an entry (e.g. `results` is empty), only a placeholder line
    /// is drawn instead of panicking.
    fn draw(&mut self, frame: &mut Frame) {
        let area = frame.size();
        let [text_area, im_area, bottom_area] =
            self.body_layout.areas(area.inner(&Margin::new(1, 1)));
        let total = self.results.len();
        let position = self.item;
        let Some((im, predicted, label)) = self.results.get_mut(position) else {
            frame.render_widget(Paragraph::new("No results to display").white(), text_area);
            return;
        };
        let (predicted, label) = (*predicted, *label);
        // Green for a correct prediction, red for a mistake.
        let verdict_color = if predicted == i32::from(label) {
            ratatui::style::Color::Green
        } else {
            ratatui::style::Color::Red
        };
        frame.render_widget(
            Paragraph::new(Text::from(vec![Line::from(vec![
                Span::raw("Predicted: "),
                Span::styled(predicted.to_string(), Style::default().fg(verdict_color)),
                Span::raw(format!(", Expected: {label}")),
            ])]))
            .white(),
            text_area,
        );
        frame.render_stateful_widget(StatefulImage::new(None), im_area, im);
        frame.render_widget(Block::new().borders(Borders::RIGHT), area);
        let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight)
            .begin_symbol(Some("↑"))
            .end_symbol(Some("↓"));
        self.scrollbar_state = self.scrollbar_state.position(position);
        frame.render_stateful_widget(
            scrollbar,
            area.inner(&Margin {
                horizontal: 0,
                vertical: 1,
            }),
            &mut self.scrollbar_state,
        );
        // `position` is 0-based; show a 1-based counter so the last item reads "N out of N".
        frame.render_widget(
            Paragraph::new(format!("{} out of {}", position + 1, total)),
            bottom_area,
        );
    }
}
#[derive(thiserror::Error, Debug)]
#[error("{}", _0.lock().unwrap())]
struct ThreadsafeBoxedError(std::sync::Arc<std::sync::Mutex<Box<dyn std::error::Error>>>);
unsafe impl Send for ThreadsafeBoxedError {}
unsafe impl Sync for ThreadsafeBoxedError {}
impl ThreadsafeBoxedError {
fn new(boxed: Box<dyn std::error::Error>) -> Self {
use std::sync::Mutex;
#[allow(clippy::arc_with_non_send_sync)]
Self(Arc::new(Mutex::new(boxed)))
}
}
/// Train the model, or run it over the MNIST test set and browse its predictions.
#[derive(clap::Parser, Debug)]
struct Cli {
    // Which action to run; see `RunMode`.
    #[command(subcommand)]
    mode: RunMode,
}
// Subcommands; each variant's doc comment is shown as its help text by clap.
#[derive(clap::Subcommand, Debug)]
enum RunMode {
    /// Train the model (artifact directory: /tmp/guide)
    Train,
    /// Classify the MNIST test set with the trained model and browse the results (q quits)
    Infer,
}
// Backend used for inference (`infer`): WGPU with `AutoGraphicsApi`.
// NOTE(review): per burn's `Wgpu<G, F, I>` signature the `f32`/`i32` parameters are
// presumably the float/int element types — confirm against the burn version in use.
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
// Backend used for training (`train`): `Backend` wrapped in burn's `Autodiff` decorator.
type AutodiffBackend = Autodiff<Backend>;
/// Program entry point.
///
/// Installs `color_eyre` error and panic reporting, parses the command line, then runs
/// the selected subcommand. Training has no error channel and always yields `Ok(())`.
fn main() -> color_eyre::Result<()> {
    color_eyre::install()?;
    match Cli::parse().mode {
        RunMode::Infer => infer(),
        RunMode::Train => {
            train();
            Ok(())
        }
    }
}
/// Runs `training::train` on the autodiff WGPU backend with the default device.
///
/// Passes `/tmp/guide`, the same path `infer` later hands to `inference::infer`.
/// NOTE(review): the `ModelConfig::new(10, 512)` arguments are presumably
/// (number of classes, hidden size) — confirm in `model.rs`.
fn train() {
    let gpu = WgpuDevice::default();
    let optimizer = AdamConfig::new();
    let config = training::TrainingConfig::new(ModelConfig::new(10, 512), optimizer);
    training::train::<AutodiffBackend>("/tmp/guide", config, gpu);
}
/// Discards the sub-millisecond part of `duration`.
///
/// This truncates toward zero; it does not round to the nearest millisecond. Durations
/// of more than `u64::MAX` milliseconds saturate to `u64::MAX` ms. It is used to keep
/// the `format_duration` readout at millisecond precision.
fn round_duration_to_millisecond(duration: Duration) -> Duration {
    let whole_millis = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
    Duration::from_millis(whole_millis)
}
fn infer() -> color_eyre::Result<()> {
let hook = std::panic::take_hook();
std::panic::set_hook(Box::new(move |panic_info| {
disable_raw_mode().unwrap();
stdout().execute(LeaveAlternateScreen).unwrap();
(hook)(panic_info);
}));
let device = WgpuDevice::default();
let mnist_data = MNISTDataset::test();
stdout().execute(EnterAlternateScreen)?;
enable_raw_mode()?;
let mut terminal = Terminal::new(CrosstermBackend::new(stdout()))?;
terminal.clear()?;
let mut picker = Picker::from_termios().map_err(ThreadsafeBoxedError::new)?;
picker.guess_protocol();
let data_len = mnist_data.len();
let inferred_count = Arc::new(AtomicUsize::new(0));
let (tx, rx) = channel();
let immediate_stop = Arc::new(AtomicBool::new(false));
let handles: Vec<_> = (0..num_cpus::get())
.map(|i| {
let tx = tx.clone();
let inferred_count = inferred_count.clone();
let dataset = mnist_data
.iter()
.skip(i * data_len / 4)
.take(data_len / 4)
.collect::<Vec<_>>();
let device = device.clone();
let immediate_stop = immediate_stop.clone();
std::thread::spawn(move || {
for item in dataset {
if immediate_stop.load(Ordering::Acquire) {
break;
}
let dyn_img = GrayImage::from_raw(
28,
28,
item.image
.iter()
.flatten()
.map(|f| (255.0 * *f) as u8)
.collect(),
)
.unwrap();
let dyn_img =
imageops::resize(&dyn_img, 150, 150, imageops::FilterType::Nearest);
let dyn_img = image::DynamicImage::ImageLuma8(dyn_img);
let image = picker.new_resize_protocol(dyn_img);
let (predicted, label) =
inference::infer::<Backend>("/tmp/guide", device.clone(), item);
tx.send((image, predicted, label)).unwrap();
inferred_count.fetch_add(1, Ordering::AcqRel);
}
})
})
.collect();
let loading_layout = Layout::vertical([
Constraint::Length(1),
Constraint::Length(1),
Constraint::Min(1),
]);
let mut running_avg_dur = Duration::ZERO;
let mut previous_count = 0;
let mut start = Instant::now();
while inferred_count.load(Ordering::Acquire) < data_len {
if event::poll(std::time::Duration::from_millis(16))? {
if let event::Event::Key(key) = event::read()? {
if key.kind == KeyEventKind::Press && key.code == KeyCode::Char('q') {
immediate_stop.store(true, Ordering::Release);
break;
}
}
}
let inferred_count = inferred_count.load(Ordering::Acquire);
if inferred_count > previous_count + 100 {
let dur = start.elapsed();
start = Instant::now();
running_avg_dur = Duration::from_secs_f64(
(running_avg_dur * previous_count.try_into().unwrap() + dur).as_secs_f64()
/ inferred_count as f64,
);
previous_count = inferred_count;
}
terminal.draw(|f| {
let [text_area, bar_area, _] = loading_layout.areas(f.size());
let avg_secs_per_item = running_avg_dur.as_secs_f64();
let throughput = if avg_secs_per_item > 0.0 {
format!("{:.2} items/sec", avg_secs_per_item.recip())
} else {
"--.-- items/sec".to_string()
};
f.render_widget(
Paragraph::new(format!(
"Loading... ({}/{}) [{throughput}] ({})",
inferred_count,
data_len,
format_duration(round_duration_to_millisecond(
running_avg_dur * previous_count.try_into().unwrap() + start.elapsed()
))
)),
text_area,
);
f.render_widget(
Gauge::default()
.gauge_style(
Style::default()
.fg(ratatui::style::Color::White)
.bg(ratatui::style::Color::Black)
.add_modifier(Modifier::ITALIC),
)
.use_unicode(true)
.ratio(inferred_count as f64 / data_len as f64),
bar_area,
)
})?;
}
for handle in handles {
handle.join().unwrap();
}
let mut results = Vec::with_capacity(data_len);
drop(tx);
results.extend(rx);
let mut app = App::new(results);
while !app.should_close {
if event::poll(std::time::Duration::from_millis(16))? {
if let event::Event::Key(key) = event::read()? {
if key.kind == KeyEventKind::Press && key.code == KeyCode::Char('q') {
break;
}
if matches!(key.code, KeyCode::Char('u') | KeyCode::Up)
&& matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
{
app.item += app.results.len() - 1;
app.item %= app.results.len();
}
if matches!(key.code, KeyCode::Char('d') | KeyCode::Down)
&& matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
{
app.item += 1;
app.item %= app.results.len();
}
if matches!(key.code, KeyCode::Char('n') | KeyCode::Right)
&& matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
{
let mut mistakes = app
.results
.iter()
.enumerate()
.cycle()
.skip(app.item)
.take(app.results.len())
.filter(|(idx, (_, p, l))| *idx != app.item && *p != *l as i32);
if let Some((idx, _)) = mistakes.next() {
app.item = idx;
}
}
if matches!(key.code, KeyCode::Char('p') | KeyCode::Left)
&& matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat)
{
let mut mistakes = app
.results
.iter()
.enumerate()
.rev()
.cycle()
.skip(app.results.len() - app.item)
.take(app.results.len())
.filter(|(idx, (_, p, l))| *idx != app.item && *p != *l as i32);
if let Some((idx, _)) = mistakes.next() {
app.item = idx;
}
}
}
}
terminal.draw(|f| app.draw(f))?;
}
stdout().execute(LeaveAlternateScreen)?;
disable_raw_mode()?;
Ok(())
}