vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

mod.rs (8182B)


      1 use crate::{
      2     config::{self, Config},
      3     error::{Error, MapResult},
      4 };
      5 use diesel::{
      6     connection::SimpleConnection,
      7     r2d2::{self, ConnectionManager, CustomizeConnection, Pool, PooledConnection},
      8     SqliteConnection,
      9 };
     10 use rocket::{
     11     http::Status,
     12     request::{FromRequest, Outcome},
     13     Request,
     14 };
     15 use std::{panic, sync::Arc, time::Duration};
     16 use tokio::{
     17     runtime,
     18     sync::{Mutex, OwnedSemaphorePermit, Semaphore},
     19     task,
     20     time::timeout,
     21 };
     22 mod schema;
     23 
     24 // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
     25 // A wrapper around spawn_blocking that propagates panics to the calling code.
     26 async fn run_blocking<F, R>(job: F) -> R
     27 where
     28     F: FnOnce() -> R + Send + 'static,
     29     R: Send + 'static,
     30 {
     31     match task::spawn_blocking(job).await {
     32         Ok(ret) => ret,
     33         Err(e) => e.try_into_panic().map_or_else(
     34             |_| unreachable!("spawn_blocking tasks are never cancelled"),
     35             |panic| panic::resume_unwind(panic),
     36         ),
     37     }
     38 }
/// A connection checked out of the pool, bundled with the semaphore permit
/// that `DbPool::get` acquired before checking it out.
pub struct DbConn {
    // Shared, async-locked slot holding the pooled connection. `db_run!` locks
    // it to run queries; `Drop` locks it and takes the connection out.
    conn: Arc<Mutex<Option<PooledConnection<ConnectionManager<SqliteConnection>>>>>,
    // Permit from the `DbPool` semaphore. It is an `Option` so `Drop` can move it
    // into the blocking task and release it after the connection is dropped.
    permit: Option<OwnedSemaphorePermit>,
}
     43 #[derive(Debug)]
     44 struct DbConnOptions;
     45 impl CustomizeConnection<SqliteConnection, r2d2::Error> for DbConnOptions {
     46     fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), r2d2::Error> {
     47         conn.batch_execute("PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;")
     48             .map_err(r2d2::Error::QueryError)
     49     }
     50 }
/// Handle to the r2d2 connection pool plus the semaphore that `DbPool::get`
/// acquires a permit from before checking a connection out.
#[derive(Clone)]
pub struct DbPool {
    // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
    pool: Option<Pool<ConnectionManager<SqliteConnection>>>,
    // Created in `from_config` with `database_max_conns` permits, the same
    // value used for the pool's `max_size`.
    semaphore: Arc<Semaphore>,
}
     57 impl Drop for DbConn {
     58     fn drop(&mut self) {
     59         let conn = Arc::clone(&self.conn);
     60         let permit = self.permit.take();
     61         // Since connection can't be on the stack in an async fn during an
     62         // await, we have to spawn a new blocking-safe thread...
     63         task::spawn_blocking(move || {
     64             // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
     65             let mut conn = runtime::Handle::current().block_on(conn.lock_owned());
     66             if let Some(conn) = conn.take() {
     67                 drop(conn);
     68             }
     69             // Drop permit after the connection is dropped
     70             drop(permit);
     71         });
     72     }
     73 }
     74 impl Drop for DbPool {
     75     fn drop(&mut self) {
     76         let pool = self.pool.take();
     77         task::spawn_blocking(move || drop(pool));
     78     }
     79 }
     80 impl DbPool {
     81     // For the given database URL, guess its type, run migrations, create pool, and return it
     82     pub fn from_config() -> Result<Self, Error> {
     83         let url = Config::DATABASE_URL;
     84         let manager = ConnectionManager::new(url);
     85         let pool = Pool::builder()
     86             .max_size(u32::from(config::get_config().database_max_conns.get()))
     87             .connection_timeout(Duration::from_secs(u64::from(
     88                 config::get_config().database_timeout,
     89             )))
     90             .connection_customizer(Box::new(DbConnOptions))
     91             .build(manager)
     92             .map_res("Failed to create pool")?;
     93         Ok(Self {
     94             pool: Some(pool),
     95             semaphore: Arc::new(Semaphore::new(usize::from(
     96                 config::get_config().database_max_conns.get(),
     97             ))),
     98         })
     99     }
    100     // Get a connection from the pool
    101     async fn get(&self) -> Result<DbConn, Error> {
    102         let duration = Duration::from_secs(u64::from(config::get_config().database_timeout));
    103         let permit = match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await {
    104             Ok(p) => p.expect("Semaphore should be open"),
    105             Err(_) => {
    106                 err!("Timeout waiting for database connection");
    107             }
    108         };
    109         let pool = self
    110             .pool
    111             .as_ref()
    112             .expect("DbPool.pool should always be Some()")
    113             .clone();
    114         let c = run_blocking(move || pool.get_timeout(duration))
    115             .await
    116             .map_res("Error retrieving connection from pool")?;
    117         Ok(DbConn {
    118             conn: Arc::new(Mutex::new(Some(c))),
    119             permit: Some(permit),
    120         })
    121     }
    122 }
