vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

config.rs (6286B)


      1 use core::{
      2     fmt::{self, Display, Formatter},
      3     num::{NonZeroU8, NonZeroUsize},
      4     str,
      5 };
      6 use rocket::config::{CipherSuite, LogLevel, TlsConfig};
      7 use rocket::data::{Limits, ToByteUnit as _};
      8 use std::error;
      9 use std::fs;
     10 use std::io::Error;
     11 use std::net::IpAddr;
     12 use std::sync::OnceLock;
     13 use toml::{self, de};
     14 use url::{ParseError, Url};
     15 use webauthn_rp::request::{AsciiDomain, RpId};
/// Process-wide configuration slot.
///
/// Written once by `init_config` and read through `get_config`.
static CONFIG: OnceLock<Config> = OnceLock::new();
     17 #[inline]
     18 pub fn init_config() {
     19     CONFIG
     20         .set(Config::load().expect("valid TOML config file at 'config.toml'"))
     21         .expect("CONFIG must only be initialized once");
     22 }
     23 #[inline]
     24 pub fn get_config() -> &'static Config {
     25     CONFIG.get().expect("CONFIG must be initialized in main")
     26 }
/// Errors that `Config::load` can return.
#[derive(Debug)]
pub enum ConfigErr {
    /// An I/O error, e.g. from reading `config.toml`.
    Io(Error),
    /// The TOML contents could not be deserialized.
    De(de::Error),
    /// The URL built from the domain and port failed to parse.
    Url(ParseError),
    /// The configured domain could not be converted into an `AsciiDomain`.
    BadDomain,
    /// The configured password iteration count, which is below the
    /// minimum of 100000.
    InvalidPasswordIterations(u32),
}
     35 impl Display for ConfigErr {
     36     #[inline]
     37     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
     38         match *self {
     39             Self::Io(ref err) => err.fmt(f),
     40             Self::De(ref err) => err.fmt(f),
     41             Self::Url(ref err) => err.fmt(f),
     42             Self::BadDomain => f.write_str(
     43                 "https://<domain>:<port> was unable to be parsed into a URL with a domain",
     44             ),
     45             Self::InvalidPasswordIterations(count) => write!(
     46                 f,
     47                 "password iterations is {count} but must be at least 100000"
     48             ),
     49         }
     50     }
     51 }
// Relies entirely on the trait's default methods; no `source` override is
// provided, and the `Display` impl already forwards to the wrapped error.
impl error::Error for ConfigErr {}
     53 impl From<Error> for ConfigErr {
     54     #[inline]
     55     fn from(value: Error) -> Self {
     56         Self::Io(value)
     57     }
     58 }
     59 impl From<de::Error> for ConfigErr {
     60     #[inline]
     61     fn from(value: de::Error) -> Self {
     62         Self::De(value)
     63     }
     64 }
     65 impl From<ParseError> for ConfigErr {
     66     #[inline]
     67     fn from(value: ParseError) -> Self {
     68         Self::Url(value)
     69     }
     70 }
