use burn::{
config::Config,
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};
use crate::{data::MNISTBatcher, training::TrainingConfig};
/// Loads a trained model from `artifact_dir` and runs inference on a single
/// MNIST item, returning the predicted class index together with the item's
/// ground-truth label.
pub fn infer<B: Backend>(
    artifact_dir: &str,
    device: B::Device,
    item: MNISTItem,
) -> (B::IntElem, u8) {
    // Capture the ground-truth label before `item` is moved into the batcher.
    let label = item.label;

    // Restore the training configuration saved next to the model artifacts.
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist for the model");

    // Load the recorded weights and rebuild the model from them.
    let record = CompactRecorder::new()
        .load(format!("{artifact_dir}/model").into(), &device)
        .expect("Trained model should exist");
    let model = config.model.init_with::<B>(record);

    // Wrap the single item in a batch of one and run a forward pass.
    let batch = MNISTBatcher::new(device).batch(vec![item]);
    let output = model.forward(batch.images);

    // Argmax over the class dimension, flatten to a rank-1 tensor, and
    // extract the single predicted class index as a scalar.
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();

    (predicted, label)
}