use std::ffi::c_void;

/// Opaque handle for SDL's `SDL_PixelFormat`.
///
/// Rust code in this file never looks inside it; it only passes pointers obtained from SDL
/// back to SDL. It is modelled as the zero-sized opaque type recommended by the Rustonomicon.
/// The private fields mean it cannot be constructed from Rust. The marker makes it
/// `!Send`, `!Sync` and `!Unpin`, since nothing is known about SDL's internal invariants.
#[repr(C)]
pub struct SDLPixelFormat {
    _data: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}

/// Mirror of SDL2's `SDL_Rect`, which is four C `int`s: `x`, `y`, `w`, `h`.
///
/// It is embedded by value as `SDLSurface::clip_rect`, so its size must match SDL2 exactly.
/// Otherwise the surface fields that follow it (`map`, `refcount`) sit at the wrong offsets.
/// The previous `i16`/`u16` fields matched SDL 1.2, not SDL2.
#[repr(C)]
pub struct SDLRect {
    x: i32,
    y: i32,
    w: i32,
    h: i32,
}

/// Rust mirror of SDL2's `SDL_Surface`. In this file it is only handled through pointers
/// returned by SDL.
///
/// `#[repr(C)]` together with the exact field order is what makes the offsets line up with
/// the C struct. Fields must not be reordered, added or removed. Within this file only
/// `format`, `pitch` and `pixels` are read (by `render_tensor`).
#[repr(C)]
pub struct SDLSurface {
    flags: u32,
    // Pixel format SDL attached to this surface.
    format: *const SDLPixelFormat,
    w: i32,
    h: i32,
    // Length of one pixel row in bytes (render_tensor divides it by 4 for 32-bpp words).
    pitch: i32,
    pixels: *mut c_void,
    userdata: *mut c_void,
    locked: i32,
    // NOTE(review): newer SDL2 headers call this slot `list_blitmap`; it is a pointer either
    // way, so the layout is unaffected.
    lock_data: *mut c_void,
    // NOTE(review): SDL2's `SDL_Rect` is four C `int`s; confirm `SDLRect` has the same size,
    // otherwise `map` and `refcount` below end up at the wrong offsets.
    clip_rect: SDLRect,
    map: *mut c_void,

    refcount: i32,
}

// Raw SDL2 entry points used by this file. They are linked against the `SDL2` library under
// their C names via `link_name`; the Rust-side identifiers carry a `c_` prefix.
#[link(name = "SDL2")]

extern "C" {
    /// `SDL_MapRGBA`: maps an RGBA colour to a pixel value in the given pixel format.
    #[link_name = "SDL_MapRGBA"]
    fn c_sdl_map_rgba(format: *const SDLPixelFormat, r: u8, g: u8, b: u8, a: u8) -> u32;
    /// `SDL_CreateRGBSurface`: allocates a surface with the given size, bit depth and
    /// channel masks, or returns null on failure. SDL documents `flags` as unused; this file
    /// passes 0.
    ///
    /// NOTE(review): the C prototype takes `int` for `width`, `height` and `depth`. They are
    /// declared `u32` here, which has the same size, but values above `i32::MAX` reach SDL as
    /// negative numbers.
    #[link_name = "SDL_CreateRGBSurface"]
    fn c_sdl_create_rgb_surface(
        flags: u32,
        width: u32,
        height: u32,
        depth: u32,
        r_mask: u32,
        g_mask: u32,
        b_mask: u32,
        a_mask: u32,
    ) -> *mut SDLSurface;
}

/// Safe wrapper around `SDL_MapRGBA`: encodes an RGBA colour as a pixel value for `format`.
fn map_rgba(format: &SDLPixelFormat, r: u8, g: u8, b: u8, a: u8) -> u32 {
    let raw: *const SDLPixelFormat = format;
    // SAFETY: `raw` is derived from a live reference, so it is non-null and aligned, and the
    // pixel format it points at outlives this call.
    unsafe { c_sdl_map_rgba(raw, r, g, b, a) }
}