/// The `tls` table of the configuration file.
///
/// Unknown keys are rejected (`deny_unknown_fields`).
#[derive(serde::Deserialize)]
#[serde(deny_unknown_fields)]
struct Tls {
    /// Optional cipher suites; `Config::load` ignores an empty list.
    ciphers: Option<Vec<CipherSuite>>,
    /// Certificate path handed to `TlsConfig::from_paths`.
    cert: String,
    /// Private key path handed to `TlsConfig::from_paths`.
    key: String,
    /// Optional value forwarded to `with_preferred_server_cipher_order`.
    prefer_server_cipher_order: Option<bool>,
}
/// Raw shape of the configuration file as deserialized by `Config::load`.
///
/// Unknown keys are rejected (`deny_unknown_fields`). Defaults for the
/// optional fields are applied in `Config::load`.
#[derive(serde::Deserialize)]
#[serde(deny_unknown_fields)]
struct ConfigFile {
    /// Defaults to 10 when absent.
    database_max_conns: Option<NonZeroU8>,
    /// Defaults to 30 when absent.
    database_timeout: Option<u16>,
    /// Defaults to 15 when absent.
    db_connection_retries: Option<NonZeroU8>,
    /// Domain; must be convertible into an `AsciiDomain`.
    domain: String,
    /// Address assigned to the Rocket configuration.
    ip: IpAddr,
    /// Defaults to 600000 when absent; values below 100000 are rejected.
    password_iterations: Option<u32>,
    /// Port assigned to the Rocket configuration; also appended to the
    /// domain URL unless it is 443.
    port: u16,
    /// TLS settings.
    tls: Tls,
    /// Defaults to `true` when absent.
    web_vault_enabled: Option<bool>,
    /// Overrides Rocket's worker count when present.
    workers: Option<NonZeroU8>,
}
/// Validated configuration produced by `Config::load`.
#[derive(Debug)]
pub struct Config {
    /// From the config file, or 10 when unspecified.
    pub database_max_conns: NonZeroU8,
    /// From the config file, or 30 when unspecified.
    pub database_timeout: u16,
    /// From the config file, or 15 when unspecified.
    pub db_connection_retries: NonZeroU8,
    /// `https://<domain>[:<port>]` parsed as a URL; read it through
    /// `Config::domain_url`.
    domain_url: Url,
    /// Relying party ID built from the configured domain.
    pub rp_id: RpId,
    /// From the config file (at least 100000), or 600000 when unspecified.
    pub password_iterations: u32,
    /// Rocket configuration assembled from the config file.
    pub rocket: rocket::Config,
    /// From the config file, or `true` when unspecified.
    pub web_vault_enabled: bool,
}
    104 impl Config {
    105     #[inline]
    106     pub fn load() -> Result<Self, ConfigErr> {
    107         let config_file =
    108             toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?;
    109         let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key);
    110         tls = match config_file.tls.ciphers {
    111             Some(ciphers) => {
    112                 if ciphers.is_empty() {
    113                     tls
    114                 } else {
    115                     tls.with_ciphers(ciphers)
    116                 }
    117             }
    118             None => tls,
    119         };
    120         tls = match config_file.tls.prefer_server_cipher_order {
    121             Some(prefer) => tls.with_preferred_server_cipher_order(prefer),
    122             None => tls,
    123         };
    124         let mut rocket = rocket::Config {
    125             address: config_file.ip,
    126             cli_colors: false,
    127             limits: Limits::new()
    128                 .limit("json", 20i32.megabytes())
    129                 .limit("data-form", 525i32.megabytes())
    130                 .limit("file", 525i32.megabytes()),
    131             log_level: LogLevel::Off,
    132             port: config_file.port,
    133             temp_dir: "data/tmp".into(),
    134             tls: Some(tls),
    135             ..Default::default()
    136         };
    137         if let Some(count) = config_file.workers {
    138             rocket.workers = NonZeroUsize::from(count).get();
    139         }
    140         let domain =
    141             AsciiDomain::try_from(config_file.domain).map_err(|_e| ConfigErr::BadDomain)?;
    142         let url = format!(
    143             "https://{}{}",
    144             domain.as_ref(),
    145             if config_file.port == 443 {
    146                 String::new()
    147             } else {
    148                 format!(":{}", config_file.port)
    149             }
    150         );
    151         let domain_url = Url::parse(url.as_str())?;
    152         Ok(Self {
    153             database_max_conns: config_file
    154                 .database_max_conns
    155                 .unwrap_or(NonZeroU8::new(10).unwrap()),
    156             database_timeout: config_file.database_timeout.unwrap_or(30),
    157             db_connection_retries: config_file
    158                 .db_connection_retries
    159                 .unwrap_or(NonZeroU8::new(15).unwrap()),
    160             domain_url,
    161             rp_id: RpId::Domain(domain),
    162             password_iterations: match config_file.password_iterations {
    163                 None => 600_000,
    164                 Some(count) => {
    165                     if count < 100_000u32 {
    166                         return Err(ConfigErr::InvalidPasswordIterations(count));
    167                     }
    168                     count
    169                 }
    170             },
    171             rocket,
    172             web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true),
    173         })
    174     }
    175 }
    176 impl Config {
    177     pub const DATA_FOLDER: &'static str = "data";
    178     pub const DATABASE_URL: &'static str = "data/db.sqlite3";
    179     pub const PRIVATE_ED25519_KEY: &'static str = "data/ed25519_key.pem";
    180     pub const WEB_VAULT_FOLDER: &'static str = "web-vault/";
    181     #[allow(clippy::arithmetic_side_effects, clippy::string_slice)]
    182     #[inline]
    183     pub fn domain_url(&self) -> &str {
    184         let val = self.domain_url.as_str();
    185         // The last Unicode scalar value is '/' which is a
    186         // single UTF-8 code unit, and we want to remove that.
    187         // Note if this changes in the future such that the last
    188         // Unicode scalar value is encoded using more than one
    189         // UTF-8 code unit, then this will panic.
    190         // Additionally if `len` is somehow 0, indexing will panic
    191         // making this memory and logic safe.
    192         &val[..val.len() - 1]
    193     }
    194 }