config.rs (6261B)
1 use core::fmt::{self, Display, Formatter}; 2 use core::num::NonZeroU8; 3 use core::str; 4 use rocket::config::{CipherSuite, LogLevel, TlsConfig}; 5 use rocket::data::{Limits, ToByteUnit}; 6 use std::error; 7 use std::fs; 8 use std::io::Error; 9 use std::net::IpAddr; 10 use std::sync::OnceLock; 11 use toml::{self, de}; 12 use url::{ParseError, Url}; 13 use webauthn_rp::request::{AsciiDomain, RpId}; 14 static CONFIG: OnceLock<Config> = OnceLock::new(); 15 #[inline] 16 pub fn init_config() { 17 CONFIG 18 .set(Config::load().expect("valid TOML config file at 'config.toml'")) 19 .expect("CONFIG must only be initialized once"); 20 } 21 #[inline] 22 pub fn get_config() -> &'static Config { 23 CONFIG.get().expect("CONFIG must be initialized in main") 24 } 25 #[derive(Debug)] 26 pub enum ConfigErr { 27 Io(Error), 28 De(de::Error), 29 Url(ParseError), 30 BadDomain, 31 InvalidPasswordIterations(u32), 32 } 33 impl Display for ConfigErr { 34 #[inline] 35 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 36 match *self { 37 Self::Io(ref err) => err.fmt(f), 38 Self::De(ref err) => err.fmt(f), 39 Self::Url(ref err) => err.fmt(f), 40 Self::BadDomain => f.write_str( 41 "https://<domain>:<port> was unable to be parsed into a URL with a domain", 42 ), 43 Self::InvalidPasswordIterations(count) => write!( 44 f, 45 "password iterations is {count} but must be at least 100000" 46 ), 47 } 48 } 49 } 50 impl error::Error for ConfigErr {} 51 impl From<Error> for ConfigErr { 52 #[inline] 53 fn from(value: Error) -> Self { 54 Self::Io(value) 55 } 56 } 57 impl From<de::Error> for ConfigErr { 58 #[inline] 59 fn from(value: de::Error) -> Self { 60 Self::De(value) 61 } 62 } 63 impl From<ParseError> for ConfigErr { 64 #[inline] 65 fn from(value: ParseError) -> Self { 66 Self::Url(value) 67 } 68 } 69 #[derive(serde::Deserialize)] 70 #[serde(deny_unknown_fields)] 71 struct Tls { 72 ciphers: Option<Vec<CipherSuite>>, 73 cert: String, 74 key: String, 75 prefer_server_cipher_order: Option<bool>, 76 } 77 #[derive(serde::Deserialize)] 78 #[serde(deny_unknown_fields)] 79 struct ConfigFile { 80 database_max_conns: Option<NonZeroU8>, 81 database_timeout: Option<u16>, 82 db_connection_retries: Option<NonZeroU8>, 83 domain: String, 84 ip: IpAddr, 85 password_iterations: Option<u32>, 86 port: u16, 87 tls: Tls, 88 web_vault_enabled: Option<bool>, 89 workers: Option<NonZeroU8>, 90 } 91 #[derive(Debug)] 92 pub struct Config { 93 pub database_max_conns: NonZeroU8, 94 pub database_timeout: u16, 95 pub db_connection_retries: NonZeroU8, 96 domain_url: Url, 97 pub rp_id: RpId, 98 pub password_iterations: u32, 99 pub rocket: rocket::Config, 100 pub web_vault_enabled: bool, 101 } 102 impl Config { 103 #[inline] 104 pub fn load() -> Result<Self, ConfigErr> { 105 let config_file = 106 toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?; 107 let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key); 108 tls = match config_file.tls.ciphers { 109 Some(ciphers) => { 110 if ciphers.is_empty() { 111 tls 112 } else { 113 tls.with_ciphers(ciphers) 114 } 115 } 116 None => tls, 117 }; 118 tls = match config_file.tls.prefer_server_cipher_order { 119 Some(prefer) => tls.with_preferred_server_cipher_order(prefer), 120 None => tls, 121 }; 122 let mut rocket = rocket::Config { 123 address: config_file.ip, 124 cli_colors: false, 125 limits: Limits::new() 126 .limit("json", 20i32.megabytes()) 127 .limit("data-form", 525i32.megabytes()) 128 .limit("file", 525i32.megabytes()), 129 log_level: LogLevel::Off, 130 port: config_file.port, 131 temp_dir: "data/tmp".into(), 132 tls: Some(tls), 133 ..Default::default() 134 }; 135 if let Some(count) = config_file.workers { 136 rocket.workers = usize::from(count.get()); 137 } 138 let domain = 139 AsciiDomain::try_from(config_file.domain).map_err(|_e| ConfigErr::BadDomain)?; 140 let url = format!( 141 "https://{}{}", 142 domain.as_ref(), 143 if config_file.port == 443 { 144 String::new() 145 } else { 146 format!(":{}", config_file.port) 147 } 148 ); 149 let domain_url = Url::parse(url.as_str())?; 150 Ok(Self { 151 database_max_conns: config_file 152 .database_max_conns 153 .unwrap_or(NonZeroU8::new(10).unwrap()), 154 database_timeout: config_file.database_timeout.unwrap_or(30), 155 db_connection_retries: config_file 156 .db_connection_retries 157 .unwrap_or(NonZeroU8::new(15).unwrap()), 158 domain_url, 159 rp_id: RpId::Domain(domain), 160 password_iterations: match config_file.password_iterations { 161 None => 600_000, 162 Some(count) => { 163 if count < 100_000u32 { 164 return Err(ConfigErr::InvalidPasswordIterations(count)); 165 } 166 count 167 } 168 }, 169 rocket, 170 web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true), 171 }) 172 } 173 } 174 impl Config { 175 pub const DATA_FOLDER: &'static str = "data"; 176 pub const DATABASE_URL: &'static str = "data/db.sqlite3"; 177 pub const PRIVATE_ED25519_KEY: &'static str = "data/ed25519_key.pem"; 178 pub const WEB_VAULT_FOLDER: &'static str = "web-vault/"; 179 #[allow(clippy::arithmetic_side_effects, clippy::string_slice)] 180 #[inline] 181 pub fn domain_url(&self) -> &str { 182 let val = self.domain_url.as_str(); 183 // The last Unicode scalar value is '/' which is a 184 // single UTF-8 code unit, and we want to remove that. 185 // Note if this changes in the future such that the last 186 // Unicode scalar value is encoded using more than one 187 // UTF-8 code unit, then this will panic. 188 // Additionally if `len` is somehow 0, indexing will panic 189 // making this memory and logic safe. 190 &val[..val.len() - 1] 191 } 192 }