use std::{io::Cursor, str::SplitWhitespace};

extern crate murmur3;
extern crate ndarray;
extern crate ndarray_linalg;
extern crate sprs;
extern crate unicode_normalization;

use murmur3::murmur3_32;
use ndarray::{azip, Array1, Array2, Axis};
use ndarray_linalg::Norm;
use sprs::CsMat;
use unicode_normalization::UnicodeNormalization;


/// Vectorises a corpus using the hashing trick.
///
/// Each input document becomes one row of a sparse matrix with `n_features` columns;
/// see [`HashingVectorizer::fit_transform`].
#[derive(Debug, Clone)]
pub struct HashingVectorizer {
    /// Lower bound (inclusive) of the n-gram length range given to `new`.
    /// NOTE(review): check `fit_transform` for how this bound is applied.
    n_min: usize,
    /// Upper bound (inclusive) of the n-gram length range given to `new`.
    /// NOTE(review): check `fit_transform` for how this bound is applied.
    n_max: usize,
    /// Number of output columns; hashed n-grams are reduced modulo this value.
    n_features: u32,
    /// Whether `fit_transform` scales every output row to unit L2 norm.
    normalise: bool,
}

impl HashingVectorizer {
    /// Creates a new vectorizer.
    ///
    /// # Arguments
    ///
    /// * `n_features` - Number of output columns; hashed n-grams are reduced modulo this value.
    /// * `n_min` - Smallest character n-gram length to extract (inclusive).
    /// * `n_max` - Largest character n-gram length to extract (inclusive).
    /// * `normalise` - Whether `fit_transform` scales every row to unit L2 norm.
    pub fn new(n_features: u32, n_min: usize, n_max: usize, normalise: bool) -> Self {
        Self {
            n_features,
            n_min,
            n_max,
            normalise,
        }
    }

    /// Transforms the input into a sparse CSR matrix of signed n-gram occurrences.
    ///
    /// Each document is lowercased and split on whitespace. Every word is padded with one
    /// space on each side and cut into character n-grams of every length in
    /// `n_min..=n_max`, mirroring scikit-learn's `char_wb` analyzer. Each n-gram is hashed
    /// with MurmurHash3 (seed 0). The hash selects the column, and its sign selects whether
    /// the occurrence counts as `+1` or `-1`, as in scikit-learn's `HashingVectorizer`.
    /// Collisions within a document are summed. A document without any n-gram (empty or
    /// whitespace-only) produces an empty row.
    ///
    /// # Panics
    ///
    /// Panics if `input` is empty, if `n_features` is 0, or if the n-gram range is invalid
    /// (`n_min == 0` or `n_min > n_max`).
    pub fn fit_transform(&self, input: &[&str]) -> CsMat<f64> {
        assert!(!input.is_empty(), "input must contain at least one document");
        assert!(self.n_features > 0, "n_features must be greater than zero");
        assert!(
            self.n_min >= 1 && self.n_min <= self.n_max,
            "invalid n-gram range: n_min = {}, n_max = {}",
            self.n_min,
            self.n_max
        );

        // CSR buffers: `row_idx[r]..row_idx[r + 1]` is the span of `col_idx`/`values`
        // holding row `r`.
        let mut row_idx = Vec::with_capacity(input.len() + 1);
        row_idx.push(0);
        let mut col_idx = Vec::new();
        let mut values = Vec::new();
        let mut sorting_buf = Vec::new();

        for &doc in input {
            let start = col_idx.len();
            let lower = doc.to_lowercase();

            for word in lower.split_whitespace() {
                // Length, in characters, of the padded word " word ".
                let padded_len = word.chars().count() + 2;
                for n in self.n_min..=self.n_max {
                    for ngram in NGramIter::new(word, n) {
                        let (col, val) = self.hash_ngram(&ngram);
                        col_idx.push(col);
                        values.push(val);
                    }
                    // scikit-learn emits a word that fits in a single n-gram only once, for
                    // the smallest such n, then moves on to the next word.
                    if padded_len <= n {
                        break;
                    }
                }
            }

            // Sort this row's column indices and merge duplicates. An empty row has nothing
            // to sort (and `sort_dedup` must not be given one).
            let row_size = if col_idx.len() > start {
                sort_dedup(&mut col_idx[start..], &mut values[start..], &mut sorting_buf)
            } else {
                0
            };
            let end = start + row_size;
            col_idx.truncate(end);
            values.truncate(end);
            row_idx.push(end);
        }

        let mut m = CsMat::new_from_unsorted(
            (input.len(), self.n_features as usize),
            row_idx,
            col_idx,
            values,
        )
        .unwrap();

        if self.normalise {
            for mut row in m.outer_iterator_mut() {
                row.unit_normalize();
            }
        }

        m
    }

