vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

auth.rs (27747B)


      1 use crate::{
      2     config::{self, Config},
      3     error::Error,
      4 };
      5 use chrono::{TimeDelta, Utc};
      6 use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
      7 use openssl::pkey::{Id, PKey};
      8 use serde::de::DeserializeOwned;
      9 use serde::ser::Serialize;
     10 use std::fs::File;
     11 use std::io::{Read, Write};
     12 use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
     13 use webauthn_rp::request::{
     14     auth::{
     15         AuthenticationServerState, AuthenticationVerificationOptions,
     16         AuthenticatorAttachmentEnforcement, SignatureCounterEnforcement,
     17     },
     18     register::{RegistrationServerState, RegistrationVerificationOptions},
     19     BackupReq, BuildIdentityHasher, FixedCapHashSet,
     20 };
     21 static ALLOWED_ORIGINS: OnceLock<[&str; 1]> = OnceLock::new();
     22 #[inline]
     23 fn init_allowed_origins() {
     24     ALLOWED_ORIGINS
     25         .set([config::get_config().domain_url()])
     26         .expect("ALLOWED_ORIGINS must only be initialized once");
     27 }
     28 #[inline]
     29 fn get_allowed_origins() -> &'static [&'static str] {
     30     ALLOWED_ORIGINS
     31         .get()
     32         .expect("ALLOWED_ORIGINS must be initialized in main")
     33         .as_slice()
     34 }
     35 static REG_CEREMONIES: OnceLock<
     36     Arc<Mutex<FixedCapHashSet<RegistrationServerState, BuildIdentityHasher>>>,
     37 > = OnceLock::new();
     38 #[inline]
     39 fn init_reg_ceremonies() {
     40     REG_CEREMONIES
     41         .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000))))
     42         .expect("REG_CEREMONIES must only be initialized once");
     43 }
     44 #[inline]
     45 pub fn get_reg_ceremonies(
     46 ) -> MutexGuard<'static, FixedCapHashSet<RegistrationServerState, BuildIdentityHasher>> {
     47     REG_CEREMONIES
     48         .get()
     49         .expect("REG_CEREMONIES must be initialized in main")
     50         .lock()
     51         .unwrap()
     52 }
     53 static REG_OPTIONS: OnceLock<RegistrationVerificationOptions<'static, 'static, &str, &str>> =
     54     OnceLock::new();
     55 #[inline]
     56 fn init_reg_options() {
     57     REG_OPTIONS
     58         .set(RegistrationVerificationOptions {
     59             allowed_origins: get_allowed_origins(),
     60             allowed_top_origins: None,
     61             backup_requirement: BackupReq::None,
     62             error_on_unsolicited_extensions: true,
     63             require_authenticator_attachment: false,
     64             client_data_json_relaxed: true,
     65         })
     66         .expect("REG_OPTIONS must only be initialized once");
     67 }
     68 #[inline]
     69 pub fn get_reg_options(
     70 ) -> &'static RegistrationVerificationOptions<'static, 'static, &'static str, &'static str> {
     71     REG_OPTIONS
     72         .get()
     73         .expect("REG_OPTIONS must be initialized in main")
     74 }
     75 static AUTH_CEREMONIES: OnceLock<
     76     Arc<Mutex<FixedCapHashSet<AuthenticationServerState, BuildIdentityHasher>>>,
     77 > = OnceLock::new();
     78 #[inline]
     79 fn init_auth_ceremonies() {
     80     AUTH_CEREMONIES
     81         .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000))))
     82         .expect("AUTH_CEREMONIES must only be initialized once");
     83 }
     84 #[inline]
     85 pub fn get_auth_ceremonies(
     86 ) -> MutexGuard<'static, FixedCapHashSet<AuthenticationServerState, BuildIdentityHasher>> {
     87     AUTH_CEREMONIES
     88         .get()
     89         .expect("AUTH_CEREMONIES must be initialized in main")
     90         .lock()
     91         .unwrap()
     92 }
     93 static AUTH_OPTIONS: OnceLock<AuthenticationVerificationOptions<'static, 'static, &str, &str>> =
     94     OnceLock::new();
     95 #[inline]
     96 fn init_auth_options() {
     97     AUTH_OPTIONS
     98         .set(AuthenticationVerificationOptions {
     99             allowed_origins: get_allowed_origins(),
    100             allowed_top_origins: None,
    101             backup_requirement: None,
    102             error_on_unsolicited_extensions: true,
    103             auth_attachment_enforcement: AuthenticatorAttachmentEnforcement::Update(false),
    104             client_data_json_relaxed: true,
    105             update_uv: false,
    106             sig_counter_enforcement: SignatureCounterEnforcement::Fail,
    107         })
    108         .expect("AUTH_OPTIONS must only be initialized once");
    109 }
    110 #[inline]
    111 pub fn get_auth_options(
    112 ) -> &'static AuthenticationVerificationOptions<'static, 'static, &'static str, &'static str> {
    113     AUTH_OPTIONS
    114         .get()
    115         .expect("AUTH_OPTIONS must be initialized in main")
    116 }
    117 static DEFAULT_VALIDITY: OnceLock<TimeDelta> = OnceLock::new();
    118 #[inline]
    119 fn init_default_validity() {
    120     DEFAULT_VALIDITY
    121         .set(TimeDelta::try_hours(2).expect("TimeDelta::try_hours(2) should work"))
    122         .expect("DEFAULT_VALIDITY must only be initialized once");
    123 }
    124 #[inline]
    125 pub fn get_default_validity() -> &'static TimeDelta {
    126     DEFAULT_VALIDITY
    127         .get()
    128         .expect("DEFAULT_VALIDITY must be initialized in main")
    129 }
    130 static JWT_HEADER: OnceLock<Header> = OnceLock::new();
    131 #[inline]
    132 fn init_jwt_header() {
    133     JWT_HEADER
    134         .set(Header::new(JWT_ALGORITHM))
    135         .expect("JWT_HEADER must only be initialized once");
    136 }
    137 #[inline]
    138 fn get_jwt_header() -> &'static Header {
    139     JWT_HEADER
    140         .get()
    141         .expect("JWT_HEADER must be initialized in main")
    142 }
    143 static JWT_LOGIN_ISSUER: OnceLock<String> = OnceLock::new();
    144 #[inline]
    145 fn init_jwt_login_issuer() {
    146     JWT_LOGIN_ISSUER
    147         .set(format!("{}|login", config::get_config().domain_url()))
    148         .expect("JWT_LOGIN_ISSUER must only be initialized once");
    149 }
    150 #[inline]
    151 pub fn get_jwt_login_issuer() -> &'static str {
    152     JWT_LOGIN_ISSUER
    153         .get()
    154         .expect("JWT_LOGIN_ISSUER must be initialized in main")
    155         .as_str()
    156 }
    157 static JWT_INVITE_ISSUER: OnceLock<String> = OnceLock::new();
    158 #[inline]
    159 fn init_jwt_invite_issuer() {
    160     JWT_INVITE_ISSUER
    161         .set(format!("{}|invite", config::get_config().domain_url()))
    162         .expect("JWT_INVITE_ISSUER must only be initialized once");
    163 }
    164 #[inline]
    165 fn get_jwt_invite_issuer() -> &'static str {
    166     JWT_INVITE_ISSUER
    167         .get()
    168         .expect("JWT_INVITE_ISSUER must be initialized in main")
    169         .as_str()
    170 }
    171 static JWT_DELETE_ISSUER: OnceLock<String> = OnceLock::new();
    172 #[inline]
    173 fn init_jwt_delete_issuer() {
    174     JWT_DELETE_ISSUER
    175         .set(format!("{}|delete", config::get_config().domain_url()))
    176         .expect("JWT_DELETE_ISSUER must only be initialized once");
    177 }
    178 #[inline]
    179 fn get_jwt_delete_issuer() -> &'static str {
    180     JWT_DELETE_ISSUER
    181         .get()
    182         .expect("JWT_DELETE_ISSUER must be initialized in main")
    183         .as_str()
    184 }
    185 const JWT_ALGORITHM: Algorithm = Algorithm::EdDSA;
    186 static ED_KEYS: OnceLock<(EncodingKey, DecodingKey)> = OnceLock::new();
    187 #[allow(clippy::map_err_ignore, clippy::verbose_file_reads)]
    188 #[inline]
    189 fn init_ed_keys() -> Result<(), Error> {
    190     let mut file = File::options()
    191         .create(true)
    192         .read(true)
    193         .truncate(false)
    194         .write(true)
    195         .open(Config::PRIVATE_ED25519_KEY)?;
    196     let mut priv_pem = Vec::with_capacity(128);
    197     let ed_key = if file.read_to_end(&mut priv_pem)? == 0 {
    198         let ed_key = PKey::generate_ed25519()?;
    199         priv_pem = ed_key.private_key_to_pem_pkcs8()?;
    200         file.write_all(priv_pem.as_slice())?;
    201         ed_key
    202     } else {
    203         let ed_key = PKey::private_key_from_pem(priv_pem.as_slice())?;
    204         if ed_key.id() == Id::ED25519 {
    205             ed_key
    206         } else {
    207             let msg = format!(
    208                 "{} is not a private Ed25519 key",
    209                 Config::PRIVATE_ED25519_KEY
    210             );
    211             return Err(Error::new(msg.as_str(), msg.as_str()));
    212         }
    213     };
    214     ED_KEYS
    215         .set((
    216             EncodingKey::from_ed_pem(priv_pem.as_slice())?,
    217             DecodingKey::from_ed_pem(ed_key.public_key_to_pem()?.as_slice())?,
    218         ))
    219         .map_err(|_| {
    220             const MSG: &str = "ED_KEYS must only be initialized once";
    221             Error::new(MSG, MSG)
    222         })
    223 }
    224 #[inline]
    225 fn get_private_ed_key() -> &'static EncodingKey {
    226     &ED_KEYS
    227         .get()
    228         .expect("ED_KEYS must be initialized in main")
    229         .0
    230 }
    231 #[inline]
    232 fn get_public_ed_key() -> &'static DecodingKey {
    233     &ED_KEYS
    234         .get()
    235         .expect("ED_KEYS must be initialized in main")
    236         .1
    237 }
    238 #[inline]
    239 pub fn init_values() {
    240     init_allowed_origins();
    241     init_reg_ceremonies();
    242     init_reg_options();
    243     init_auth_ceremonies();
    244     init_auth_options();
    245     init_default_validity();
    246     init_jwt_header();
    247     init_jwt_login_issuer();
    248     init_jwt_invite_issuer();
    249     init_jwt_delete_issuer();
    250     init_ed_keys().expect("error creating Ed25519 keys");
    251 }
    252 pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
    253     match jsonwebtoken::encode(get_jwt_header(), claims, get_private_ed_key()) {
    254         Ok(token) => token,
    255         Err(e) => panic!("Error encoding jwt {e}"),
    256     }
    257 }
    258 
    259 #[allow(clippy::match_same_arms)]
    260 fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> {
    261     let mut validation = jsonwebtoken::Validation::new(JWT_ALGORITHM);
    262     validation.leeway = 30; // 30 seconds
    263     validation.validate_exp = true;
    264     validation.validate_nbf = true;
    265     validation.set_issuer(&[issuer]);
    266     let token = token.replace(char::is_whitespace, "");
    267     match jsonwebtoken::decode(&token, get_public_ed_key(), &validation) {
    268         Ok(d) => Ok(d.claims),
    269         Err(err) => match *err.kind() {
    270             ErrorKind::InvalidToken => err!("Token is invalid"),
    271             ErrorKind::InvalidIssuer => err!("Issuer is invalid"),
    272             ErrorKind::ExpiredSignature => err!("Token has expired"),
    273             ErrorKind::InvalidSignature
    274             | ErrorKind::InvalidEcdsaKey
    275             | ErrorKind::InvalidRsaKey(_)
    276             | ErrorKind::RsaFailedSigning
    277             | ErrorKind::InvalidAlgorithmName
    278             | ErrorKind::InvalidKeyFormat
    279             | ErrorKind::MissingRequiredClaim(_)
    280             | ErrorKind::InvalidAudience
    281             | ErrorKind::InvalidSubject
    282             | ErrorKind::ImmatureSignature
    283             | ErrorKind::InvalidAlgorithm
    284             | ErrorKind::MissingAlgorithm
    285             | ErrorKind::Base64(_)
    286             | ErrorKind::Json(_)
    287             | ErrorKind::Utf8(_)
    288             | ErrorKind::Crypto(_) => err!("Error decoding JWT"),
    289             _ => err!("Error decoding JWT"),
    290         },
    291     }
    292 }
    293 pub fn decode_login(token: &str) -> Result<LoginJwtClaims, Error> {
    294     decode_jwt(token, get_jwt_login_issuer().to_owned())
    295 }
    296 pub fn decode_invite(token: &str) -> Result<InviteJwtClaims, Error> {
    297     decode_jwt(token, get_jwt_invite_issuer().to_owned())
    298 }
    299 pub fn decode_delete(token: &str) -> Result<BasicJwtClaims, Error> {
    300     decode_jwt(token, get_jwt_delete_issuer().to_owned())
    301 }
    302 
