use crate::Config;
use axum::extract::State;
use axum::response::IntoResponse;
use http_body_util::BodyExt;
use hyper::header::*;
use tracing::*;
/// `Strict-Transport-Security` header value: one year (31536000 s), applied
/// to subdomains as well.
pub const HSTS: &str = "max-age=31536000; includeSubDomains";
fn cached(
cache: &std::sync::RwLock<std::collections::HashMap<String, crate::config::Cached>>,
t: Option<&str>,
path: &str,
) -> Option<axum::response::Response> {
if let Ok(c) = cache.read() {
if let Some(cached) = c.get(path) {
debug!("cached {:?}", path);
let mut resp = hyper::Response::new(cached.body.clone().into());
let h = resp.headers_mut();
h.insert(
hyper::header::STRICT_TRANSPORT_SECURITY,
HSTS.try_into().unwrap(),
);
if let Some(ref c) = cached.content_type {
h.insert(CONTENT_TYPE, c.clone());
}
h.insert(
hyper::header::CONTENT_LENGTH,
cached.body.len().to_string().parse().unwrap(),
);
return Some(resp);
}
}
None
}
#[axum::debug_handler]
pub async fn node_proxy_cached(
State(config): State<Config>,
req: axum::extract::Request,
) -> axum::response::Response {
node_proxy(config, true, false, req).await.unwrap()
}
#[axum::debug_handler]
pub async fn node_proxy_strong_cached(
State(config): State<Config>,
req: axum::extract::Request,
) -> axum::response::Response {
node_proxy(config, true, true, req).await.unwrap()
}
#[axum::debug_handler]
pub async fn node_proxy_not_cached(
State(config): State<Config>,
req: axum::extract::Request,
) -> axum::response::Response {
node_proxy(config, false, false, req).await.unwrap()
}
/// Core proxy shared by the three handlers above.
///
/// Forwards `req` to the local Node server via [`node_proxy_`], optionally
/// serving from / populating the in-memory cache, minifying `text/html`
/// bodies, and answering `If-Modified-Since` conditional requests with 304.
///
/// * `cache`   — when true, consult `config.cache` before proxying, store
///   successful responses in it, and emit `Last-Modified`/`Cache-Control`.
/// * `forever` — only meaningful with `cache`: send
///   `Cache-Control: public, max-age=31536000` instead of plain `public`.
///
/// Errors only surface from building the fallback 500 response; upstream
/// failures are converted to a 500 internally.
async fn node_proxy(
    config: Config,
    cache: bool,
    forever: bool,
    mut req: axum::extract::Request,
) -> Result<axum::response::Response, hyper::http::Error> {
    // The deployment is versioned by a single timestamp: anything the client
    // fetched at or after `config.version_time` is considered fresh.
    let ref last_server_modif: chrono::DateTime<chrono::Utc> = config.version_time.into();
    if let Some(modif) = req.headers().get(http::header::IF_MODIFIED_SINCE) {
        debug!("req last_modif: {:?}", modif);
        if let Ok(parsed) = chrono::NaiveDateTime::parse_from_str(
            std::str::from_utf8(modif.as_bytes()).unwrap_or(""),
            crate::config::RFC1123,
        ) {
            // RFC 1123 dates are GMT, so interpreting the naive time as UTC is
            // correct; `unwrap` is safe because UTC never has ambiguous times.
            let parsed = parsed.and_local_timezone(chrono::Utc).unwrap();
            debug!("parsed: {:?} {:?}", parsed, last_server_modif);
            if &parsed >= last_server_modif {
                return Ok(http::StatusCode::NOT_MODIFIED.into_response());
            }
        }
    }
    // Cache hit: decorate the stored response with freshness headers and
    // return early. Cache miss: remember the path so the proxied response
    // can be inserted below.
    let mut path = if cache {
        let path = req.uri().path();
        if let Some(mut r) = cached(&config.cache, None, &path) {
            let ref l = config.version_time_str;
            debug!("cached {:?}", l);
            let h = r.headers_mut();
            h.insert(hyper::header::LAST_MODIFIED, l.parse().unwrap());
            h.insert(
                hyper::header::CACHE_CONTROL,
                if forever {
                    "public, max-age=31536000".parse().unwrap()
                } else {
                    "public".parse().unwrap()
                },
            );
            h.remove(hyper::header::VARY);
            return Ok(r);
        }
        Some(path.to_string())
    } else {
        None
    };
    // Request an identity-encoded body from upstream — presumably so it can
    // be minified and cached uncompressed; TODO confirm.
    req.headers_mut().remove(ACCEPT_ENCODING);
    debug!("node_proxy {:?}", path);
    for h in req.headers() {
        debug!("{:?}", h);
    }
    match node_proxy_(&config, req).await {
        Ok(resp) => {
            let (parts, body) = resp.into_parts();
            debug!("parts {:?}", parts);
            // Buffer the whole upstream body; it is needed in full for
            // minification, caching, and Content-Length.
            let body = body.collect().await.unwrap().to_bytes();
            debug!("body {:?}", body.len());
            // Only successful responses are cached.
            if !parts.status.is_success() {
                path = None;
            }
            let html = parts
                .headers
                .get(CONTENT_TYPE)
                .and_then(|x| x.to_str().ok())
                == Some("text/html");
            let data = if html {
                minify_html::minify(&body, &minify_html::Cfg::default()).into()
            } else {
                body
            };
            if let Some(path) = path {
                // NOTE(review): `unwrap` on the write lock panics if the lock
                // was poisoned by an earlier panic; `cached()` above treats
                // poisoning as a miss instead.
                config.cache.write().unwrap().insert(
                    path,
                    crate::config::Cached {
                        content_type: parts.headers.get(CONTENT_TYPE).cloned(),
                        body: data.clone(),
                    },
                );
            };
            let len = data.len();
            debug!("response len {:?}", len);
            let mut resp = hyper::Response::new(data.into());
            *resp.status_mut() = parts.status;
            let h = resp.headers_mut();
            // Copy upstream headers, dropping hop-by-hop ones that no longer
            // apply to the fully buffered response.
            for (k, v) in parts.headers {
                // NOTE(review): consuming iteration of `HeaderMap` yields a
                // `None` key for the 2nd+ value of a repeated header name, so
                // this `unwrap` panics on duplicates (e.g. several
                // `Set-Cookie` headers) — verify upstream never sends them.
                let k = k.unwrap();
                if k == hyper::header::TRANSFER_ENCODING || k == hyper::header::CONNECTION {
                    continue;
                }
                h.insert(k, v);
            }
            // The body may have been minified, so recompute the length rather
            // than trusting the upstream Content-Length.
            h.insert(
                hyper::header::CONTENT_LENGTH,
                len.to_string().parse().unwrap(),
            );
            if cache {
                let ref l = config.version_time_str;
                debug!("cache, last_modified {:?}", l);
                h.insert(hyper::header::LAST_MODIFIED, l.parse().unwrap());
                if forever {
                    h.insert(
                        hyper::header::CACHE_CONTROL,
                        "public, max-age=31536000".parse().unwrap(),
                    );
                } else {
                    h.insert(hyper::header::CACHE_CONTROL, "public".parse().unwrap());
                }
                h.remove(hyper::header::VARY);
            } else {
                debug!("not cache")
            }
            Ok(resp)
        }
        Err(e) => {
            // Upstream unreachable or the exchange failed: plain 500 carrying
            // the HSTS header.
            error!("error {:?}", e);
            Ok(hyper::Response::builder()
                .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
                .header(hyper::header::STRICT_TRANSPORT_SECURITY, HSTS)
                .body("".into())?)
        }
    }
}
/// Forward `req` over a fresh HTTP/1 connection to the backing Node server
/// and return the response with its body fully buffered into memory.
///
/// Connects to `config.svelte_socket` (unix domain socket) when configured,
/// otherwise to TCP `127.0.0.1:5173` — presumably the Vite/SvelteKit dev
/// server default; verify against deployment docs.
pub async fn node_proxy_(
    config: &Config,
    req: hyper::Request<axum::body::Body>,
) -> Result<hyper::Response<axum::body::Body>, crate::Error> {
    let mut sender = if let Some(ref svelte_socket) = config.svelte_socket {
        use hyper_util::rt::TokioIo;
        use tokio::net::UnixStream;
        let stream = UnixStream::connect(svelte_socket).await?;
        let io = TokioIo::new(stream);
        let (sender, conn) = hyper::client::conn::http1::handshake(io).await?;
        // Drive the connection in the background: `sender` only queues
        // requests, the spawned task performs the actual socket I/O.
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                error!("Connection failed: {:?}", err);
            }
        });
        sender
    } else {
        debug!("node_proxy_ 127.0.0.1:5173");
        use hyper_util::rt::TokioIo;
        use tokio::net::TcpStream;
        let stream = TcpStream::connect("127.0.0.1:5173").await?;
        let io = TokioIo::new(stream);
        let (sender, conn) = hyper::client::conn::http1::handshake(io).await?;
        // Same background connection driver as the unix-socket branch.
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                error!("Connection failed: {:?}", err);
            }
        });
        sender
    };
    // Rebuild the request URI in origin-form (path + query only).
    // NOTE(review): `path_and_query()` is Some for requests a server receives
    // over HTTP/1, but this `unwrap` would panic otherwise — confirm no other
    // callers construct requests without a path.
    let mut req_ = hyper::Request::builder().method(req.method().clone()).uri(
        hyper::Uri::builder()
            .path_and_query(req.uri().path_and_query().unwrap().clone())
            .build()
            .unwrap(),
    );
    // Copy request headers verbatim, except Host, which is rewritten to the
    // configured backend host.
    for (h, v) in req.headers() {
        if h == HOST {
            req_ = req_.header(HOST, config.host.clone());
        } else {
            req_ = req_.header(h, v.clone());
        }
    }
    // NOTE(review): the request body is buffered with no size limit
    // (`usize::MAX`), so a very large upload is held entirely in memory.
    let body = req.into_body();
    let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
    let req_: http::Request<axum::body::Body> = req_.body(bytes.into())?;
    debug!("node_proxy_ {:?}", req_);
    let mut resp = sender.send_request(req_).await?;
    // Eagerly collect the response body so the caller receives a complete,
    // in-memory payload (the caller recomputes Content-Length from it).
    let body = resp.body_mut().collect().await.unwrap().to_bytes();
    debug!("resp {:?} body {:?}", resp, body.len());
    Ok(resp.map(|_| body.into()))
}