util.rs (16655B)
1 use crate::config; 2 use core::fmt::{self, Display, Formatter}; 3 use rocket::{ 4 fairing::{Fairing, Info, Kind}, 5 http::{ContentType, Header, HeaderMap, Method, Status}, 6 request::FromParam, 7 response::{self, Responder}, 8 Request, Response, 9 }; 10 use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor}; 11 use std::{error, io::Cursor, ops::Deref, string::ToString}; 12 use tokio::{ 13 runtime::Handle, 14 time::{sleep, Duration}, 15 }; 16 17 pub struct AppHeaders; 18 19 #[rocket::async_trait] 20 impl Fairing for AppHeaders { 21 fn info(&self) -> Info { 22 Info { 23 name: "Application Headers", 24 kind: Kind::Response, 25 } 26 } 27 #[allow(clippy::similar_names)] 28 async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { 29 let req_uri_path = req.uri().path(); 30 let req_headers = req.headers(); 31 32 // Check if this connection is an Upgrade/WebSocket connection and return early 33 // We do not want add any extra headers, this could cause issues with reverse proxies or CloudFlare 34 if req_uri_path.ends_with("notifications/hub") 35 || req_uri_path.ends_with("notifications/anonymous-hub") 36 { 37 match ( 38 req_headers.get_one("connection"), 39 req_headers.get_one("upgrade"), 40 ) { 41 (Some(c), Some(u)) 42 if c.to_lowercase().contains("upgrade") 43 && u.to_lowercase().contains("websocket") => 44 { 45 // Remove headers which could cause websocket connection issues 46 res.remove_header("X-Frame-Options"); 47 res.remove_header("X-Content-Type-Options"); 48 res.remove_header("Permissions-Policy"); 49 return; 50 } 51 (_, _) => (), 52 } 53 } 54 res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()"); 55 res.set_raw_header("Referrer-Policy", "same-origin"); 56 res.set_raw_header("X-Content-Type-Options", "nosniff"); 57 // Obsolete in modern browsers, unsafe (XS-Leak), and largely replaced by CSP 58 res.set_raw_header("X-XSS-Protection", "0"); 59 // Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files. 60 // This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn. 61 // This is the same behavior as upstream Bitwarden. 62 if req_uri_path.ends_with("connector.html") { 63 // It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues. 64 res.remove_header("X-Frame-Options"); 65 } else { 66 // # Frame Ancestors: 67 // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb 68 // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US 69 // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/ 70 // # img/child/frame src: 71 // Have I Been Pwned to allow those calls to work. 72 // # Connect src: 73 // Leaked Passwords check: api.pwnedpasswords.com 74 // 2FA/MFA Site check: api.2fa.directory 75 // # Mail Relay: https://bitwarden.com/blog/add-privacy-and-security-using-email-aliases-with-bitwarden/ 76 // app.simplelogin.io, app.addy.io, api.fastmail.com, quack.duckduckgo.com 77 let csp = format!( 78 "default-src 'self'; \ 79 base-uri 'self'; \ 80 form-action 'self'; \ 81 object-src 'self' blob:; \ 82 script-src 'self' 'wasm-unsafe-eval'; \ 83 style-src 'self' 'unsafe-inline'; \ 84 child-src 'self' https://*.duosecurity.com https://*.duofederal.com; \ 85 frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; \ 86 frame-ancestors 'self' \ 87 chrome-extension://nngceckbapebfimnlniiiahkandclblb \ 88 chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh \ 89 moz-extension://* \ 90 {allowed_iframe_ancestors}; \ 91 img-src 'self' data: \ 92 https://haveibeenpwned.com \ 93 {icon_service_csp}; \ 94 connect-src 'self' \ 95 https://api.pwnedpasswords.com \ 96 https://api.2fa.directory \ 97 https://app.simplelogin.io/api/ \ 98 https://app.addy.io/api/ \ 99 https://api.fastmail.com/ \ 100 https://api.forwardemail.net \ 101 ;\ 102 ", 103 icon_service_csp = "", 104 allowed_iframe_ancestors = "" 105 ); 106 res.set_raw_header("Content-Security-Policy", csp); 107 res.set_raw_header("X-Frame-Options", "SAMEORIGIN"); 108 } 109 // Disable cache unless otherwise specified 110 if !res.headers().contains("cache-control") { 111 res.set_raw_header("Cache-Control", "no-cache, no-store, max-age=0"); 112 } 113 } 114 } 115 116 pub struct Cors; 117 118 impl Cors { 119 fn get_header(headers: &HeaderMap<'_>, name: &str) -> String { 120 headers 121 .get_one(name) 122 .map_or_else(String::new, ToString::to_string) 123 } 124 // Check a request's `Origin` header against the list of allowed origins. 125 // If a match exists, return it. Otherwise, return None. 126 fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option<String> { 127 let origin = Self::get_header(headers, "Origin"); 128 let domain_origin = config::get_config().domain_url(); 129 let safari_extension_origin = "file://"; 130 if origin == domain_origin || origin == safari_extension_origin { 131 Some(origin) 132 } else { 133 None 134 } 135 } 136 } 137 138 #[rocket::async_trait] 139 impl Fairing for Cors { 140 fn info(&self) -> Info { 141 Info { 142 name: "Cors", 143 kind: Kind::Response, 144 } 145 } 146 147 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { 148 let req_headers = request.headers(); 149 if let Some(origin) = Self::get_allowed_origin(req_headers) { 150 response.set_header(Header::new("Access-Control-Allow-Origin", origin)); 151 } 152 // Preflight request 153 if request.method() == Method::Options { 154 let req_allow_headers = Self::get_header(req_headers, "Access-Control-Request-Headers"); 155 let req_allow_method = Self::get_header(req_headers, "Access-Control-Request-Method"); 156 response.set_header(Header::new( 157 "Access-Control-Allow-Methods", 158 req_allow_method, 159 )); 160 response.set_header(Header::new( 161 "Access-Control-Allow-Headers", 162 req_allow_headers, 163 )); 164 response.set_header(Header::new("Access-Control-Allow-Credentials", "true")); 165 response.set_status(Status::Ok); 166 response.set_header(ContentType::Plain); 167 response.set_sized_body(Some(0), Cursor::new("")); 168 } 169 } 170 } 171 172 pub struct Cached<R> { 173 response: R, 174 is_immutable: bool, 175 ttl: u64, 176 } 177 178 impl<R> Cached<R> { 179 pub const fn long(response: R, is_immutable: bool) -> Self { 180 Self { 181 response, 182 is_immutable, 183 ttl: 604_800, // 7 days 184 } 185 } 186 pub const fn short(response: R, is_immutable: bool) -> Self { 187 Self { 188 response, 189 is_immutable, 190 ttl: 600, // 10 minutes 191 } 192 } 193 pub const fn ttl(response: R, ttl: u64, is_immutable: bool) -> Self { 194 Self { 195 response, 196 is_immutable, 197 ttl, 198 } 199 } 200 } 201 impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> { 202 fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { 203 let mut res = self.response.respond_to(request)?; 204 let cache_control_header = if self.is_immutable { 205 format!("public, immutable, max-age={}", self.ttl) 206 } else { 207 format!("public, max-age={}", self.ttl) 208 }; 209 res.set_raw_header("Cache-Control", cache_control_header); 210 let time_now = chrono::Local::now(); 211 let expiry_time = time_now 212 .checked_add_signed( 213 chrono::TimeDelta::try_seconds(self.ttl.try_into().unwrap()).unwrap(), 214 ) 215 .expect("Duration add overflowed"); 216 res.set_raw_header("Expires", format_datetime_http(&expiry_time)); 217 Ok(res) 218 } 219 } 220 pub struct SafeString(String); 221 222 impl Display for SafeString { 223 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 224 self.0.fmt(f) 225 } 226 } 227 228 impl Deref for SafeString { 229 type Target = String; 230 fn deref(&self) -> &Self::Target { 231 &self.0 232 } 233 } 234 235 impl AsRef<Path> for SafeString { 236 #[inline] 237 fn as_ref(&self) -> &Path { 238 Path::new(&self.0) 239 } 240 } 241 242 impl<'r> FromParam<'r> for SafeString { 243 type Error = (); 244 #[inline] 245 fn from_param(param: &'r str) -> Result<Self, Self::Error> { 246 if param 247 .chars() 248 .all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) 249 { 250 Ok(Self(param.to_owned())) 251 } else { 252 Err(()) 253 } 254 } 255 } 256 use std::path::Path; 257 258 pub fn get_uuid() -> String { 259 uuid::Uuid::new_v4().to_string() 260 } 261 use std::str::FromStr; 262 fn lcase_first(s: &str) -> String { 263 let mut c = s.chars(); 264 c.next().map_or_else(String::new, |f| { 265 let mut val = f.to_lowercase().collect::<String>(); 266 val.push_str(c.as_str()); 267 val 268 }) 269 } 270 pub fn try_parse_string<S, T>(string: Option<S>) -> Option<T> 271 where 272 S: AsRef<str>, 273 T: FromStr, 274 { 275 if let Some(Ok(value)) = string.map(|s| s.as_ref().parse::<T>()) { 276 Some(value) 277 } else { 278 None 279 } 280 } 281 use chrono::{DateTime, Local, NaiveDateTime}; 282 // Format used by Bitwarden API 283 const DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.6fZ"; 284 /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API 285 /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.). 286 pub fn format_date(dt: &NaiveDateTime) -> String { 287 dt.format(DATETIME_FORMAT).to_string() 288 } 289 /// Formats a `DateTime<Local>` as required for HTTP 290 /// 291 /// [http](https://httpwg.org/specs/rfc7231.html#http.date) 292 fn format_datetime_http(dt: &DateTime<Local>) -> String { 293 let expiry_time = 294 DateTime::<chrono::Utc>::from_naive_utc_and_offset(dt.naive_utc(), chrono::Utc); 295 296 // HACK: HTTP expects the date to always be GMT (UTC) rather than giving an 297 // offset (which would always be 0 in UTC anyway) 298 expiry_time.to_rfc2822().replace("+0000", "GMT") 299 } 300 use serde_json::{self, Value}; 301 type JsonMap = serde_json::Map<String, Value>; 302 303 #[derive(Serialize, Deserialize)] 304 pub struct LowerCase<T: DeserializeOwned> { 305 #[serde(deserialize_with = "lowercase_deserialize")] 306 #[serde(flatten)] 307 pub data: T, 308 } 309 310 impl Default for LowerCase<Value> { 311 fn default() -> Self { 312 Self { data: Value::Null } 313 } 314 } 315 316 pub fn lowercase_deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error> 317 where 318 T: DeserializeOwned, 319 D: Deserializer<'de>, 320 { 321 let d = deserializer.deserialize_any(LowerCaseVisitor)?; 322 T::deserialize(d).map_err(de::Error::custom) 323 } 324 325 struct LowerCaseVisitor; 326 327 impl<'de> Visitor<'de> for LowerCaseVisitor { 328 type Value = Value; 329 330 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { 331 formatter.write_str("an object or an array") 332 } 333 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error> 334 where 335 A: MapAccess<'de>, 336 { 337 let mut result_map = JsonMap::new(); 338 339 while let Some((key, value)) = map.next_entry()? { 340 result_map.insert(_process_key(key), convert_json_key_lcase_first(value)); 341 } 342 343 Ok(Value::Object(result_map)) 344 } 345 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> 346 where 347 A: SeqAccess<'de>, 348 { 349 let mut result_seq = Vec::<Value>::new(); 350 351 while let Some(value) = seq.next_element()? { 352 result_seq.push(convert_json_key_lcase_first(value)); 353 } 354 355 Ok(Value::Array(result_seq)) 356 } 357 } 358 359 // Inner function to handle a special case for the 'ssn' key. 360 // This key is part of the Identity Cipher (Social Security Number) 361 fn _process_key(key: &str) -> String { 362 match key.to_lowercase().as_ref() { 363 "ssn" => "ssn".into(), 364 _ => self::lcase_first(key), 365 } 366 } 367 368 #[derive(Clone, Debug, Deserialize)] 369 #[serde(untagged)] 370 pub enum NumberOrString { 371 Number(i64), 372 String(String), 373 } 374 375 impl NumberOrString { 376 pub fn into_string(self) -> String { 377 match self { 378 Self::Number(n) => n.to_string(), 379 Self::String(s) => s, 380 } 381 } 382 pub fn into_i32(self) -> Result<i32, crate::Error> { 383 use std::num::ParseIntError as PIE; 384 match self { 385 Self::Number(n) => match i32::try_from(n) { 386 Ok(n) => Ok(n), 387 Err(_) => err!("Number does not fit in i32"), 388 }, 389 Self::String(s) => s 390 .parse() 391 .map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string())), 392 } 393 } 394 } 395 396 pub fn retry<F, T, E>(mut func: F, max_tries: u32) -> Result<T, E> 397 where 398 F: FnMut() -> Result<T, E>, 399 { 400 let mut tries = 0u32; 401 loop { 402 match func() { 403 ok @ Ok(_) => return ok, 404 err @ Err(_) => { 405 tries = tries.checked_add(1).expect("u32 add overflowed"); 406 if tries >= max_tries { 407 return err; 408 } 409 Handle::current().block_on(sleep(Duration::from_millis(500))); 410 } 411 } 412 } 413 } 414 415 pub async fn retry_db<F, T: Send, E>(mut func: F, max_tries: u32) -> Result<T, E> 416 where 417 F: FnMut() -> Result<T, E> + Send, 418 E: error::Error + Send, 419 { 420 let mut tries = 0u32; 421 loop { 422 match func() { 423 ok @ Ok(_) => return ok, 424 Err(e) => { 425 tries = tries.checked_add(1).expect("u32 add overflowed"); 426 if tries >= max_tries && max_tries > 0 { 427 return Err(e); 428 } 429 warn!("Can't connect to database, retrying: {:?}", e); 430 sleep(Duration::from_millis(1_000)).await; 431 } 432 } 433 } 434 } 435 pub fn convert_json_key_lcase_first(src_json: Value) -> Value { 436 match src_json { 437 Value::Array(elm) => { 438 let mut new_array: Vec<Value> = Vec::with_capacity(elm.len()); 439 440 for obj in elm { 441 new_array.push(convert_json_key_lcase_first(obj)); 442 } 443 Value::Array(new_array) 444 } 445 446 Value::Object(obj) => { 447 let mut json_map = JsonMap::new(); 448 for (key, value) in obj { 449 match (key, value) { 450 (key, Value::Object(elm)) => { 451 let inner_value = convert_json_key_lcase_first(Value::Object(elm)); 452 json_map.insert(_process_key(&key), inner_value); 453 } 454 455 (key, Value::Array(elm)) => { 456 let mut inner_array: Vec<Value> = Vec::with_capacity(elm.len()); 457 458 for inner_obj in elm { 459 inner_array.push(convert_json_key_lcase_first(inner_obj)); 460 } 461 462 json_map.insert(_process_key(&key), Value::Array(inner_array)); 463 } 464 465 (key, value) => { 466 json_map.insert(_process_key(&key), value); 467 } 468 } 469 } 470 471 Value::Object(json_map) 472 } 473 value @ (Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_)) => value, 474 } 475 }