/// Converts a packed `u8` image into a freshly created 32-bit SDL surface.
///
/// `data` holds `height` rows of `width` pixels with `channels` bytes each and no row
/// padding. One channel is greyscale, 3 is RGB and 4 is RGBA; 1- and 3-channel input is
/// given opaque alpha (255). Each pixel is encoded with `SDL_MapRGBA`, so the result does not
/// depend on the byte order of the surface's channel masks.
///
/// Returns null if `data` is null, `channels` is not 1, 3 or 4, the byte count overflows
/// `usize`, or SDL fails to create the surface. The function does not free the surface it
/// returns; releasing it is left to the caller.
///
/// Caller contract: `data` must point to at least `width * height * channels` readable bytes.
#[no_mangle]
pub extern "C" fn render_tensor(
    data: *const u8,
    width: u32,
    height: u32,
    channels: i32,
) -> *mut SDLSurface {
    // Validate everything before allocating. Bad input must neither leak a surface nor panic
    // across the FFI boundary. Matching on the i32 also rejects negative counts, which
    // `as usize` would otherwise wrap to huge values.
    let channels = match channels {
        1 | 3 | 4 => channels as usize,
        _ => return std::ptr::null_mut(),
    };
    if data.is_null() {
        return std::ptr::null_mut();
    }
    let (w, h) = (width as usize, height as usize);
    let row_len = match w.checked_mul(channels) {
        Some(n) => n,
        None => return std::ptr::null_mut(),
    };
    let len = match row_len.checked_mul(h) {
        Some(n) => n,
        None => return std::ptr::null_mut(),
    };

    // SAFETY: plain FFI call with by-value arguments; a null result is handled below.
    let surface = unsafe {
        c_sdl_create_rgb_surface(0, width, height, 32, 0xff, 0xff00, 0xff0000, 0xff000000)
    };
    if surface.is_null() || len == 0 {
        // Either creation failed, or there are no pixels to fill. SDL may leave `pixels`
        // null for a zero-sized surface, and that must not be turned into a slice.
        return surface;
    }

    // SAFETY: `surface` is non-null and was just returned by SDL, so its header is readable.
    let (pitch_bytes, pixels_ptr, format_ptr) =
        unsafe { ((*surface).pitch, (*surface).pixels, (*surface).format) };
    // `pitch` is in bytes; at 32 bpp each row holds `pitch / 4` u32 words.
    let pitch = pitch_bytes as usize / 4;
    // SAFETY: the caller guarantees `data` addresses `len` readable bytes.
    let data = unsafe { std::slice::from_raw_parts(data, len) };
    // SAFETY: a non-empty 32-bpp surface owns `h` rows of `pitch` words at `pixels`, and
    // nothing else accesses them while we fill the freshly created surface.
    let pixels: &mut [u32] =
        unsafe { std::slice::from_raw_parts_mut(pixels_ptr.cast(), h * pitch) };
    // SAFETY: `format_ptr` comes from the surface header; `as_ref` handles the null case.
    let format = unsafe { format_ptr.as_ref() }.expect("SDL surface has no pixel format");

    for (src_row, dst_row) in data.chunks_exact(row_len).zip(pixels.chunks_exact_mut(pitch)) {
        // `zip` stops after `w` pixels, ignoring any padding words at the end of a row.
        for (src, dst) in src_row.chunks_exact(channels).zip(dst_row.iter_mut()) {
            let (red, green, blue, alpha) = match src {
                &[grey] => (grey, grey, grey, 255),
                &[red, green, blue] => (red, green, blue, 255),
                &[red, green, blue, alpha] => (red, green, blue, alpha),
                // `channels` was validated on entry, so every chunk has 1, 3 or 4 bytes.
                _ => unreachable!("channel count validated on entry"),
            };
            *dst = map_rgba(format, red, green, blue, alpha);
        }
    }
    surface
}

/// Writes one direction vector per pixel of a `width` × `height` equirectangular grid.
///
/// `target` must point to `width * height * 3` writable `f64`s, laid out row-major with three
/// components per pixel.
///
/// - Row `y` uses the angle `polar = 0 - (y / height) * PI`.
/// - Column `x` uses `azimuth = (x / width) * 2 * PI`.
/// - Each pixel receives `(sin(polar) * sin(azimuth), cos(polar), sin(polar) * cos(azimuth))`.
#[no_mangle]
pub extern "C" fn equirectangular(target: *mut f64, width: u32, height: u32) {
    use std::f64::consts::PI;
    let cols = width as usize;
    let rows = height as usize;
    let row_stride = cols * 3;
    // SAFETY: the caller guarantees `target` addresses `width * height * 3` writable f64s.
    let out = unsafe { std::slice::from_raw_parts_mut(target, row_stride * rows) };

    for row in 0..rows {
        let polar = 0. - (row as f64) / (height as f64) * PI;
        let sin_polar = polar.sin();
        let cos_polar = polar.cos();
        let start = row * row_stride;
        let line = &mut out[start..start + row_stride];
        for (col, px) in line.chunks_exact_mut(3).enumerate() {
            let azimuth = (col as f64) / (width as f64) * 2. * PI;
            px[0] = sin_polar * azimuth.sin();
            px[1] = cos_polar;
            px[2] = sin_polar * azimuth.cos();
        }
    }
}

/// Height threshold separating sea from land in the heightmap helpers below.
///
/// NOTE(review): `colourize_heightmap` draws land only for `s > SEA_LEVEL`, while
/// `landmass_steradians` counts `s >= SEA_LEVEL` as land. A pixel exactly at this height is
/// therefore drawn as sea but counted as land; confirm which comparison is intended.
const SEA_LEVEL: f64 = 0.1;

