auth.rs (27747B)
1 use crate::{ 2 config::{self, Config}, 3 error::Error, 4 }; 5 use chrono::{TimeDelta, Utc}; 6 use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; 7 use openssl::pkey::{Id, PKey}; 8 use serde::de::DeserializeOwned; 9 use serde::ser::Serialize; 10 use std::fs::File; 11 use std::io::{Read, Write}; 12 use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; 13 use webauthn_rp::request::{ 14 auth::{ 15 AuthenticationServerState, AuthenticationVerificationOptions, 16 AuthenticatorAttachmentEnforcement, SignatureCounterEnforcement, 17 }, 18 register::{RegistrationServerState, RegistrationVerificationOptions}, 19 BackupReq, BuildIdentityHasher, FixedCapHashSet, 20 }; 21 static ALLOWED_ORIGINS: OnceLock<[&str; 1]> = OnceLock::new(); 22 #[inline] 23 fn init_allowed_origins() { 24 ALLOWED_ORIGINS 25 .set([config::get_config().domain_url()]) 26 .expect("ALLOWED_ORIGINS must only be initialized once"); 27 } 28 #[inline] 29 fn get_allowed_origins() -> &'static [&'static str] { 30 ALLOWED_ORIGINS 31 .get() 32 .expect("ALLOWED_ORIGINS must be initialized in main") 33 .as_slice() 34 } 35 static REG_CEREMONIES: OnceLock< 36 Arc<Mutex<FixedCapHashSet<RegistrationServerState, BuildIdentityHasher>>>, 37 > = OnceLock::new(); 38 #[inline] 39 fn init_reg_ceremonies() { 40 REG_CEREMONIES 41 .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000)))) 42 .expect("REG_CEREMONIES must only be initialized once"); 43 } 44 #[inline] 45 pub fn get_reg_ceremonies( 46 ) -> MutexGuard<'static, FixedCapHashSet<RegistrationServerState, BuildIdentityHasher>> { 47 REG_CEREMONIES 48 .get() 49 .expect("REG_CEREMONIES must be initialized in main") 50 .lock() 51 .unwrap() 52 } 53 static REG_OPTIONS: OnceLock<RegistrationVerificationOptions<'static, 'static, &str, &str>> = 54 OnceLock::new(); 55 #[inline] 56 fn init_reg_options() { 57 REG_OPTIONS 58 .set(RegistrationVerificationOptions { 59 allowed_origins: get_allowed_origins(), 60 allowed_top_origins: None, 61 backup_requirement: BackupReq::None, 62 error_on_unsolicited_extensions: true, 63 require_authenticator_attachment: false, 64 client_data_json_relaxed: true, 65 }) 66 .expect("REG_OPTIONS must only be initialized once"); 67 } 68 #[inline] 69 pub fn get_reg_options( 70 ) -> &'static RegistrationVerificationOptions<'static, 'static, &'static str, &'static str> { 71 REG_OPTIONS 72 .get() 73 .expect("REG_OPTIONS must be initialized in main") 74 } 75 static AUTH_CEREMONIES: OnceLock< 76 Arc<Mutex<FixedCapHashSet<AuthenticationServerState, BuildIdentityHasher>>>, 77 > = OnceLock::new(); 78 #[inline] 79 fn init_auth_ceremonies() { 80 AUTH_CEREMONIES 81 .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000)))) 82 .expect("AUTH_CEREMONIES must only be initialized once"); 83 } 84 #[inline] 85 pub fn get_auth_ceremonies( 86 ) -> MutexGuard<'static, FixedCapHashSet<AuthenticationServerState, BuildIdentityHasher>> { 87 AUTH_CEREMONIES 88 .get() 89 .expect("AUTH_CEREMONIES must be initialized in main") 90 .lock() 91 .unwrap() 92 } 93 static AUTH_OPTIONS: OnceLock<AuthenticationVerificationOptions<'static, 'static, &str, &str>> = 94 OnceLock::new(); 95 #[inline] 96 fn init_auth_options() { 97 AUTH_OPTIONS 98 .set(AuthenticationVerificationOptions { 99 allowed_origins: get_allowed_origins(), 100 allowed_top_origins: None, 101 backup_requirement: None, 102 error_on_unsolicited_extensions: true, 103 auth_attachment_enforcement: AuthenticatorAttachmentEnforcement::Update(false), 104 client_data_json_relaxed: true, 105 update_uv: false, 106 sig_counter_enforcement: SignatureCounterEnforcement::Fail, 107 }) 108 .expect("AUTH_OPTIONS must only be initialized once"); 109 } 110 #[inline] 111 pub fn get_auth_options( 112 ) -> &'static AuthenticationVerificationOptions<'static, 'static, &'static str, &'static str> { 113 AUTH_OPTIONS 114 .get() 115 .expect("AUTH_OPTIONS must be initialized in main") 116 } 117 static DEFAULT_VALIDITY: OnceLock<TimeDelta> = OnceLock::new(); 118 #[inline] 119 fn init_default_validity() { 120 DEFAULT_VALIDITY 121 .set(TimeDelta::try_hours(2).expect("TimeDelta::try_hours(2) should work")) 122 .expect("DEFAULT_VALIDITY must only be initialized once"); 123 } 124 #[inline] 125 pub fn get_default_validity() -> &'static TimeDelta { 126 DEFAULT_VALIDITY 127 .get() 128 .expect("DEFAULT_VALIDITY must be initialized in main") 129 } 130 static JWT_HEADER: OnceLock<Header> = OnceLock::new(); 131 #[inline] 132 fn init_jwt_header() { 133 JWT_HEADER 134 .set(Header::new(JWT_ALGORITHM)) 135 .expect("JWT_HEADER must only be initialized once"); 136 } 137 #[inline] 138 fn get_jwt_header() -> &'static Header { 139 JWT_HEADER 140 .get() 141 .expect("JWT_HEADER must be initialized in main") 142 } 143 static JWT_LOGIN_ISSUER: OnceLock<String> = OnceLock::new(); 144 #[inline] 145 fn init_jwt_login_issuer() { 146 JWT_LOGIN_ISSUER 147 .set(format!("{}|login", config::get_config().domain_url())) 148 .expect("JWT_LOGIN_ISSUER must only be initialized once"); 149 } 150 #[inline] 151 pub fn get_jwt_login_issuer() -> &'static str { 152 JWT_LOGIN_ISSUER 153 .get() 154 .expect("JWT_LOGIN_ISSUER must be initialized in main") 155 .as_str() 156 } 157 static JWT_INVITE_ISSUER: OnceLock<String> = OnceLock::new(); 158 #[inline] 159 fn init_jwt_invite_issuer() { 160 JWT_INVITE_ISSUER 161 .set(format!("{}|invite", config::get_config().domain_url())) 162 .expect("JWT_INVITE_ISSUER must only be initialized once"); 163 } 164 #[inline] 165 fn get_jwt_invite_issuer() -> &'static str { 166 JWT_INVITE_ISSUER 167 .get() 168 .expect("JWT_INVITE_ISSUER must be initialized in main") 169 .as_str() 170 } 171 static JWT_DELETE_ISSUER: OnceLock<String> = OnceLock::new(); 172 #[inline] 173 fn init_jwt_delete_issuer() { 174 JWT_DELETE_ISSUER 175 .set(format!("{}|delete", config::get_config().domain_url())) 176 .expect("JWT_DELETE_ISSUER must only be initialized once"); 177 } 178 #[inline] 179 fn get_jwt_delete_issuer() -> &'static str { 180 JWT_DELETE_ISSUER 181 .get() 182 .expect("JWT_DELETE_ISSUER must be initialized in main") 183 .as_str() 184 } 185 const JWT_ALGORITHM: Algorithm = Algorithm::EdDSA; 186 static ED_KEYS: OnceLock<(EncodingKey, DecodingKey)> = OnceLock::new(); 187 #[allow(clippy::map_err_ignore, clippy::verbose_file_reads)] 188 #[inline] 189 fn init_ed_keys() -> Result<(), Error> { 190 let mut file = File::options() 191 .create(true) 192 .read(true) 193 .truncate(false) 194 .write(true) 195 .open(Config::PRIVATE_ED25519_KEY)?; 196 let mut priv_pem = Vec::with_capacity(128); 197 let ed_key = if file.read_to_end(&mut priv_pem)? == 0 { 198 let ed_key = PKey::generate_ed25519()?; 199 priv_pem = ed_key.private_key_to_pem_pkcs8()?; 200 file.write_all(priv_pem.as_slice())?; 201 ed_key 202 } else { 203 let ed_key = PKey::private_key_from_pem(priv_pem.as_slice())?; 204 if ed_key.id() == Id::ED25519 { 205 ed_key 206 } else { 207 let msg = format!( 208 "{} is not a private Ed25519 key", 209 Config::PRIVATE_ED25519_KEY 210 ); 211 return Err(Error::new(msg.as_str(), msg.as_str())); 212 } 213 }; 214 ED_KEYS 215 .set(( 216 EncodingKey::from_ed_pem(priv_pem.as_slice())?, 217 DecodingKey::from_ed_pem(ed_key.public_key_to_pem()?.as_slice())?, 218 )) 219 .map_err(|_| { 220 const MSG: &str = "ED_KEYS must only be initialized once"; 221 Error::new(MSG, MSG) 222 }) 223 } 224 #[inline] 225 fn get_private_ed_key() -> &'static EncodingKey { 226 &ED_KEYS 227 .get() 228 .expect("ED_KEYS must be initialized in main") 229 .0 230 } 231 #[inline] 232 fn get_public_ed_key() -> &'static DecodingKey { 233 &ED_KEYS 234 .get() 235 .expect("ED_KEYS must be initialized in main") 236 .1 237 } 238 #[inline] 239 pub fn init_values() { 240 init_allowed_origins(); 241 init_reg_ceremonies(); 242 init_reg_options(); 243 init_auth_ceremonies(); 244 init_auth_options(); 245 init_default_validity(); 246 init_jwt_header(); 247 init_jwt_login_issuer(); 248 init_jwt_invite_issuer(); 249 init_jwt_delete_issuer(); 250 init_ed_keys().expect("error creating Ed25519 keys"); 251 } 252 pub fn encode_jwt<T: Serialize>(claims: &T) -> String { 253 match jsonwebtoken::encode(get_jwt_header(), claims, get_private_ed_key()) { 254 Ok(token) => token, 255 Err(e) => panic!("Error encoding jwt {e}"), 256 } 257 } 258 259 #[allow(clippy::match_same_arms)] 260 fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> { 261 let mut validation = jsonwebtoken::Validation::new(JWT_ALGORITHM); 262 validation.leeway = 30; // 30 seconds 263 validation.validate_exp = true; 264 validation.validate_nbf = true; 265 validation.set_issuer(&[issuer]); 266 let token = token.replace(char::is_whitespace, ""); 267 match jsonwebtoken::decode(&token, get_public_ed_key(), &validation) { 268 Ok(d) => Ok(d.claims), 269 Err(err) => match *err.kind() { 270 ErrorKind::InvalidToken => err!("Token is invalid"), 271 ErrorKind::InvalidIssuer => err!("Issuer is invalid"), 272 ErrorKind::ExpiredSignature => err!("Token has expired"), 273 ErrorKind::InvalidSignature 274 | ErrorKind::InvalidEcdsaKey 275 | ErrorKind::InvalidRsaKey(_) 276 | ErrorKind::RsaFailedSigning 277 | ErrorKind::InvalidAlgorithmName 278 | ErrorKind::InvalidKeyFormat 279 | ErrorKind::MissingRequiredClaim(_) 280 | ErrorKind::InvalidAudience 281 | ErrorKind::InvalidSubject 282 | ErrorKind::ImmatureSignature 283 | ErrorKind::InvalidAlgorithm 284 | ErrorKind::MissingAlgorithm 285 | ErrorKind::Base64(_) 286 | ErrorKind::Json(_) 287 | ErrorKind::Utf8(_) 288 | ErrorKind::Crypto(_) => err!("Error decoding JWT"), 289 _ => err!("Error decoding JWT"), 290 }, 291 } 292 } 293 pub fn decode_login(token: &str) -> Result<LoginJwtClaims, Error> { 294 decode_jwt(token, get_jwt_login_issuer().to_owned()) 295 } 296 pub fn decode_invite(token: &str) -> Result<InviteJwtClaims, Error> { 297 decode_jwt(token, get_jwt_invite_issuer().to_owned()) 298 } 299 pub fn decode_delete(token: &str) -> Result<BasicJwtClaims, Error> { 300 decode_jwt(token, get_jwt_delete_issuer().to_owned()) 301 } 302 303 #[derive(Serialize, Deserialize)] 304 pub struct LoginJwtClaims { 305 // Not before 306 pub nbf: i64, 307 // Expiration time 308 pub exp: i64, 309 // Issuer 310 pub iss: String, 311 // Subject 312 pub sub: String, 313 pub premium: bool, 314 pub name: String, 315 pub email: String, 316 pub email_verified: bool, 317 // --- 318 // Disabled these keys to be added to the JWT since they could cause the JWT to get too large 319 // Also These key/value pairs are not used anywhere by either Vaultwarden or Bitwarden Clients 320 // Because these might get used in the future, and they are added by the Bitwarden Server, lets keep it, but then commented out 321 // See: https://github.com/dani-garcia/vaultwarden/issues/4156 322 // --- 323 // pub orgowner: Vec<String>, 324 // pub orgadmin: Vec<String>, 325 // pub orguser: Vec<String>, 326 // pub orgmanager: Vec<String>, 327 // user security_stamp 328 pub sstamp: String, 329 // device uuid 330 pub device: String, 331 // [ "api", "offline_access" ] 332 pub scope: Vec<String>, 333 // [ "Application" ] 334 pub amr: Vec<String>, 335 } 336 337 #[derive(Serialize, Deserialize)] 338 pub struct InviteJwtClaims { 339 // Not before 340 nbf: i64, 341 // Expiration time 342 exp: i64, 343 // Issuer 344 iss: String, 345 // Subject 346 sub: String, 347 pub email: String, 348 pub org_id: Option<String>, 349 pub user_org_id: Option<String>, 350 invited_by_email: Option<String>, 351 } 352 353 #[derive(Serialize, Deserialize)] 354 pub struct FileDownloadClaims { 355 // Not before 356 nbf: i64, 357 // Expiration time 358 exp: i64, 359 // Issuer 360 iss: String, 361 // Subject 362 pub sub: String, 363 pub file_id: String, 364 } 365 #[derive(Serialize, Deserialize)] 366 pub struct BasicJwtClaims { 367 // Not before 368 nbf: i64, 369 // Expiration time 370 exp: i64, 371 // Issuer 372 iss: String, 373 // Subject 374 pub sub: String, 375 } 376 use crate::db::{ 377 models::{ 378 Collection, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException, 379 }, 380 DbConn, 381 }; 382 use rocket::{ 383 outcome::try_outcome, 384 request::{FromRequest, Outcome, Request}, 385 }; 386 387 struct Host { 388 host: String, 389 } 390 391 #[rocket::async_trait] 392 impl<'r> FromRequest<'r> for Host { 393 type Error = &'static str; 394 async fn from_request(_: &'r Request<'_>) -> Outcome<Self, Self::Error> { 395 Outcome::Success(Self { 396 host: config::get_config().domain_url().to_owned(), 397 }) 398 } 399 } 400 pub struct ClientHeaders { 401 #[allow(dead_code)] 402 pub device_type: i32, 403 pub ip: ClientIp, 404 } 405 406 #[rocket::async_trait] 407 impl<'r> FromRequest<'r> for ClientHeaders { 408 type Error = &'static str; 409 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 410 let Outcome::Success(ip) = ClientIp::from_request(request).await else { 411 err_handler!("Error getting Client IP") 412 }; 413 // When unknown or unable to parse, return 14, which is 'Unknown Browser' 414 let device_type: i32 = request 415 .headers() 416 .get_one("device-type") 417 .map_or(14i32, |d| d.parse().unwrap_or(14i32)); 418 419 Outcome::Success(Self { device_type, ip }) 420 } 421 } 422 423 pub struct Headers { 424 pub host: String, 425 pub device: Device, 426 pub user: User, 427 pub ip: ClientIp, 428 } 429 #[allow(clippy::else_if_without_else)] 430 #[rocket::async_trait] 431 impl<'r> FromRequest<'r> for Headers { 432 type Error = &'static str; 433 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 434 let headers = request.headers(); 435 let host = try_outcome!(Host::from_request(request).await).host; 436 let Outcome::Success(ip) = ClientIp::from_request(request).await else { 437 err_handler!("Error getting Client IP") 438 }; 439 // Get access_token 440 let access_token: &str = match headers.get_one("Authorization") { 441 Some(a) => match a.rsplit("Bearer ").next() { 442 Some(split) => split, 443 None => err_handler!("No access token provided"), 444 }, 445 None => err_handler!("No access token provided"), 446 }; 447 // Check JWT token is valid and get device and user from it 448 let Ok(claims) = decode_login(access_token) else { 449 err_handler!("Invalid claim") 450 }; 451 let device_uuid = claims.device; 452 let user_uuid = claims.sub; 453 let Outcome::Success(conn) = DbConn::from_request(request).await else { 454 err_handler!("Error getting DB") 455 }; 456 let Some(device) = Device::find_by_uuid_and_user(&device_uuid, &user_uuid, &conn).await 457 else { 458 err_handler!("Invalid device id") 459 }; 460 let Some(user) = User::find_by_uuid(&user_uuid, &conn).await else { 461 err_handler!("Device has no user associated") 462 }; 463 if user.security_stamp != claims.sstamp { 464 if let Some(stamp_exception) = user 465 .stamp_exception 466 .as_deref() 467 .and_then(|s| serde_json::from_str::<UserStampException>(s).ok()) 468 { 469 let Some(current_route) = request.route().and_then(|r| r.name.as_deref()) else { 470 err_handler!("Error getting current route for stamp exception") 471 }; 472 // Check if the stamp exception has expired first. 473 // Then, check if the current route matches any of the allowed routes. 474 // After that check the stamp in exception matches the one in the claims. 475 if u64::try_from(Utc::now().naive_utc().and_utc().timestamp()).expect("underflow") 476 > stamp_exception.expire 477 { 478 // If the stamp exception has been expired remove it from the database. 479 // This prevents checking this stamp exception for new requests. 480 let mut user = user; 481 user.reset_stamp_exception(); 482 if let Err(e) = user.save(&conn).await { 483 error!("Error updating user: {:#?}", e); 484 } 485 err_handler!("Stamp exception is expired") 486 } else if !stamp_exception.routes.contains(¤t_route.to_owned()) { 487 err_handler!( 488 "Invalid security stamp: Current route and exception route do not match" 489 ) 490 } else if stamp_exception.security_stamp != claims.sstamp { 491 err_handler!("Invalid security stamp for matched stamp exception") 492 } 493 } else { 494 err_handler!("Invalid security stamp") 495 } 496 } 497 Outcome::Success(Self { 498 host, 499 device, 500 user, 501 ip, 502 }) 503 } 504 } 505 506 pub struct OrgHeaders { 507 host: String, 508 device: Device, 509 user: User, 510 org_user_type: UserOrgType, 511 org_user: UserOrganization, 512 ip: ClientIp, 513 } 514 515 #[rocket::async_trait] 516 impl<'r> FromRequest<'r> for OrgHeaders { 517 type Error = &'static str; 518 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 519 let headers = try_outcome!(Headers::from_request(request).await); 520 // org_id is usually the second path param ("/organizations/<org_id>"), 521 // but there are cases where it is a query value. 522 // First check the path, if this is not a valid uuid, try the query values. 523 let url_org_id: Option<&str> = { 524 let mut url_org_id = None; 525 if let Some(Ok(org_id)) = request.param::<&str>(1) { 526 if uuid::Uuid::parse_str(org_id).is_ok() { 527 url_org_id = Some(org_id); 528 } 529 } 530 if let Some(Ok(org_id)) = request.query_value::<&str>("organizationId") { 531 if uuid::Uuid::parse_str(org_id).is_ok() { 532 url_org_id = Some(org_id); 533 } 534 } 535 url_org_id 536 }; 537 match url_org_id { 538 Some(org_id) => { 539 let user = headers.user; 540 let org_user = match DbConn::from_request(request).await { 541 Outcome::Success(conn) => { 542 match UserOrganization::find_by_user_and_org(&user.uuid, org_id, &conn) 543 .await 544 { 545 Some(user) => { 546 if user.status == i32::from(UserOrgStatus::Confirmed) { 547 user 548 } else { 549 err_handler!( 550 "The current user isn't confirmed member of the organization" 551 ) 552 } 553 } 554 None => { 555 err_handler!("The current user isn't member of the organization") 556 } 557 } 558 } 559 Outcome::Error(_) | Outcome::Forward(_) => err_handler!("Error getting DB"), 560 }; 561 Outcome::Success(Self { 562 host: headers.host, 563 device: headers.device, 564 user, 565 org_user_type: { 566 if let Ok(org_usr_type) = UserOrgType::try_from(org_user.atype) { 567 org_usr_type 568 } else { 569 // This should only happen if the DB is corrupted 570 err_handler!("Unknown user type in the database") 571 } 572 }, 573 org_user, 574 ip: headers.ip, 575 }) 576 } 577 _ => err_handler!("Error getting the organization id"), 578 } 579 } 580 } 581 582 pub struct AdminHeaders { 583 pub host: String, 584 pub device: Device, 585 pub user: User, 586 pub org_user_type: UserOrgType, 587 pub client_version: Option<String>, 588 pub ip: ClientIp, 589 } 590 591 #[rocket::async_trait] 592 impl<'r> FromRequest<'r> for AdminHeaders { 593 type Error = &'static str; 594 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 595 let headers = try_outcome!(OrgHeaders::from_request(request).await); 596 let client_version = request 597 .headers() 598 .get_one("Bitwarden-Client-Version") 599 .map(String::from); 600 if headers.org_user_type >= UserOrgType::Admin { 601 Outcome::Success(Self { 602 host: headers.host, 603 device: headers.device, 604 user: headers.user, 605 org_user_type: headers.org_user_type, 606 client_version, 607 ip: headers.ip, 608 }) 609 } else { 610 err_handler!("You need to be Admin or Owner to call this endpoint") 611 } 612 } 613 } 614 615 impl From<AdminHeaders> for Headers { 616 fn from(h: AdminHeaders) -> Self { 617 Self { 618 host: h.host, 619 device: h.device, 620 user: h.user, 621 ip: h.ip, 622 } 623 } 624 } 625 626 // col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"), 627 // but there could be cases where it is a query value. 628 // First check the path, if this is not a valid uuid, try the query values. 629 fn get_col_id(request: &Request<'_>) -> Option<String> { 630 if let Some(Ok(col_id)) = request.param::<String>(3) { 631 if uuid::Uuid::parse_str(&col_id).is_ok() { 632 return Some(col_id); 633 } 634 } 635 if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") { 636 if uuid::Uuid::parse_str(&col_id).is_ok() { 637 return Some(col_id); 638 } 639 } 640 None 641 } 642 643 /// The ManagerHeaders are used to check if you are at least a Manager 644 /// and have access to the specific collection provided via the <col_id>/collections/collectionId. 645 /// This does strict checking on the collection_id, ManagerHeadersLoose does not. 646 pub struct ManagerHeaders { 647 host: String, 648 device: Device, 649 pub user: User, 650 ip: ClientIp, 651 } 652 653 #[rocket::async_trait] 654 impl<'r> FromRequest<'r> for ManagerHeaders { 655 type Error = &'static str; 656 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 657 let headers = try_outcome!(OrgHeaders::from_request(request).await); 658 if headers.org_user_type >= UserOrgType::Manager { 659 match get_col_id(request) { 660 Some(col_id) => { 661 let Outcome::Success(conn) = DbConn::from_request(request).await else { 662 err_handler!("Error getting DB") 663 }; 664 665 if !Collection::can_access_collection(&headers.org_user, &col_id, &conn).await { 666 err_handler!("The current user isn't a manager for this collection") 667 } 668 } 669 _ => err_handler!("Error getting the collection id"), 670 } 671 672 Outcome::Success(Self { 673 host: headers.host, 674 device: headers.device, 675 user: headers.user, 676 ip: headers.ip, 677 }) 678 } else { 679 err_handler!("You need to be a Manager, Admin or Owner to call this endpoint") 680 } 681 } 682 } 683 684 impl From<ManagerHeaders> for Headers { 685 fn from(h: ManagerHeaders) -> Self { 686 Self { 687 host: h.host, 688 device: h.device, 689 user: h.user, 690 ip: h.ip, 691 } 692 } 693 } 694 695 /// The ManagerHeadersLoose is used when you at least need to be a Manager, 696 /// but there is no collection_id sent with the request (either in the path or as form data). 697 pub struct ManagerHeadersLoose { 698 host: String, 699 device: Device, 700 pub user: User, 701 pub org_user: UserOrganization, 702 ip: ClientIp, 703 } 704 705 #[rocket::async_trait] 706 impl<'r> FromRequest<'r> for ManagerHeadersLoose { 707 type Error = &'static str; 708 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 709 let headers = try_outcome!(OrgHeaders::from_request(request).await); 710 if headers.org_user_type >= UserOrgType::Manager { 711 Outcome::Success(Self { 712 host: headers.host, 713 device: headers.device, 714 user: headers.user, 715 org_user: headers.org_user, 716 ip: headers.ip, 717 }) 718 } else { 719 err_handler!("You need to be a Manager, Admin or Owner to call this endpoint") 720 } 721 } 722 } 723 724 impl From<ManagerHeadersLoose> for Headers { 725 fn from(h: ManagerHeadersLoose) -> Self { 726 Self { 727 host: h.host, 728 device: h.device, 729 user: h.user, 730 ip: h.ip, 731 } 732 } 733 } 734 735 impl ManagerHeaders { 736 pub async fn from_loose( 737 h: ManagerHeadersLoose, 738 collections: &Vec<String>, 739 conn: &DbConn, 740 ) -> Result<Self, Error> { 741 for col_id in collections { 742 if uuid::Uuid::parse_str(col_id).is_err() { 743 err!("Collection Id is malformed!"); 744 } 745 if !Collection::can_access_collection(&h.org_user, col_id, conn).await { 746 err!("You don't have access to all collections!"); 747 } 748 } 749 Ok(Self { 750 host: h.host, 751 device: h.device, 752 user: h.user, 753 ip: h.ip, 754 }) 755 } 756 } 757 758 pub struct OwnerHeaders { 759 #[allow(dead_code)] 760 pub device: Device, 761 pub user: User, 762 #[allow(dead_code)] 763 pub ip: ClientIp, 764 } 765 766 #[rocket::async_trait] 767 impl<'r> FromRequest<'r> for OwnerHeaders { 768 type Error = &'static str; 769 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 770 let headers = try_outcome!(OrgHeaders::from_request(request).await); 771 if headers.org_user_type == UserOrgType::Owner { 772 Outcome::Success(Self { 773 device: headers.device, 774 user: headers.user, 775 ip: headers.ip, 776 }) 777 } else { 778 err_handler!("You need to be Owner to call this endpoint") 779 } 780 } 781 } 782 783 // 784 // Client IP address detection 785 // 786 use std::net::IpAddr; 787 #[derive(Clone, Copy)] 788 pub struct ClientIp { 789 pub ip: IpAddr, 790 } 791 792 #[rocket::async_trait] 793 impl<'r> FromRequest<'r> for ClientIp { 794 type Error = (); 795 #[allow(clippy::map_err_ignore, clippy::string_slice)] 796 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 797 let ip = request.headers().get_one("X-Real-IP").and_then(|ip| { 798 ip.find(',') 799 .map_or(ip, |idx| &ip[..idx]) 800 .parse() 801 .map_err(|_| warn!("'X-Real-IP' header is malformed: {ip}")) 802 .ok() 803 }); 804 let ip = ip 805 .or_else(|| request.remote().map(|r| r.ip())) 806 .unwrap_or_else(|| "0.0.0.0".parse().unwrap()); 807 Outcome::Success(Self { ip }) 808 } 809 } 810 811 pub struct WsAccessTokenHeader { 812 #[allow(dead_code)] 813 pub access_token: Option<String>, 814 } 815 816 #[rocket::async_trait] 817 impl<'r> FromRequest<'r> for WsAccessTokenHeader { 818 type Error = (); 819 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 820 let headers = request.headers(); 821 let access_token = headers 822 .get_one("Authorization") 823 .and_then(|a| a.rsplit("Bearer ").next().map(String::from)); 824 Outcome::Success(Self { access_token }) 825 } 826 }