    /// Maps an n-gram to its `(column, sign)` pair.
    ///
    /// This mirrors scikit-learn's `FeatureHasher`, which works with the *signed* 32-bit
    /// MurmurHash3 value.
    fn hash_ngram(&self, ngram: &str) -> (usize, f64) {
        let h = murmur3_32(&mut Cursor::new(ngram), 0)
            .expect("reading from an in-memory buffer cannot fail") as i32;
        // `unsigned_abs` maps i32::MIN to 2^31 instead of overflowing (a panic in debug
        // builds). scikit-learn special-cases that value to the same index, 2^31 % n_features.
        let col = h.unsigned_abs() % self.n_features;
        // Non-negative hashes count +1 and negative ones -1 (scikit-learn: `(h >= 0) * 2 - 1`).
        let sign = if h >= 0 { 1.0 } else { -1.0 };
        (col as usize, sign)
    }
}

/// Decomposes `input` into Unicode NFKD form and keeps only the ASCII characters.
///
/// Accented letters therefore lose their (non-ASCII) combining marks, and any other
/// non-ASCII character is dropped. NOTE(review): no caller is visible in this file.
fn normalize_str(input: &str) -> String {
    input
        .nfkd()
        .filter(|ch| ch.is_ascii())
        .collect::<String>()
}

/// Sort and deduplicate the index and value arrays for one CSR row.
///
/// The `(index, value)` pairs are sorted by index, and pairs that share an index are merged
/// by summing their values. The result is written back to the front of `indices` and
/// `values`, and the number of distinct indices is returned. Slots past that count hold
/// leftover data and should be truncated by the caller. `buf` is scratch space, reused
/// between calls to avoid reallocating.
///
/// Empty input is valid and returns 0.
fn sort_dedup(indices: &mut [usize], values: &mut [f64], buf: &mut Vec<(usize, f64)>) -> usize {
    buf.clear();
    buf.extend(indices.iter().copied().zip(values.iter().copied()));
    buf.sort_unstable_by_key(|&(idx, _)| idx);

    // Compact in place. `size` counts the distinct indices written so far, so the most
    // recently written entry is at `size - 1`. Writing over the inputs is safe because
    // `buf` holds a copy of them.
    let mut size = 0;
    for (idx, val) in buf.drain(..) {
        if size > 0 && indices[size - 1] == idx {
            values[size - 1] += val;
        } else {
            indices[size] = idx;
            values[size] = val;
            size += 1;
        }
    }

    size
}

/// Transformer used to acquire a TF-IDF representation of the provided frequency matrix.
#[derive(Debug, Clone)]
pub struct TfidfTransformer {
    /// Whether fitting adds one to the document count and to every document frequency.
    smooth_idf: bool,
    /// Per-column idf weights learned by fitting; `None` until the transformer is fitted.
    idf: Option<Array1<f64>>,
}

impl TfidfTransformer {
    /// Constructs a new `TfidfTransformer`
    ///
    /// # Arguments
    ///
    /// * `smooth_idf` - Adds a 1 to the numerator and denominator of the IDF-term. Prevents zero divisions
    pub fn new(smooth_idf: bool) -> Self {
        Self {
            smooth_idf,
            idf: None,
        }
    }

    /// Fit the transformer to the input matrix.
    ///
    /// Learns one weight per column: `idf = ln(n / df) + 1`. Here `n` is the number of rows
    /// (documents) and `df` is the number of stored entries in that column. With
    /// `smooth_idf`, both `n` and every `df` are incremented by one first. Without
    /// smoothing, a column with no stored entries gets an infinite weight.
    pub fn fit(&mut self, matrix: &CsMat<f64>) {
        // Document frequency: stored entries per column. `iter` yields `(value, (row, col))`
        // for both CSR and CSC storage. `indices()` would hold row ids for a CSC matrix.
        let mut df = vec![0.0; matrix.cols()];
        for (_, (_, col)) in matrix.iter() {
            df[col] += 1.0;
        }

        // We need the number of documents to calculate the idf
        let mut n_samples = matrix.rows();
        let mut idf_vector = Array1::from_iter(df);

        // Apply smoothing if desired. Acts as if there is one document containing every term
        if self.smooth_idf {
            idf_vector += 1.;
            n_samples += 1;
        }

        // Calculate idf from df
        idf_vector.mapv_inplace(|df| (n_samples as f64 / df).ln() + 1.0);
        self.idf = Some(idf_vector);
    }

