config.rs (6286B)
1 use core::{ 2 fmt::{self, Display, Formatter}, 3 num::{NonZeroU8, NonZeroUsize}, 4 str, 5 }; 6 use rocket::config::{CipherSuite, LogLevel, TlsConfig}; 7 use rocket::data::{Limits, ToByteUnit as _}; 8 use std::error; 9 use std::fs; 10 use std::io::Error; 11 use std::net::IpAddr; 12 use std::sync::OnceLock; 13 use toml::{self, de}; 14 use url::{ParseError, Url}; 15 use webauthn_rp::request::{AsciiDomain, RpId}; 16 static CONFIG: OnceLock<Config> = OnceLock::new(); 17 #[inline] 18 pub fn init_config() { 19 CONFIG 20 .set(Config::load().expect("valid TOML config file at 'config.toml'")) 21 .expect("CONFIG must only be initialized once"); 22 } 23 #[inline] 24 pub fn get_config() -> &'static Config { 25 CONFIG.get().expect("CONFIG must be initialized in main") 26 } 27 #[derive(Debug)] 28 pub enum ConfigErr { 29 Io(Error), 30 De(de::Error), 31 Url(ParseError), 32 BadDomain, 33 InvalidPasswordIterations(u32), 34 } 35 impl Display for ConfigErr { 36 #[inline] 37 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 38 match *self { 39 Self::Io(ref err) => err.fmt(f), 40 Self::De(ref err) => err.fmt(f), 41 Self::Url(ref err) => err.fmt(f), 42 Self::BadDomain => f.write_str( 43 "https://<domain>:<port> was unable to be parsed into a URL with a domain", 44 ), 45 Self::InvalidPasswordIterations(count) => write!( 46 f, 47 "password iterations is {count} but must be at least 100000" 48 ), 49 } 50 } 51 } 52 impl error::Error for ConfigErr {} 53 impl From<Error> for ConfigErr { 54 #[inline] 55 fn from(value: Error) -> Self { 56 Self::Io(value) 57 } 58 } 59 impl From<de::Error> for ConfigErr { 60 #[inline] 61 fn from(value: de::Error) -> Self { 62 Self::De(value) 63 } 64 } 65 impl From<ParseError> for ConfigErr { 66 #[inline] 67 fn from(value: ParseError) -> Self { 68 Self::Url(value) 69 } 70 } 71 #[derive(serde::Deserialize)] 72 #[serde(deny_unknown_fields)] 73 struct Tls { 74 ciphers: Option<Vec<CipherSuite>>, 75 cert: String, 76 key: String, 77 prefer_server_cipher_order: Option<bool>, 78 } 79 #[derive(serde::Deserialize)] 80 #[serde(deny_unknown_fields)] 81 struct ConfigFile { 82 database_max_conns: Option<NonZeroU8>, 83 database_timeout: Option<u16>, 84 db_connection_retries: Option<NonZeroU8>, 85 domain: String, 86 ip: IpAddr, 87 password_iterations: Option<u32>, 88 port: u16, 89 tls: Tls, 90 web_vault_enabled: Option<bool>, 91 workers: Option<NonZeroU8>, 92 } 93 #[derive(Debug)] 94 pub struct Config { 95 pub database_max_conns: NonZeroU8, 96 pub database_timeout: u16, 97 pub db_connection_retries: NonZeroU8, 98 domain_url: Url, 99 pub rp_id: RpId, 100 pub password_iterations: u32, 101 pub rocket: rocket::Config, 102 pub web_vault_enabled: bool, 103 } 104 impl Config { 105 #[inline] 106 pub fn load() -> Result<Self, ConfigErr> { 107 let config_file = 108 toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?; 109 let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key); 110 tls = match config_file.tls.ciphers { 111 Some(ciphers) => { 112 if ciphers.is_empty() { 113 tls 114 } else { 115 tls.with_ciphers(ciphers) 116 } 117 } 118 None => tls, 119 }; 120 tls = match config_file.tls.prefer_server_cipher_order { 121 Some(prefer) => tls.with_preferred_server_cipher_order(prefer), 122 None => tls, 123 }; 124 let mut rocket = rocket::Config { 125 address: config_file.ip, 126 cli_colors: false, 127 limits: Limits::new() 128 .limit("json", 20i32.megabytes()) 129 .limit("data-form", 525i32.megabytes()) 130 .limit("file", 525i32.megabytes()), 131 log_level: LogLevel::Off, 132 port: config_file.port, 133 temp_dir: "data/tmp".into(), 134 tls: Some(tls), 135 ..Default::default() 136 }; 137 if let Some(count) = config_file.workers { 138 rocket.workers = NonZeroUsize::from(count).get(); 139 } 140 let domain = 141 AsciiDomain::try_from(config_file.domain).map_err(|_e| ConfigErr::BadDomain)?; 142 let url = format!( 143 "https://{}{}", 144 domain.as_ref(), 145 if config_file.port == 443 { 146 String::new() 147 } else { 148 format!(":{}", config_file.port) 149 } 150 ); 151 let domain_url = Url::parse(url.as_str())?; 152 Ok(Self { 153 database_max_conns: config_file 154 .database_max_conns 155 .unwrap_or(NonZeroU8::new(10).unwrap()), 156 database_timeout: config_file.database_timeout.unwrap_or(30), 157 db_connection_retries: config_file 158 .db_connection_retries 159 .unwrap_or(NonZeroU8::new(15).unwrap()), 160 domain_url, 161 rp_id: RpId::Domain(domain), 162 password_iterations: match config_file.password_iterations { 163 None => 600_000, 164 Some(count) => { 165 if count < 100_000u32 { 166 return Err(ConfigErr::InvalidPasswordIterations(count)); 167 } 168 count 169 } 170 }, 171 rocket, 172 web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true), 173 }) 174 } 175 } 176 impl Config { 177 pub const DATA_FOLDER: &'static str = "data"; 178 pub const DATABASE_URL: &'static str = "data/db.sqlite3"; 179 pub const PRIVATE_ED25519_KEY: &'static str = "data/ed25519_key.pem"; 180 pub const WEB_VAULT_FOLDER: &'static str = "web-vault/"; 181 #[allow(clippy::arithmetic_side_effects, clippy::string_slice)] 182 #[inline] 183 pub fn domain_url(&self) -> &str { 184 let val = self.domain_url.as_str(); 185 // The last Unicode scalar value is '/' which is a 186 // single UTF-8 code unit, and we want to remove that. 187 // Note if this changes in the future such that the last 188 // Unicode scalar value is encoded using more than one 189 // UTF-8 code unit, then this will panic. 190 // Additionally if `len` is somehow 0, indexing will panic 191 // making this memory and logic safe. 192 &val[..val.len() - 1] 193 } 194 }