/// Claims carried by a login (access) token; verified with `decode_login`.
///
/// Field order is also the serialization order of the claims.
#[derive(Serialize, Deserialize)]
pub struct LoginJwtClaims {
    /// Not before: seconds since the Unix epoch, validated when decoding.
    pub nbf: i64,
    /// Expiration time: seconds since the Unix epoch, validated when decoding.
    pub exp: i64,
    /// Issuer; `decode_login` only accepts the login issuer.
    pub iss: String,
    /// Subject: the user uuid that the `Headers` guard looks up.
    pub sub: String,
    pub premium: bool,
    pub name: String,
    pub email: String,
    pub email_verified: bool,
    // ---
    // These keys are intentionally left out of the JWT because they could make it too large.
    // They are also not used anywhere by Vaultwarden or the Bitwarden clients.
    // The Bitwarden server does add them, and they may be used in the future, so they are kept here commented out.
    // See: https://github.com/dani-garcia/vaultwarden/issues/4156
    // ---
    // pub orgowner: Vec<String>,
    // pub orgadmin: Vec<String>,
    // pub orguser: Vec<String>,
    // pub orgmanager: Vec<String>,
    /// The user's security stamp at issuance; `Headers` compares it with the stored stamp.
    pub sstamp: String,
    /// Device uuid; `Headers` requires it to belong to the `sub` user.
    pub device: String,
    /// Granted scopes, e.g. `[ "api", "offline_access" ]`.
    pub scope: Vec<String>,
    /// Authentication methods, e.g. `[ "Application" ]`.
    pub amr: Vec<String>,
}
    336 