/// Colours a `width` × `height` heightmap into an RGB image of `f64` components.
///
/// `source` holds one height per pixel (row-major). `target` receives three components per
/// pixel in the same order.
///
/// The snow line varies by row: with `t = row / height`, it is `t * (1 - t) * 2`, which is 0
/// at the first row and 0.5 mid-image. Each pixel with height `s` is coloured as follows:
/// - `s > snow line`: white `(1, 1, 1)`.
/// - otherwise `s > SEA_LEVEL`: `(s^0.6, 0.9, s * 0.5)`.
/// - otherwise: `(0, 0, s * 4 + 0.4)`.
///
/// Caller contract: `source` must hold `width * height` values, and `target` must have room
/// for three times that many.
#[no_mangle]
pub extern "C" fn colourize_heightmap(
    target: *mut f64,
    source: *const f64,
    width: u32,
    height: u32,
) {
    let cols = width as usize;
    let rows = height as usize;
    // SAFETY: the caller guarantees `target` addresses `cols * rows * 3` writable f64s.
    let rgb = unsafe { std::slice::from_raw_parts_mut(target, cols * rows * 3) };
    // SAFETY: the caller guarantees `source` addresses `cols * rows` readable f64s.
    let heights = unsafe { std::slice::from_raw_parts(source, cols * rows) };

    for row in 0..rows {
        let t = (row as f64) / (height as f64);
        let snow_line = t * (1. - t) * 2.;
        let src_row = &heights[row * cols..(row + 1) * cols];
        let dst_row = &mut rgb[row * cols * 3..(row + 1) * cols * 3];
        for (&s, px) in src_row.iter().zip(dst_row.chunks_exact_mut(3)) {
            let colour = match s {
                s if s > snow_line => [1., 1., 1.],
                s if s > SEA_LEVEL => [s.powf(0.6), 0.9, s * 0.5],
                s => [0., 0., s * 4. + 0.4],
            };
            px.copy_from_slice(&colour);
        }
    }
}

/// Estimates the solid angle (in steradians) of the land mass containing pixel (`x`, `y`).
///
/// `source` is a `width` × `height` row-major heightmap. A pixel counts as land when its
/// value is `>= SEA_LEVEL`. The land mass is every land pixel reachable from the start pixel
/// through 8-connected neighbours:
/// - Columns wrap around (column `-1` is column `width - 1`).
/// - Rows do not wrap; rows outside `0..height` are skipped.
///
/// Each land pixel in row `r` contributes `dy * sin(r * dy) * dx`, where
/// `dy = PI / height` and `dx = 2 * PI / width`.
///
/// Returns 0.0 if `source` is null, the start pixel is outside the image, or the start pixel
/// is not land.
///
/// Caller contract: `source` must point to `width * height` readable f64s.
#[no_mangle]
pub extern "C" fn landmass_steradians(
    source: *const f64,
    width: u32,
    height: u32,
    x: u32,
    y: u32,
) -> f64 {
    use std::collections::VecDeque;
    use std::f64::consts::PI;
    if source.is_null() || x >= width || y >= height {
        return 0.0;
    }
    let (w, h) = (width as usize, height as usize);
    // SAFETY: the caller guarantees `source` addresses `width * height` readable f64s.
    let source = unsafe { std::slice::from_raw_parts(source, w * h) };
    let is_land = |idx: usize| source[idx] >= SEA_LEVEL;

    // Two changes from hash-based bookkeeping:
    // - One visited flag per pixel replaces a set of coordinates.
    // - One counter per row replaces a map, so the final sum runs in a fixed order.
    // Together they keep the fill O(width * height) and the result deterministic.
    let mut visited = vec![false; w * h];
    let mut row_counts = vec![0u64; h];
    let mut queue = VecDeque::new();

    // Indices are computed in usize so `y * width + x` cannot overflow u32.
    let start = y as usize * w + x as usize;
    if is_land(start) {
        visited[start] = true;
        queue.push_back((x as usize, y as usize));
    }
    while let Some((px, py)) = queue.pop_front() {
        row_counts[py] += 1;
        for &oy in &[-1isize, 0, 1] {
            let ny = py as isize + oy;
            // Rows do not wrap; row 0 is a valid neighbour like any other in-range row.
            if ny < 0 || ny >= h as isize {
                continue;
            }
            for &ox in &[-1isize, 0, 1] {
                if ox == 0 && oy == 0 {
                    continue;
                }
                // Columns wrap. `rem_euclid` maps -1 to width - 1, whereas `%` would keep it
                // negative and produce an invalid column.
                let nx = (px as isize + ox).rem_euclid(w as isize) as usize;
                let idx = ny as usize * w + nx;
                if !visited[idx] && is_land(idx) {
                    visited[idx] = true;
                    queue.push_back((nx, ny as usize));
                }
            }
        }
    }

    let dy = PI / (height as f64);
    let dx = PI * 2. / (width as f64);
    row_counts
        .iter()
        .enumerate()
        .map(|(row, &count)| dy * ((row as f64) * dy).sin() * (count as f64) * dx)
        .sum()
}