vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

config.rs (6640B)


      1 use core::fmt::{self, Display, Formatter};
      2 use core::num::NonZeroU8;
      3 use core::str;
      4 use rocket::config::{CipherSuite, LogLevel, TlsConfig};
      5 use rocket::data::{Limits, ToByteUnit};
      6 use std::error;
      7 use std::fs;
      8 use std::io::Error;
      9 use std::net::IpAddr;
     10 use std::sync::OnceLock;
     11 use toml::{self, de};
     12 use url::{ParseError, Url};
     13 static CONFIG: OnceLock<Config> = OnceLock::new();
     14 #[inline]
     15 pub fn init_config() {
     16     CONFIG
     17         .set(Config::load().expect("valid TOML config file at 'config.toml'"))
     18         .expect("CONFIG must only be initialized once");
     19 }
     20 #[inline]
     21 pub fn get_config() -> &'static Config {
     22     CONFIG.get().expect("CONFIG must be initialized in main")
     23 }
     24 #[derive(Debug)]
     25 pub enum ConfigErr {
     26     Io(Error),
     27     De(de::Error),
     28     Url(ParseError),
     29     BadDomain,
     30     InvalidPasswordIterations(u32),
     31 }
     32 impl Display for ConfigErr {
     33     #[inline]
     34     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
     35         match *self {
     36             Self::Io(ref err) => err.fmt(f),
     37             Self::De(ref err) => err.fmt(f),
     38             Self::Url(ref err) => err.fmt(f),
     39             Self::BadDomain => f.write_str(
     40                 "https://<domain>:<port> was unable to be parsed into a URL with a domain",
     41             ),
     42             Self::InvalidPasswordIterations(count) => write!(
     43                 f,
     44                 "password iterations is {count} but must be at least 100000"
     45             ),
     46         }
     47     }
     48 }
     49 impl error::Error for ConfigErr {}
     50 impl From<Error> for ConfigErr {
     51     #[inline]
     52     fn from(value: Error) -> Self {
     53         Self::Io(value)
     54     }
     55 }
     56 impl From<de::Error> for ConfigErr {
     57     #[inline]
     58     fn from(value: de::Error) -> Self {
     59         Self::De(value)
     60     }
     61 }
     62 impl From<ParseError> for ConfigErr {
     63     #[inline]
     64     fn from(value: ParseError) -> Self {
     65         Self::Url(value)
     66     }
     67 }
     68 #[derive(serde::Deserialize)]
     69 #[serde(deny_unknown_fields)]
     70 struct Tls {
     71     ciphers: Option<Vec<CipherSuite>>,
     72     cert: String,
     73     key: String,
     74     prefer_server_cipher_order: Option<bool>,
     75 }
     76 #[derive(serde::Deserialize)]
     77 #[serde(deny_unknown_fields)]
     78 struct ConfigFile {
     79     database_max_conns: Option<NonZeroU8>,
     80     database_timeout: Option<u16>,
     81     db_connection_retries: Option<NonZeroU8>,
     82     domain: String,
     83     ip: IpAddr,
     84     password_iterations: Option<u32>,
     85     port: u16,
     86     tls: Tls,
     87     web_vault_enabled: Option<bool>,
     88     workers: Option<NonZeroU8>,
     89 }
     90 #[derive(Debug)]
     91 pub struct Config {
     92     pub database_max_conns: NonZeroU8,
     93     pub database_timeout: u16,
     94     pub db_connection_retries: NonZeroU8,
     95     domain: Url,
     96     pub password_iterations: u32,
     97     pub rocket: rocket::Config,
     98     pub web_vault_enabled: bool,
     99 }
    100 impl Config {
    101     #[inline]
    102     pub fn load() -> Result<Self, ConfigErr> {
    103         let config_file =
    104             toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?;
    105         let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key);
    106         tls = match config_file.tls.ciphers {
    107             Some(ciphers) => {
    108                 if ciphers.is_empty() {
    109                     tls
    110                 } else {
    111                     tls.with_ciphers(ciphers)
    112                 }
    113             }
    114             None => tls,
    115         };
    116         tls = match config_file.tls.prefer_server_cipher_order {
    117             Some(prefer) => tls.with_preferred_server_cipher_order(prefer),
    118             None => tls,
    119         };
    120         let mut rocket = rocket::Config {
    121             address: config_file.ip,
    122             cli_colors: false,
    123             limits: Limits::new()
    124                 .limit("json", 20i32.megabytes())
    125                 .limit("data-form", 525i32.megabytes())
    126                 .limit("file", 525i32.megabytes()),
    127             log_level: LogLevel::Off,
    128             port: config_file.port,
    129             temp_dir: "data/tmp".into(),
    130             tls: Some(tls),
    131             ..Default::default()
    132         };
    133         if let Some(count) = config_file.workers {
    134             rocket.workers = usize::from(count.get());
    135         }
    136         let url = format!(
    137             "https://{}{}",
    138             config_file.domain,
    139             if config_file.port == 443 {
    140                 String::new()
    141             } else {
    142                 format!(":{}", config_file.port)
    143             }
    144         );
    145         let domain = Url::parse(url.as_str())?;
    146         if domain
    147             .domain()
    148             // We only allow domains in the config file.
    149             // Note currently this check is overly conservative and
    150             // disallows any domains that `Url` will encode in Punycode.
    151             .map_or(true, |dom| !dom.eq_ignore_ascii_case(&config_file.domain))
    152         {
    153             return Err(ConfigErr::BadDomain);
    154         }
    155         Ok(Self {
    156             database_max_conns: config_file
    157                 .database_max_conns
    158                 .unwrap_or(NonZeroU8::new(10).unwrap()),
    159             database_timeout: config_file.database_timeout.unwrap_or(30),
    160             db_connection_retries: config_file
    161                 .db_connection_retries
    162                 .unwrap_or(NonZeroU8::new(15).unwrap()),
    163             domain,
    164             password_iterations: match config_file.password_iterations {
    165                 None => 600_000,
    166                 Some(count) => {
    167                     if count < 100_000u32 {
    168                         return Err(ConfigErr::InvalidPasswordIterations(count));
    169                     }
    170                     count
    171                 }
    172             },
    173             rocket,
    174             web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true),
    175         })
    176     }
    177 }
    178 impl Config {
    179     pub const DATA_FOLDER: &'static str = "data";
    180     pub const DATABASE_URL: &'static str = "data/db.sqlite3";
    181     pub const PRIVATE_ED25519_KEY: &'static str = "data/ed25519_key.pem";
    182     pub const WEB_VAULT_FOLDER: &'static str = "web-vault/";
    183     #[allow(clippy::arithmetic_side_effects, clippy::string_slice)]
    184     #[inline]
    185     pub fn domain_url(&self) -> &str {
    186         let val = self.domain.as_str();
    187         // The last Unicode scalar value is '/' which is a
    188         // single UTF-8 code unit, and we want to remove that.
    189         // Note if this changes in the future such that the last
    190         // Unicode scalar value is encoded using more than one
    191         // UTF-8 code unit, then this will panic.
    192         // Additionally if `len` is somehow 0, indexing will panic
    193         // making this memory and logic safe.
    194         &val[..val.len() - 1]
    195     }
    196     #[inline]
    197     pub fn domain(&self) -> &str {
    198         self.domain.domain().expect("impossible to error")
    199     }
    200     #[inline]
    201     pub fn domain_origin(&self) -> String {
    202         self.domain.origin().ascii_serialization()
    203     }
    204 }