    /// Transform the frequency matrix to a dense TF-IDF representation.
    ///
    /// Every stored term frequency is multiplied by its column's idf weight. Each row is
    /// then scaled to unit L2 norm. Rows that are entirely zero are left as zeros.
    ///
    /// # Panics
    ///
    /// Panics if the transformer has not been fitted, or if the number of columns differs
    /// from the one seen by `fit`.
    pub fn transform(&self, matrix: &CsMat<f64>) -> Array2<f64> {
        let idf = self
            .idf
            .as_ref()
            .expect("TfidfTransformer::transform called before fit");

        let mut dense = matrix.to_dense();
        for mut row in dense.axis_iter_mut(Axis(0)) {
            // Only scale stored terms. Without smoothing, an unseen column has an infinite
            // idf, and 0 * inf would turn the whole row into NaN.
            azip!((tf in &mut row, &weight in idf) {
                if *tf != 0.0 {
                    *tf *= weight;
                }
            });
            let norm = row.norm();
            // An empty document has norm 0; keep it as zeros instead of dividing by zero.
            if norm > 0.0 {
                row /= norm;
            }
        }
        dense
    }

    /// Fits the transformer on `matrix` if it has not been fitted yet, then transforms it.
    ///
    /// Unlike scikit-learn's `fit_transform`, an idf learned by an earlier call to `fit` (or
    /// `fit_transform`) is reused rather than recomputed from `matrix`.
    pub fn fit_transform(&mut self, matrix: &CsMat<f64>) -> Array2<f64> {
        if self.idf.is_none() {
            self.fit(matrix);
        }
        self.transform(matrix)
    }

    /// The learned idf weights, or `None` if the transformer has not been fitted.
    pub fn idf(&self) -> Option<&Array1<f64>> {
        self.idf.as_ref()
    }
}

/// Iterator over n-grams analogous to those used by scikit-learn's HashingVectorizer
///
/// Yields the n-grams of each whitespace-separated word of the source, after padding the
/// word with one space on each side (`" word "`). The input is not lowercased here.
#[derive(Debug, Clone)]
struct NGramIter<'a> {
    /// Remaining words of the input.
    source: SplitWhitespace<'a>,
    /// The current word, already padded with spaces; `None` when the next word must be fetched.
    cur_word: Option<String>,
    /// Target n-gram length (the `n` in n-gram); never zero.
    size: usize,
    /// Byte offset in `cur_word` where the next n-gram starts.
    pos: usize,
}

impl<'a> NGramIter<'a> {
    pub fn new(source: &'a str, size: usize) -> Self {
        assert!(size != 0);
        Self {
            source: source.split_whitespace(),
            cur_word: None,
            size,
            pos: 0,
        }
    }
}

impl<'a> Iterator for NGramIter<'a> {
    type Item = String;

    /// Returns the next n-gram, moving on to the next word once the current one is exhausted.
    ///
    /// The window is `size` *characters* wide and slides by one character, as in
    /// scikit-learn's `char_wb` analyzer. A padded word shorter than `size` is yielded
    /// whole, exactly once.
    fn next(&mut self) -> Option<Self::Item> {
        if self.cur_word.is_none() {
            // Pad with spaces so n-grams at word edges differ from those inside a word.
            // This still allocates one String per word.
            let word = self.source.next()?;
            self.cur_word = Some(format!(" {} ", word));
            self.pos = 0;
        }
        let word = self.cur_word.as_deref()?;

        // Measure the window in characters, not bytes, so multibyte text gives the same
        // n-grams as scikit-learn. `nth(size)` is the first character *after* the window.
        // If there is none, this is the last n-gram of the word.
        let rest = &word[self.pos..];
        let (end, last) = match rest.char_indices().nth(self.size) {
            Some((offset, _)) => (self.pos + offset, false),
            None => (word.len(), true),
        };
        let step = rest.chars().next().map_or(0, char::len_utf8);
        let ngram = word[self.pos..end].to_owned();

        if last {
            self.cur_word = None;
        } else {
            // Slide the window forward by exactly one character (stays on a char boundary).
            self.pos += step;
        }

        Some(ngram)
    }
}

#[cfg(test)]
mod tests {
    // Regression tests for the n-gram iterator, the hashing vectorizer (16 features, 5-grams)
    // and the TF-IDF transformer. NOTE(review): the expected matrices look like scikit-learn
    // reference output; confirm how they were generated before changing them.
    use std::collections::HashSet;

    use ndarray::{array, Array};

    use crate::{HashingVectorizer, NGramIter, TfidfTransformer};

    /// Four-document corpus shared by the vectorizer and TF-IDF tests.
    static INPUT: [&'static str; 4] = [
        "This is the first document.",
        "This document is the second document.",
        "And this is the third one.",
        "Is this the first document?",
    ];

    /// Rounds `n` to `d` decimal places so computed floats can be compared with the
    /// 8-decimal literals below using exact equality.
    fn round_dec(n: f64, d: usize) -> f64 {
        let md = 10.0_f64.powi(d as i32);
        (n * md).round() / md
    }

