use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
pool: AdaptiveAvgPool2d,
dropout: Dropout,
linear1: Linear<B>,
linear2: Linear<B>,
activation: ReLU,
}
impl<B: Backend> Model<B> {
pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
let [batch_size, height, width] = images.dims();
let x = images.reshape([batch_size, 1, height, width]);
let x = self.conv1.forward(x); let x = self.dropout.forward(x);
let x = self.conv2.forward(x); let x = self.dropout.forward(x);
let x = self.activation.forward(x);
let x = self.pool.forward(x); let x = x.reshape([batch_size, 16 * 8 * 8]);
let x = self.linear1.forward(x);
let x = self.dropout.forward(x);
let x = self.activation.forward(x);
self.linear2.forward(x)
}
}
#[derive(Config, Debug)]
pub struct ModelConfig {
num_classes: usize,
hidden_size: usize,
#[config(default = "0.5")]
dropout: f64,
}
impl ModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
Model {
conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
activation: ReLU::new(),
linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
dropout: DropoutConfig::new(self.dropout).init(),
}
}
pub fn init_with<B: Backend>(&self, record: ModelRecord<B>) -> Model<B> {
Model {
conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1),
conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2),
pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
activation: ReLU::new(),
linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1),
linear2: LinearConfig::new(self.hidden_size, self.num_classes)
.init_with(record.linear2),
dropout: DropoutConfig::new(self.dropout).init(),
}
}
}