vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

auth.rs (27825B)


      1 use crate::{
      2     config::{self, Config},
      3     error::Error,
      4 };
      5 use chrono::{TimeDelta, Utc};
      6 use jsonwebtoken::{self, Algorithm, DecodingKey, EncodingKey, Header, errors::ErrorKind};
      7 use openssl::pkey::{Id, PKey};
      8 use serde::de::DeserializeOwned;
      9 use serde::ser::Serialize;
     10 use std::fs::File;
     11 use std::io::{Read as _, Write as _};
     12 use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
     13 use webauthn_rp::request::{
     14     BackupReq, BuildIdentityHasher, FixedCapHashSet,
     15     auth::{
     16         AuthenticationVerificationOptions, AuthenticatorAttachmentEnforcement,
     17         NonDiscoverableAuthenticationServerState, SignatureCounterEnforcement,
     18     },
     19     register::{RegistrationServerState, RegistrationVerificationOptions},
     20 };
     21 static ALLOWED_ORIGINS: OnceLock<[&str; 1]> = OnceLock::new();
     22 #[inline]
     23 fn init_allowed_origins() {
     24     ALLOWED_ORIGINS
     25         .set([config::get_config().domain_url()])
     26         .expect("ALLOWED_ORIGINS must only be initialized once");
     27 }
     28 #[inline]
     29 fn get_allowed_origins() -> &'static [&'static str] {
     30     ALLOWED_ORIGINS
     31         .get()
     32         .expect("ALLOWED_ORIGINS must be initialized in main")
     33         .as_slice()
     34 }
     35 static REG_CEREMONIES: OnceLock<
     36     Arc<Mutex<FixedCapHashSet<RegistrationServerState<16>, BuildIdentityHasher>>>,
     37 > = OnceLock::new();
     38 #[inline]
     39 fn init_reg_ceremonies() {
     40     REG_CEREMONIES
     41         .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000))))
     42         .expect("REG_CEREMONIES must only be initialized once");
     43 }
     44 #[inline]
     45 pub fn get_reg_ceremonies()
     46 -> MutexGuard<'static, FixedCapHashSet<RegistrationServerState<16>, BuildIdentityHasher>> {
     47     REG_CEREMONIES
     48         .get()
     49         .expect("REG_CEREMONIES must be initialized in main")
     50         .lock()
     51         .unwrap()
     52 }
     53 static REG_OPTIONS: OnceLock<RegistrationVerificationOptions<'static, 'static, &str, &str>> =
     54     OnceLock::new();
     55 #[inline]
     56 fn init_reg_options() {
     57     REG_OPTIONS
     58         .set(RegistrationVerificationOptions {
     59             allowed_origins: get_allowed_origins(),
     60             allowed_top_origins: None,
     61             backup_requirement: BackupReq::None,
     62             error_on_unsolicited_extensions: true,
     63             require_authenticator_attachment: false,
     64             client_data_json_relaxed: true,
     65         })
     66         .expect("REG_OPTIONS must only be initialized once");
     67 }
     68 #[inline]
     69 pub fn get_reg_options()
     70 -> &'static RegistrationVerificationOptions<'static, 'static, &'static str, &'static str> {
     71     REG_OPTIONS
     72         .get()
     73         .expect("REG_OPTIONS must be initialized in main")
     74 }
     75 static AUTH_CEREMONIES: OnceLock<
     76     Arc<Mutex<FixedCapHashSet<NonDiscoverableAuthenticationServerState, BuildIdentityHasher>>>,
     77 > = OnceLock::new();
     78 #[inline]
     79 fn init_auth_ceremonies() {
     80     AUTH_CEREMONIES
     81         .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000))))
     82         .expect("AUTH_CEREMONIES must only be initialized once");
     83 }
     84 #[inline]
     85 pub fn get_auth_ceremonies() -> MutexGuard<
     86     'static,
     87     FixedCapHashSet<NonDiscoverableAuthenticationServerState, BuildIdentityHasher>,
     88 > {
     89     AUTH_CEREMONIES
     90         .get()
     91         .expect("AUTH_CEREMONIES must be initialized in main")
     92         .lock()
     93         .unwrap()
     94 }
     95 static AUTH_OPTIONS: OnceLock<AuthenticationVerificationOptions<'static, 'static, &str, &str>> =
     96     OnceLock::new();
     97 #[inline]
     98 fn init_auth_options() {
     99     AUTH_OPTIONS
    100         .set(AuthenticationVerificationOptions {
    101             allowed_origins: get_allowed_origins(),
    102             allowed_top_origins: None,
    103             backup_requirement: None,
    104             error_on_unsolicited_extensions: true,
    105             auth_attachment_enforcement: AuthenticatorAttachmentEnforcement::Update(false),
    106             client_data_json_relaxed: true,
    107             update_uv: false,
    108             sig_counter_enforcement: SignatureCounterEnforcement::Fail,
    109         })
    110         .expect("AUTH_OPTIONS must only be initialized once");
    111 }
    112 #[inline]
    113 pub fn get_auth_options()
    114 -> &'static AuthenticationVerificationOptions<'static, 'static, &'static str, &'static str> {
    115     AUTH_OPTIONS
    116         .get()
    117         .expect("AUTH_OPTIONS must be initialized in main")
    118 }
    119 static DEFAULT_VALIDITY: OnceLock<TimeDelta> = OnceLock::new();
    120 #[inline]
    121 fn init_default_validity() {
    122     DEFAULT_VALIDITY
    123         .set(TimeDelta::try_hours(2).expect("TimeDelta::try_hours(2) should work"))
    124         .expect("DEFAULT_VALIDITY must only be initialized once");
    125 }
    126 #[inline]
    127 pub fn get_default_validity() -> &'static TimeDelta {
    128     DEFAULT_VALIDITY
    129         .get()
    130         .expect("DEFAULT_VALIDITY must be initialized in main")
    131 }
    132 static JWT_HEADER: OnceLock<Header> = OnceLock::new();
    133 #[inline]
    134 fn init_jwt_header() {
    135     JWT_HEADER
    136         .set(Header::new(JWT_ALGORITHM))
    137         .expect("JWT_HEADER must only be initialized once");
    138 }
    139 #[inline]
    140 fn get_jwt_header() -> &'static Header {
    141     JWT_HEADER
    142         .get()
    143         .expect("JWT_HEADER must be initialized in main")
    144 }
    145 static JWT_LOGIN_ISSUER: OnceLock<String> = OnceLock::new();
    146 #[inline]
    147 fn init_jwt_login_issuer() {
    148     JWT_LOGIN_ISSUER
    149         .set(format!("{}|login", config::get_config().domain_url()))
    150         .expect("JWT_LOGIN_ISSUER must only be initialized once");
    151 }
    152 #[inline]
    153 pub fn get_jwt_login_issuer() -> &'static str {
    154     JWT_LOGIN_ISSUER
    155         .get()
    156         .expect("JWT_LOGIN_ISSUER must be initialized in main")
    157         .as_str()
    158 }
    159 static JWT_INVITE_ISSUER: OnceLock<String> = OnceLock::new();
    160 #[inline]
    161 fn init_jwt_invite_issuer() {
    162     JWT_INVITE_ISSUER
    163         .set(format!("{}|invite", config::get_config().domain_url()))
    164         .expect("JWT_INVITE_ISSUER must only be initialized once");
    165 }
    166 #[inline]
    167 fn get_jwt_invite_issuer() -> &'static str {
    168     JWT_INVITE_ISSUER
    169         .get()
    170         .expect("JWT_INVITE_ISSUER must be initialized in main")
    171         .as_str()
    172 }
    173 static JWT_DELETE_ISSUER: OnceLock<String> = OnceLock::new();
    174 #[inline]
    175 fn init_jwt_delete_issuer() {
    176     JWT_DELETE_ISSUER
    177         .set(format!("{}|delete", config::get_config().domain_url()))
    178         .expect("JWT_DELETE_ISSUER must only be initialized once");
    179 }
    180 #[inline]
    181 fn get_jwt_delete_issuer() -> &'static str {
    182     JWT_DELETE_ISSUER
    183         .get()
    184         .expect("JWT_DELETE_ISSUER must be initialized in main")
    185         .as_str()
    186 }
    187 const JWT_ALGORITHM: Algorithm = Algorithm::EdDSA;
    188 static ED_KEYS: OnceLock<(EncodingKey, DecodingKey)> = OnceLock::new();
    189 #[allow(clippy::map_err_ignore, clippy::verbose_file_reads)]
    190 #[inline]
    191 fn init_ed_keys() -> Result<(), Error> {
    192     let mut file = File::options()
    193         .create(true)
    194         .read(true)
    195         .truncate(false)
    196         .write(true)
    197         .open(Config::PRIVATE_ED25519_KEY)?;
    198     let mut priv_pem = Vec::with_capacity(128);
    199     let ed_key = if file.read_to_end(&mut priv_pem)? == 0 {
    200         let ed_key = PKey::generate_ed25519()?;
    201         priv_pem = ed_key.private_key_to_pem_pkcs8()?;
    202         file.write_all(priv_pem.as_slice())?;
    203         ed_key
    204     } else {
    205         let ed_key = PKey::private_key_from_pem(priv_pem.as_slice())?;
    206         if ed_key.id() == Id::ED25519 {
    207             ed_key
    208         } else {
    209             let msg = format!(
    210                 "{} is not a private Ed25519 key",
    211                 Config::PRIVATE_ED25519_KEY
    212             );
    213             return Err(Error::new(msg.as_str(), msg.as_str()));
    214         }
    215     };
    216     ED_KEYS
    217         .set((
    218             EncodingKey::from_ed_pem(priv_pem.as_slice())?,
    219             DecodingKey::from_ed_pem(ed_key.public_key_to_pem()?.as_slice())?,
    220         ))
    221         .map_err(|_| {
    222             const MSG: &str = "ED_KEYS must only be initialized once";
    223             Error::new(MSG, MSG)
    224         })
    225 }
    226 #[inline]
    227 fn get_private_ed_key() -> &'static EncodingKey {
    228     &ED_KEYS
    229         .get()
    230         .expect("ED_KEYS must be initialized in main")
    231         .0
    232 }
    233 #[inline]
    234 fn get_public_ed_key() -> &'static DecodingKey {
    235     &ED_KEYS
    236         .get()
    237         .expect("ED_KEYS must be initialized in main")
    238         .1
    239 }
