use std::{sync::Arc, time::Duration};
use diesel::{
r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection},
use rocket::{
request::{FromRequest, Outcome},
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
use crate::{
error::{Error, MapResult},
// Per-backend schema modules. `db_run!` and `db_object!` glob-import them as
// `crate::db::__<backend>_schema`; presumably they hold each backend's Diesel
// `table!` definitions — confirm.
#[path = "schemas/sqlite/"]
pub mod __sqlite_schema;
#[path = "schemas/mysql/"]
pub mod __mysql_schema;
#[path = "schemas/postgresql/"]
pub mod __postgresql_schema;
// These changes are based on Rocket 0.5-rc's Diesel wrapper (the `sync_db_pools` contrib crate).
/// Runs `job` on Tokio's blocking thread pool via `spawn_blocking` and awaits its result.
///
/// A panic inside `job` is re-raised in the calling task with
/// `std::panic::resume_unwind`, so callers observe the original panic instead of a
/// `JoinError`.
///
/// # Panics
/// Re-panics with the job's payload if `job` panics. If the join error is not a panic
/// (i.e. the blocking task was cancelled) it hits `unreachable!`, on the assumption that
/// blocking tasks are never cancelled.
pub async fn run_blocking<F, R>(job: F) -> R
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
match tokio::task::spawn_blocking(job).await {
Ok(ret) => ret,
// A JoinError is either a panic or a cancellation; `try_into_panic` yields the panic
// payload in the first case and hands the error back in the second.
Err(e) => match e.try_into_panic() {
Ok(panic) => std::panic::resume_unwind(panic),
Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
// This is used to generate the main DbConn and DbPool types (and their inner enums), which contain one variant for each database supported
//
// Input: a comma-separated list of `name: ConnectionType` pairs (see the invocation
// following this macro). For each backend `name` the expansion adds a variant to
// `DbConnType`, `DbConnInner` and `DbPoolInner`, an r2d2 `CustomizeConnection` impl for
// `DbConnOptions`, and a match arm in `DbPool::from_config` and `DbPool::get`.
// The `DbConnInner`/`DbPoolInner` variants carry `#[cfg($name)]`, so only backends
// enabled at compile time exist in those enums.
macro_rules! generate_connections {
( $( $name:ident: $ty:ty ),+ ) => {
// Backend kind, one variant per listed backend. Not cfg-gated, so every backend can be
// named even when its support is compiled out.
#[allow(non_camel_case_types, dead_code)]
#[derive(Eq, PartialEq)]
pub enum DbConnType { $( $name, )+ }
// A connection checked out of the pool. The inner connection lives behind an
// `Arc<tokio::sync::Mutex<..>>` so `db_run!` can hold the lock across `.await`, and in
// an `Option` so `Drop` can take it out. `permit` is the semaphore permit acquired in
// `DbPool::get`; `Drop` releases it after the connection itself is dropped.
pub struct DbConn {
conn: Arc<Mutex<Option<DbConnInner>>>,
permit: Option<OwnedSemaphorePermit>,
// The backend-specific pooled connection.
pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
// Per-connection settings handed to r2d2 as a connection customizer.
pub struct DbConnOptions {
// SQL executed on each newly created connection (may be empty).
pub init_stmts: String,
$( // Based on <>. NOTE(review): the source link for this adaptation is missing; restore it.
impl CustomizeConnection<$ty, diesel::r2d2::Error> for DbConnOptions {
// r2d2 calls this once for each newly created connection; it applies `init_stmts`
// when they are non-empty.
fn on_acquire(&self, conn: &mut $ty) -> Result<(), diesel::r2d2::Error> {
(|| {
if !self.init_stmts.is_empty() {
// The connection pool for the configured backend, plus a semaphore that limits how
// many connections may be checked out at once.
pub struct DbPool {
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<DbPoolInner>,
semaphore: Arc<Semaphore>
// The backend-specific r2d2 pool.
pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
// `drop` cannot `.await`, so the connection is released from a blocking task that waits
// on the async mutex synchronously.
// NOTE(review): `tokio::task::spawn_blocking` panics when called outside a Tokio
// runtime context, so a DbConn must not be dropped outside one.
impl Drop for DbConn {
fn drop(&mut self) {
let conn = self.conn.clone();
let permit = self.permit.take();
// Since connection can't be on the stack in an async fn during an
// await, we have to spawn a new blocking-safe thread...
tokio::task::spawn_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
if let Some(conn) = conn.take() {
// Drop permit after the connection is dropped
// Drops the r2d2 pool on Tokio's blocking thread pool; `pool` is an `Option` so it can
// be moved out of `&mut self` here.
impl Drop for DbPool {
fn drop(&mut self) {
let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool));
impl DbPool {
// For the given database URL, guess its type, run migrations, create pool, and return it
pub fn from_config() -> Result<Self, Error> {
let url = CONFIG.database_url();
let conn_type = DbConnType::from_url(&url)?;
match conn_type { $(
DbConnType::$name => {
// Migrations use their own connection (`<backend>_migrations::run_migrations`)
// and complete before the pool is created.
paste::paste!{ [< $name _migrations >]::run_migrations()?; }
let manager = ConnectionManager::new(&url);
let pool = Pool::builder()
init_stmts: conn_type.get_init_stmts()
.map_res("Failed to create pool")?;
return Ok(DbPool {
pool: Some(DbPoolInner::$name(pool)),
// One permit per allowed connection (`database_max_conns`); `get` waits on it.
semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
// Presumably compiled only when this backend is disabled (its `#[cfg]` is not visible
// here). NOTE(review): message typo — "it's" should be "its".
return unreachable!("Trying to use a DB backend when it's feature is disabled");
)+ }
// Get a connection from the pool
//
// First waits up to `database_timeout` seconds for a semaphore permit. It then gives r2d2
// up to the same duration (on a blocking thread) to hand out a connection, so the total
// wait can approach twice the configured timeout.
pub async fn get(&self) -> Result<DbConn, Error> {
let duration = Duration::from_secs(CONFIG.database_timeout());
let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
// `acquire_owned` fails only if the semaphore is closed; nothing in this file closes it.
Ok(p) => p.expect("Semaphore should be open"),
Err(_) => {
err!("Timeout waiting for database connection");
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
DbPoolInner::$name(p) => {
let pool = p.clone();
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
return Ok(DbConn {
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
permit: Some(permit)
)+ }
// Instantiate the database types for every supported backend. Each identifier serves as:
// - the `DbConnType`/`DbConnInner`/`DbPoolInner` variant name;
// - the cfg flag that gates the inner-enum variants;
// - the prefix of the backend's `<name>_migrations` module.
generate_connections! {
sqlite: diesel::sqlite::SqliteConnection,
mysql: diesel::mysql::MysqlConnection,
postgresql: diesel::pg::PgConnection
impl DbConnType {
// Detects the backend from `url`:
// - a `mysql:` prefix selects MySQL;
// - `postgresql:` or `postgres:` selects PostgreSQL;
// - anything else is treated as SQLite.
// Each branch contains both a success return and an `err!` explaining that the matching
// feature is not enabled; presumably a `#[cfg]` (not visible here) keeps exactly one.
pub fn from_url(url: &str) -> Result<DbConnType, Error> {
// Mysql
if url.starts_with("mysql:") {
return Ok(DbConnType::mysql);
err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled")
// Postgres
} else if url.starts_with("postgresql:") || url.starts_with("postgres:") {
return Ok(DbConnType::postgresql);
err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")
// Sqlite: any URL without a recognised scheme prefix.
} else {
return Ok(DbConnType::sqlite);
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
// SQL to run on every new pooled connection. Uses the configured `database_conn_init`
// value when it is non-empty, otherwise presumably this backend's
// `default_init_stmts()` (the branch bodies are not visible here).
pub fn get_init_stmts(&self) -> String {
let init_stmts = CONFIG.database_conn_init();
if !init_stmts.is_empty() {
} else {
// Built-in per-connection statements. SQLite gets a 5000 ms `busy_timeout` and
// `synchronous = NORMAL`; MySQL and PostgreSQL get none.
pub fn default_init_stmts(&self) -> String {
match self {
Self::sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
Self::mysql => "".to_string(),
Self::postgresql => "".to_string(),
// Runs `$body` against the concrete backend connection held by the `DbConn` named
// `$conn`. Must be used in an async context, because it awaits the connection's async mutex.
//
// Forms:
// - `db_run! { conn: { ... } }` — one body shared by every backend.
// - `db_run! { conn: sqlite { ... } mysql, postgresql { ... } }` — per-backend bodies.
// - `@raw` versions of both: the same, except the backend's `__<db>_model` module is not
//   glob-imported (only `schema` and its contents are).
//
// Inside the body, `$conn` is rebound to the backend's pooled connection. The body runs
// under `tokio::task::block_in_place`, which panics on a current-thread runtime.
macro_rules! db_run {
// Same for all dbs
( $conn:ident: $body:block ) => {
db_run! { $conn: sqlite, mysql, postgresql $body }
( @raw $conn:ident: $body:block ) => {
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
// Different code for each db
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use crate::db::FromDb;
let conn = $conn.conn.clone();
// The lock is held for the rest of the block, serializing use of this connection.
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
// Rebinding `$conn` lets the body use the same identifier for the raw connection.
crate::db::DbConnInner::$db($conn) => {
paste::paste! {
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
#[allow(unused)] use [<__ $db _model>]::*;
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use crate::db::FromDb;
let conn = $conn.conn.clone();
// The lock is held for the rest of the block, serializing use of this connection.
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
// Rebinding `$conn` lets the body use the same identifier for the raw connection.
crate::db::DbConnInner::$db($conn) => {
paste::paste! {
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
// @ RAW: #[allow(unused)] use [<__ $db _model>]::*;
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
/// Converts a backend-specific Diesel struct (e.g. the `*Db` structs generated by
/// `db_object!`) into its backend-independent counterpart.
pub trait FromDb {
type Output;
fn from_db(self) -> Self::Output;
// Element-wise conversion: a `Vec<T>` becomes a `Vec<T::Output>`.
impl<T: FromDb> FromDb for Vec<T> {
type Output = Vec<T::Output>;
fn from_db(self) -> Self::Output {
// Converts the contained value, if any: an `Option<T>` becomes an `Option<T::Output>`.
impl<T: FromDb> FromDb for Option<T> {
type Output = Option<T::Output>;
fn from_db(self) -> Self::Output {
// For each struct, e.g. `Cipher`, we create a `CipherDb` inside a module named `__$db_model`
// (where `$db` is sqlite, mysql or postgresql) that carries the Diesel attributes and
// derives. We also provide methods to convert between it and the plain struct
// (`to_db` and `FromDb::from_db`). That module is later auto-imported by the non-`@raw`
// forms of `db_run!`.
macro_rules! db_object {
( $(
$( #[$attr:meta] )*
pub struct $name:ident {
$( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+
)+ ) => {
// Create the normal struct, without attributes
$( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+
// One `*Db` twin per backend. NOTE(review): `$vis` is not passed to the `@db` arm, so
// the fields of the generated `*Db` structs are private to their `__<db>_model` module.
pub mod __sqlite_model { $( db_object! { @db sqlite | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
pub mod __mysql_model { $( db_object! { @db mysql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
pub mod __postgresql_model { $( db_object! { @db postgresql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
( @db $db:ident | $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => {
paste::paste! {
#[allow(unused)] use super::*;
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use crate::db::[<__ $db _schema>]::*;
// Backend-specific struct carrying the caller's attributes (e.g. Diesel derives).
$( #[$attr] )*
pub struct [<$name Db>] { $(
$( #[$field_attr] )* $vis $field : $typ,
)+ }
impl [<$name Db>] {
// Builds the DB struct from a model, cloning every field.
#[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
impl crate::db::FromDb for [<$name Db>] {
type Output = super::$name;
// Moves every field back into the plain model struct.
#[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
// Declare the models module. It must come after the macro definitions above, because
// `macro_rules!` macros are only in scope for code that follows them textually.
pub mod models;
/// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported.
///
/// Runs `VACUUM INTO '<dir>/db_<YYYYMMDD_HHMMSS>.sqlite3'`. `<dir>` is the parent
/// directory of `DATABASE_URL`, and the timestamp is the current UTC time.
///
/// # Errors
/// Fails via `err!` on MySQL/MariaDB and PostgreSQL, and propagates any error from the
/// `VACUUM INTO` query.
///
/// # Panics
/// Panics if `DATABASE_URL` has no parent component (`Path::parent` returns `None`, e.g.
/// for an empty string or a bare root).
pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
db_run! {@raw conn:
postgresql, mysql {
// Silences the unused-variable warning for the rebound connection.
let _ = conn;
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
sqlite {
use std::path::Path;
let db_url = CONFIG.database_url();
// NOTE(review): for a bare file name such as `db.sqlite3`, `parent()` is `Some("")`,
// so the target below becomes `/db_<timestamp>.sqlite3` (the filesystem root).
let db_path = Path::new(&db_url).parent().unwrap().to_string_lossy();
let file_date = chrono::Utc::now().format("%Y%m%d_%H%M%S").to_string();
// NOTE(review): the path is interpolated into SQL unescaped; a `'` in the directory
// name would break the statement.
diesel::sql_query(format!("VACUUM INTO '{}/db_{}.sqlite3'", db_path, file_date)).execute(conn)?;
/// Gets the version string reported by the database server.
///
/// Uses `SELECT version()` on PostgreSQL and MySQL/MariaDB, and `SELECT sqlite_version()`
/// on SQLite. Any query error yields `"Unknown"`.
pub async fn get_sql_server_version(conn: &DbConn) -> String {
db_run! {@raw conn:
postgresql, mysql {
no_arg_sql_function!(version, diesel::sql_types::Text);
diesel::select(version).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
sqlite {
no_arg_sql_function!(sqlite_version, diesel::sql_types::Text);
diesel::select(sqlite_version).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
/// Attempts to retrieve a single connection from the managed database pool. If
/// no pool is currently managed, fails with an `InternalServerError` status. If
/// no connections are available, fails with a `ServiceUnavailable` status.
///
/// Any error from `DbPool::get` (e.g. a timeout waiting for a connection) is mapped to
/// `ServiceUnavailable`, and the error detail is discarded.
impl<'r> FromRequest<'r> for DbConn {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.rocket().state::<DbPool>() {
Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
None => Outcome::Failure((Status::InternalServerError, ())),
// Embed the migrations from the migrations folder into the application
// This way, the program automatically migrates the database to the latest version
mod sqlite_migrations {
// Called by `DbPool::from_config` before the pool is created. Steps:
// - ensure the database's parent directory exists;
// - open a dedicated connection and apply migration-time PRAGMAs;
// - run the embedded migrations, printing progress to stdout.
// `embedded_migrations` is presumably generated by an `embed_migrations!` call that is
// not visible here — confirm.
pub fn run_migrations() -> Result<(), super::Error> {
// Make sure the directory exists
let url = crate::CONFIG.database_url();
let path = std::path::Path::new(&url);
if let Some(parent) = path.parent() {
if std::fs::create_dir_all(parent).is_err() {
error!("Error creating database directory");
use diesel::{Connection, RunQueryDsl};
// Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let connection = diesel::sqlite::SqliteConnection::establish(&crate::CONFIG.database_url())?;
// Disable Foreign Key Checks during migration
// Scoped to a connection.
diesel::sql_query("PRAGMA foreign_keys = OFF")
.expect("Failed to disable Foreign Key Checks during migrations");
// Turn on WAL in SQLite
// Unlike the PRAGMA above, WAL journal mode is persistent: it is stored in the database
// file, so it also applies to pooled connections opened later.
if crate::CONFIG.enable_db_wal() {
diesel::sql_query("PRAGMA journal_mode=wal").execute(&connection).expect("Failed to turn on WAL");
embedded_migrations::run_with_output(&connection, &mut std::io::stdout())?;
// MySQL/MariaDB migrations, run by `DbPool::from_config` before the pool is created.
mod mysql_migrations {
// Opens a dedicated connection, disables foreign key checks for that session, and runs
// the embedded migrations, printing progress to stdout.
pub fn run_migrations() -> Result<(), super::Error> {
use diesel::{Connection, RunQueryDsl};
// Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let connection = diesel::mysql::MysqlConnection::establish(&crate::CONFIG.database_url())?;
// Disable Foreign Key Checks during migration
// Scoped to a connection/session.
diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0")
.expect("Failed to disable Foreign Key Checks during migrations");
embedded_migrations::run_with_output(&connection, &mut std::io::stdout())?;
// PostgreSQL migrations, run by `DbPool::from_config` before the pool is created.
mod postgresql_migrations {
// Opens a dedicated connection and runs the embedded migrations, printing progress to
// stdout.
pub fn run_migrations() -> Result<(), super::Error> {
use diesel::{Connection, RunQueryDsl};
// Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let connection = diesel::pg::PgConnection::establish(&crate::CONFIG.database_url())?;
// Disable Foreign Key Checks during migration
// FIXME: Per the PostgreSQL `SET CONSTRAINTS` documentation,
// "SET CONSTRAINTS sets the behavior of constraint checking within the
// current transaction", so this setting probably won't take effect for
// any of the migrations since it's being run outside of a transaction.
// Migrations that need to disable foreign key checks should run this
// from within the migration script itself.
.expect("Failed to disable Foreign Key Checks during migrations");
embedded_migrations::run_with_output(&connection, &mut std::io::stdout())?;