use ci::Message;
use serde_derive::*;
use std::io::Read;
use std::net::{Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use tracing::*;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

/// Client configuration, parsed from the TOML file passed with `--config`.
#[derive(Serialize, Deserialize)]
struct ConfigFile {
    /// Path of the SSH secret key used to authenticate to the server as `ci`.
    key_path: String,
    /// TCP port of the server; `main` connects to the unspecified IPv6
    /// address (`::`) on this port.
    port: u16,
    // NOTE(review): not read by the code shown here — confirm whether it is
    // still used.
    timeout_secs: usize,
    /// Base64-encoded server host keys; a server presenting any other key is
    /// rejected.
    server_public_keys: Vec<String>,
    /// Directory holding the per-job `.stdout`, `.stderr` and `.result` files.
    log_path: String,
    /// Directory into which job tarballs are downloaded and extracted.
    tarball_path: String,
}

/// Outcome of a finished build, stored as JSON in `<log_path>/<job id>.result`
/// and sent back instead of rebuilding when the same job is received again.
#[derive(Serialize, Deserialize)]
struct BuildResult {
    /// Time (UTC) at which the result was recorded.
    finished: chrono::DateTime<chrono::Utc>,
    /// Exit code of `nix-build`; `None` when it has none (on Unix, e.g. it
    /// was terminated by a signal).
    status: Option<i32>,
    /// Target of the `result` symlink left in the build directory, if any.
    link: Option<PathBuf>,
    /// The job that was built.
    job: ci::Job,
}

use clap::*;

// Command-line arguments. Plain `//` comments are used on purpose: `///` doc
// comments on a clap `Parser` would become the program's `--help` text.
#[derive(Debug, Parser)]
pub struct App {
    // Path of the TOML configuration file (see `ConfigFile`); `-c`/`--config`.
    #[arg(short, long)]
    config: PathBuf,
}

/// Entry point of the CI client.
///
/// Installs a `tracing` subscriber (filter taken from the environment,
/// defaulting to `ci=debug`), loads the TOML configuration named by
/// `--config`, and then keeps a session with the CI server alive forever:
/// whenever [`CiClient::protocol`] returns, any error is logged and the
/// client reconnects one second later.
///
/// # Panics
///
/// Panics with a descriptive message if the configuration file, the secret
/// key, or any configured server public key cannot be read or parsed.
#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "ci=debug".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    let args = App::parse();
    let conf_text = std::fs::read_to_string(&args.config)
        .unwrap_or_else(|e| panic!("cannot read config file {:?}: {}", args.config, e));
    let conf: ConfigFile = toml::from_str(&conf_text)
        .unwrap_or_else(|e| panic!("invalid config file {:?}: {}", args.config, e));
    let config = Arc::new(thrussh::client::Config::default());
    // NOTE(review): this connects to the unspecified address `::` on the
    // configured port — confirm that reaching the local host is intended.
    let addr = (Ipv6Addr::UNSPECIFIED, conf.port)
        .to_socket_addrs()
        .expect("resolving a literal socket address cannot fail")
        .next()
        .expect("a literal socket address resolves to itself");
    let key = Arc::new(
        thrussh_keys::load_secret_key(&conf.key_path, None)
            .unwrap_or_else(|e| panic!("cannot load secret key {:?}: {:?}", conf.key_path, e)),
    );
    let server_public_keys: Vec<_> = conf
        .server_public_keys
        .iter()
        .map(|k| {
            thrussh_keys::parse_public_key_base64(k)
                .unwrap_or_else(|e| panic!("invalid server public key {:?}: {:?}", k, e))
        })
        .collect();
    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
    let client = CiClient {
        process: Arc::new(Mutex::new(Process::default())),
        log_path: Path::new(&conf.log_path).to_path_buf(),
        tarball_path: Path::new(&conf.tarball_path).to_path_buf(),
        last_window_adjustment: SystemTime::now(),
        server_public_keys: Arc::new(server_public_keys),
        sender,
    };
    loop {
        if let Err(e) = client
            .protocol(&addr, config.clone(), key.clone(), &mut receiver)
            .await
        {
            error!("restarting because of error: {:?}", e)
        }
        tokio::time::sleep(Duration::from_secs(1)).await;
    }
}