/// Claims carried by an invitation token; verified with `decode_invite`.
#[derive(Serialize, Deserialize)]
pub struct InviteJwtClaims {
    /// Not before: seconds since the Unix epoch, validated when decoding.
    nbf: i64,
    /// Expiration time: seconds since the Unix epoch, validated when decoding.
    exp: i64,
    /// Issuer; `decode_invite` only accepts the invite issuer.
    iss: String,
    /// Subject.
    sub: String,
    pub email: String,
    pub org_id: Option<String>,
    pub user_org_id: Option<String>,
    invited_by_email: Option<String>,
}
    352 
/// Claims for a file-download token (no decoder for it appears in this part of the file).
#[derive(Serialize, Deserialize)]
pub struct FileDownloadClaims {
    /// Not before: seconds since the Unix epoch.
    nbf: i64,
    /// Expiration time: seconds since the Unix epoch.
    exp: i64,
    /// Issuer.
    iss: String,
    /// Subject.
    pub sub: String,
    pub file_id: String,
}
/// Minimal claim set (time bounds, issuer, subject); returned by `decode_delete`.
#[derive(Serialize, Deserialize)]
pub struct BasicJwtClaims {
    /// Not before: seconds since the Unix epoch, validated when decoding.
    nbf: i64,
    /// Expiration time: seconds since the Unix epoch, validated when decoding.
    exp: i64,
    /// Issuer; checked against the expected issuer when decoding.
    iss: String,
    /// Subject.
    pub sub: String,
}
    376 use crate::db::{
    377     models::{
    378         Collection, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException,
    379     },
    380     DbConn,
    381 };
    382 use rocket::{
    383     outcome::try_outcome,
    384     request::{FromRequest, Outcome, Request},
    385 };
    386 
    387 struct Host {
    388     host: String,
    389 }
    390 
    391 #[rocket::async_trait]
    392 impl<'r> FromRequest<'r> for Host {
    393     type Error = &'static str;
    394     async fn from_request(_: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    395         Outcome::Success(Self {
    396             host: config::get_config().domain_url().to_owned(),
    397         })
    398     }
    399 }
    400 pub struct ClientHeaders {
    401     #[allow(dead_code)]
    402     pub device_type: i32,
    403     pub ip: ClientIp,
    404 }
    405 
    406 #[rocket::async_trait]
    407 impl<'r> FromRequest<'r> for ClientHeaders {
    408     type Error = &'static str;
    409     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    410         let Outcome::Success(ip) = ClientIp::from_request(request).await else {
    411             err_handler!("Error getting Client IP")
    412         };
    413         // When unknown or unable to parse, return 14, which is 'Unknown Browser'
    414         let device_type: i32 = request
    415             .headers()
    416             .get_one("device-type")
    417             .map_or(14i32, |d| d.parse().unwrap_or(14i32));
    418 
    419         Outcome::Success(Self { device_type, ip })
    420     }
    421 }
    422 
    423 pub struct Headers {
    424     pub host: String,
    425     pub device: Device,
    426     pub user: User,
    427     pub ip: ClientIp,
    428 }
    429 #[allow(clippy::else_if_without_else)]
    430 #[rocket::async_trait]
    431 impl<'r> FromRequest<'r> for Headers {
    432     type Error = &'static str;
    433     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    434         let headers = request.headers();
    435         let host = try_outcome!(Host::from_request(request).await).host;
    436         let Outcome::Success(ip) = ClientIp::from_request(request).await else {
    437             err_handler!("Error getting Client IP")
    438         };
    439         // Get access_token
    440         let access_token: &str = match headers.get_one("Authorization") {
    441             Some(a) => match a.rsplit("Bearer ").next() {
    442                 Some(split) => split,
    443                 None => err_handler!("No access token provided"),
    444             },
    445             None => err_handler!("No access token provided"),
    446         };
    447         // Check JWT token is valid and get device and user from it
    448         let Ok(claims) = decode_login(access_token) else {
    449             err_handler!("Invalid claim")
    450         };
    451         let device_uuid = claims.device;
    452         let user_uuid = claims.sub;
    453         let Outcome::Success(conn) = DbConn::from_request(request).await else {
    454             err_handler!("Error getting DB")
    455         };
    456         let Some(device) = Device::find_by_uuid_and_user(&device_uuid, &user_uuid, &conn).await
    457         else {
    458             err_handler!("Invalid device id")
    459         };
    460         let Some(user) = User::find_by_uuid(&user_uuid, &conn).await else {
    461             err_handler!("Device has no user associated")
    462         };
    463         if user.security_stamp != claims.sstamp {
    464             if let Some(stamp_exception) = user
    465                 .stamp_exception
    466                 .as_deref()
    467                 .and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
    468             {
    469                 let Some(current_route) = request.route().and_then(|r| r.name.as_deref()) else {
    470                     err_handler!("Error getting current route for stamp exception")
    471                 };
    472                 // Check if the stamp exception has expired first.
    473                 // Then, check if the current route matches any of the allowed routes.
    474                 // After that check the stamp in exception matches the one in the claims.
    475                 if u64::try_from(Utc::now().naive_utc().and_utc().timestamp()).expect("underflow")
    476                     > stamp_exception.expire
    477                 {
    478                     // If the stamp exception has been expired remove it from the database.
    479                     // This prevents checking this stamp exception for new requests.
    480                     let mut user = user;
    481                     user.reset_stamp_exception();
    482                     if let Err(e) = user.save(&conn).await {
    483                         error!("Error updating user: {:#?}", e);
    484                     }
    485                     err_handler!("Stamp exception is expired")
    486                 } else if !stamp_exception.routes.contains(&current_route.to_owned()) {
    487                     err_handler!(
    488                         "Invalid security stamp: Current route and exception route do not match"
    489                     )
    490                 } else if stamp_exception.security_stamp != claims.sstamp {
    491                     err_handler!("Invalid security stamp for matched stamp exception")
    492                 }
    493             } else {
    494                 err_handler!("Invalid security stamp")
    495             }
    496         }
    497         Outcome::Success(Self {
    498             host,
    499             device,
    500             user,
    501             ip,
    502         })
    503     }
    504 }
    505 
    506 pub struct OrgHeaders {
    507     host: String,
    508     device: Device,
    509     user: User,
    510     org_user_type: UserOrgType,
    511     org_user: UserOrganization,
    512     ip: ClientIp,
    513 }
    514 
    515 #[rocket::async_trait]
    516 impl<'r> FromRequest<'r> for OrgHeaders {
    517     type Error = &'static str;
    518     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    519         let headers = try_outcome!(Headers::from_request(request).await);
    520         // org_id is usually the second path param ("/organizations/<org_id>"),
    521         // but there are cases where it is a query value.
    522         // First check the path, if this is not a valid uuid, try the query values.
    523         let url_org_id: Option<&str> = {
    524             let mut url_org_id = None;
    525             if let Some(Ok(org_id)) = request.param::<&str>(1) {
    526                 if uuid::Uuid::parse_str(org_id).is_ok() {
    527                     url_org_id = Some(org_id);
    528                 }
    529             }
    530             if let Some(Ok(org_id)) = request.query_value::<&str>("organizationId") {
    531                 if uuid::Uuid::parse_str(org_id).is_ok() {
    532                     url_org_id = Some(org_id);
    533                 }
    534             }
    535             url_org_id
    536         };
    537         match url_org_id {
    538             Some(org_id) => {
    539                 let user = headers.user;
    540                 let org_user = match DbConn::from_request(request).await {
    541                     Outcome::Success(conn) => {
    542                         match UserOrganization::find_by_user_and_org(&user.uuid, org_id, &conn)
    543                             .await
    544                         {
    545                             Some(user) => {
    546                                 if user.status == i32::from(UserOrgStatus::Confirmed) {
    547                                     user
    548                                 } else {
    549                                     err_handler!(
    550                                     "The current user isn't confirmed member of the organization"
    551                                 )
    552                                 }
    553                             }
    554                             None => {
    555                                 err_handler!("The current user isn't member of the organization")
    556                             }
    557                         }
    558                     }
    559                     Outcome::Error(_) | Outcome::Forward(_) => err_handler!("Error getting DB"),
    560                 };
    561                 Outcome::Success(Self {
    562                     host: headers.host,
    563                     device: headers.device,
    564                     user,
    565                     org_user_type: {
    566                         if let Ok(org_usr_type) = UserOrgType::try_from(org_user.atype) {
    567                             org_usr_type
    568                         } else {
    569                             // This should only happen if the DB is corrupted
    570                             err_handler!("Unknown user type in the database")
    571                         }
    572                     },
    573                     org_user,
    574                     ip: headers.ip,
    575                 })
    576             }
    577             _ => err_handler!("Error getting the organization id"),
    578         }
    579     }
    580 }
    581 
    582 pub struct AdminHeaders {
    583     pub host: String,
    584     pub device: Device,
    585     pub user: User,
    586     pub org_user_type: UserOrgType,
    587     pub client_version: Option<String>,
    588     pub ip: ClientIp,
    589 }
    590 
    591 #[rocket::async_trait]
    592 impl<'r> FromRequest<'r> for AdminHeaders {
    593     type Error = &'static str;
    594     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    595         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    596         let client_version = request
    597             .headers()
    598             .get_one("Bitwarden-Client-Version")
    599             .map(String::from);
    600         if headers.org_user_type >= UserOrgType::Admin {
    601             Outcome::Success(Self {
    602                 host: headers.host,
    603                 device: headers.device,
    604                 user: headers.user,
    605                 org_user_type: headers.org_user_type,
    606                 client_version,
    607                 ip: headers.ip,
    608             })
    609         } else {
    610             err_handler!("You need to be Admin or Owner to call this endpoint")
    611         }
    612     }
    613 }
    614 
    615 impl From<AdminHeaders> for Headers {
    616     fn from(h: AdminHeaders) -> Self {
    617         Self {
    618             host: h.host,
    619             device: h.device,
    620             user: h.user,
    621             ip: h.ip,
    622         }
    623     }
    624 }
    625 
    626 // col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"),
    627 // but there could be cases where it is a query value.
    628 // First check the path, if this is not a valid uuid, try the query values.
    629 fn get_col_id(request: &Request<'_>) -> Option<String> {
    630     if let Some(Ok(col_id)) = request.param::<String>(3) {
    631         if uuid::Uuid::parse_str(&col_id).is_ok() {
    632             return Some(col_id);
    633         }
    634     }
    635     if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
    636         if uuid::Uuid::parse_str(&col_id).is_ok() {
    637             return Some(col_id);
    638         }
    639     }
    640     None
    641 }
    642 
    643 /// The ManagerHeaders are used to check if you are at least a Manager
    644 /// and have access to the specific collection provided via the <col_id>/collections/collectionId.
    645 /// This does strict checking on the collection_id, ManagerHeadersLoose does not.
    646 pub struct ManagerHeaders {
    647     host: String,
    648     device: Device,
    649     pub user: User,
    650     ip: ClientIp,
    651 }
    652 
    653 #[rocket::async_trait]
    654 impl<'r> FromRequest<'r> for ManagerHeaders {
    655     type Error = &'static str;
    656     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    657         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    658         if headers.org_user_type >= UserOrgType::Manager {
    659             match get_col_id(request) {
    660                 Some(col_id) => {
    661                     let Outcome::Success(conn) = DbConn::from_request(request).await else {
    662                         err_handler!("Error getting DB")
    663                     };
    664 
    665                     if !Collection::can_access_collection(&headers.org_user, &col_id, &conn).await {
    666                         err_handler!("The current user isn't a manager for this collection")
    667                     }
    668                 }
    669                 _ => err_handler!("Error getting the collection id"),
    670             }
    671 
    672             Outcome::Success(Self {
    673                 host: headers.host,
    674                 device: headers.device,
    675                 user: headers.user,
    676                 ip: headers.ip,
    677             })
    678         } else {
    679             err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
    680         }
    681     }
    682 }
    683 
    684 impl From<ManagerHeaders> for Headers {
    685     fn from(h: ManagerHeaders) -> Self {
    686         Self {
    687             host: h.host,
    688             device: h.device,
    689             user: h.user,
    690             ip: h.ip,
    691         }
    692     }
    693 }
    694 
    695 /// The ManagerHeadersLoose is used when you at least need to be a Manager,
    696 /// but there is no collection_id sent with the request (either in the path or as form data).
    697 pub struct ManagerHeadersLoose {
    698     host: String,
    699     device: Device,
    700     pub user: User,
    701     pub org_user: UserOrganization,
    702     ip: ClientIp,
    703 }
    704 
    705 #[rocket::async_trait]
    706 impl<'r> FromRequest<'r> for ManagerHeadersLoose {
    707     type Error = &'static str;
    708     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    709         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    710         if headers.org_user_type >= UserOrgType::Manager {
    711             Outcome::Success(Self {
    712                 host: headers.host,
    713                 device: headers.device,
    714                 user: headers.user,
    715                 org_user: headers.org_user,
    716                 ip: headers.ip,
    717             })
    718         } else {
    719             err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
    720         }
    721     }
    722 }
    723 
    724 impl From<ManagerHeadersLoose> for Headers {
    725     fn from(h: ManagerHeadersLoose) -> Self {
    726         Self {
    727             host: h.host,
    728             device: h.device,
    729             user: h.user,
    730             ip: h.ip,
    731         }
    732     }
    733 }
    734 
    735 impl ManagerHeaders {
    736     pub async fn from_loose(
    737         h: ManagerHeadersLoose,
    738         collections: &Vec<String>,
    739         conn: &DbConn,
    740     ) -> Result<Self, Error> {
    741         for col_id in collections {
    742             if uuid::Uuid::parse_str(col_id).is_err() {
    743                 err!("Collection Id is malformed!");
    744             }
    745             if !Collection::can_access_collection(&h.org_user, col_id, conn).await {
    746                 err!("You don't have access to all collections!");
    747             }
    748         }
    749         Ok(Self {
    750             host: h.host,
    751             device: h.device,
    752             user: h.user,
    753             ip: h.ip,
    754         })
    755     }
    756 }
    757 
    758 pub struct OwnerHeaders {
    759     #[allow(dead_code)]
    760     pub device: Device,
    761     pub user: User,
    762     #[allow(dead_code)]
    763     pub ip: ClientIp,
    764 }
    765 
    766 #[rocket::async_trait]
    767 impl<'r> FromRequest<'r> for OwnerHeaders {
    768     type Error = &'static str;
    769     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    770         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    771         if headers.org_user_type == UserOrgType::Owner {
    772             Outcome::Success(Self {
    773                 device: headers.device,
    774                 user: headers.user,
    775                 ip: headers.ip,
    776             })
    777         } else {
    778             err_handler!("You need to be Owner to call this endpoint")
    779         }
    780     }
    781 }
    782 
    783 //
    784 // Client IP address detection
    785 //
    786 use std::net::IpAddr;
    787 #[derive(Clone, Copy)]
    788 pub struct ClientIp {
    789     pub ip: IpAddr,
    790 }
    791 
    792 #[rocket::async_trait]
    793 impl<'r> FromRequest<'r> for ClientIp {
    794     type Error = ();
    795     #[allow(clippy::map_err_ignore, clippy::string_slice)]
    796     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    797         let ip = request.headers().get_one("X-Real-IP").and_then(|ip| {
    798             ip.find(',')
    799                 .map_or(ip, |idx| &ip[..idx])
    800                 .parse()
    801                 .map_err(|_| warn!("'X-Real-IP' header is malformed: {ip}"))
    802                 .ok()
    803         });
    804         let ip = ip
    805             .or_else(|| request.remote().map(|r| r.ip()))
    806             .unwrap_or_else(|| "0.0.0.0".parse().unwrap());
    807         Outcome::Success(Self { ip })
    808     }
    809 }
    810 
    811 pub struct WsAccessTokenHeader {
    812     #[allow(dead_code)]
    813     pub access_token: Option<String>,
    814 }
    815 
    816 #[rocket::async_trait]
    817 impl<'r> FromRequest<'r> for WsAccessTokenHeader {
    818     type Error = ();
    819     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    820         let headers = request.headers();
    821         let access_token = headers
    822             .get_one("Authorization")
    823             .and_then(|a| a.rsplit("Bearer ").next().map(String::from));
    824         Outcome::Success(Self { access_token })
    825     }
    826 }