use ci::Message;
use serde_derive::*;
use std::io::Read;
use std::net::{Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use tracing::*;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Client configuration, deserialized from the TOML file given with `--config`.
#[derive(Serialize, Deserialize)]
struct ConfigFile {
/// Path of the SSH secret key used to authenticate as user `ci`.
key_path: String,
/// Port the client connects to (on the unspecified IPv6 address `::`).
port: u16,
/// NOTE(review): parsed from the config but not read anywhere in this file.
timeout_secs: usize,
/// Base64-encoded public keys; the server's host key must match one of them.
server_public_keys: Vec<String>,
/// Directory holding the per-job `<id>.stdout`, `<id>.stderr` and `<id>.result` files.
log_path: String,
/// Directory where job tarballs are downloaded and extracted.
tarball_path: String,
}
/// Outcome of one build, stored as JSON in `<log_path>/<job id>.result`; when the same
/// job is received again this stored result is reported instead of rebuilding.
#[derive(Serialize, Deserialize)]
struct BuildResult {
/// When the result was recorded.
finished: chrono::DateTime<chrono::Utc>,
/// Exit code of `nix-build`; `None` when the process reported no code
/// (e.g. it was terminated by a signal on Unix).
status: Option<i32>,
/// Target of the `result` symlink in the build directory, if it could be read.
link: Option<PathBuf>,
/// The job that was built.
job: ci::Job,
}
use clap::*;
// Command-line arguments. Plain `//` comments are used on purpose here: `///` doc
// comments would be picked up by clap and change the `--help` output.
#[derive(Debug, Parser)]
pub struct App {
// Path to the TOML configuration file (`-c` / `--config`).
#[arg(short, long)]
config: PathBuf,
}
/// Entry point: sets up logging, loads the configuration named by `--config`, then keeps
/// a session with the CI server alive forever, reconnecting one second after each session
/// ends (logging the error, if there was one).
///
/// # Panics
/// Panics with a descriptive message if the configuration file cannot be read or parsed,
/// or if the secret key or one of the server public keys cannot be loaded.
#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "ci=debug".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    let matches = App::parse();
    let conf_text = std::fs::read_to_string(&matches.config)
        .unwrap_or_else(|e| panic!("cannot read config file {:?}: {}", matches.config, e));
    let conf: ConfigFile = toml::from_str(&conf_text)
        .unwrap_or_else(|e| panic!("cannot parse config file {:?}: {}", matches.config, e));
    // NOTE(review): `conf.timeout_secs` is never applied; the SSH client uses its defaults.
    let config = Arc::new(thrussh::client::Config::default());
    // The server is reached on the unspecified IPv6 address `::` at the configured port.
    let addr = (Ipv6Addr::UNSPECIFIED, conf.port)
        .to_socket_addrs()
        .expect("an (Ipv6Addr, u16) pair always resolves")
        .next()
        .expect("an (Ipv6Addr, u16) pair resolves to exactly one address");
    let key = Arc::new(
        thrussh_keys::load_secret_key(&conf.key_path, None)
            .unwrap_or_else(|e| panic!("cannot load secret key {:?}: {:?}", conf.key_path, e)),
    );
    let server_public_keys: Vec<_> = conf
        .server_public_keys
        .iter()
        .map(|p| {
            thrussh_keys::parse_public_key_base64(p)
                .unwrap_or_else(|e| panic!("invalid server public key {:?}: {:?}", p, e))
        })
        .collect();
    // Finished builds report back to the connection loop through this channel.
    let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
    let client = CiClient {
        process: Arc::new(Mutex::new(Process::default())),
        log_path: Path::new(&conf.log_path).to_path_buf(),
        tarball_path: Path::new(&conf.tarball_path).to_path_buf(),
        last_window_adjustment: SystemTime::now(),
        server_public_keys: Arc::new(server_public_keys),
        sender,
    };
    loop {
        if let Err(e) = client
            .protocol(&addr, config.clone(), key.clone(), &mut receiver)
            .await
        {
            error!("restarting because of error: {:?}", e)
        }
        tokio::time::sleep(Duration::from_secs(1)).await;
    }
}
/// Client state shared by the connection loop (`protocol`) and the SSH handler.
///
/// A clone is handed to `thrussh::client::connect` for every connection; the `Arc`
/// fields are shared between all clones.
#[derive(Clone, Debug)]
pub struct CiClient {
/// Current build / tarball-download state.
process: Arc<Mutex<Process>>,
/// Directory holding downloaded `<id>.tar.gz` tarballs and their extracted sources.
tarball_path: PathBuf,
/// Directory holding the `<id>.stdout`, `<id>.stderr` and `<id>.result` files.
log_path: PathBuf,
/// Time of the previous `adjust_window` call, used to scale the receive window.
last_window_adjustment: SystemTime,
/// Host keys accepted by `check_server_key`.
server_public_keys: Arc<Vec<thrussh_keys::key::PublicKey>>,
/// Reports `(job, exit code, result link)` for finished or cached builds to the
/// connection loop.
sender: tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
}
/// Build state of the client; one instance lives for the whole program (it is created
/// once in `main`) and is shared across reconnections.
#[derive(Debug, Default)]
struct Process {
/// Background task running the current `nix-build`, if any.
child: Option<tokio::task::JoinHandle<Result<(), anyhow::Error>>>,
/// Job waiting for its tarball download to complete.
job: Option<ci::Job>,
/// Tarball chunk currently being received: the open `.tar.gz.tmp` file and the number
/// of bytes of the chunk still expected.
tarball: Option<(std::fs::File, usize)>,
}
impl Process {
    /// Reports whether the client is idle: no build task is running and no job is
    /// parked waiting for its tarball.
    fn is_ready(&self) -> bool {
        matches!((&self.child, &self.job), (None, None))
    }
}
impl CiClient {
pub async fn protocol(
&self,
addr: &SocketAddr,
config: Arc<thrussh::client::Config>,
key: Arc<thrussh_keys::key::KeyPair>,
receiver: &mut tokio::sync::mpsc::Receiver<(ci::Job, Option<i32>, Option<PathBuf>)>,
) -> Result<(), anyhow::Error> {
let mut h = thrussh::client::connect(config, &addr, self.clone()).await?;
debug!("Opening session");
if !h.authenticate_publickey("ci", key).await? {
return Ok(());
}
let mut channel = h.channel_open_session().await?;
channel
.data(
&bincode::serialize(&Message::Handshake {
version: ci::VERSION,
id: 0,
})
.unwrap()[..],
)
.await?;
debug!("handshake done");
'outer: loop {
if self.process.lock().unwrap().is_ready() {
channel
.data(&bincode::serialize(&Message::Ready).unwrap()[..])
.await?;
debug!("ready");
}
loop {
tokio::select! {
msg = channel.wait() => {
debug!("msg = {:?}", msg);
if !self.handle_msg(&mut channel, &self.sender, msg).await? {
break 'outer
}
}
msg = receiver.recv() => {
debug!("message {:#?}", msg);
if let Some(p) = self.process.lock().unwrap().child.take() {
p.await??
}
if let Some((job, exit_status, path)) = msg {
self.send_log(&mut channel, job, exit_status, path).await?
}
channel.data(&bincode::serialize(&Message::Ready).unwrap()[..]).await?;
}
}
}
}
Ok(())
}
async fn handle_msg(
&self,
channel: &mut thrussh::client::Channel,
sender: &tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
msg: Option<thrussh::ChannelMsg>,
) -> Result<bool, anyhow::Error> {
match msg {
Some(thrussh::ChannelMsg::Data { data }) => {
let mut proc = self.process.lock().unwrap();
if let Some((mut f, mut len)) = proc.tarball.take() {
debug!("len = {:?}", len);
use std::io::Write;
f.write_all(&data)?;
len -= data.len();
if len > 0 {
proc.tarball = Some((f, len));
}
return Ok(true);
}
let msg = bincode::deserialize::<Message>(&data);
debug!("msg = {:?}", msg);
match msg {
Ok(Message::Job(job)) => {
self.handle_job(channel, sender.clone(), &mut proc, job)
.await?
}
Ok(Message::Chunk { id, len, .. }) => {
let p = self.tarball_path.join(&format!("{}.tar.gz.tmp", id));
if len == 0 {
let p2 = self.tarball_path.join(&format!("{}.tar.gz", id));
std::fs::rename(&p, &p2)?;
proc.tarball = None;
let job = proc.job.take().unwrap();
self.handle_job(channel, sender.clone(), &mut proc, job)
.await?;
return Ok(true);
}
let file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.append(true)
.open(&p)
.unwrap();
proc.tarball = Some((file, len as usize));
}
Ok(msg) => {
debug!("msg = {:?}", msg);
}
_ => return Ok(false),
}
}
None => return Ok(false),
msg => debug!("{:?}", msg),
}
Ok(true)
}
async fn handle_job(
&self,
channel: &mut thrussh::client::Channel,
sender: tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
process: &mut Process,
job: ci::Job,
) -> Result<(), anyhow::Error> {
let p = self.tarball_path.join(&format!("{}.tar.gz", job.id));
debug!("p = {:?}", p);
if std::fs::metadata(&p).is_err() {
debug!("getting tarball");
channel
.data(&bincode::serialize(&Message::GetTarball { id: job.id }).unwrap()[..])
.await?;
process.job = Some(job);
return Ok(());
}
debug!("tar = {:?}", p);
let status = std::process::Command::new("tar")
.args(&["-xf", p.to_str().unwrap()])
.current_dir(&self.tarball_path)
.status()
.unwrap();
debug!("nix: {:?}", status);
let tarballp = self.tarball_path.join(job.id.to_string());
let logp = self.log_path.clone();
let result_path = logp.join(&format!("{}.result", job.id));
if let Ok(mut f) = std::fs::File::open(&result_path) {
if let Ok(build_result) = serde_json::from_reader::<_, BuildResult>(&mut f) {
sender
.send((build_result.job, build_result.status, build_result.link))
.await?;
return Ok(());
}
}
debug!("p = {:?}", tarballp);
process.child = Some(tokio::task::spawn(async move {
let mut process = tokio::process::Command::new("nix-build")
.arg("default.nix")
.current_dir(&tarballp)
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.unwrap();
let stdout = process.stdout.as_mut().unwrap();
let stderr = process.stderr.as_mut().unwrap();
let mut fstdout =
tokio::fs::File::create(logp.join(&format!("{}.stdout", job.id))).await?;
let mut fstderr =
tokio::fs::File::create(logp.join(&format!("{}.stderr", job.id))).await?;
let (a, b) = futures::future::join(
tokio::io::copy(stdout, &mut fstdout),
tokio::io::copy(stderr, &mut fstderr),
)
.await;
a?;
b?;
let status = process.wait().await?;
debug!("status = {:?}", status);
let mut result_file = std::fs::File::create(&result_path)?;
let link = std::fs::read_link(&tarballp.join("result")).ok();
serde_json::to_writer(
&mut result_file,
&BuildResult {
finished: chrono::Utc::now(),
status: status.code(),
job: job.clone(),
link: link.clone(),
},
)?;
sender.send((job, status.code(), link)).await?;
std::fs::remove_dir_all(&tarballp).unwrap_or(());
std::fs::remove_file(&p)?;
Ok(())
}));
Ok(())
}
async fn send_log(
&self,
channel: &mut thrussh::client::Channel,
job: ci::Job,
exit_status: Option<i32>,
path: Option<PathBuf>,
) -> Result<(), anyhow::Error> {
let id = job.id;
let msg = Message::Log {
job,
exit_status,
path,
};
channel.data(&bincode::serialize(&msg).unwrap()[..]).await?;
let mut buf = Vec::with_capacity(4096);
debug!(
"stdout: {:?}",
self.log_path.join(&format!("{}.stdout", id))
);
if let Ok(ref mut stdout) =
std::fs::File::open(&self.log_path.join(&format!("{}.stdout", id)))
{
let len = channel.writable_packet_size().min(MAX_BUF_SIZE);
buf.resize(len, 0);
while let Ok(n) = stdout.read(&mut buf) {
if n == 0 {
channel
.data(
&bincode::serialize(&Message::Chunk {
id,
stderr: false,
len: 0,
})
.unwrap()[..],
)
.await?;
break;
}
channel
.data(
&bincode::serialize(&Message::Chunk {
id,
stderr: false,
len: n as u32,
})
.unwrap()[..],
)
.await?;
channel.data(&buf[..n]).await?
}
}
if let Ok(ref mut stdout) =
std::fs::File::open(&self.log_path.join(&format!("{}.stderr", id)))
{
let len = channel.writable_packet_size().min(MAX_BUF_SIZE);
buf.resize(len, 0);
while let Ok(n) = stdout.read(&mut buf) {
if n == 0 {
channel
.data(
&bincode::serialize(&Message::Chunk {
id,
stderr: true,
len: 0,
})
.unwrap()[..],
)
.await?;
break;
}
channel
.data(
&bincode::serialize(&Message::Chunk {
id,
stderr: true,
len: n as u32,
})
.unwrap()[..],
)
.await?;
channel.data(&buf[..n]).await?
}
}
Ok(())
}
}
/// Upper bound, in bytes, on the read buffer used when streaming log files to the server.
const MAX_BUF_SIZE: usize = 64 * 1024;
impl thrussh::client::Handler for CiClient {
    type Error = anyhow::Error;