/// Runs a block of synchronous database code against a `DbConn`.
///
/// Invoked as `db_run! { conn: { ... } }`, where `conn` is an identifier naming
/// a value with a `conn` field (a `DbConn`). The expansion awaits the
/// connection's async mutex, so it must be used inside an `async` context.
/// Inside the block, `conn` is rebound to a mutable reference to the pooled
/// connection, and the Diesel prelude, `FromDb`, the schema module and the
/// caller's `__sqlite_model` module are in scope.
///
/// # Panics
///
/// Panics if the connection slot is empty.
#[macro_export]
macro_rules! db_run {
    ( $conn:ident: $body:block ) => {
        #[allow(unused)]
        use diesel::prelude::*;
        use tokio::task;
        #[allow(unused)]
        use $crate::db::FromDb;
        // Clone the `Arc` and wait for exclusive access to the connection slot.
        let mut con = $conn.conn.clone().lock_owned().await;
        // NOTE(review): no `[< >]` identifiers are pasted here, so the `paste!`
        // wrapper looks vestigial -- confirm before removing.
        paste::paste! {
            #[allow(unused)] use $crate::db::schema::{self, *};
            #[allow(unused)] use __sqlite_model::*;
        }
        task::block_in_place(move || {
            // Shadow the caller's identifier with the unwrapped connection so
            // that `$body` can use it directly.
            let $conn = con
                .as_mut()
                .expect("internal invariant broken: self.connection is Some");
            $body
        }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
    };
}
    144 
    145 trait FromDb {
    146     type Output;
    147     #[allow(clippy::wrong_self_convention)]
    148     fn from_db(self) -> Self::Output;
    149 }
    150 
    151 impl<T: FromDb> FromDb for Vec<T> {
    152     type Output = Vec<T::Output>;
    153     #[allow(clippy::wrong_self_convention)]
    154     #[inline]
    155     fn from_db(self) -> Self::Output {
    156         self.into_iter().map(FromDb::from_db).collect()
    157     }
    158 }
    159 
    160 impl<T: FromDb> FromDb for Option<T> {
    161     type Output = Option<T::Output>;
    162     #[allow(clippy::wrong_self_convention)]
    163     #[inline]
    164     fn from_db(self) -> Self::Output {
    165         self.map(FromDb::from_db)
    166     }
    167 }
    168 
// For each struct eg. Cipher, we create a CipherDb inside a module named __sqlite_model,
// to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run!
//
// Usage: pass one or more `pub struct Name { ... }` definitions, each optionally
// carrying struct-level and field-level attributes. Every invocation emits exactly
// one `mod __sqlite_model`.
#[macro_export]
macro_rules! db_object {
    // Public entry point: one or more attributed struct definitions.
    ( $(
        $( #[$attr:meta] )*
        pub struct $name:ident {
            $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+
            $(,)?
        }
    )+ ) => {
        // Create the normal struct, without attributes
        $( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+
        // Re-invoke the macro once per struct, using the internal `| Name |` form,
        // to generate the attributed `<Name>Db` twin inside the module. The field
        // visibility is not forwarded here.
        mod __sqlite_model     { $( db_object! { $( #[$attr] )* | $name |  $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
    };

    // Internal form: generates `<Name>Db` and its conversions for a single struct.
    ( $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => {
        paste::paste! {
            #[allow(unused)] use super::*;
            #[allow(unused)] use diesel::prelude::*;
            #[allow(unused)] use $crate::db::schema::*;

            // Same fields as the plain struct, but carrying the original
            // struct-level and field-level attributes.
            $( #[$attr] )*
            pub struct [<$name Db>] { $(
                $( #[$field_attr] )* $vis $field : $typ,
            )+ }

            impl [<$name Db>] {
                // Builds the Db twin by cloning every field of the plain struct.
                #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)]
                #[inline] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
            }

            impl $crate::db::FromDb for [<$name Db>] {
                type Output = super::$name;
                // Moves every field back into the plain struct.
                #[allow(clippy::used_underscore_binding, clippy::wrong_self_convention)]
                #[inline] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
            }
        }
    };
}
    209 
    210 // Reexport the models, needs to be after the macros are defined so it can access them
    211 pub mod models;
    212 /// Attempts to retrieve a single connection from the managed database pool. If
    213 /// no pool is currently managed, fails with an `InternalServerError` status. If
    214 /// no connections are available, fails with a `ServiceUnavailable` status.
    215 #[rocket::async_trait]
    216 impl<'r> FromRequest<'r> for DbConn {
    217     type Error = ();
    218 
    219     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    220         match request.rocket().state::<DbPool>() {
    221             Some(p) => (p.get().await).map_or_else(
    222                 |_| Outcome::Error((Status::ServiceUnavailable, ())),
    223                 Outcome::Success,
    224             ),
    225             None => Outcome::Error((Status::InternalServerError, ())),
    226         }
    227     }
    228 }