use burn::{
    config::Config,
    data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
    module::Module,
    nn::loss::CrossEntropyLossConfig,
    optim::AdamConfig,
    record::CompactRecorder,
    tensor::{
        backend::{AutodiffBackend, Backend},
        Int, Tensor,
    },
    train::{
        metric::{AccuracyMetric, LossMetric},
        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
    },
};

use crate::{
    data::{MNISTBatch, MNISTBatcher},
    model::{Model, ModelConfig},
};
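
// Pair the model's forward pass with a cross-entropy loss and bundle the
// result into a `ClassificationOutput`, which both the training and
// validation steps below reuse.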
impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}
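
// Training runs on an autodiff backend: the step computes the loss, then
// calls `backward()` so the learner receives the gradients with the output.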
impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}
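
// Validation needs no gradients, so it runs on the plain backend and simply
// returns the classification output.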
impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}
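
// Hyperparameters for one training run. `#[derive(Config)]` generates the
// constructor, `with_*` setters, and the JSON (de)serialization used below.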
#[derive(Config)]
pub struct TrainingConfig {
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 4)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 1.0e-4)]
    pub learning_rate: f64,
}
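
/// Trains the model on MNIST and writes the config, checkpoints, and final
/// model weights into `artifact_dir`.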
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
    std::fs::create_dir_all(artifact_dir).ok();
    config
        .save(format!("{artifact_dir}/config.json"))
        .expect("Config should be saved successfully");

    B::seed(config.seed);
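
    // One batcher per split; the validation batcher targets the inner
    // (non-autodiff) backend, since validation computes no gradients.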
    let batcher_train = MNISTBatcher::<B>::new(device.clone());
    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MNISTDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MNISTDataset::test());
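
    // Wire up the learner: metrics for both splits, checkpointing, devices,
    // and epoch count, then attach the model, optimizer, and learning rate.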
    let model = config.model.init::<B>(&device);

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device])
        .num_epochs(config.num_epochs)
        .build(model, config.optimizer.init(), config.learning_rate);
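
    // Run the training loop, then persist the trained weights.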
    let model_trained = learner.fit(dataloader_train, dataloader_test);

    model_trained
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}
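
// A minimal sketch of how `train` might be invoked from a binary's `main`,
// assuming the wgpu backend wrapped in `Autodiff` and a
// `ModelConfig::new(num_classes, hidden_size)` constructor; both are
// assumptions, so adapt the types to your `model.rs` and Burn version
// (the `Wgpu` type parameters in particular vary between releases):
//
// use burn::backend::wgpu::AutoGraphicsApi;
// use burn::backend::{Autodiff, Wgpu};
//
// fn main() {
//     type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
//     type MyAutodiffBackend = Autodiff<MyBackend>;
//
//     let device = burn::backend::wgpu::WgpuDevice::default();
//     crate::training::train::<MyAutodiffBackend>(
//         "/tmp/guide",
//         TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
//         device,
//     );
// }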