/// CI worker client. It is cloned into the SSH session as the
/// `thrussh::client::Handler`; clones share `process` and
/// `server_public_keys` through their `Arc`s.
#[derive(Clone, Debug)]
pub struct CiClient {
    /// Download/build state shared by all clones.
    process: Arc<Mutex<Process>>,
    /// Directory where tarballs are downloaded (`<id>.tar.gz`) and extracted
    /// (`<id>/`).
    tarball_path: PathBuf,
    /// Directory holding `<id>.stdout`, `<id>.stderr` and `<id>.result`.
    log_path: PathBuf,
    /// Time of the previous `adjust_window` call, used to scale the window.
    last_window_adjustment: SystemTime,
    /// Host keys accepted by `check_server_key`.
    server_public_keys: Arc<Vec<thrussh_keys::key::PublicKey>>,
    /// Where build outcomes are reported as `(job, exit code, result link)`;
    /// the matching receiver is drained by `protocol`.
    sender: tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
}

/// Mutable state of the job this client is currently handling.
#[derive(Debug, Default)]
struct Process {
    /// Handle of the spawned build task; taken and awaited by `protocol` once
    /// the task has reported on the result channel.
    child: Option<tokio::task::JoinHandle<Result<(), anyhow::Error>>>,
    /// Job parked while its tarball is being downloaded; resumed when the
    /// final (zero-length) chunk arrives.
    job: Option<ci::Job>,
    /// Open `<id>.tar.gz.tmp` file and the number of bytes still expected for
    /// the tarball chunk currently being received.
    tarball: Option<(std::fs::File, usize)>,
}

impl Process {
    /// Reports whether a new job can be accepted: no build task is running
    /// and no job is parked waiting for its tarball.
    fn is_ready(&self) -> bool {
        matches!((&self.child, &self.job), (None, None))
    }
}

