use crate::DynResult;
use std::borrow::Borrow;
use tokio::sync::{mpsc, oneshot};
/// Creates the schema used by this module if it is not already present.
///
/// Runs, in order, the `CREATE TABLE IF NOT EXISTS` statements for `names`
/// and `preferences`, followed by the three `CREATE INDEX IF NOT EXISTS`
/// statements on `preferences`. Execution stops at the first statement that
/// fails and that error is returned.
///
/// NOTE(review): nothing in this function issues `PRAGMA foreign_keys = ON`,
/// which SQLite requires before it enforces `FOREIGN KEY` clauses. Confirm
/// whether the `ON DELETE CASCADE` rules below are expected to fire.
fn create_tables(conn: &sqlite::Connection) -> sqlite::Result<()> {
    // Order matters: `preferences` refers to `names`, and the indexes refer
    // to `preferences`.
    const SCHEMA: &[&str] = &[
        "CREATE TABLE IF NOT EXISTS names (\n\
         name TEXT,\n\
         UNIQUE (name) ON CONFLICT FAIL)",
        "CREATE TABLE IF NOT EXISTS preferences (\n\
         better INTEGER,\n\
         worse INTEGER,\n\
         FOREIGN KEY (better) REFERENCES names(rowid) ON DELETE CASCADE,\n\
         FOREIGN KEY (worse) REFERENCES names(rowid) ON DELETE CASCADE,\n\
         UNIQUE(better,worse))",
        "CREATE INDEX IF NOT EXISTS prefs_index ON preferences(better,worse)",
        "CREATE INDEX IF NOT EXISTS prefs_better ON preferences(better)",
        "CREATE INDEX IF NOT EXISTS prefs_worse ON preferences(worse)",
    ];

    for &statement in SCHEMA {
        conn.execute(statement)?;
    }
    Ok(())
}
/// A boxed, one-shot unit of database work.
///
/// The closure is handed a borrowed [`sqlite::Connection`] (for any borrow
/// lifetime) and returns nothing; any result has to be delivered through
/// state the closure captured. The trait object is required to be
/// `Send + Sync + 'static`.
type DBFunction =
Box<dyn for<'a> FnOnce(&'a sqlite::Connection) + Send + Sync + 'static>;
/// Cloneable handle used to run closures against a SQLite connection.
///
/// The handle is only the sending half of an `mpsc` channel of
/// [`DBFunction`] jobs; cloning it clones that sender. The receiving half is
/// drained by the thread spawned in [`AsyncConnection::open`], which is also
/// where the `sqlite::Connection` itself lives.
#[derive(Clone)]
pub struct AsyncConnection(mpsc::Sender<DBFunction>);
impl AsyncConnection {
pub async fn open<T: AsRef<std::path::Path> + Send + 'static>(
path: T,
) -> DynResult<Self> {
let (tx, mut rx): (_, mpsc::Receiver<DBFunction>) = mpsc::channel(1);
std::thread::Builder::new()
.name(format!("database thread: {}", path.as_ref().display()))
.spawn(move || {
let conn = sqlite::open(path).unwrap();
while let Some(cb) = rx.blocking_recv() {
cb(&conn);
}
})?;
let result = Self(tx);
result.post(create_tables).await??;
Ok(result)
}
pub async fn add_name(&self, name: String) -> DynResult<(i64, bool)> {
Ok(self
.post(move |conn| {
let mut stmnt =
conn.prepare("INSERT INTO names(name) VALUES(?)")?;
stmnt.bind::<&str>(1, name.borrow())?;
let is_new = stmnt.next().is_ok();
let mut stmnt = conn
.prepare("SELECT rowid FROM names WHERE names.name = ?")?;
stmnt.bind::<&str>(1, name.borrow())?;
match stmnt.next()? {
sqlite::State::Row => Ok((stmnt.read::<i64>(0)?, is_new)),
_ => Err(sqlite::Error {
code: None,
message: Some(String::from(
"Expected rowid from row just inserted",
)),
}),
}
})
.await??)
}
pub async fn compare(
&self,
a: i64,
b: i64,
) -> DynResult<Option<std::cmp::Ordering>> {
if a == b {
return Ok(Some(std::cmp::Ordering::Equal));
}
Ok(self.post(move |conn| -> DynResult<Option<std::cmp::Ordering>> {
let mut stmnt = conn.prepare("WITH RECURSIVE t(better,worse) AS (SELECT better, worse FROM preferences WHERE preferences.worse IN (?,?) UNION SELECT preferences.better, t.worse FROM preferences INNER JOIN t ON preferences.worse = t.better) SELECT t.better, t.worse FROM t WHERE t.better IN (?,?) LIMIT 1")?;
stmnt.bind::<i64>(1,a)?;
stmnt.bind::<i64>(2,b)?;
stmnt.bind::<i64>(3,a)?;
stmnt.bind::<i64>(4,b)?;
match stmnt.next()? {
sqlite::State::Done => Ok(None),
sqlite::State::Row => {
let better = stmnt.read::<i64>(0)?;
let worse = stmnt.read::<i64>(1)?;
if better == a && worse == b {
Ok(Some(std::cmp::Ordering::Greater))
} else if better == b && worse == a {
Ok(Some(std::cmp::Ordering::Less))
} else {
Ok(Some(std::cmp::Ordering::Equal))
}
}
}
}).await??)
}
pub async fn check_depth(&self, a: i64, cutoff: i64) -> DynResult<i64> {
self.post(|conn| {
let mut stmnt = conn.prepare("WITH RECURSIVE t(better,worse) AS (SELECT better, worse FROM preferences WHERE preferences.worse = ? UNION SELECT preferences.better, t.worse FROM preferences INNER JOIN t ON preferences.worse = t.better) SELECT count(*) FROM (SELECT DISTINCT better FROM t) LIMIT ?")?;
stmnt.bind::<i64>(1, a)?;
stmnt.bind::<i64>(2, cutoff)?;
match stmnt.next()? {
sqlite::State::Done => Ok(0),
sqlite::State::Row => Ok(stmnt.read::<i64>(0)?),
}
}).await?
}
pub async fn list_names(&self) -> DynResult<Vec<(String, i64)>> {
match self
.post(|conn| {
let mut stmnt =
conn.prepare("SELECT name, rowid FROM names")?;
let mut rows = Vec::new();
while let sqlite::State::Row = stmnt.next()? {
rows.push((
stmnt.read::<String>(0)?,
stmnt.read::<i64>(1)?,
));
}
Ok(rows)
})
.await
{
Err(e) => Err(e),
Ok(Err(e)) => Err(e),
Ok(Ok(r)) => Ok(r),
}
}
pub async fn insert_preference(
&self,
better: i64,
worse: i64,
) -> DynResult<()> {
self.post(move |db| -> DynResult<_> {
let mut stmnt = db.prepare(
"INSERT INTO preferences(better, worse) VALUES(?, ?)",
)?;
stmnt.bind::<i64>(1, better)?;
stmnt.bind::<i64>(2, worse)?;
stmnt.next()?;
Ok(())
})
.await?
}
async fn post<
F: FnOnce(&sqlite::Connection) -> U + Send,
U: Send + 'static,
>(
&self,
cb: F,
) -> DynResult<U> {
let (sc, rc): (_, oneshot::Receiver<U>) = oneshot::channel();
let h = move |conn: &sqlite::Connection| {
sc.send(cb(conn)).unwrap_or(());
};
self.0
.send(unsafe {
std::mem::transmute::<_, DBFunction>(Box::new(h)
as Box<dyn for<'a> FnOnce(&'a sqlite::Connection) + Send>)
})
.await
.map_err(|mpsc::error::SendError(_)| mpsc::error::SendError(()))?;
Ok(rc.await?)
}
}