vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

mod.rs (8254B)


      1 use crate::{
      2     config::{self, Config},
      3     error::{Error, MapResult as _},
      4 };
      5 use core::num::{NonZeroU32, NonZeroUsize};
      6 use diesel::{
      7     SqliteConnection,
      8     connection::SimpleConnection as _,
      9     r2d2::{self, ConnectionManager, CustomizeConnection, Pool, PooledConnection},
     10 };
     11 use rocket::{
     12     Request,
     13     http::Status,
     14     request::{FromRequest, Outcome},
     15 };
     16 use std::{panic, sync::Arc, time::Duration};
     17 use tokio::{
     18     runtime,
     19     sync::{Mutex, OwnedSemaphorePermit, Semaphore},
     20     task,
     21     time::timeout,
     22 };
     23 mod schema;
     24 
     25 // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
     26 // A wrapper around spawn_blocking that propagates panics to the calling code.
     27 async fn run_blocking<F, R>(job: F) -> R
     28 where
     29     F: FnOnce() -> R + Send + 'static,
     30     R: Send + 'static,
     31 {
     32     match task::spawn_blocking(job).await {
     33         Ok(ret) => ret,
     34         Err(e) => e.try_into_panic().map_or_else(
     35             |_| unreachable!("spawn_blocking tasks are never cancelled"),
     36             |panic| panic::resume_unwind(panic),
     37         ),
     38     }
     39 }
     40 pub struct DbConn {
     41     conn: Arc<Mutex<Option<PooledConnection<ConnectionManager<SqliteConnection>>>>>,
     42     permit: Option<OwnedSemaphorePermit>,
     43 }
     44 #[derive(Debug)]
     45 struct DbConnOptions;
     46 impl CustomizeConnection<SqliteConnection, r2d2::Error> for DbConnOptions {
     47     fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> {
     48         conn.batch_execute("PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;")
     49             .map_err(r2d2::Error::QueryError)
     50     }
     51 }
     52 #[derive(Clone)]
     53 pub struct DbPool {
     54     // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
     55     pool: Option<Pool<ConnectionManager<SqliteConnection>>>,
     56     semaphore: Arc<Semaphore>,
     57 }
     58 impl Drop for DbConn {
     59     fn drop(&mut self) {
     60         let conn = Arc::clone(&self.conn);
     61         let permit = self.permit.take();
     62         // Since connection can't be on the stack in an async fn during an
     63         // await, we have to spawn a new blocking-safe thread...
     64         task::spawn_blocking(move || {
     65             // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
     66             let mut conn = runtime::Handle::current().block_on(conn.lock_owned());
     67             if let Some(conn) = conn.take() {
     68                 drop(conn);
     69             }
     70             // Drop permit after the connection is dropped
     71             drop(permit);
     72         });
     73     }
     74 }
     75 impl Drop for DbPool {
     76     fn drop(&mut self) {
     77         let pool = self.pool.take();
     78         task::spawn_blocking(move || drop(pool));
     79     }
     80 }
     81 impl DbPool {
     82     // For the given database URL, guess its type, run migrations, create pool, and return it
     83     pub fn from_config() -> Result<Self, Error> {
     84         let url = Config::DATABASE_URL;
     85         let manager = ConnectionManager::new(url);
     86         let pool = Pool::builder()
     87             .max_size(NonZeroU32::from(config::get_config().database_max_conns).get())
     88             .connection_timeout(Duration::from_secs(u64::from(
     89                 config::get_config().database_timeout,
     90             )))
     91             .connection_customizer(Box::new(DbConnOptions))
     92             .build(manager)
     93             .map_res("Failed to create pool")?;
     94         Ok(Self {
     95             pool: Some(pool),
     96             semaphore: Arc::new(Semaphore::new(
     97                 NonZeroUsize::from(config::get_config().database_max_conns).get(),
     98             )),
     99         })
    100     }
    101     // Get a connection from the pool
    102     async fn get(&self) -> Result<DbConn, Error> {
    103         let duration = Duration::from_secs(u64::from(config::get_config().database_timeout));
    104         let permit = match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await {
    105             Ok(p) => p.expect("Semaphore should be open"),
    106             Err(_) => {
    107                 err!("Timeout waiting for database connection");
    108             }
    109         };
    110         let pool = self
    111             .pool
    112             .as_ref()
    113             .expect("DbPool.pool should always be Some()")
    114             .clone();
    115         let c = run_blocking(move || pool.get_timeout(duration))
    116             .await
    117             .map_res("Error retrieving connection from pool")?;
    118         Ok(DbConn {
    119             conn: Arc::new(Mutex::new(Some(c))),
    120             permit: Some(permit),
    121         })
    122     }
    123 }
    124 #[macro_export]
    125 macro_rules! db_run {
    126     ( $conn:ident: $body:block ) => {
    127         #[allow(unused)]
    128         use diesel::prelude::*;
    129         use tokio::task;
    130         #[allow(unused)]
    131         use $crate::db::FromDb as _;
    132         let mut con = $conn.conn.clone().lock_owned().await;
    133         paste::paste! {
    134             #[allow(unused)] use $crate::db::schema::{self, *};
    135             #[allow(unused)] use __sqlite_model::*;
    136         }
    137         task::block_in_place(move || {
    138             let $conn = con
    139                 .as_mut()
    140                 .expect("internal invariant broken: self.connection is Some");
    141             $body
    142         }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
    143     };
    144 }
    145 
    146 trait FromDb {
    147     type Output;
    148     #[allow(clippy::wrong_self_convention)]
    149     fn from_db(self) -> Self::Output;
    150 }
    151 
    152 impl<T: FromDb> FromDb for Vec<T> {
    153     type Output = Vec<T::Output>;
    154     #[allow(clippy::wrong_self_convention)]
    155     #[inline]
    156     fn from_db(self) -> Self::Output {
    157         self.into_iter().map(FromDb::from_db).collect()
    158     }
    159 }
    160 
    161 impl<T: FromDb> FromDb for Option<T> {
    162     type Output = Option<T::Output>;
    163     #[allow(clippy::wrong_self_convention)]
    164     #[inline]
    165     fn from_db(self) -> Self::Output {
    166         self.map(FromDb::from_db)
    167     }
    168 }
    169 
    170 // For each struct eg. Cipher, we create a CipherDb inside a module named __sqlite_model,
    171 // to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run!
    172 #[macro_export]
    173 macro_rules! db_object {
    174     ( $(
    175         $( #[$attr:meta] )*
    176         pub struct $name:ident {
    177             $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+
    178             $(,)?
    179         }
    180     )+ ) => {
    181         // Create the normal struct, without attributes
    182         $( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+
    183         mod __sqlite_model     { $( db_object! { $( #[$attr] )* | $name |  $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
    184     };
    185 
    186     ( $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => {
    187         paste::paste! {
    188             #[allow(unused)] use super::*;
    189             #[allow(unused)] use diesel::prelude::*;
    190             #[allow(unused)] use $crate::db::schema::*;
    191 
    192             $( #[$attr] )*
    193             pub struct [<$name Db>] { $(
    194                 $( #[$field_attr] )* $vis $field : $typ,
    195             )+ }
    196 
    197             impl [<$name Db>] {
    198                 #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)]
    199                 #[inline] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
    200             }
    201 
    202             impl $crate::db::FromDb for [<$name Db>] {
    203                 type Output = super::$name;
    204                 #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)]
    205                 #[inline] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
    206             }
    207         }
    208     };
    209 }
    210 
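// Declare the public `models` module. This must come after the `db_run!` and
// `db_object!` definitions above: `macro_rules!` macros are textually scoped,
// so code in `models` can only invoke them by bare name if they are defined earlier.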
    212 pub mod models;
    213 /// Attempts to retrieve a single connection from the managed database pool. If
    214 /// no pool is currently managed, fails with an `InternalServerError` status. If
    215 /// no connections are available, fails with a `ServiceUnavailable` status.
    216 #[rocket::async_trait]
    217 impl<'r> FromRequest<'r> for DbConn {
    218     type Error = ();
    219 
    220     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    221         match request.rocket().state::<DbPool>() {
    222             Some(p) => (p.get().await).map_or_else(
    223                 |_| Outcome::Error((Status::ServiceUnavailable, ())),
    224                 Outcome::Success,
    225             ),
    226             None => Outcome::Error((Status::InternalServerError, ())),
    227         }
    228     }
    229 }