impl CiClient {
    pub async fn protocol(
        &self,
        addr: &SocketAddr,
        config: Arc<thrussh::client::Config>,
        key: Arc<thrussh_keys::key::KeyPair>,
        receiver: &mut tokio::sync::mpsc::Receiver<(ci::Job, Option<i32>, Option<PathBuf>)>,
    ) -> Result<(), anyhow::Error> {
        let mut h = thrussh::client::connect(config, &addr, self.clone()).await?;
        debug!("Opening session");
        if !h.authenticate_publickey("ci", key).await? {
            return Ok(());
        }
        let mut channel = h.channel_open_session().await?;
        channel
            .data(
                &bincode::serialize(&Message::Handshake {
                    version: ci::VERSION,
                    id: 0,
                })
                .unwrap()[..],
            )
            .await?;
        debug!("handshake done");
        'outer: loop {
            if self.process.lock().unwrap().is_ready() {
                channel
                    .data(&bincode::serialize(&Message::Ready).unwrap()[..])
                    .await?;
                debug!("ready");
            }
            loop {
                tokio::select! {
                    msg = channel.wait() => {
                        debug!("msg = {:?}", msg);
                        if !self.handle_msg(&mut channel, &self.sender, msg).await? {
                            break 'outer
                        }
                    }
                    msg = receiver.recv() => {
                        debug!("message {:#?}", msg);
                        if let Some(p) = self.process.lock().unwrap().child.take() {
                            p.await??
                        }
                        if let Some((job, exit_status, path)) = msg {
                            self.send_log(&mut channel, job, exit_status, path).await?
                        }
                        channel.data(&bincode::serialize(&Message::Ready).unwrap()[..]).await?;
                    }
                }
            }
        }
        Ok(())
    }

    async fn handle_msg(
        &self,
        channel: &mut thrussh::client::Channel,
        sender: &tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
        msg: Option<thrussh::ChannelMsg>,
    ) -> Result<bool, anyhow::Error> {
        match msg {
            Some(thrussh::ChannelMsg::Data { data }) => {
                let mut proc = self.process.lock().unwrap();
                if let Some((mut f, mut len)) = proc.tarball.take() {
                    debug!("len = {:?}", len);
                    use std::io::Write;
                    f.write_all(&data)?;
                    len -= data.len();
                    if len > 0 {
                        proc.tarball = Some((f, len));
                    }
                    return Ok(true);
                }
                let msg = bincode::deserialize::<Message>(&data);
                debug!("msg = {:?}", msg);
                match msg {
                    Ok(Message::Job(job)) => {
                        self.handle_job(channel, sender.clone(), &mut proc, job)
                            .await?
                    }
                    Ok(Message::Chunk { id, len, .. }) => {
                        let p = self.tarball_path.join(&format!("{}.tar.gz.tmp", id));
                        if len == 0 {
                            let p2 = self.tarball_path.join(&format!("{}.tar.gz", id));
                            std::fs::rename(&p, &p2)?;
                            proc.tarball = None;
                            let job = proc.job.take().unwrap();
                            self.handle_job(channel, sender.clone(), &mut proc, job)
                                .await?;
                            return Ok(true);
                        }
                        let file = std::fs::OpenOptions::new()
                            .write(true)
                            .create(true)
                            .append(true)
                            .open(&p)
                            .unwrap();
                        proc.tarball = Some((file, len as usize));
                    }
                    Ok(msg) => {
                        debug!("msg = {:?}", msg);
                    }
                    _ => return Ok(false),
                }
            }
            None => return Ok(false),
            msg => debug!("{:?}", msg),
        }
        Ok(true)
    }

    async fn handle_job(
        &self,
        channel: &mut thrussh::client::Channel,
        sender: tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
        process: &mut Process,
        job: ci::Job,
    ) -> Result<(), anyhow::Error> {
        let p = self.tarball_path.join(&format!("{}.tar.gz", job.id));
        debug!("p = {:?}", p);
        if std::fs::metadata(&p).is_err() {
            debug!("getting tarball");
            channel
                .data(&bincode::serialize(&Message::GetTarball { id: job.id }).unwrap()[..])
                .await?;
            process.job = Some(job);
            return Ok(());
        }
        debug!("tar = {:?}", p);
        let status = std::process::Command::new("tar")
            .args(&["-xf", p.to_str().unwrap()])
            .current_dir(&self.tarball_path)
            .status()
            .unwrap();
        debug!("nix: {:?}", status);

        let tarballp = self.tarball_path.join(job.id.to_string());
        let logp = self.log_path.clone();

        let result_path = logp.join(&format!("{}.result", job.id));
        if let Ok(mut f) = std::fs::File::open(&result_path) {
            if let Ok(build_result) = serde_json::from_reader::<_, BuildResult>(&mut f) {
                sender
                    .send((build_result.job, build_result.status, build_result.link))
                    .await?;
                return Ok(());
            }
        }

        debug!("p = {:?}", tarballp);
        process.child = Some(tokio::task::spawn(async move {
            let mut process = tokio::process::Command::new("nix-build")
                .arg("default.nix")
                .current_dir(&tarballp)
                .stdin(std::process::Stdio::null())
                .stdout(std::process::Stdio::piped())
                .stderr(std::process::Stdio::piped())
                .spawn()
                .unwrap();
            let stdout = process.stdout.as_mut().unwrap();
            let stderr = process.stderr.as_mut().unwrap();
            let mut fstdout =
                tokio::fs::File::create(logp.join(&format!("{}.stdout", job.id))).await?;
            let mut fstderr =
                tokio::fs::File::create(logp.join(&format!("{}.stderr", job.id))).await?;
            let (a, b) = futures::future::join(
                tokio::io::copy(stdout, &mut fstdout),
                tokio::io::copy(stderr, &mut fstderr),
            )
            .await;
            a?;
            b?;
            let status = process.wait().await?;
            debug!("status = {:?}", status);

            let mut result_file = std::fs::File::create(&result_path)?;
            let link = std::fs::read_link(&tarballp.join("result")).ok();
            serde_json::to_writer(
                &mut result_file,
                &BuildResult {
                    finished: chrono::Utc::now(),
                    status: status.code(),
                    job: job.clone(),
                    link: link.clone(),
                },
            )?;

            sender.send((job, status.code(), link)).await?;

            std::fs::remove_dir_all(&tarballp).unwrap_or(());
            std::fs::remove_file(&p)?;

            Ok(())
        }));
        Ok(())
    }

    async fn send_log(
        &self,
        channel: &mut thrussh::client::Channel,
        job: ci::Job,
        exit_status: Option<i32>,
        path: Option<PathBuf>,
    ) -> Result<(), anyhow::Error> {
        let id = job.id;
        let msg = Message::Log {
            job,
            exit_status,
            path,
        };
        channel.data(&bincode::serialize(&msg).unwrap()[..]).await?;

        let mut buf = Vec::with_capacity(4096);
        debug!(
            "stdout: {:?}",
            self.log_path.join(&format!("{}.stdout", id))
        );
        if let Ok(ref mut stdout) =
            std::fs::File::open(&self.log_path.join(&format!("{}.stdout", id)))
        {
            let len = channel.writable_packet_size().min(MAX_BUF_SIZE);
            buf.resize(len, 0);
            while let Ok(n) = stdout.read(&mut buf) {
                if n == 0 {
                    channel
                        .data(
                            &bincode::serialize(&Message::Chunk {
                                id,
                                stderr: false,
                                len: 0,
                            })
                            .unwrap()[..],
                        )
                        .await?;
                    break;
                }
                channel
                    .data(
                        &bincode::serialize(&Message::Chunk {
                            id,
                            stderr: false,
                            len: n as u32,
                        })
                        .unwrap()[..],
                    )
                    .await?;
                channel.data(&buf[..n]).await?
            }
        }
        if let Ok(ref mut stdout) =
            std::fs::File::open(&self.log_path.join(&format!("{}.stderr", id)))
        {
            let len = channel.writable_packet_size().min(MAX_BUF_SIZE);
            buf.resize(len, 0);
            while let Ok(n) = stdout.read(&mut buf) {
                if n == 0 {
                    channel
                        .data(
                            &bincode::serialize(&Message::Chunk {
                                id,
                                stderr: true,
                                len: 0,
                            })
                            .unwrap()[..],
                        )
                        .await?;
                    break;
                }
                channel
                    .data(
                        &bincode::serialize(&Message::Chunk {
                            id,
                            stderr: true,
                            len: n as u32,
                        })
                        .unwrap()[..],
                    )
                    .await?;
                channel.data(&buf[..n]).await?
            }
        }
        Ok(())
    }
}