    /// The 5-grams of one sentence, compared as a set (order and repeats are ignored).
    /// `NGramIter` itself does not lowercase, so " This" keeps its capital letter.
    #[test]
    fn ngram_5() {
        let input = "This is the first document.";
        let expected: HashSet<_> = [
            " docu", " firs", " is ", " the ", " This", "cumen", "docum", "ent. ", "first",
            "irst ", "ment.", "ocume", "This ", "ument",
        ]
        .iter()
        .map(|&s| s.to_owned())
        .collect();
        let res: HashSet<_> = NGramIter::new(input, 5).collect();
        assert_eq!(expected, res);
    }

    /// Raw signed counts of the corpus' 5-grams hashed into 16 columns.
    #[test]
    fn vectorization_hashing_not_normalised() {
        let v = HashingVectorizer::new(16, 5, 5, false);
        let r = v.fit_transform(&INPUT).to_dense();
        let t: Array<f64, _> = array![
            [
                0.0, 0.0, -1.0, -1.0, 1.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, -1.0, -1.0, 0.0, -1.0,
                0.0
            ],
            [
                0.0, 1.0, -2.0, -2.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0,
                -1.0,
            ],
            [-2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0, -1.0, 0.0],
            [0.0, 0.0, -2.0, -1.0, 1.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ];
        assert_eq!(t, r);
    }

    /// Same configuration with unit-norm rows; results are rounded to 8 decimals.
    #[test]
    fn vectorization_hashing_normalised() {
        let v = HashingVectorizer::new(16, 5, 5, true);
        let t = array![
            [
                0.,
                0.,
                -0.35355339,
                -0.35355339,
                0.35355339,
                0.,
                0.,
                -0.35355339,
                -0.35355339,
                0.,
                0.,
                -0.35355339,
                -0.35355339,
                0.,
                -0.35355339,
                0.
            ],
            [
                0.,
                0.24253563,
                -0.48507125,
                -0.48507125,
                0.24253563,
                0.24253563,
                0.,
                0.24253563,
                0.24253563,
                0.,
                -0.24253563,
                0.,
                -0.24253563,
                0.,
                -0.24253563,
                -0.24253563
            ],
            [
                -0.63245553,
                0.,
                0.,
                0.,
                0.31622777,
                0.31622777,
                0.,
                -0.31622777,
                0.,
                0.,
                0.,
                -0.31622777,
                0.,
                0.31622777,
                -0.31622777,
                0.
            ],
            [
                0.,
                0.,
                -0.70710678,
                -0.35355339,
                0.35355339,
                0.,
                0.,
                -0.35355339,
                -0.35355339,
                0.,
                0.,
                0.,
                0.,
                0.,
                0.,
                0.
            ]
        ];
        let mut r = v.fit_transform(&INPUT).to_dense();
        r.mapv_inplace(|v| round_dec(v, 8));
        assert_eq!(t, r);
    }

    /// TF-IDF (with `smooth_idf`) of the un-normalised hashing output, rounded to 8 decimals.
    #[test]
    fn tfidf_transformation() {
        let v = HashingVectorizer::new(16, 5, 5, false);
        let mut tf = TfidfTransformer::new(true);

        let t = array![
            [
                0.,
                0.,
                -0.3726943,
                -0.3726943,
                0.30470201,
                0.,
                0.,
                -0.30470201,
                -0.30470201,
                0.,
                0.,
                -0.30470201,
                -0.46035161,
                0.,
                -0.3726943,
                0.
            ],
            [
                0.,
                0.3506238,
                -0.44759726,
                -0.44759726,
                0.18297004,
                0.27643583,
                0.,
                0.18297004,
                0.18297004,
                0.,
                -0.22379863,
                0.,
                -0.27643583,
                0.,
                -0.22379863,
                -0.3506238
            ],
            [
                -0.76438624,
                0.,
                0.,
                0.,
                0.19944423,
                0.30132545,
                0.,
                -0.19944423,
                0.,
                0.,
                0.,
                -0.19944423,
                0.,
                0.38219312,
                -0.24394892,
                0.
            ],
            [
                0.,
                0.,
                -0.75564616,
                -0.37782308,
                0.30889513,
                0.,
                0.,
                -0.30889513,
                -0.30889513,
                0.,
                0.,
                0.,
                0.,
                0.,
                0.,
                0.
            ]
        ];

        // Fit and transform on the same matrix, then round so the float literals compare exactly.
        let r = v.fit_transform(&INPUT);
        tf.fit(&r);
        let mut r = tf.transform(&r);

        r.mapv_inplace(|v| round_dec(v, 8));

        assert_eq!(t, r);
    }
}