config.rs (6640B)
1 use core::fmt::{self, Display, Formatter}; 2 use core::num::NonZeroU8; 3 use core::str; 4 use rocket::config::{CipherSuite, LogLevel, TlsConfig}; 5 use rocket::data::{Limits, ToByteUnit}; 6 use std::error; 7 use std::fs; 8 use std::io::Error; 9 use std::net::IpAddr; 10 use std::sync::OnceLock; 11 use toml::{self, de}; 12 use url::{ParseError, Url}; 13 static CONFIG: OnceLock<Config> = OnceLock::new(); 14 #[inline] 15 pub fn init_config() { 16 CONFIG 17 .set(Config::load().expect("valid TOML config file at 'config.toml'")) 18 .expect("CONFIG must only be initialized once"); 19 } 20 #[inline] 21 pub fn get_config() -> &'static Config { 22 CONFIG.get().expect("CONFIG must be initialized in main") 23 } 24 #[derive(Debug)] 25 pub enum ConfigErr { 26 Io(Error), 27 De(de::Error), 28 Url(ParseError), 29 BadDomain, 30 InvalidPasswordIterations(u32), 31 } 32 impl Display for ConfigErr { 33 #[inline] 34 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 35 match *self { 36 Self::Io(ref err) => err.fmt(f), 37 Self::De(ref err) => err.fmt(f), 38 Self::Url(ref err) => err.fmt(f), 39 Self::BadDomain => f.write_str( 40 "https://<domain>:<port> was unable to be parsed into a URL with a domain", 41 ), 42 Self::InvalidPasswordIterations(count) => write!( 43 f, 44 "password iterations is {count} but must be at least 100000" 45 ), 46 } 47 } 48 } 49 impl error::Error for ConfigErr {} 50 impl From<Error> for ConfigErr { 51 #[inline] 52 fn from(value: Error) -> Self { 53 Self::Io(value) 54 } 55 } 56 impl From<de::Error> for ConfigErr { 57 #[inline] 58 fn from(value: de::Error) -> Self { 59 Self::De(value) 60 } 61 } 62 impl From<ParseError> for ConfigErr { 63 #[inline] 64 fn from(value: ParseError) -> Self { 65 Self::Url(value) 66 } 67 } 68 #[derive(serde::Deserialize)] 69 #[serde(deny_unknown_fields)] 70 struct Tls { 71 ciphers: Option<Vec<CipherSuite>>, 72 cert: String, 73 key: String, 74 prefer_server_cipher_order: Option<bool>, 75 } 76 #[derive(serde::Deserialize)] 77 #[serde(deny_unknown_fields)] 78 struct ConfigFile { 79 database_max_conns: Option<NonZeroU8>, 80 database_timeout: Option<u16>, 81 db_connection_retries: Option<NonZeroU8>, 82 domain: String, 83 ip: IpAddr, 84 password_iterations: Option<u32>, 85 port: u16, 86 tls: Tls, 87 web_vault_enabled: Option<bool>, 88 workers: Option<NonZeroU8>, 89 } 90 #[derive(Debug)] 91 pub struct Config { 92 pub database_max_conns: NonZeroU8, 93 pub database_timeout: u16, 94 pub db_connection_retries: NonZeroU8, 95 domain: Url, 96 pub password_iterations: u32, 97 pub rocket: rocket::Config, 98 pub web_vault_enabled: bool, 99 } 100 impl Config { 101 #[inline] 102 pub fn load() -> Result<Self, ConfigErr> { 103 let config_file = 104 toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?; 105 let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key); 106 tls = match config_file.tls.ciphers { 107 Some(ciphers) => { 108 if ciphers.is_empty() { 109 tls 110 } else { 111 tls.with_ciphers(ciphers) 112 } 113 } 114 None => tls, 115 }; 116 tls = match config_file.tls.prefer_server_cipher_order { 117 Some(prefer) => tls.with_preferred_server_cipher_order(prefer), 118 None => tls, 119 }; 120 let mut rocket = rocket::Config { 121 address: config_file.ip, 122 cli_colors: false, 123 limits: Limits::new() 124 .limit("json", 20i32.megabytes()) 125 .limit("data-form", 525i32.megabytes()) 126 .limit("file", 525i32.megabytes()), 127 log_level: LogLevel::Off, 128 port: config_file.port, 129 temp_dir: "data/tmp".into(), 130 tls: Some(tls), 131 ..Default::default() 132 }; 133 if let Some(count) = config_file.workers { 134 rocket.workers = usize::from(count.get()); 135 } 136 let url = format!( 137 "https://{}{}", 138 config_file.domain, 139 if config_file.port == 443 { 140 String::new() 141 } else { 142 format!(":{}", config_file.port) 143 } 144 ); 145 let domain = Url::parse(url.as_str())?; 146 if domain 147 .domain() 148 // We only allow domains in the config file. 149 // Note currently this check is overly conservative and 150 // disallows any domains that `Url` will encode in Punycode. 151 .map_or(true, |dom| !dom.eq_ignore_ascii_case(&config_file.domain)) 152 { 153 return Err(ConfigErr::BadDomain); 154 } 155 Ok(Self { 156 database_max_conns: config_file 157 .database_max_conns 158 .unwrap_or(NonZeroU8::new(10).unwrap()), 159 database_timeout: config_file.database_timeout.unwrap_or(30), 160 db_connection_retries: config_file 161 .db_connection_retries 162 .unwrap_or(NonZeroU8::new(15).unwrap()), 163 domain, 164 password_iterations: match config_file.password_iterations { 165 None => 600_000, 166 Some(count) => { 167 if count < 100_000u32 { 168 return Err(ConfigErr::InvalidPasswordIterations(count)); 169 } 170 count 171 } 172 }, 173 rocket, 174 web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true), 175 }) 176 } 177 } 178 impl Config { 179 pub const DATA_FOLDER: &'static str = "data"; 180 pub const DATABASE_URL: &'static str = "data/db.sqlite3"; 181 pub const PRIVATE_ED25519_KEY: &'static str = "data/ed25519_key.pem"; 182 pub const WEB_VAULT_FOLDER: &'static str = "web-vault/"; 183 #[allow(clippy::arithmetic_side_effects, clippy::string_slice)] 184 #[inline] 185 pub fn domain_url(&self) -> &str { 186 let val = self.domain.as_str(); 187 // The last Unicode scalar value is '/' which is a 188 // single UTF-8 code unit, and we want to remove that. 189 // Note if this changes in the future such that the last 190 // Unicode scalar value is encoded using more than one 191 // UTF-8 code unit, then this will panic. 192 // Additionally if `len` is somehow 0, indexing will panic 193 // making this memory and logic safe. 194 &val[..val.len() - 1] 195 } 196 #[inline] 197 pub fn domain(&self) -> &str { 198 self.domain.domain().expect("impossible to error") 199 } 200 #[inline] 201 pub fn domain_origin(&self) -> String { 202 self.domain.origin().ascii_serialization() 203 } 204 }