    /// Accepts the server only if its host key is one of the configured
    /// `server_public_keys`.
    fn check_server_key(
        self,
        server_public_key: &thrussh_keys::key::PublicKey,
    ) -> impl futures::Future<Output = Result<(Self, bool), Self::Error>> {
        let valid = self
            .server_public_keys
            .iter()
            .any(|p| p == server_public_key);
        futures::future::ready(Ok((self, valid)))
    }

    /// Chooses the next receive-window size from the requested `target` and the time
    /// since the previous call.
    ///
    /// Targets of 10 000 000 or more are returned unchanged. Otherwise the window is
    /// doubled when calls are less than 2 s apart, halved when they are more than 8 s
    /// apart, and left as `target` in between or when no measurement is available.
    fn adjust_window(&mut self, _channel: thrussh::ChannelId, target: u32) -> u32 {
        // `SystemTime::elapsed` fails if the wall clock moved backwards since the last
        // call; that must not panic the connection handler.
        let elapsed = self.last_window_adjustment.elapsed().ok();
        self.last_window_adjustment = SystemTime::now();
        if target >= 10_000_000 {
            return target;
        }
        // `target < 10_000_000`, so doubling cannot overflow a u32.
        match elapsed {
            Some(e) if e < Duration::from_secs(2) => target * 2,
            Some(e) if e > Duration::from_secs(8) => target / 2,
            _ => target,
        }
    }
}