//! Bindings to the seekable variant of the ZSTD compression format.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
mod bindings;
use bindings::*;
use libc::*;
use std::ffi::CStr;
use std::{fmt, io, ptr};
use thiserror::*;

mod compress;
pub use compress::*;

mod decompress;
pub use decompress::*;

/// Base-2 logarithm of the maximum block size.
const BLOCK_SIZE_LOG_MAX: usize = 17;
/// Maximum block size in bytes (`1 << 17`, i.e. 128 KiB). This is the value
/// reported by `CStream::in_size`.
const BLOCK_SIZE_MAX: usize = 1 << BLOCK_SIZE_LOG_MAX;

// NOTE(review): the constants below are not used in this part of the file.
// Judging by their names they describe the seek table of the seekable format
// (footer length in bytes, magic number, skippable-frame header length and
// skippable-frame magic) -- confirm against the zstd seekable format
// specification before relying on them.
const SEEK_TABLE_FOOTER_SIZE: usize = 9;
const SEEKABLE_MAGIC_NUMBER: u32 = 0x8F92EAB1;
const SKIPPABLE_HEADER_SIZE: usize = 8;
const MAGIC_SKIPPABLE_START: u32 = 0x184D2A50;
// NOTE(review): presumably mirrors the `ZSTD_reset_session_only` value of
// zstd's `ZSTD_ResetDirective` enum -- confirm against `zstd.h`.
const ZSTD_reset_session_only: c_uint = 1;

/// The type of compressors.
///
/// Owns a raw `ZSTD_CStream` handle obtained from `ZSTD_createCStream`; the
/// handle is released with `ZSTD_freeCStream` when the value is dropped.
pub struct CStream {
    /// Raw pointer to the underlying zstd compression stream.
    p: *mut ZSTD_CStream,
}

#[derive(Debug, Error)]
pub enum Error {
    #[error("Could not open file {}", 0)]
    CouldNotOpenFile(String),
    #[error(transparent)]
    ZSTD(#[from] ZSTDError),
    #[error(transparent)]
    Io(#[from] io::Error),
    #[error("Null pointer")]
    Null,
    #[error("Frame index too large, expected at most {}, found {}", 0, 1)]
    FIndexTooLarge(usize, usize),
    #[error("Unknown error")]
    Generic,
    #[error("Unsupported frame parameter. Expected at most {}, found {}", 0, 1)]
    FParamUnsupported(usize, usize),
    #[error("Unknown frame descriptor {}", 0)]
    PrefixUnknown(u32),
    #[error("Corrupted block detected {}", 0)]
    Corruption(&'static str),
    #[error("Destination buffer too small. Expected at most {}, found {}", 0, 1)]
    DSizeTooSmall(u64, u64),
}

/// A raw error code returned by a zstd function.
///
/// The wrapped `size_t` is turned into a readable message by the `Display`
/// implementation, which looks it up with `ZSTD_getErrorName`.
#[derive(Error)]
pub struct ZSTDError(size_t);

impl fmt::Display for ZSTDError {
    /// Writes the name zstd associates with the wrapped error code.
    ///
    /// The name is converted lossily, so a non-UTF-8 byte in the C string
    /// is replaced rather than causing a panic while formatting an error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: relies on `ZSTD_getErrorName` returning a valid,
        // NUL-terminated C string for any code, as the zstd API documents.
        let name = unsafe { CStr::from_ptr(ZSTD_getErrorName(self.0)) };
        f.write_str(&name.to_string_lossy())
    }
}

impl fmt::Debug for ZSTDError {
    /// Debug output is the same text as the `Display` output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl Drop for CStream {
    /// Releases the underlying zstd stream, if there is one, and clears the
    /// stored pointer.
    fn drop(&mut self) {
        // Nothing to release when the handle is null.
        if self.p.is_null() {
            return;
        }
        unsafe {
            ZSTD_freeCStream(self.p);
        }
        self.p = ptr::null_mut();
    }
}

impl CStream {
    pub fn new(level: usize) -> Result<Self, Error> {
        unsafe {
            let p = ZSTD_createCStream();
            if p.is_null() {
                return Err(Error::Null);
            }
            let result = ZSTD_initCStream(p, level as c_int);
            if ZSTD_isError(result) != 0 {
                return Err(Error::ZSTD(ZSTDError(result)));
            }
            Ok(CStream { p })
        }
    }

    pub fn in_size() -> usize {
        BLOCK_SIZE_MAX
    }

    pub fn out_size() -> usize {
        unsafe { ZSTD_CStreamOutSize() }
    }

    /// Compress one chunk of input, and write it into the output. The
    /// `output` array must be large enough to hold the result. If
    /// successful, this function returns three integers `(out_pos,
    /// in_pos, next_read_size)`, where `out_pos` is the number of
    /// bytes written in `output`, `in_pos` is the number of input
    /// bytes consumed, and `next_read_size` is a hint for the next
    /// read size.
    pub fn compress(
        &mut self,
        output: &mut [u8],
        input: &[u8],
    ) -> Result<(usize, usize, usize), Error> {
        let mut input = ZSTD_inBuffer {
            src: input.as_ptr() as *const c_void,
            size: input.len() as size_t,
            pos: 0,
        };

        let mut output = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };

        let result = unsafe { ZSTD_compressStream(self.p, &mut output, &mut input) };
        Ok((output.pos as usize, input.pos as usize, result as usize))
    }

    pub fn compress2(
        &mut self,
        output: &mut [u8],
        input: &[u8],
        op: EndDirective,
    ) -> Result<(usize, usize, usize), Error> {
        let mut input = ZSTD_inBuffer {
            src: input.as_ptr() as *const c_void,
            size: input.len() as size_t,
            pos: 0,
        };
        let mut output = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };
        let result = unsafe {
            ZSTD_compressStream2(self.p, &mut output, &mut input, op as ZSTD_EndDirective)
        };
        Ok((output.pos as usize, input.pos as usize, result as usize))
    }