/// Initializes every process-wide value in this module. Call it exactly once at startup;
/// the getters' panic messages expect this to have happened in `main`.
///
/// The order is significant: `init_reg_options` and `init_auth_options` read the allowed
/// origins, so `init_allowed_origins` must run before them.
///
/// # Panics
///
/// Panics if any value was already initialized, or if loading or creating the Ed25519 key
/// fails (`init_ed_keys`).
#[inline]
pub fn init_values() {
    init_allowed_origins();
    init_reg_ceremonies();
    // Depends on the allowed origins initialized above.
    init_reg_options();
    init_auth_ceremonies();
    // Depends on the allowed origins initialized above.
    init_auth_options();
    init_default_validity();
    init_jwt_header();
    init_jwt_login_issuer();
    init_jwt_invite_issuer();
    init_jwt_delete_issuer();
    init_ed_keys().expect("error creating Ed25519 keys");
}
    254 pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
    255     match jsonwebtoken::encode(get_jwt_header(), claims, get_private_ed_key()) {
    256         Ok(token) => token,
    257         Err(e) => panic!("Error encoding jwt {e}"),
    258     }
    259 }
    260 
    261 #[allow(clippy::match_same_arms)]
    262 fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> {
    263     let mut validation = jsonwebtoken::Validation::new(JWT_ALGORITHM);
    264     validation.leeway = 30; // 30 seconds
    265     validation.validate_exp = true;
    266     validation.validate_nbf = true;
    267     validation.set_issuer(&[issuer]);
    268     let token = token.replace(char::is_whitespace, "");
    269     match jsonwebtoken::decode(&token, get_public_ed_key(), &validation) {
    270         Ok(d) => Ok(d.claims),
    271         Err(err) => match *err.kind() {
    272             ErrorKind::InvalidToken => err!("Token is invalid"),
    273             ErrorKind::InvalidIssuer => err!("Issuer is invalid"),
    274             ErrorKind::ExpiredSignature => err!("Token has expired"),
    275             ErrorKind::InvalidSignature
    276             | ErrorKind::InvalidEcdsaKey
    277             | ErrorKind::InvalidRsaKey(_)
    278             | ErrorKind::RsaFailedSigning
    279             | ErrorKind::InvalidAlgorithmName
    280             | ErrorKind::InvalidKeyFormat
    281             | ErrorKind::MissingRequiredClaim(_)
    282             | ErrorKind::InvalidAudience
    283             | ErrorKind::InvalidSubject
    284             | ErrorKind::ImmatureSignature
    285             | ErrorKind::InvalidAlgorithm
    286             | ErrorKind::MissingAlgorithm
    287             | ErrorKind::Base64(_)
    288             | ErrorKind::Json(_)
    289             | ErrorKind::Utf8(_)
    290             | ErrorKind::Crypto(_) => err!("Error decoding JWT"),
    291             _ => err!("Error decoding JWT"),
    292         },
    293     }
    294 }
    295 pub fn decode_login(token: &str) -> Result<LoginJwtClaims, Error> {
    296     decode_jwt(token, get_jwt_login_issuer().to_owned())
    297 }
    298 pub fn decode_invite(token: &str) -> Result<InviteJwtClaims, Error> {
    299     decode_jwt(token, get_jwt_invite_issuer().to_owned())
    300 }
    301 pub fn decode_delete(token: &str) -> Result<BasicJwtClaims, Error> {
    302     decode_jwt(token, get_jwt_delete_issuer().to_owned())
    303 }
    304 
