mod.rs (8254B)
1 use crate::{ 2 config::{self, Config}, 3 error::{Error, MapResult as _}, 4 }; 5 use core::num::{NonZeroU32, NonZeroUsize}; 6 use diesel::{ 7 SqliteConnection, 8 connection::SimpleConnection as _, 9 r2d2::{self, ConnectionManager, CustomizeConnection, Pool, PooledConnection}, 10 }; 11 use rocket::{ 12 Request, 13 http::Status, 14 request::{FromRequest, Outcome}, 15 }; 16 use std::{panic, sync::Arc, time::Duration}; 17 use tokio::{ 18 runtime, 19 sync::{Mutex, OwnedSemaphorePermit, Semaphore}, 20 task, 21 time::timeout, 22 }; 23 mod schema; 24 25 // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools 26 // A wrapper around spawn_blocking that propagates panics to the calling code. 27 async fn run_blocking<F, R>(job: F) -> R 28 where 29 F: FnOnce() -> R + Send + 'static, 30 R: Send + 'static, 31 { 32 match task::spawn_blocking(job).await { 33 Ok(ret) => ret, 34 Err(e) => e.try_into_panic().map_or_else( 35 |_| unreachable!("spawn_blocking tasks are never cancelled"), 36 |panic| panic::resume_unwind(panic), 37 ), 38 } 39 } 40 pub struct DbConn { 41 conn: Arc<Mutex<Option<PooledConnection<ConnectionManager<SqliteConnection>>>>>, 42 permit: Option<OwnedSemaphorePermit>, 43 } 44 #[derive(Debug)] 45 struct DbConnOptions; 46 impl CustomizeConnection<SqliteConnection, r2d2::Error> for DbConnOptions { 47 fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> { 48 conn.batch_execute("PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;") 49 .map_err(r2d2::Error::QueryError) 50 } 51 } 52 #[derive(Clone)] 53 pub struct DbPool { 54 // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'. 55 pool: Option<Pool<ConnectionManager<SqliteConnection>>>, 56 semaphore: Arc<Semaphore>, 57 } 58 impl Drop for DbConn { 59 fn drop(&mut self) { 60 let conn = Arc::clone(&self.conn); 61 let permit = self.permit.take(); 62 // Since connection can't be on the stack in an async fn during an 63 // await, we have to spawn a new blocking-safe thread... 64 task::spawn_blocking(move || { 65 // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion. 66 let mut conn = runtime::Handle::current().block_on(conn.lock_owned()); 67 if let Some(conn) = conn.take() { 68 drop(conn); 69 } 70 // Drop permit after the connection is dropped 71 drop(permit); 72 }); 73 } 74 } 75 impl Drop for DbPool { 76 fn drop(&mut self) { 77 let pool = self.pool.take(); 78 task::spawn_blocking(move || drop(pool)); 79 } 80 } 81 impl DbPool { 82 // For the given database URL, guess its type, run migrations, create pool, and return it 83 pub fn from_config() -> Result<Self, Error> { 84 let url = Config::DATABASE_URL; 85 let manager = ConnectionManager::new(url); 86 let pool = Pool::builder() 87 .max_size(NonZeroU32::from(config::get_config().database_max_conns).get()) 88 .connection_timeout(Duration::from_secs(u64::from( 89 config::get_config().database_timeout, 90 ))) 91 .connection_customizer(Box::new(DbConnOptions)) 92 .build(manager) 93 .map_res("Failed to create pool")?; 94 Ok(Self { 95 pool: Some(pool), 96 semaphore: Arc::new(Semaphore::new( 97 NonZeroUsize::from(config::get_config().database_max_conns).get(), 98 )), 99 }) 100 } 101 // Get a connection from the pool 102 async fn get(&self) -> Result<DbConn, Error> { 103 let duration = Duration::from_secs(u64::from(config::get_config().database_timeout)); 104 let permit = match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await { 105 Ok(p) => p.expect("Semaphore should be open"), 106 Err(_) => { 107 err!("Timeout waiting for database connection"); 108 } 109 }; 110 let pool = self 111 .pool 112 .as_ref() 113 .expect("DbPool.pool should always be Some()") 114 .clone(); 115 let c = run_blocking(move || pool.get_timeout(duration)) 116 .await 117 .map_res("Error retrieving connection from pool")?; 118 Ok(DbConn { 119 conn: Arc::new(Mutex::new(Some(c))), 120 permit: Some(permit), 121 }) 122 } 123 } 124 #[macro_export] 125 macro_rules! db_run { 126 ( $conn:ident: $body:block ) => { 127 #[allow(unused)] 128 use diesel::prelude::*; 129 use tokio::task; 130 #[allow(unused)] 131 use $crate::db::FromDb as _; 132 let mut con = $conn.conn.clone().lock_owned().await; 133 paste::paste! { 134 #[allow(unused)] use $crate::db::schema::{self, *}; 135 #[allow(unused)] use __sqlite_model::*; 136 } 137 task::block_in_place(move || { 138 let $conn = con 139 .as_mut() 140 .expect("internal invariant broken: self.connection is Some"); 141 $body 142 }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead 143 }; 144 } 145 146 trait FromDb { 147 type Output; 148 #[allow(clippy::wrong_self_convention)] 149 fn from_db(self) -> Self::Output; 150 } 151 152 impl<T: FromDb> FromDb for Vec<T> { 153 type Output = Vec<T::Output>; 154 #[allow(clippy::wrong_self_convention)] 155 #[inline] 156 fn from_db(self) -> Self::Output { 157 self.into_iter().map(FromDb::from_db).collect() 158 } 159 } 160 161 impl<T: FromDb> FromDb for Option<T> { 162 type Output = Option<T::Output>; 163 #[allow(clippy::wrong_self_convention)] 164 #[inline] 165 fn from_db(self) -> Self::Output { 166 self.map(FromDb::from_db) 167 } 168 } 169 170 // For each struct eg. Cipher, we create a CipherDb inside a module named __sqlite_model, 171 // to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run! 172 #[macro_export] 173 macro_rules! db_object { 174 ( $( 175 $( #[$attr:meta] )* 176 pub struct $name:ident { 177 $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+ 178 $(,)? 179 } 180 )+ ) => { 181 // Create the normal struct, without attributes 182 $( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+ 183 mod __sqlite_model { $( db_object! { $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ } 184 }; 185 186 ( $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => { 187 paste::paste! { 188 #[allow(unused)] use super::*; 189 #[allow(unused)] use diesel::prelude::*; 190 #[allow(unused)] use $crate::db::schema::*; 191 192 $( #[$attr] )* 193 pub struct [<$name Db>] { $( 194 $( #[$field_attr] )* $vis $field : $typ, 195 )+ } 196 197 impl [<$name Db>] { 198 #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)] 199 #[inline] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } } 200 } 201 202 impl $crate::db::FromDb for [<$name Db>] { 203 type Output = super::$name; 204 #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)] 205 #[inline] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } } 206 } 207 } 208 }; 209 } 210 211 // Reexport the models, needs to be after the macros are defined so it can access them 212 pub mod models; 213 /// Attempts to retrieve a single connection from the managed database pool. If 214 /// no pool is currently managed, fails with an `InternalServerError` status. If 215 /// no connections are available, fails with a `ServiceUnavailable` status. 216 #[rocket::async_trait] 217 impl<'r> FromRequest<'r> for DbConn { 218 type Error = (); 219 220 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 221 match request.rocket().state::<DbPool>() { 222 Some(p) => (p.get().await).map_or_else( 223 |_| Outcome::Error((Status::ServiceUnavailable, ())), 224 Outcome::Success, 225 ), 226 None => Outcome::Error((Status::InternalServerError, ())), 227 } 228 } 229 }