    pub fn flush(&mut self, output: &mut [u8]) -> Result<(usize, usize), Error> {
        let mut output = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };
        let result = unsafe { ZSTD_flushStream(self.p, &mut output) };
        Ok((output.pos as usize, result as usize))
    }

    /// Finish writing the message, i.e. write the remaining pending block.
    pub fn end(&mut self, output: &mut [u8]) -> Result<usize, Error> {
        let mut output = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };

        unsafe {
            let result = ZSTD_endStream(self.p, &mut output);
            if ZSTD_isError(result) != 0 {
                return Err(Error::ZSTD(ZSTDError(result)));
            }
        }

        Ok(output.pos as usize)
    }
}

/// Directive passed to `CStream::compress2`.
///
/// `compress2` casts the value directly to zstd's `ZSTD_EndDirective`, so the
/// discriminants must stay equal to the corresponding C enum values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EndDirective {
    /// Keep collecting input (zstd's `ZSTD_e_continue`).
    Continue = 0,
    /// Flush the data provided so far (zstd's `ZSTD_e_flush`).
    Flush = 1,
    /// Flush the remaining data and end the frame (zstd's `ZSTD_e_end`).
    End = 2,
}

/// The type of decompressors.
///
/// Owns a raw `ZSTD_DStream` handle obtained from `ZSTD_createDStream`; the
/// handle is released with `ZSTD_freeDStream` when the value is dropped.
pub struct DStream {
    /// Raw pointer to the underlying zstd decompression stream.
    p: *mut ZSTD_DStream,
}

impl Drop for DStream {
    /// Releases the underlying zstd stream and clears the stored pointer.
    fn drop(&mut self) {
        // Take the handle out, leaving null behind, then hand it to zstd.
        let raw = std::mem::replace(&mut self.p, ptr::null_mut());
        unsafe {
            ZSTD_freeDStream(raw);
        }
    }
}

impl DStream {
    pub fn new() -> Result<Self, Error> {
        unsafe {
            let p = ZSTD_createDStream();
            if p.is_null() {
                return Err(Error::Null);
            }
            let result = ZSTD_initDStream(p);
            if ZSTD_isError(result) != 0 {
                return Err(Error::ZSTD(ZSTDError(result)));
            }
            Ok(DStream { p })
        }
    }

    pub fn decompress(&mut self, output: &mut [u8], input: &[u8]) -> Result<(usize, usize), Error> {
        let mut input = ZSTD_inBuffer {
            src: input.as_ptr() as *const c_void,
            size: input.len() as size_t,
            pos: 0,
        };
        let mut output = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };

        unsafe {
            let _result = ZSTD_decompressStream(self.p, &mut output, &mut input);

            if ZSTD_isError(_result) != 0 {
                return Err(Error::ZSTD(ZSTDError(_result)));
            }
        }

        Ok((output.pos as usize, input.pos as usize))
    }
}