use chisel_tuto::get_handle;
use cxxrtl::{CxxrtlHandle, CxxrtlSignal, Vcd};
use std::{env, fs::File};

struct Counter {
    pub handle: CxxrtlHandle,
    pub clk: CxxrtlSignal<1>,
    pub io_inc: CxxrtlSignal<1>,
    pub io_amt: CxxrtlSignal<4>,
    pub io_tot: CxxrtlSignal<8>,
}

impl Counter {
    fn new() -> Self {
        let lib = concat!(env!("OUT_DIR"), "/Counter.so");
        let handle = get_handle(lib);
        let clk = handle.get("clock").unwrap().signal();
        let io_inc = handle.get("io_inc").unwrap().signal();
        let io_amt = handle.get("io_amt").unwrap().signal();
        let io_tot = handle.get("io_tot").unwrap().signal();
        Self {
            handle,
            clk,
            io_inc,
            io_amt,
            io_tot,
        }
    }

    fn step(&mut self) {
        self.handle.step()
    }
}

#[test]
fn test_counter() {
    let mut dut = Counter::new();
    const MAX_INT: u8 = 16;
    let mut cur_cnt = 0;

    let int_wrap_around = |n: u8, max| {
        if n > max {
            0
        } else {
            n
        }
    };

    dut.io_inc.set(false);
    dut.io_amt.set::<u8>(0);

    for _i in 0..5 {
        dut.step();
    }
    // dut.step();

    // let mut vcd = Vcd::default();
    // vcd.timescale(cxxrtl::TimescaleNumber::One, cxxrtl::TimescaleUnit::Us);
    // vcd.add(&dut.handle);
    // let mut vcd_file = File::create("counter.vcd").unwrap();

    for _i in 0..10 {
        let inc = rand::random();
        let amt = rand::random::<u8>() % MAX_INT;

        dut.io_inc.set::<bool>(inc);
        dut.io_amt.set::<u8>(amt);

        dut.clk.set(false);
        dut.step();
        // vcd.sample(i * 2);
        dut.clk.set(true);
        dut.step();
        // vcd.sample(i * 2 + 1);

        // vcd.write(&mut vcd_file).unwrap();
        cur_cnt = if inc {
            int_wrap_around(cur_cnt + amt, 255)
        } else {
            cur_cnt
        };
        assert_eq!(dut.io_tot.get::<u8>(), cur_cnt);
    }
}