vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

config.rs (6261B)


      1 use core::fmt::{self, Display, Formatter};
      2 use core::num::NonZeroU8;
      3 use core::str;
      4 use rocket::config::{CipherSuite, LogLevel, TlsConfig};
      5 use rocket::data::{Limits, ToByteUnit};
      6 use std::error;
      7 use std::fs;
      8 use std::io::Error;
      9 use std::net::IpAddr;
     10 use std::sync::OnceLock;
     11 use toml::{self, de};
     12 use url::{ParseError, Url};
     13 use webauthn_rp::request::{AsciiDomain, RpId};
     14 static CONFIG: OnceLock<Config> = OnceLock::new();
     15 #[inline]
     16 pub fn init_config() {
     17     CONFIG
     18         .set(Config::load().expect("valid TOML config file at 'config.toml'"))
     19         .expect("CONFIG must only be initialized once");
     20 }
     21 #[inline]
     22 pub fn get_config() -> &'static Config {
     23     CONFIG.get().expect("CONFIG must be initialized in main")
     24 }
     25 #[derive(Debug)]
     26 pub enum ConfigErr {
     27     Io(Error),
     28     De(de::Error),
     29     Url(ParseError),
     30     BadDomain,
     31     InvalidPasswordIterations(u32),
     32 }
     33 impl Display for ConfigErr {
     34     #[inline]
     35     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
     36         match *self {
     37             Self::Io(ref err) => err.fmt(f),
     38             Self::De(ref err) => err.fmt(f),
     39             Self::Url(ref err) => err.fmt(f),
     40             Self::BadDomain => f.write_str(
     41                 "https://<domain>:<port> was unable to be parsed into a URL with a domain",
     42             ),
     43             Self::InvalidPasswordIterations(count) => write!(
     44                 f,
     45                 "password iterations is {count} but must be at least 100000"
     46             ),
     47         }
     48     }
     49 }
     50 impl error::Error for ConfigErr {}
     51 impl From<Error> for ConfigErr {
     52     #[inline]
     53     fn from(value: Error) -> Self {
     54         Self::Io(value)
     55     }
     56 }
     57 impl From<de::Error> for ConfigErr {
     58     #[inline]
     59     fn from(value: de::Error) -> Self {
     60         Self::De(value)
     61     }
     62 }
     63 impl From<ParseError> for ConfigErr {
     64     #[inline]
     65     fn from(value: ParseError) -> Self {
     66         Self::Url(value)
     67     }
     68 }
     69 #[derive(serde::Deserialize)]
     70 #[serde(deny_unknown_fields)]
     71 struct Tls {
     72     ciphers: Option<Vec<CipherSuite>>,
     73     cert: String,
     74     key: String,
     75     prefer_server_cipher_order: Option<bool>,
     76 }
     77 #[derive(serde::Deserialize)]
     78 #[serde(deny_unknown_fields)]
     79 struct ConfigFile {
     80     database_max_conns: Option<NonZeroU8>,
     81     database_timeout: Option<u16>,
     82     db_connection_retries: Option<NonZeroU8>,
     83     domain: String,
     84     ip: IpAddr,
     85     password_iterations: Option<u32>,
     86     port: u16,
     87     tls: Tls,
     88     web_vault_enabled: Option<bool>,
     89     workers: Option<NonZeroU8>,
     90 }
     91 #[derive(Debug)]
     92 pub struct Config {
     93     pub database_max_conns: NonZeroU8,
     94     pub database_timeout: u16,
     95     pub db_connection_retries: NonZeroU8,
     96     domain_url: Url,
     97     pub rp_id: RpId,
     98     pub password_iterations: u32,
     99     pub rocket: rocket::Config,
    100     pub web_vault_enabled: bool,
    101 }
    102 impl Config {
    103     #[inline]
    104     pub fn load() -> Result<Self, ConfigErr> {
    105         let config_file =
    106             toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?;
    107         let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key);
    108         tls = match config_file.tls.ciphers {
    109             Some(ciphers) => {
    110                 if ciphers.is_empty() {
    111                     tls
    112                 } else {
    113                     tls.with_ciphers(ciphers)
    114                 }
    115             }
    116             None => tls,
    117         };
    118         tls = match config_file.tls.prefer_server_cipher_order {
    119             Some(prefer) => tls.with_preferred_server_cipher_order(prefer),
    120             None => tls,
    121         };
    122         let mut rocket = rocket::Config {
    123             address: config_file.ip,
    124             cli_colors: false,
    125             limits: Limits::new()
    126                 .limit("json", 20i32.megabytes())
    127                 .limit("data-form", 525i32.megabytes())
    128                 .limit("file", 525i32.megabytes()),
    129             log_level: LogLevel::Off,
    130             port: config_file.port,
    131             temp_dir: "data/tmp".into(),
    132             tls: Some(tls),
    133             ..Default::default()
    134         };
    135         if let Some(count) = config_file.workers {
    136             rocket.workers = usize::from(count.get());
    137         }
    138         let domain =
    139             AsciiDomain::try_from(config_file.domain).map_err(|_e| ConfigErr::BadDomain)?;
    140         let url = format!(
    141             "https://{}{}",
    142             domain.as_ref(),
    143             if config_file.port == 443 {
    144                 String::new()
    145             } else {
    146                 format!(":{}", config_file.port)
    147             }
    148         );
    149         let domain_url = Url::parse(url.as_str())?;
    150         Ok(Self {
    151             database_max_conns: config_file
    152                 .database_max_conns
    153                 .unwrap_or(NonZeroU8::new(10).unwrap()),
    154             database_timeout: config_file.database_timeout.unwrap_or(30),
    155             db_connection_retries: config_file
    156                 .db_connection_retries
    157                 .unwrap_or(NonZeroU8::new(15).unwrap()),
    158             domain_url,
    159             rp_id: RpId::Domain(domain),
    160             password_iterations: match config_file.password_iterations {
    161                 None => 600_000,
    162                 Some(count) => {
    163                     if count < 100_000u32 {
    164                         return Err(ConfigErr::InvalidPasswordIterations(count));
    165                     }
    166                     count
    167                 }
    168             },
    169             rocket,
    170             web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true),
    171         })
    172     }
    173 }
    174 impl Config {
    175     pub const DATA_FOLDER: &'static str = "data";
    176     pub const DATABASE_URL: &'static str = "data/db.sqlite3";
    177     pub const PRIVATE_ED25519_KEY: &'static str = "data/ed25519_key.pem";
    178     pub const WEB_VAULT_FOLDER: &'static str = "web-vault/";
    179     #[allow(clippy::arithmetic_side_effects, clippy::string_slice)]
    180     #[inline]
    181     pub fn domain_url(&self) -> &str {
    182         let val = self.domain_url.as_str();
    183         // The last Unicode scalar value is '/' which is a
    184         // single UTF-8 code unit, and we want to remove that.
    185         // Note if this changes in the future such that the last
    186         // Unicode scalar value is encoded using more than one
    187         // UTF-8 code unit, then this will panic.
    188         // Additionally if `len` is somehow 0, indexing will panic
    189         // making this memory and logic safe.
    190         &val[..val.len() - 1]
    191     }
    192 }