/// Claims of a login token. `Headers` decodes the `Authorization` bearer token into this type
/// via `decode_login`.
///
/// Field names and order define the serialized JWT payload; do not rename or reorder them.
#[derive(Serialize, Deserialize)]
pub struct LoginJwtClaims {
    /// Not before (`nbf`); validated by `decode_jwt`.
    pub nbf: i64,
    /// Expiration time (`exp`); validated by `decode_jwt`.
    pub exp: i64,
    /// Issuer (`iss`); must equal `<domain_url>|login`.
    pub iss: String,
    /// Subject (`sub`); `Headers` treats it as the user UUID.
    pub sub: String,
    pub premium: bool,
    pub name: String,
    pub email: String,
    pub email_verified: bool,
    // ---
    // The keys below are intentionally not added to the JWT: they could make it too large,
    // and neither Vaultwarden nor the Bitwarden clients read them. The Bitwarden server does
    // emit them and they may be needed in the future, so they are kept here commented out.
    // See: https://github.com/dani-garcia/vaultwarden/issues/4156
    // ---
    // pub orgowner: Vec<String>,
    // pub orgadmin: Vec<String>,
    // pub orguser: Vec<String>,
    // pub orgmanager: Vec<String>,
    /// User security stamp; `Headers` compares it with the stored `User::security_stamp`.
    pub sstamp: String,
    /// Device UUID; `Headers` looks the device up by this and the user UUID.
    pub device: String,
    /// Granted scopes, e.g. `["api", "offline_access"]`.
    pub scope: Vec<String>,
    /// Authentication methods, e.g. `["Application"]`.
    pub amr: Vec<String>,
}
    338 
