vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

util.rs (16658B)


      1 use crate::config;
      2 use core::fmt::{self, Display, Formatter};
      3 use rocket::{
      4     fairing::{Fairing, Info, Kind},
      5     http::{ContentType, Header, HeaderMap, Method, Status},
      6     request::FromParam,
      7     response::{self, Responder},
      8     Request, Response,
      9 };
     10 use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor};
     11 use std::{error, io::Cursor, ops::Deref, string::ToString};
     12 use tokio::{
     13     runtime::Handle,
     14     time::{sleep, Duration},
     15 };
     16 
/// Response fairing that adds security-related headers to every response.
///
/// WebSocket upgrade requests to the notification hubs are handled specially
/// in `on_response`: a few headers are stripped and nothing is added.
pub struct AppHeaders;

#[rocket::async_trait]
impl Fairing for AppHeaders {
    fn info(&self) -> Info {
        Info {
            name: "Application Headers",
            kind: Kind::Response,
        }
    }
    /// Adds the application's security headers to `res`.
    ///
    /// - For WebSocket upgrade requests (`Connection` containing `upgrade` and
    ///   `Upgrade` containing `websocket`, compared case-insensitively) to a path
    ///   ending in `notifications/hub` or `notifications/anonymous-hub`, it removes
    ///   `X-Frame-Options`, `X-Content-Type-Options` and `Permissions-Policy` and
    ///   returns without adding anything.
    /// - Otherwise it always sets `Permissions-Policy`, `Referrer-Policy`,
    ///   `X-Content-Type-Options` and `X-XSS-Protection`.
    /// - It sets `Content-Security-Policy` and `X-Frame-Options` for every path
    ///   except those ending in `connector.html`, where `X-Frame-Options` is removed.
    /// - It sets a no-cache `Cache-Control` when the response does not already have one.
    // NOTE(review): presumably silences clippy about the similarly named
    // `req_*` / `res` bindings below — confirm.
    #[allow(clippy::similar_names)]
    async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
        let req_uri_path = req.uri().path();
        let req_headers = req.headers();

