mod.rs (8182B)
1 use crate::{ 2 config::{self, Config}, 3 error::{Error, MapResult}, 4 }; 5 use diesel::{ 6 connection::SimpleConnection, 7 r2d2::{self, ConnectionManager, CustomizeConnection, Pool, PooledConnection}, 8 SqliteConnection, 9 }; 10 use rocket::{ 11 http::Status, 12 request::{FromRequest, Outcome}, 13 Request, 14 }; 15 use std::{panic, sync::Arc, time::Duration}; 16 use tokio::{ 17 runtime, 18 sync::{Mutex, OwnedSemaphorePermit, Semaphore}, 19 task, 20 time::timeout, 21 }; 22 mod schema; 23 24 // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools 25 // A wrapper around spawn_blocking that propagates panics to the calling code. 26 async fn run_blocking<F, R>(job: F) -> R 27 where 28 F: FnOnce() -> R + Send + 'static, 29 R: Send + 'static, 30 { 31 match task::spawn_blocking(job).await { 32 Ok(ret) => ret, 33 Err(e) => e.try_into_panic().map_or_else( 34 |_| unreachable!("spawn_blocking tasks are never cancelled"), 35 |panic| panic::resume_unwind(panic), 36 ), 37 } 38 } 39 pub struct DbConn { 40 conn: Arc<Mutex<Option<PooledConnection<ConnectionManager<SqliteConnection>>>>>, 41 permit: Option<OwnedSemaphorePermit>, 42 } 43 #[derive(Debug)] 44 struct DbConnOptions; 45 impl CustomizeConnection<SqliteConnection, r2d2::Error> for DbConnOptions { 46 fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> { 47 conn.batch_execute("PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;") 48 .map_err(r2d2::Error::QueryError) 49 } 50 } 51 #[derive(Clone)] 52 pub struct DbPool { 53 // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'. 54 pool: Option<Pool<ConnectionManager<SqliteConnection>>>, 55 semaphore: Arc<Semaphore>, 56 } 57 impl Drop for DbConn { 58 fn drop(&mut self) { 59 let conn = Arc::clone(&self.conn); 60 let permit = self.permit.take(); 61 // Since connection can't be on the stack in an async fn during an 62 // await, we have to spawn a new blocking-safe thread... 63 task::spawn_blocking(move || { 64 // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion. 65 let mut conn = runtime::Handle::current().block_on(conn.lock_owned()); 66 if let Some(conn) = conn.take() { 67 drop(conn); 68 } 69 // Drop permit after the connection is dropped 70 drop(permit); 71 }); 72 } 73 } 74 impl Drop for DbPool { 75 fn drop(&mut self) { 76 let pool = self.pool.take(); 77 task::spawn_blocking(move || drop(pool)); 78 } 79 } 80 impl DbPool { 81 // For the given database URL, guess its type, run migrations, create pool, and return it 82 pub fn from_config() -> Result<Self, Error> { 83 let url = Config::DATABASE_URL; 84 let manager = ConnectionManager::new(url); 85 let pool = Pool::builder() 86 .max_size(u32::from(config::get_config().database_max_conns.get())) 87 .connection_timeout(Duration::from_secs(u64::from( 88 config::get_config().database_timeout, 89 ))) 90 .connection_customizer(Box::new(DbConnOptions)) 91 .build(manager) 92 .map_res("Failed to create pool")?; 93 Ok(Self { 94 pool: Some(pool), 95 semaphore: Arc::new(Semaphore::new(usize::from( 96 config::get_config().database_max_conns.get(), 97 ))), 98 }) 99 } 100 // Get a connection from the pool 101 async fn get(&self) -> Result<DbConn, Error> { 102 let duration = Duration::from_secs(u64::from(config::get_config().database_timeout)); 103 let permit = match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await { 104 Ok(p) => p.expect("Semaphore should be open"), 105 Err(_) => { 106 err!("Timeout waiting for database connection"); 107 } 108 }; 109 let pool = self 110 .pool 111 .as_ref() 112 .expect("DbPool.pool should always be Some()") 113 .clone(); 114 let c = run_blocking(move || pool.get_timeout(duration)) 115 .await 116 .map_res("Error retrieving connection from pool")?; 117 Ok(DbConn { 118 conn: Arc::new(Mutex::new(Some(c))), 119 permit: Some(permit), 120 }) 121 } 122 } 123 #[macro_export] 124 macro_rules! db_run { 125 ( $conn:ident: $body:block ) => { 126 #[allow(unused)] 127 use diesel::prelude::*; 128 use tokio::task; 129 #[allow(unused)] 130 use $crate::db::FromDb; 131 let mut con = $conn.conn.clone().lock_owned().await; 132 paste::paste! { 133 #[allow(unused)] use $crate::db::schema::{self, *}; 134 #[allow(unused)] use __sqlite_model::*; 135 } 136 task::block_in_place(move || { 137 let $conn = con 138 .as_mut() 139 .expect("internal invariant broken: self.connection is Some"); 140 $body 141 }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead 142 }; 143 } 144 145 trait FromDb { 146 type Output; 147 #[allow(clippy::wrong_self_convention)] 148 fn from_db(self) -> Self::Output; 149 } 150 151 impl<T: FromDb> FromDb for Vec<T> { 152 type Output = Vec<T::Output>; 153 #[allow(clippy::wrong_self_convention)] 154 #[inline] 155 fn from_db(self) -> Self::Output { 156 self.into_iter().map(FromDb::from_db).collect() 157 } 158 } 159 160 impl<T: FromDb> FromDb for Option<T> { 161 type Output = Option<T::Output>; 162 #[allow(clippy::wrong_self_convention)] 163 #[inline] 164 fn from_db(self) -> Self::Output { 165 self.map(FromDb::from_db) 166 } 167 } 168 169 // For each struct eg. Cipher, we create a CipherDb inside a module named __sqlite_model, 170 // to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run! 171 #[macro_export] 172 macro_rules! db_object { 173 ( $( 174 $( #[$attr:meta] )* 175 pub struct $name:ident { 176 $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+ 177 $(,)? 178 } 179 )+ ) => { 180 // Create the normal struct, without attributes 181 $( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+ 182 mod __sqlite_model { $( db_object! { $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ } 183 }; 184 185 ( $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => { 186 paste::paste! { 187 #[allow(unused)] use super::*; 188 #[allow(unused)] use diesel::prelude::*; 189 #[allow(unused)] use $crate::db::schema::*; 190 191 $( #[$attr] )* 192 pub struct [<$name Db>] { $( 193 $( #[$field_attr] )* $vis $field : $typ, 194 )+ } 195 196 impl [<$name Db>] { 197 #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)] 198 #[inline] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } } 199 } 200 201 impl $crate::db::FromDb for [<$name Db>] { 202 type Output = super::$name; 203 #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)] 204 #[inline] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } } 205 } 206 } 207 }; 208 } 209 210 // Reexport the models, needs to be after the macros are defined so it can access them 211 pub mod models; 212 /// Attempts to retrieve a single connection from the managed database pool. If 213 /// no pool is currently managed, fails with an `InternalServerError` status. If 214 /// no connections are available, fails with a `ServiceUnavailable` status. 215 #[rocket::async_trait] 216 impl<'r> FromRequest<'r> for DbConn { 217 type Error = (); 218 219 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 220 match request.rocket().state::<DbPool>() { 221 Some(p) => (p.get().await).map_or_else( 222 |_| Outcome::Error((Status::ServiceUnavailable, ())), 223 Outcome::Success, 224 ), 225 None => Outcome::Error((Status::InternalServerError, ())), 226 } 227 } 228 }