vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

commit b9c434addbbfcdf4ce430182181d7962053eb951
parent 7f61dd5fe3d27fc3dc5c580d842972c273678362
Author: Daniel GarcĂ­a <dani-garcia@users.noreply.github.com>
Date:   Wed, 11 May 2022 21:36:11 +0200

Merge branch 'jjlin-db-conn-init' into main

Diffstat:
M.env.template | 9+++++++++
Msrc/config.rs | 5++++-
Msrc/db/mod.rs | 46++++++++++++++++++++++++++++++++++++++++++++--
3 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/.env.template b/.env.template @@ -29,6 +29,15 @@ ## Define the size of the connection pool used for connecting to the database. # DATABASE_MAX_CONNS=10 +## Database connection initialization +## Allows SQL statements to be run whenever a new database connection is created. +## This is mainly useful for connection-scoped pragmas. +## If empty, a database-specific default is used: +## - SQLite: "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;" +## - MySQL: "" +## - PostgreSQL: "" +# DATABASE_CONN_INIT="" + ## Individual folders, these override %DATA_FOLDER% # RSA_KEY_FILENAME=data/rsa_key # ICON_CACHE_FOLDER=data/icon_cache diff --git a/src/config.rs b/src/config.rs @@ -515,11 +515,14 @@ make_config! { db_connection_retries: u32, false, def, 15; /// Timeout when aquiring database connection - database_timeout: u64, false, def, 30; + database_timeout: u64, false, def, 30; /// Database connection pool size database_max_conns: u32, false, def, 10; + /// Database connection init |> SQL statements to run when creating a new database connection, mainly useful for connection-scoped pragmas. If empty, a database-specific default is used. + database_conn_init: String, false, def, "".to_string(); + /// Bypass admin page security (Know the risks!) |> Disables the Admin Token for the admin page so you may use your own auth in-front disable_admin_token: bool, true, def, false; diff --git a/src/db/mod.rs b/src/db/mod.rs @@ -1,6 +1,10 @@ use std::{sync::Arc, time::Duration}; -use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; +use diesel::{ + connection::SimpleConnection, + r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection}, +}; + use rocket::{ http::Status, outcome::IntoOutcome, @@ -62,6 +66,23 @@ macro_rules! generate_connections { #[allow(non_camel_case_types)] pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ } + #[derive(Debug)] + pub struct DbConnOptions { + pub init_stmts: String, + } + + $( // Based on <https://stackoverflow.com/a/57717533>. + #[cfg($name)] + impl CustomizeConnection<$ty, diesel::r2d2::Error> for DbConnOptions { + fn on_acquire(&self, conn: &mut $ty) -> Result<(), diesel::r2d2::Error> { + (|| { + if !self.init_stmts.is_empty() { + conn.batch_execute(&self.init_stmts)?; + } + Ok(()) + })().map_err(diesel::r2d2::Error::QueryError) + } + })+ #[derive(Clone)] pub struct DbPool { @@ -103,7 +124,8 @@ macro_rules! generate_connections { } impl DbPool { - // For the given database URL, guess it's type, run migrations create pool and return it + // For the given database URL, guess its type, run migrations, create pool, and return it + #[allow(clippy::diverging_sub_expression)] pub fn from_config() -> Result<Self, Error> { let url = CONFIG.database_url(); let conn_type = DbConnType::from_url(&url)?; @@ -117,6 +139,9 @@ macro_rules! generate_connections { let pool = Pool::builder() .max_size(CONFIG.database_max_conns()) .connection_timeout(Duration::from_secs(CONFIG.database_timeout())) + .connection_customizer(Box::new(DbConnOptions{ + init_stmts: conn_type.get_init_stmts() + })) .build(manager) .map_res("Failed to create pool")?; return Ok(DbPool { @@ -190,6 +215,23 @@ impl DbConnType { err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled") } } + + pub fn get_init_stmts(&self) -> String { + let init_stmts = CONFIG.database_conn_init(); + if !init_stmts.is_empty() { + init_stmts + } else { + self.default_init_stmts() + } + } + + pub fn default_init_stmts(&self) -> String { + match self { + Self::sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(), + Self::mysql => "".to_string(), + Self::postgresql => "".to_string(), + } + } } #[macro_export]