        // Check if this connection is an Upgrade/WebSocket connection and return early.
        // We do not want to add any extra headers; they could cause issues with reverse proxies or CloudFlare.
        if req_uri_path.ends_with("notifications/hub")
            || req_uri_path.ends_with("notifications/anonymous-hub")
        {
            match (
                req_headers.get_one("connection"),
                req_headers.get_one("upgrade"),
            ) {
                (Some(c), Some(u))
                    if c.to_lowercase().contains("upgrade")
                        && u.to_lowercase().contains("websocket") =>
                {
                    // Remove headers which could cause websocket connection issues
                    res.remove_header("X-Frame-Options");
                    res.remove_header("X-Content-Type-Options");
                    res.remove_header("Permissions-Policy");
                    return;
                }
                (_, _) => (),
            }
        }
        res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()");
        res.set_raw_header("Referrer-Policy", "same-origin");
        res.set_raw_header("X-Content-Type-Options", "nosniff");
        // Obsolete in modern browsers, unsafe (XS-Leak), and largely replaced by CSP
        res.set_raw_header("X-XSS-Protection", "0");
        // Do not send the Content-Security-Policy (CSP) header and X-Frame-Options for the *-connector.html files.
        // They can cause issues when some MFA requests need to open a popup or page within the clients, like WebAuthn.
        // This is the same behavior as upstream Bitwarden.
        if req_uri_path.ends_with("connector.html") {
            // This header appears to be set somewhere else too; make sure it is not sent for these files, as it causes MFA issues.
            res.remove_header("X-Frame-Options");
        } else {
            // # Frame Ancestors:
            // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb
            // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US
            // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/
            // # img src:
            // Have I Been Pwned to allow those calls to work.
            // # child/frame src:
            // Duo (duosecurity.com / duofederal.com).
            // # Connect src:
            // Leaked Passwords check: api.pwnedpasswords.com
            // 2FA/MFA Site check: api.2fa.directory
            // # Mail Relay: https://bitwarden.com/blog/add-privacy-and-security-using-email-aliases-with-bitwarden/
            // app.simplelogin.io, app.addy.io, api.fastmail.com, api.forwardemail.net
            //
            // Each `\` line continuation below drops the newline and the next line's
            // leading whitespace, so the header is a single line. `icon_service_csp`
            // and `allowed_iframe_ancestors` are always empty here, so they only leave
            // an extra space before their `;`.
            let csp = format!(
                "default-src 'self'; \
                base-uri 'self'; \
                form-action 'self'; \
                object-src 'self' blob:; \
                script-src 'self' 'wasm-unsafe-eval'; \
                style-src 'self' 'unsafe-inline'; \
                child-src 'self' https://*.duosecurity.com https://*.duofederal.com; \
                frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; \
                frame-ancestors 'self' \
                  chrome-extension://nngceckbapebfimnlniiiahkandclblb \
                  chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh \
                  moz-extension://* \
                  {allowed_iframe_ancestors}; \
                img-src 'self' data: \
                  https://haveibeenpwned.com \
                  {icon_service_csp}; \
                connect-src 'self' \
                  https://api.pwnedpasswords.com \
                  https://api.2fa.directory \
                  https://app.simplelogin.io/api/ \
                  https://app.addy.io/api/ \
                  https://api.fastmail.com/ \
                  https://api.forwardemail.net \
                  ;\
                ",
                icon_service_csp = "",
                allowed_iframe_ancestors = ""
            );
            res.set_raw_header("Content-Security-Policy", csp);
            res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
        }
        // Disable cache unless otherwise specified
        if !res.headers().contains("cache-control") {
            res.set_raw_header("Cache-Control", "no-cache, no-store, max-age=0");
        }
    }
}
    115 
    116 pub struct Cors;
    117 
    118 impl Cors {
    119     fn get_header(headers: &HeaderMap<'_>, name: &str) -> String {
    120         headers
    121             .get_one(name)
    122             .map_or_else(String::new, ToString::to_string)
    123     }
    124     // Check a request's `Origin` header against the list of allowed origins.
    125     // If a match exists, return it. Otherwise, return None.
    126     fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option<String> {
    127         let origin = Self::get_header(headers, "Origin");
    128         let domain_origin = config::get_config().domain_origin();
    129         let safari_extension_origin = "file://";
    130         if origin == domain_origin || origin == safari_extension_origin {
    131             Some(origin)
    132         } else {
    133             None
    134         }
    135     }
    136 }
    137 
    138 #[rocket::async_trait]
    139 impl Fairing for Cors {
    140     fn info(&self) -> Info {
    141         Info {
    142             name: "Cors",
    143             kind: Kind::Response,
    144         }
    145     }
    146 
    147     async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
    148         let req_headers = request.headers();
    149         if let Some(origin) = Self::get_allowed_origin(req_headers) {
    150             response.set_header(Header::new("Access-Control-Allow-Origin", origin));
    151         }
    152         // Preflight request
    153         if request.method() == Method::Options {
    154             let req_allow_headers = Self::get_header(req_headers, "Access-Control-Request-Headers");
    155             let req_allow_method = Self::get_header(req_headers, "Access-Control-Request-Method");
    156             response.set_header(Header::new(
    157                 "Access-Control-Allow-Methods",
    158                 req_allow_method,
    159             ));
    160             response.set_header(Header::new(
    161                 "Access-Control-Allow-Headers",
    162                 req_allow_headers,
    163             ));
    164             response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
    165             response.set_status(Status::Ok);
    166             response.set_header(ContentType::Plain);
    167             response.set_sized_body(Some(0), Cursor::new(""));
    168         }
    169     }
    170 }
    171 
    172 pub struct Cached<R> {
    173     response: R,
    174     is_immutable: bool,
    175     ttl: u64,
    176 }
    177 
    178 impl<R> Cached<R> {
    179     pub const fn long(response: R, is_immutable: bool) -> Self {
    180         Self {
    181             response,
    182             is_immutable,
    183             ttl: 604_800, // 7 days
    184         }
    185     }
    186     pub const fn short(response: R, is_immutable: bool) -> Self {
    187         Self {
    188             response,
    189             is_immutable,
    190             ttl: 600, // 10 minutes
    191         }
    192     }
    193     pub const fn ttl(response: R, ttl: u64, is_immutable: bool) -> Self {
    194         Self {
    195             response,
    196             is_immutable,
    197             ttl,
    198         }
    199     }
    200 }
    201 impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
    202     fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
    203         let mut res = self.response.respond_to(request)?;
    204         let cache_control_header = if self.is_immutable {
    205             format!("public, immutable, max-age={}", self.ttl)
    206         } else {
    207             format!("public, max-age={}", self.ttl)
    208         };
    209         res.set_raw_header("Cache-Control", cache_control_header);
    210         let time_now = chrono::Local::now();
    211         let expiry_time = time_now
    212             .checked_add_signed(
    213                 chrono::TimeDelta::try_seconds(self.ttl.try_into().unwrap()).unwrap(),
    214             )
    215             .expect("Duration add overflowed");
    216         res.set_raw_header("Expires", format_datetime_http(&expiry_time));
    217         Ok(res)
    218     }
    219 }
    220 pub struct SafeString(String);
    221 
    222 impl Display for SafeString {
    223     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    224         self.0.fmt(f)
    225     }
    226 }
    227 
    228 impl Deref for SafeString {
    229     type Target = String;
    230     fn deref(&self) -> &Self::Target {
    231         &self.0
    232     }
    233 }
    234 
    235 impl AsRef<Path> for SafeString {
    236     #[inline]
    237     fn as_ref(&self) -> &Path {
    238         Path::new(&self.0)
    239     }
    240 }
    241 
    242 impl<'r> FromParam<'r> for SafeString {
    243     type Error = ();
    244     #[inline]
    245     fn from_param(param: &'r str) -> Result<Self, Self::Error> {
    246         if param
    247             .chars()
    248             .all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-'))
    249         {
    250             Ok(Self(param.to_owned()))
    251         } else {
    252             Err(())
    253         }
    254     }
    255 }
    256 use std::path::Path;
    257 
    258 pub fn get_uuid() -> String {
    259     uuid::Uuid::new_v4().to_string()
    260 }
    261 use std::str::FromStr;
    262 fn lcase_first(s: &str) -> String {
    263     let mut c = s.chars();
    264     c.next().map_or_else(String::new, |f| {
    265         let mut val = f.to_lowercase().collect::<String>();
    266         val.push_str(c.as_str());
    267         val
    268     })
    269 }
    270 pub fn try_parse_string<S, T>(string: Option<S>) -> Option<T>
    271 where
    272     S: AsRef<str>,
    273     T: FromStr,
    274 {
    275     if let Some(Ok(value)) = string.map(|s| s.as_ref().parse::<T>()) {
    276         Some(value)
    277     } else {
    278         None
    279     }
    280 }
    281 use chrono::{DateTime, Local, NaiveDateTime};
    282 // Format used by Bitwarden API
    283 const DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.6fZ";
    284 /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API
    285 /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.).
    286 pub fn format_date(dt: &NaiveDateTime) -> String {
    287     dt.format(DATETIME_FORMAT).to_string()
    288 }
    289 /// Formats a `DateTime<Local>` as required for HTTP
    290 ///
    291 /// [http](https://httpwg.org/specs/rfc7231.html#http.date)
    292 fn format_datetime_http(dt: &DateTime<Local>) -> String {
    293     let expiry_time =
    294         DateTime::<chrono::Utc>::from_naive_utc_and_offset(dt.naive_utc(), chrono::Utc);
    295 
    296     // HACK: HTTP expects the date to always be GMT (UTC) rather than giving an
    297     // offset (which would always be 0 in UTC anyway)
    298     expiry_time.to_rfc2822().replace("+0000", "GMT")
    299 }
    300 use serde_json::{self, Value};