/// Upper bound, in bytes, on the buffer used to stream log files to the
/// server (64 KiB).
const MAX_BUF_SIZE: usize = 64 * 1024;

impl thrussh::client::Handler for CiClient {
    type Error = anyhow::Error;

    /// Accepts the server only if its host key is one of the configured
    /// `server_public_keys`.
    fn check_server_key(
        self,
        server_public_key: &thrussh_keys::key::PublicKey,
    ) -> impl futures::Future<Output = Result<(Self, bool), Self::Error>> {
        let valid = self
            .server_public_keys
            .iter()
            .any(|p| p == server_public_key);
        futures::future::ready(Ok((self, valid)))
    }

    /// Chooses the next receive-window size for a channel.
    ///
    /// Targets of 10 000 000 or more are kept as they are. Below that, the
    /// target depends on how long ago the previous adjustment was:
    /// - less than 2 s: doubled;
    /// - more than 8 s: halved;
    /// - otherwise: unchanged.
    fn adjust_window(&mut self, _channel: thrussh::ChannelId, target: u32) -> u32 {
        let since_last = self.last_window_adjustment.elapsed();
        self.last_window_adjustment = SystemTime::now();
        let elapsed = match since_last {
            Ok(elapsed) => elapsed,
            // `SystemTime` is not monotonic. If the clock stepped backwards
            // the interval is unknown, so keep the target rather than
            // panicking inside the SSH session.
            Err(_) => return target,
        };
        if target >= 10_000_000 {
            return target;
        }
        // `target < 10_000_000` here, so doubling cannot overflow `u32`.
        if elapsed < Duration::from_secs(2) {
            target * 2
        } else if elapsed > Duration::from_secs(8) {
            target / 2
        } else {
            target
        }
    }
}