/// Claims of an invite token, decoded by `decode_invite`.
///
/// Field names and order define the serialized JWT payload; do not rename or reorder them.
#[derive(Serialize, Deserialize)]
pub struct InviteJwtClaims {
    /// Not before (`nbf`)
    nbf: i64,
    /// Expiration time (`exp`)
    exp: i64,
    /// Issuer (`iss`); must equal `<domain_url>|invite`.
    iss: String,
    /// Subject (`sub`)
    sub: String,
    pub email: String,
    pub org_id: Option<String>,
    pub user_org_id: Option<String>,
    invited_by_email: Option<String>,
}
    354 
/// Claims of a file-download token.
///
/// NOTE(review): no encoder or decoder for this type is visible in this file; confirm where it
/// is produced and which issuer it uses.
#[derive(Serialize, Deserialize)]
pub struct FileDownloadClaims {
    /// Not before (`nbf`)
    nbf: i64,
    /// Expiration time (`exp`)
    exp: i64,
    /// Issuer (`iss`)
    iss: String,
    /// Subject (`sub`)
    pub sub: String,
    pub file_id: String,
}
/// Minimal claim set (`nbf`, `exp`, `iss`, `sub`); `decode_delete` decodes into this type.
#[derive(Serialize, Deserialize)]
pub struct BasicJwtClaims {
    /// Not before (`nbf`)
    nbf: i64,
    /// Expiration time (`exp`)
    exp: i64,
    /// Issuer (`iss`)
    iss: String,
    /// Subject (`sub`)
    pub sub: String,
}
    378 use crate::db::{
    379     DbConn,
    380     models::{
    381         Collection, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException,
    382     },
    383 };
    384 use rocket::{
    385     outcome::try_outcome,
    386     request::{FromRequest, Outcome, Request},
    387 };
    388 
    389 struct Host {
    390     host: String,
    391 }
    392 
    393 #[rocket::async_trait]
    394 impl<'r> FromRequest<'r> for Host {
    395     type Error = &'static str;
    396     async fn from_request(_: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    397         Outcome::Success(Self {
    398             host: config::get_config().domain_url().to_owned(),
    399         })
    400     }
    401 }
    402 pub struct ClientHeaders {
    403     #[allow(dead_code)]
    404     pub device_type: i32,
    405     pub ip: ClientIp,
    406 }
    407 
    408 #[rocket::async_trait]
    409 impl<'r> FromRequest<'r> for ClientHeaders {
    410     type Error = &'static str;
    411     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    412         let Outcome::Success(ip) = ClientIp::from_request(request).await else {
    413             err_handler!("Error getting Client IP")
    414         };
    415         // When unknown or unable to parse, return 14, which is 'Unknown Browser'
    416         let device_type: i32 = request
    417             .headers()
    418             .get_one("device-type")
    419             .map_or(14i32, |d| d.parse().unwrap_or(14i32));
    420 
    421         Outcome::Success(Self { device_type, ip })
    422     }
    423 }
    424 
    425 pub struct Headers {
    426     pub host: String,
    427     pub device: Device,
    428     pub user: User,
    429     pub ip: ClientIp,
    430 }
    431 #[allow(clippy::else_if_without_else)]
    432 #[rocket::async_trait]
    433 impl<'r> FromRequest<'r> for Headers {
    434     type Error = &'static str;
    435     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    436         let headers = request.headers();
    437         let host = try_outcome!(Host::from_request(request).await).host;
    438         let Outcome::Success(ip) = ClientIp::from_request(request).await else {
    439             err_handler!("Error getting Client IP")
    440         };
    441         // Get access_token
    442         let access_token: &str = match headers.get_one("Authorization") {
    443             Some(a) => match a.rsplit("Bearer ").next() {
    444                 Some(split) => split,
    445                 None => err_handler!("No access token provided"),
    446             },
    447             None => err_handler!("No access token provided"),
    448         };
    449         // Check JWT token is valid and get device and user from it
    450         let Ok(claims) = decode_login(access_token) else {
    451             err_handler!("Invalid claim")
    452         };
    453         let device_uuid = claims.device;
    454         let user_uuid = claims.sub;
    455         let Outcome::Success(conn) = DbConn::from_request(request).await else {
    456             err_handler!("Error getting DB")
    457         };
    458         let Some(device) = Device::find_by_uuid_and_user(&device_uuid, &user_uuid, &conn).await
    459         else {
    460             err_handler!("Invalid device id")
    461         };
    462         let Some(user) = User::find_by_uuid(&user_uuid, &conn).await else {
    463             err_handler!("Device has no user associated")
    464         };
    465         if user.security_stamp != claims.sstamp {
    466             if let Some(stamp_exception) = user
    467                 .stamp_exception
    468                 .as_deref()
    469                 .and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
    470             {
    471                 let Some(current_route) = request.route().and_then(|r| r.name.as_deref()) else {
    472                     err_handler!("Error getting current route for stamp exception")
    473                 };
    474                 // Check if the stamp exception has expired first.
    475                 // Then, check if the current route matches any of the allowed routes.
    476                 // After that check the stamp in exception matches the one in the claims.
    477                 if u64::try_from(Utc::now().naive_utc().and_utc().timestamp()).expect("underflow")
    478                     > stamp_exception.expire
    479                 {
    480                     // If the stamp exception has been expired remove it from the database.
    481                     // This prevents checking this stamp exception for new requests.
    482                     let mut user = user;
    483                     user.reset_stamp_exception();
    484                     if let Err(e) = user.save(&conn).await {
    485                         error!("Error updating user: {:#?}", e);
    486                     }
    487                     err_handler!("Stamp exception is expired")
    488                 } else if !stamp_exception.routes.contains(&current_route.to_owned()) {
    489                     err_handler!(
    490                         "Invalid security stamp: Current route and exception route do not match"
    491                     )
    492                 } else if stamp_exception.security_stamp != claims.sstamp {
    493                     err_handler!("Invalid security stamp for matched stamp exception")
    494                 }
    495             } else {
    496                 err_handler!("Invalid security stamp")
    497             }
    498         }
    499         Outcome::Success(Self {
    500             host,
    501             device,
    502             user,
    503             ip,
    504         })
    505     }
    506 }
    507 
    508 pub struct OrgHeaders {
    509     host: String,
    510     device: Device,
    511     user: User,
    512     org_user_type: UserOrgType,
    513     org_user: UserOrganization,
    514     ip: ClientIp,
    515 }
    516 
    517 #[rocket::async_trait]
    518 impl<'r> FromRequest<'r> for OrgHeaders {
    519     type Error = &'static str;
    520     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    521         let headers = try_outcome!(Headers::from_request(request).await);
    522         // org_id is usually the second path param ("/organizations/<org_id>"),
    523         // but there are cases where it is a query value.
    524         // First check the path, if this is not a valid uuid, try the query values.
    525         let url_org_id: Option<&str> = {
    526             let mut url_org_id = None;
    527             if let Some(Ok(org_id)) = request.param::<&str>(1) {
    528                 if uuid::Uuid::parse_str(org_id).is_ok() {
    529                     url_org_id = Some(org_id);
    530                 }
    531             }
    532             if let Some(Ok(org_id)) = request.query_value::<&str>("organizationId") {
    533                 if uuid::Uuid::parse_str(org_id).is_ok() {
    534                     url_org_id = Some(org_id);
    535                 }
    536             }
    537             url_org_id
    538         };
    539         match url_org_id {
    540             Some(org_id) => {
    541                 let user = headers.user;
    542                 let org_user = match DbConn::from_request(request).await {
    543                     Outcome::Success(conn) => {
    544                         match UserOrganization::find_by_user_and_org(&user.uuid, org_id, &conn)
    545                             .await
    546                         {
    547                             Some(user) => {
    548                                 if user.status == i32::from(UserOrgStatus::Confirmed) {
    549                                     user
    550                                 } else {
    551                                     err_handler!(
    552                                         "The current user isn't confirmed member of the organization"
    553                                     )
    554                                 }
    555                             }
    556                             None => {
    557                                 err_handler!("The current user isn't member of the organization")
    558                             }
    559                         }
    560                     }
    561                     Outcome::Error(_) | Outcome::Forward(_) => err_handler!("Error getting DB"),
    562                 };
    563                 Outcome::Success(Self {
    564                     host: headers.host,
    565                     device: headers.device,
    566                     user,
    567                     org_user_type: {
    568                         if let Ok(org_usr_type) = UserOrgType::try_from(org_user.atype) {
    569                             org_usr_type
    570                         } else {
    571                             // This should only happen if the DB is corrupted
    572                             err_handler!("Unknown user type in the database")
    573                         }
    574                     },
    575                     org_user,
    576                     ip: headers.ip,
    577                 })
    578             }
    579             _ => err_handler!("Error getting the organization id"),
    580         }
    581     }
    582 }
    583 
    584 pub struct AdminHeaders {
    585     pub host: String,
    586     pub device: Device,
    587     pub user: User,
    588     pub org_user_type: UserOrgType,
    589     pub client_version: Option<String>,
    590     pub ip: ClientIp,
    591 }
    592 
    593 #[rocket::async_trait]
    594 impl<'r> FromRequest<'r> for AdminHeaders {
    595     type Error = &'static str;
    596     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    597         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    598         let client_version = request
    599             .headers()
    600             .get_one("Bitwarden-Client-Version")
    601             .map(String::from);
    602         if headers.org_user_type >= UserOrgType::Admin {
    603             Outcome::Success(Self {
    604                 host: headers.host,
    605                 device: headers.device,
    606                 user: headers.user,
    607                 org_user_type: headers.org_user_type,
    608                 client_version,
    609                 ip: headers.ip,
    610             })
    611         } else {
    612             err_handler!("You need to be Admin or Owner to call this endpoint")
    613         }
    614     }
    615 }
    616 
    617 impl From<AdminHeaders> for Headers {
    618     fn from(h: AdminHeaders) -> Self {
    619         Self {
    620             host: h.host,
    621             device: h.device,
    622             user: h.user,
    623             ip: h.ip,
    624         }
    625     }
    626 }
    627 
    628 // col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"),
    629 // but there could be cases where it is a query value.
    630 // First check the path, if this is not a valid uuid, try the query values.
    631 fn get_col_id(request: &Request<'_>) -> Option<String> {
    632     if let Some(Ok(col_id)) = request.param::<String>(3) {
    633         if uuid::Uuid::parse_str(&col_id).is_ok() {
    634             return Some(col_id);
    635         }
    636     }
    637     if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
    638         if uuid::Uuid::parse_str(&col_id).is_ok() {
    639             return Some(col_id);
    640         }
    641     }
    642     None
    643 }
    644 
    645 /// The ManagerHeaders are used to check if you are at least a Manager
    646 /// and have access to the specific collection provided via the <col_id>/collections/collectionId.
    647 /// This does strict checking on the collection_id, ManagerHeadersLoose does not.
    648 pub struct ManagerHeaders {
    649     host: String,
    650     device: Device,
    651     pub user: User,
    652     ip: ClientIp,
    653 }
    654 
    655 #[rocket::async_trait]
    656 impl<'r> FromRequest<'r> for ManagerHeaders {
    657     type Error = &'static str;
    658     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    659         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    660         if headers.org_user_type >= UserOrgType::Manager {
    661             match get_col_id(request) {
    662                 Some(col_id) => {
    663                     let Outcome::Success(conn) = DbConn::from_request(request).await else {
    664                         err_handler!("Error getting DB")
    665                     };
    666 
    667                     if !Collection::can_access_collection(&headers.org_user, &col_id, &conn).await {
    668                         err_handler!("The current user isn't a manager for this collection")
    669                     }
    670                 }
    671                 _ => err_handler!("Error getting the collection id"),
    672             }
    673 
    674             Outcome::Success(Self {
    675                 host: headers.host,
    676                 device: headers.device,
    677                 user: headers.user,
    678                 ip: headers.ip,
    679             })
    680         } else {
    681             err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
    682         }
    683     }
    684 }
    685 
    686 impl From<ManagerHeaders> for Headers {
    687     fn from(h: ManagerHeaders) -> Self {
    688         Self {
    689             host: h.host,
    690             device: h.device,
    691             user: h.user,
    692             ip: h.ip,
    693         }
    694     }
    695 }
    696 
    697 /// The ManagerHeadersLoose is used when you at least need to be a Manager,
    698 /// but there is no collection_id sent with the request (either in the path or as form data).
    699 pub struct ManagerHeadersLoose {
    700     host: String,
    701     device: Device,
    702     pub user: User,
    703     pub org_user: UserOrganization,
    704     ip: ClientIp,
    705 }
    706 
    707 #[rocket::async_trait]
    708 impl<'r> FromRequest<'r> for ManagerHeadersLoose {
    709     type Error = &'static str;
    710     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    711         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    712         if headers.org_user_type >= UserOrgType::Manager {
    713             Outcome::Success(Self {
    714                 host: headers.host,
    715                 device: headers.device,
    716                 user: headers.user,
    717                 org_user: headers.org_user,
    718                 ip: headers.ip,
    719             })
    720         } else {
    721             err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
    722         }
    723     }
    724 }
    725 
    726 impl From<ManagerHeadersLoose> for Headers {
    727     fn from(h: ManagerHeadersLoose) -> Self {
    728         Self {
    729             host: h.host,
    730             device: h.device,
    731             user: h.user,
    732             ip: h.ip,
    733         }
    734     }
    735 }
    736 
    737 impl ManagerHeaders {
    738     pub async fn from_loose(
    739         h: ManagerHeadersLoose,
    740         collections: &Vec<String>,
    741         conn: &DbConn,
    742     ) -> Result<Self, Error> {
    743         for col_id in collections {
    744             if uuid::Uuid::parse_str(col_id).is_err() {
    745                 err!("Collection Id is malformed!");
    746             }
    747             if !Collection::can_access_collection(&h.org_user, col_id, conn).await {
    748                 err!("You don't have access to all collections!");
    749             }
    750         }
    751         Ok(Self {
    752             host: h.host,
    753             device: h.device,
    754             user: h.user,
    755             ip: h.ip,
    756         })
    757     }
    758 }
    759 
    760 pub struct OwnerHeaders {
    761     #[allow(dead_code)]
    762     pub device: Device,
    763     pub user: User,
    764     #[allow(dead_code)]
    765     pub ip: ClientIp,
    766 }
    767 
    768 #[rocket::async_trait]
    769 impl<'r> FromRequest<'r> for OwnerHeaders {
    770     type Error = &'static str;
    771     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    772         let headers = try_outcome!(OrgHeaders::from_request(request).await);
    773         if headers.org_user_type == UserOrgType::Owner {
    774             Outcome::Success(Self {
    775                 device: headers.device,
    776                 user: headers.user,
    777                 ip: headers.ip,
    778             })
    779         } else {
    780             err_handler!("You need to be Owner to call this endpoint")
    781         }
    782     }
    783 }
    784 
    785 //
    786 // Client IP address detection
    787 //
    788 use std::net::IpAddr;
    789 #[derive(Clone, Copy)]
    790 pub struct ClientIp {
    791     pub ip: IpAddr,
    792 }
    793 
    794 #[rocket::async_trait]
    795 impl<'r> FromRequest<'r> for ClientIp {
    796     type Error = ();
    797     #[allow(clippy::map_err_ignore, clippy::string_slice)]
    798     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    799         let ip = request.headers().get_one("X-Real-IP").and_then(|ip| {
    800             ip.find(',')
    801                 .map_or(ip, |idx| &ip[..idx])
    802                 .parse()
    803                 .map_err(|_| warn!("'X-Real-IP' header is malformed: {ip}"))
    804                 .ok()
    805         });
    806         let ip = ip
    807             .or_else(|| request.remote().map(|r| r.ip()))
    808             .unwrap_or_else(|| "0.0.0.0".parse().unwrap());
    809         Outcome::Success(Self { ip })
    810     }
    811 }
    812 
    813 pub struct WsAccessTokenHeader {
    814     #[allow(dead_code)]
    815     pub access_token: Option<String>,
    816 }
    817 
    818 #[rocket::async_trait]
    819 impl<'r> FromRequest<'r> for WsAccessTokenHeader {
    820     type Error = ();
    821     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
    822         let headers = request.headers();
    823         let access_token = headers
    824             .get_one("Authorization")
    825             .and_then(|a| a.rsplit("Bearer ").next().map(String::from));
    826         Outcome::Success(Self { access_token })
    827     }
    828 }