/// A JSON object: `String` keys mapped to `serde_json::Value`s.
type JsonMap = serde_json::Map<String, Value>;
    302 
    303 #[derive(Serialize, Deserialize)]
    304 pub struct LowerCase<T: DeserializeOwned> {
    305     #[serde(deserialize_with = "lowercase_deserialize")]
    306     #[serde(flatten)]
    307     pub data: T,
    308 }
    309 
    310 impl Default for LowerCase<Value> {
    311     fn default() -> Self {
    312         Self { data: Value::Null }
    313     }
    314 }
    315 
    316 pub fn lowercase_deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
    317 where
    318     T: DeserializeOwned,
    319     D: Deserializer<'de>,
    320 {
    321     let d = deserializer.deserialize_any(LowerCaseVisitor)?;
    322     T::deserialize(d).map_err(de::Error::custom)
    323 }
    324 
    325 struct LowerCaseVisitor;
    326 
    327 impl<'de> Visitor<'de> for LowerCaseVisitor {
    328     type Value = Value;
    329 
    330     fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
    331         formatter.write_str("an object or an array")
    332     }
    333     fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
    334     where
    335         A: MapAccess<'de>,
    336     {
    337         let mut result_map = JsonMap::new();
    338 
    339         while let Some((key, value)) = map.next_entry()? {
    340             result_map.insert(_process_key(key), convert_json_key_lcase_first(value));
    341         }
    342 
    343         Ok(Value::Object(result_map))
    344     }
    345     fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    346     where
    347         A: SeqAccess<'de>,
    348     {
    349         let mut result_seq = Vec::<Value>::new();
    350 
    351         while let Some(value) = seq.next_element()? {
    352             result_seq.push(convert_json_key_lcase_first(value));
    353         }
    354 
    355         Ok(Value::Array(result_seq))
    356     }
    357 }
    358 
    359 // Inner function to handle a special case for the 'ssn' key.
    360 // This key is part of the Identity Cipher (Social Security Number)
    361 fn _process_key(key: &str) -> String {
    362     match key.to_lowercase().as_ref() {
    363         "ssn" => "ssn".into(),
    364         _ => self::lcase_first(key),
    365     }
    366 }
    367 
    368 #[derive(Clone, Debug, Deserialize)]
    369 #[serde(untagged)]
    370 pub enum NumberOrString {
    371     Number(i64),
    372     String(String),
    373 }
    374 
    375 impl NumberOrString {
    376     pub fn into_string(self) -> String {
    377         match self {
    378             Self::Number(n) => n.to_string(),
    379             Self::String(s) => s,
    380         }
    381     }
    382     pub fn into_i32(self) -> Result<i32, crate::Error> {
    383         use std::num::ParseIntError as PIE;
    384         match self {
    385             Self::Number(n) => match i32::try_from(n) {
    386                 Ok(n) => Ok(n),
    387                 Err(_) => err!("Number does not fit in i32"),
    388             },
    389             Self::String(s) => s
    390                 .parse()
    391                 .map_err(|e: PIE| crate::Error::new("Can't convert to number", e.to_string())),
    392         }
    393     }
    394 }
    395 
    396 pub fn retry<F, T, E>(mut func: F, max_tries: u32) -> Result<T, E>
    397 where
    398     F: FnMut() -> Result<T, E>,
    399 {
    400     let mut tries = 0u32;
    401     loop {
    402         match func() {
    403             ok @ Ok(_) => return ok,
    404             err @ Err(_) => {
    405                 tries = tries.checked_add(1).expect("u32 add overflowed");
    406                 if tries >= max_tries {
    407                     return err;
    408                 }
    409                 Handle::current().block_on(sleep(Duration::from_millis(500)));
    410             }
    411         }
    412     }
    413 }
    414 
    415 pub async fn retry_db<F, T: Send, E>(mut func: F, max_tries: u32) -> Result<T, E>
    416 where
    417     F: FnMut() -> Result<T, E> + Send,
    418     E: error::Error + Send,
    419 {
    420     let mut tries = 0u32;
    421     loop {
    422         match func() {
    423             ok @ Ok(_) => return ok,
    424             Err(e) => {
    425                 tries = tries.checked_add(1).expect("u32 add overflowed");
    426                 if tries >= max_tries && max_tries > 0 {
    427                     return Err(e);
    428                 }
    429                 warn!("Can't connect to database, retrying: {:?}", e);
    430                 sleep(Duration::from_millis(1_000)).await;
    431             }
    432         }
    433     }
    434 }
    435 pub fn convert_json_key_lcase_first(src_json: Value) -> Value {
    436     match src_json {
    437         Value::Array(elm) => {
    438             let mut new_array: Vec<Value> = Vec::with_capacity(elm.len());
    439 
    440             for obj in elm {
    441                 new_array.push(convert_json_key_lcase_first(obj));
    442             }
    443             Value::Array(new_array)
    444         }
    445 
    446         Value::Object(obj) => {
    447             let mut json_map = JsonMap::new();
    448             for (key, value) in obj {
    449                 match (key, value) {
    450                     (key, Value::Object(elm)) => {
    451                         let inner_value = convert_json_key_lcase_first(Value::Object(elm));
    452                         json_map.insert(_process_key(&key), inner_value);
    453                     }
    454 
    455                     (key, Value::Array(elm)) => {
    456                         let mut inner_array: Vec<Value> = Vec::with_capacity(elm.len());
    457 
    458                         for inner_obj in elm {
    459                             inner_array.push(convert_json_key_lcase_first(inner_obj));
    460                         }
    461 
    462                         json_map.insert(_process_key(&key), Value::Array(inner_array));
    463                     }
    464 
    465                     (key, value) => {
    466                         json_map.insert(_process_key(&key), value);
    467                     }
    468                 }
    469             }
    470 
    471             Value::Object(json_map)
    472         }
    473         value @ (Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_)) => value,
    474     }
    475 }