auth.rs (27825B)
1 use crate::{ 2 config::{self, Config}, 3 error::Error, 4 }; 5 use chrono::{TimeDelta, Utc}; 6 use jsonwebtoken::{self, Algorithm, DecodingKey, EncodingKey, Header, errors::ErrorKind}; 7 use openssl::pkey::{Id, PKey}; 8 use serde::de::DeserializeOwned; 9 use serde::ser::Serialize; 10 use std::fs::File; 11 use std::io::{Read as _, Write as _}; 12 use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; 13 use webauthn_rp::request::{ 14 BackupReq, BuildIdentityHasher, FixedCapHashSet, 15 auth::{ 16 AuthenticationVerificationOptions, AuthenticatorAttachmentEnforcement, 17 NonDiscoverableAuthenticationServerState, SignatureCounterEnforcement, 18 }, 19 register::{RegistrationServerState, RegistrationVerificationOptions}, 20 }; 21 static ALLOWED_ORIGINS: OnceLock<[&str; 1]> = OnceLock::new(); 22 #[inline] 23 fn init_allowed_origins() { 24 ALLOWED_ORIGINS 25 .set([config::get_config().domain_url()]) 26 .expect("ALLOWED_ORIGINS must only be initialized once"); 27 } 28 #[inline] 29 fn get_allowed_origins() -> &'static [&'static str] { 30 ALLOWED_ORIGINS 31 .get() 32 .expect("ALLOWED_ORIGINS must be initialized in main") 33 .as_slice() 34 } 35 static REG_CEREMONIES: OnceLock< 36 Arc<Mutex<FixedCapHashSet<RegistrationServerState<16>, BuildIdentityHasher>>>, 37 > = OnceLock::new(); 38 #[inline] 39 fn init_reg_ceremonies() { 40 REG_CEREMONIES 41 .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000)))) 42 .expect("REG_CEREMONIES must only be initialized once"); 43 } 44 #[inline] 45 pub fn get_reg_ceremonies() 46 -> MutexGuard<'static, FixedCapHashSet<RegistrationServerState<16>, BuildIdentityHasher>> { 47 REG_CEREMONIES 48 .get() 49 .expect("REG_CEREMONIES must be initialized in main") 50 .lock() 51 .unwrap() 52 } 53 static REG_OPTIONS: OnceLock<RegistrationVerificationOptions<'static, 'static, &str, &str>> = 54 OnceLock::new(); 55 #[inline] 56 fn init_reg_options() { 57 REG_OPTIONS 58 .set(RegistrationVerificationOptions { 59 allowed_origins: get_allowed_origins(), 60 allowed_top_origins: None, 61 backup_requirement: BackupReq::None, 62 error_on_unsolicited_extensions: true, 63 require_authenticator_attachment: false, 64 client_data_json_relaxed: true, 65 }) 66 .expect("REG_OPTIONS must only be initialized once"); 67 } 68 #[inline] 69 pub fn get_reg_options() 70 -> &'static RegistrationVerificationOptions<'static, 'static, &'static str, &'static str> { 71 REG_OPTIONS 72 .get() 73 .expect("REG_OPTIONS must be initialized in main") 74 } 75 static AUTH_CEREMONIES: OnceLock< 76 Arc<Mutex<FixedCapHashSet<NonDiscoverableAuthenticationServerState, BuildIdentityHasher>>>, 77 > = OnceLock::new(); 78 #[inline] 79 fn init_auth_ceremonies() { 80 AUTH_CEREMONIES 81 .set(Arc::new(Mutex::new(FixedCapHashSet::new(10000)))) 82 .expect("AUTH_CEREMONIES must only be initialized once"); 83 } 84 #[inline] 85 pub fn get_auth_ceremonies() -> MutexGuard< 86 'static, 87 FixedCapHashSet<NonDiscoverableAuthenticationServerState, BuildIdentityHasher>, 88 > { 89 AUTH_CEREMONIES 90 .get() 91 .expect("AUTH_CEREMONIES must be initialized in main") 92 .lock() 93 .unwrap() 94 } 95 static AUTH_OPTIONS: OnceLock<AuthenticationVerificationOptions<'static, 'static, &str, &str>> = 96 OnceLock::new(); 97 #[inline] 98 fn init_auth_options() { 99 AUTH_OPTIONS 100 .set(AuthenticationVerificationOptions { 101 allowed_origins: get_allowed_origins(), 102 allowed_top_origins: None, 103 backup_requirement: None, 104 error_on_unsolicited_extensions: true, 105 auth_attachment_enforcement: AuthenticatorAttachmentEnforcement::Update(false), 106 client_data_json_relaxed: true, 107 update_uv: false, 108 sig_counter_enforcement: SignatureCounterEnforcement::Fail, 109 }) 110 .expect("AUTH_OPTIONS must only be initialized once"); 111 } 112 #[inline] 113 pub fn get_auth_options() 114 -> &'static AuthenticationVerificationOptions<'static, 'static, &'static str, &'static str> { 115 AUTH_OPTIONS 116 .get() 117 .expect("AUTH_OPTIONS must be initialized in main") 118 } 119 static DEFAULT_VALIDITY: OnceLock<TimeDelta> = OnceLock::new(); 120 #[inline] 121 fn init_default_validity() { 122 DEFAULT_VALIDITY 123 .set(TimeDelta::try_hours(2).expect("TimeDelta::try_hours(2) should work")) 124 .expect("DEFAULT_VALIDITY must only be initialized once"); 125 } 126 #[inline] 127 pub fn get_default_validity() -> &'static TimeDelta { 128 DEFAULT_VALIDITY 129 .get() 130 .expect("DEFAULT_VALIDITY must be initialized in main") 131 } 132 static JWT_HEADER: OnceLock<Header> = OnceLock::new(); 133 #[inline] 134 fn init_jwt_header() { 135 JWT_HEADER 136 .set(Header::new(JWT_ALGORITHM)) 137 .expect("JWT_HEADER must only be initialized once"); 138 } 139 #[inline] 140 fn get_jwt_header() -> &'static Header { 141 JWT_HEADER 142 .get() 143 .expect("JWT_HEADER must be initialized in main") 144 } 145 static JWT_LOGIN_ISSUER: OnceLock<String> = OnceLock::new(); 146 #[inline] 147 fn init_jwt_login_issuer() { 148 JWT_LOGIN_ISSUER 149 .set(format!("{}|login", config::get_config().domain_url())) 150 .expect("JWT_LOGIN_ISSUER must only be initialized once"); 151 } 152 #[inline] 153 pub fn get_jwt_login_issuer() -> &'static str { 154 JWT_LOGIN_ISSUER 155 .get() 156 .expect("JWT_LOGIN_ISSUER must be initialized in main") 157 .as_str() 158 } 159 static JWT_INVITE_ISSUER: OnceLock<String> = OnceLock::new(); 160 #[inline] 161 fn init_jwt_invite_issuer() { 162 JWT_INVITE_ISSUER 163 .set(format!("{}|invite", config::get_config().domain_url())) 164 .expect("JWT_INVITE_ISSUER must only be initialized once"); 165 } 166 #[inline] 167 fn get_jwt_invite_issuer() -> &'static str { 168 JWT_INVITE_ISSUER 169 .get() 170 .expect("JWT_INVITE_ISSUER must be initialized in main") 171 .as_str() 172 } 173 static JWT_DELETE_ISSUER: OnceLock<String> = OnceLock::new(); 174 #[inline] 175 fn init_jwt_delete_issuer() { 176 JWT_DELETE_ISSUER 177 .set(format!("{}|delete", config::get_config().domain_url())) 178 .expect("JWT_DELETE_ISSUER must only be initialized once"); 179 } 180 #[inline] 181 fn get_jwt_delete_issuer() -> &'static str { 182 JWT_DELETE_ISSUER 183 .get() 184 .expect("JWT_DELETE_ISSUER must be initialized in main") 185 .as_str() 186 } 187 const JWT_ALGORITHM: Algorithm = Algorithm::EdDSA; 188 static ED_KEYS: OnceLock<(EncodingKey, DecodingKey)> = OnceLock::new(); 189 #[allow(clippy::map_err_ignore, clippy::verbose_file_reads)] 190 #[inline] 191 fn init_ed_keys() -> Result<(), Error> { 192 let mut file = File::options() 193 .create(true) 194 .read(true) 195 .truncate(false) 196 .write(true) 197 .open(Config::PRIVATE_ED25519_KEY)?; 198 let mut priv_pem = Vec::with_capacity(128); 199 let ed_key = if file.read_to_end(&mut priv_pem)? == 0 { 200 let ed_key = PKey::generate_ed25519()?; 201 priv_pem = ed_key.private_key_to_pem_pkcs8()?; 202 file.write_all(priv_pem.as_slice())?; 203 ed_key 204 } else { 205 let ed_key = PKey::private_key_from_pem(priv_pem.as_slice())?; 206 if ed_key.id() == Id::ED25519 { 207 ed_key 208 } else { 209 let msg = format!( 210 "{} is not a private Ed25519 key", 211 Config::PRIVATE_ED25519_KEY 212 ); 213 return Err(Error::new(msg.as_str(), msg.as_str())); 214 } 215 }; 216 ED_KEYS 217 .set(( 218 EncodingKey::from_ed_pem(priv_pem.as_slice())?, 219 DecodingKey::from_ed_pem(ed_key.public_key_to_pem()?.as_slice())?, 220 )) 221 .map_err(|_| { 222 const MSG: &str = "ED_KEYS must only be initialized once"; 223 Error::new(MSG, MSG) 224 }) 225 } 226 #[inline] 227 fn get_private_ed_key() -> &'static EncodingKey { 228 &ED_KEYS 229 .get() 230 .expect("ED_KEYS must be initialized in main") 231 .0 232 } 233 #[inline] 234 fn get_public_ed_key() -> &'static DecodingKey { 235 &ED_KEYS 236 .get() 237 .expect("ED_KEYS must be initialized in main") 238 .1 239 } 240 #[inline] 241 pub fn init_values() { 242 init_allowed_origins(); 243 init_reg_ceremonies(); 244 init_reg_options(); 245 init_auth_ceremonies(); 246 init_auth_options(); 247 init_default_validity(); 248 init_jwt_header(); 249 init_jwt_login_issuer(); 250 init_jwt_invite_issuer(); 251 init_jwt_delete_issuer(); 252 init_ed_keys().expect("error creating Ed25519 keys"); 253 } 254 pub fn encode_jwt<T: Serialize>(claims: &T) -> String { 255 match jsonwebtoken::encode(get_jwt_header(), claims, get_private_ed_key()) { 256 Ok(token) => token, 257 Err(e) => panic!("Error encoding jwt {e}"), 258 } 259 } 260 261 #[allow(clippy::match_same_arms)] 262 fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> { 263 let mut validation = jsonwebtoken::Validation::new(JWT_ALGORITHM); 264 validation.leeway = 30; // 30 seconds 265 validation.validate_exp = true; 266 validation.validate_nbf = true; 267 validation.set_issuer(&[issuer]); 268 let token = token.replace(char::is_whitespace, ""); 269 match jsonwebtoken::decode(&token, get_public_ed_key(), &validation) { 270 Ok(d) => Ok(d.claims), 271 Err(err) => match *err.kind() { 272 ErrorKind::InvalidToken => err!("Token is invalid"), 273 ErrorKind::InvalidIssuer => err!("Issuer is invalid"), 274 ErrorKind::ExpiredSignature => err!("Token has expired"), 275 ErrorKind::InvalidSignature 276 | ErrorKind::InvalidEcdsaKey 277 | ErrorKind::InvalidRsaKey(_) 278 | ErrorKind::RsaFailedSigning 279 | ErrorKind::InvalidAlgorithmName 280 | ErrorKind::InvalidKeyFormat 281 | ErrorKind::MissingRequiredClaim(_) 282 | ErrorKind::InvalidAudience 283 | ErrorKind::InvalidSubject 284 | ErrorKind::ImmatureSignature 285 | ErrorKind::InvalidAlgorithm 286 | ErrorKind::MissingAlgorithm 287 | ErrorKind::Base64(_) 288 | ErrorKind::Json(_) 289 | ErrorKind::Utf8(_) 290 | ErrorKind::Crypto(_) => err!("Error decoding JWT"), 291 _ => err!("Error decoding JWT"), 292 }, 293 } 294 } 295 pub fn decode_login(token: &str) -> Result<LoginJwtClaims, Error> { 296 decode_jwt(token, get_jwt_login_issuer().to_owned()) 297 } 298 pub fn decode_invite(token: &str) -> Result<InviteJwtClaims, Error> { 299 decode_jwt(token, get_jwt_invite_issuer().to_owned()) 300 } 301 pub fn decode_delete(token: &str) -> Result<BasicJwtClaims, Error> { 302 decode_jwt(token, get_jwt_delete_issuer().to_owned()) 303 } 304 305 #[derive(Serialize, Deserialize)] 306 pub struct LoginJwtClaims { 307 // Not before 308 pub nbf: i64, 309 // Expiration time 310 pub exp: i64, 311 // Issuer 312 pub iss: String, 313 // Subject 314 pub sub: String, 315 pub premium: bool, 316 pub name: String, 317 pub email: String, 318 pub email_verified: bool, 319 // --- 320 // Disabled these keys to be added to the JWT since they could cause the JWT to get too large 321 // Also These key/value pairs are not used anywhere by either Vaultwarden or Bitwarden Clients 322 // Because these might get used in the future, and they are added by the Bitwarden Server, lets keep it, but then commented out 323 // See: https://github.com/dani-garcia/vaultwarden/issues/4156 324 // --- 325 // pub orgowner: Vec<String>, 326 // pub orgadmin: Vec<String>, 327 // pub orguser: Vec<String>, 328 // pub orgmanager: Vec<String>, 329 // user security_stamp 330 pub sstamp: String, 331 // device uuid 332 pub device: String, 333 // [ "api", "offline_access" ] 334 pub scope: Vec<String>, 335 // [ "Application" ] 336 pub amr: Vec<String>, 337 } 338 339 #[derive(Serialize, Deserialize)] 340 pub struct InviteJwtClaims { 341 // Not before 342 nbf: i64, 343 // Expiration time 344 exp: i64, 345 // Issuer 346 iss: String, 347 // Subject 348 sub: String, 349 pub email: String, 350 pub org_id: Option<String>, 351 pub user_org_id: Option<String>, 352 invited_by_email: Option<String>, 353 } 354 355 #[derive(Serialize, Deserialize)] 356 pub struct FileDownloadClaims { 357 // Not before 358 nbf: i64, 359 // Expiration time 360 exp: i64, 361 // Issuer 362 iss: String, 363 // Subject 364 pub sub: String, 365 pub file_id: String, 366 } 367 #[derive(Serialize, Deserialize)] 368 pub struct BasicJwtClaims { 369 // Not before 370 nbf: i64, 371 // Expiration time 372 exp: i64, 373 // Issuer 374 iss: String, 375 // Subject 376 pub sub: String, 377 } 378 use crate::db::{ 379 DbConn, 380 models::{ 381 Collection, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException, 382 }, 383 }; 384 use rocket::{ 385 outcome::try_outcome, 386 request::{FromRequest, Outcome, Request}, 387 }; 388 389 struct Host { 390 host: String, 391 } 392 393 #[rocket::async_trait] 394 impl<'r> FromRequest<'r> for Host { 395 type Error = &'static str; 396 async fn from_request(_: &'r Request<'_>) -> Outcome<Self, Self::Error> { 397 Outcome::Success(Self { 398 host: config::get_config().domain_url().to_owned(), 399 }) 400 } 401 } 402 pub struct ClientHeaders { 403 #[allow(dead_code)] 404 pub device_type: i32, 405 pub ip: ClientIp, 406 } 407 408 #[rocket::async_trait] 409 impl<'r> FromRequest<'r> for ClientHeaders { 410 type Error = &'static str; 411 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 412 let Outcome::Success(ip) = ClientIp::from_request(request).await else { 413 err_handler!("Error getting Client IP") 414 }; 415 // When unknown or unable to parse, return 14, which is 'Unknown Browser' 416 let device_type: i32 = request 417 .headers() 418 .get_one("device-type") 419 .map_or(14i32, |d| d.parse().unwrap_or(14i32)); 420 421 Outcome::Success(Self { device_type, ip }) 422 } 423 } 424 425 pub struct Headers { 426 pub host: String, 427 pub device: Device, 428 pub user: User, 429 pub ip: ClientIp, 430 } 431 #[allow(clippy::else_if_without_else)] 432 #[rocket::async_trait] 433 impl<'r> FromRequest<'r> for Headers { 434 type Error = &'static str; 435 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 436 let headers = request.headers(); 437 let host = try_outcome!(Host::from_request(request).await).host; 438 let Outcome::Success(ip) = ClientIp::from_request(request).await else { 439 err_handler!("Error getting Client IP") 440 }; 441 // Get access_token 442 let access_token: &str = match headers.get_one("Authorization") { 443 Some(a) => match a.rsplit("Bearer ").next() { 444 Some(split) => split, 445 None => err_handler!("No access token provided"), 446 }, 447 None => err_handler!("No access token provided"), 448 }; 449 // Check JWT token is valid and get device and user from it 450 let Ok(claims) = decode_login(access_token) else { 451 err_handler!("Invalid claim") 452 }; 453 let device_uuid = claims.device; 454 let user_uuid = claims.sub; 455 let Outcome::Success(conn) = DbConn::from_request(request).await else { 456 err_handler!("Error getting DB") 457 }; 458 let Some(device) = Device::find_by_uuid_and_user(&device_uuid, &user_uuid, &conn).await 459 else { 460 err_handler!("Invalid device id") 461 }; 462 let Some(user) = User::find_by_uuid(&user_uuid, &conn).await else { 463 err_handler!("Device has no user associated") 464 }; 465 if user.security_stamp != claims.sstamp { 466 if let Some(stamp_exception) = user 467 .stamp_exception 468 .as_deref() 469 .and_then(|s| serde_json::from_str::<UserStampException>(s).ok()) 470 { 471 let Some(current_route) = request.route().and_then(|r| r.name.as_deref()) else { 472 err_handler!("Error getting current route for stamp exception") 473 }; 474 // Check if the stamp exception has expired first. 475 // Then, check if the current route matches any of the allowed routes. 476 // After that check the stamp in exception matches the one in the claims. 477 if u64::try_from(Utc::now().naive_utc().and_utc().timestamp()).expect("underflow") 478 > stamp_exception.expire 479 { 480 // If the stamp exception has been expired remove it from the database. 481 // This prevents checking this stamp exception for new requests. 482 let mut user = user; 483 user.reset_stamp_exception(); 484 if let Err(e) = user.save(&conn).await { 485 error!("Error updating user: {:#?}", e); 486 } 487 err_handler!("Stamp exception is expired") 488 } else if !stamp_exception.routes.contains(¤t_route.to_owned()) { 489 err_handler!( 490 "Invalid security stamp: Current route and exception route do not match" 491 ) 492 } else if stamp_exception.security_stamp != claims.sstamp { 493 err_handler!("Invalid security stamp for matched stamp exception") 494 } 495 } else { 496 err_handler!("Invalid security stamp") 497 } 498 } 499 Outcome::Success(Self { 500 host, 501 device, 502 user, 503 ip, 504 }) 505 } 506 } 507 508 pub struct OrgHeaders { 509 host: String, 510 device: Device, 511 user: User, 512 org_user_type: UserOrgType, 513 org_user: UserOrganization, 514 ip: ClientIp, 515 } 516 517 #[rocket::async_trait] 518 impl<'r> FromRequest<'r> for OrgHeaders { 519 type Error = &'static str; 520 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 521 let headers = try_outcome!(Headers::from_request(request).await); 522 // org_id is usually the second path param ("/organizations/<org_id>"), 523 // but there are cases where it is a query value. 524 // First check the path, if this is not a valid uuid, try the query values. 525 let url_org_id: Option<&str> = { 526 let mut url_org_id = None; 527 if let Some(Ok(org_id)) = request.param::<&str>(1) { 528 if uuid::Uuid::parse_str(org_id).is_ok() { 529 url_org_id = Some(org_id); 530 } 531 } 532 if let Some(Ok(org_id)) = request.query_value::<&str>("organizationId") { 533 if uuid::Uuid::parse_str(org_id).is_ok() { 534 url_org_id = Some(org_id); 535 } 536 } 537 url_org_id 538 }; 539 match url_org_id { 540 Some(org_id) => { 541 let user = headers.user; 542 let org_user = match DbConn::from_request(request).await { 543 Outcome::Success(conn) => { 544 match UserOrganization::find_by_user_and_org(&user.uuid, org_id, &conn) 545 .await 546 { 547 Some(user) => { 548 if user.status == i32::from(UserOrgStatus::Confirmed) { 549 user 550 } else { 551 err_handler!( 552 "The current user isn't confirmed member of the organization" 553 ) 554 } 555 } 556 None => { 557 err_handler!("The current user isn't member of the organization") 558 } 559 } 560 } 561 Outcome::Error(_) | Outcome::Forward(_) => err_handler!("Error getting DB"), 562 }; 563 Outcome::Success(Self { 564 host: headers.host, 565 device: headers.device, 566 user, 567 org_user_type: { 568 if let Ok(org_usr_type) = UserOrgType::try_from(org_user.atype) { 569 org_usr_type 570 } else { 571 // This should only happen if the DB is corrupted 572 err_handler!("Unknown user type in the database") 573 } 574 }, 575 org_user, 576 ip: headers.ip, 577 }) 578 } 579 _ => err_handler!("Error getting the organization id"), 580 } 581 } 582 } 583 584 pub struct AdminHeaders { 585 pub host: String, 586 pub device: Device, 587 pub user: User, 588 pub org_user_type: UserOrgType, 589 pub client_version: Option<String>, 590 pub ip: ClientIp, 591 } 592 593 #[rocket::async_trait] 594 impl<'r> FromRequest<'r> for AdminHeaders { 595 type Error = &'static str; 596 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 597 let headers = try_outcome!(OrgHeaders::from_request(request).await); 598 let client_version = request 599 .headers() 600 .get_one("Bitwarden-Client-Version") 601 .map(String::from); 602 if headers.org_user_type >= UserOrgType::Admin { 603 Outcome::Success(Self { 604 host: headers.host, 605 device: headers.device, 606 user: headers.user, 607 org_user_type: headers.org_user_type, 608 client_version, 609 ip: headers.ip, 610 }) 611 } else { 612 err_handler!("You need to be Admin or Owner to call this endpoint") 613 } 614 } 615 } 616 617 impl From<AdminHeaders> for Headers { 618 fn from(h: AdminHeaders) -> Self { 619 Self { 620 host: h.host, 621 device: h.device, 622 user: h.user, 623 ip: h.ip, 624 } 625 } 626 } 627 628 // col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"), 629 // but there could be cases where it is a query value. 630 // First check the path, if this is not a valid uuid, try the query values. 631 fn get_col_id(request: &Request<'_>) -> Option<String> { 632 if let Some(Ok(col_id)) = request.param::<String>(3) { 633 if uuid::Uuid::parse_str(&col_id).is_ok() { 634 return Some(col_id); 635 } 636 } 637 if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") { 638 if uuid::Uuid::parse_str(&col_id).is_ok() { 639 return Some(col_id); 640 } 641 } 642 None 643 } 644 645 /// The ManagerHeaders are used to check if you are at least a Manager 646 /// and have access to the specific collection provided via the <col_id>/collections/collectionId. 647 /// This does strict checking on the collection_id, ManagerHeadersLoose does not. 648 pub struct ManagerHeaders { 649 host: String, 650 device: Device, 651 pub user: User, 652 ip: ClientIp, 653 } 654 655 #[rocket::async_trait] 656 impl<'r> FromRequest<'r> for ManagerHeaders { 657 type Error = &'static str; 658 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 659 let headers = try_outcome!(OrgHeaders::from_request(request).await); 660 if headers.org_user_type >= UserOrgType::Manager { 661 match get_col_id(request) { 662 Some(col_id) => { 663 let Outcome::Success(conn) = DbConn::from_request(request).await else { 664 err_handler!("Error getting DB") 665 }; 666 667 if !Collection::can_access_collection(&headers.org_user, &col_id, &conn).await { 668 err_handler!("The current user isn't a manager for this collection") 669 } 670 } 671 _ => err_handler!("Error getting the collection id"), 672 } 673 674 Outcome::Success(Self { 675 host: headers.host, 676 device: headers.device, 677 user: headers.user, 678 ip: headers.ip, 679 }) 680 } else { 681 err_handler!("You need to be a Manager, Admin or Owner to call this endpoint") 682 } 683 } 684 } 685 686 impl From<ManagerHeaders> for Headers { 687 fn from(h: ManagerHeaders) -> Self { 688 Self { 689 host: h.host, 690 device: h.device, 691 user: h.user, 692 ip: h.ip, 693 } 694 } 695 } 696 697 /// The ManagerHeadersLoose is used when you at least need to be a Manager, 698 /// but there is no collection_id sent with the request (either in the path or as form data). 699 pub struct ManagerHeadersLoose { 700 host: String, 701 device: Device, 702 pub user: User, 703 pub org_user: UserOrganization, 704 ip: ClientIp, 705 } 706 707 #[rocket::async_trait] 708 impl<'r> FromRequest<'r> for ManagerHeadersLoose { 709 type Error = &'static str; 710 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 711 let headers = try_outcome!(OrgHeaders::from_request(request).await); 712 if headers.org_user_type >= UserOrgType::Manager { 713 Outcome::Success(Self { 714 host: headers.host, 715 device: headers.device, 716 user: headers.user, 717 org_user: headers.org_user, 718 ip: headers.ip, 719 }) 720 } else { 721 err_handler!("You need to be a Manager, Admin or Owner to call this endpoint") 722 } 723 } 724 } 725 726 impl From<ManagerHeadersLoose> for Headers { 727 fn from(h: ManagerHeadersLoose) -> Self { 728 Self { 729 host: h.host, 730 device: h.device, 731 user: h.user, 732 ip: h.ip, 733 } 734 } 735 } 736 737 impl ManagerHeaders { 738 pub async fn from_loose( 739 h: ManagerHeadersLoose, 740 collections: &Vec<String>, 741 conn: &DbConn, 742 ) -> Result<Self, Error> { 743 for col_id in collections { 744 if uuid::Uuid::parse_str(col_id).is_err() { 745 err!("Collection Id is malformed!"); 746 } 747 if !Collection::can_access_collection(&h.org_user, col_id, conn).await { 748 err!("You don't have access to all collections!"); 749 } 750 } 751 Ok(Self { 752 host: h.host, 753 device: h.device, 754 user: h.user, 755 ip: h.ip, 756 }) 757 } 758 } 759 760 pub struct OwnerHeaders { 761 #[allow(dead_code)] 762 pub device: Device, 763 pub user: User, 764 #[allow(dead_code)] 765 pub ip: ClientIp, 766 } 767 768 #[rocket::async_trait] 769 impl<'r> FromRequest<'r> for OwnerHeaders { 770 type Error = &'static str; 771 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 772 let headers = try_outcome!(OrgHeaders::from_request(request).await); 773 if headers.org_user_type == UserOrgType::Owner { 774 Outcome::Success(Self { 775 device: headers.device, 776 user: headers.user, 777 ip: headers.ip, 778 }) 779 } else { 780 err_handler!("You need to be Owner to call this endpoint") 781 } 782 } 783 } 784 785 // 786 // Client IP address detection 787 // 788 use std::net::IpAddr; 789 #[derive(Clone, Copy)] 790 pub struct ClientIp { 791 pub ip: IpAddr, 792 } 793 794 #[rocket::async_trait] 795 impl<'r> FromRequest<'r> for ClientIp { 796 type Error = (); 797 #[allow(clippy::map_err_ignore, clippy::string_slice)] 798 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 799 let ip = request.headers().get_one("X-Real-IP").and_then(|ip| { 800 ip.find(',') 801 .map_or(ip, |idx| &ip[..idx]) 802 .parse() 803 .map_err(|_| warn!("'X-Real-IP' header is malformed: {ip}")) 804 .ok() 805 }); 806 let ip = ip 807 .or_else(|| request.remote().map(|r| r.ip())) 808 .unwrap_or_else(|| "0.0.0.0".parse().unwrap()); 809 Outcome::Success(Self { ip }) 810 } 811 } 812 813 pub struct WsAccessTokenHeader { 814 #[allow(dead_code)] 815 pub access_token: Option<String>, 816 } 817 818 #[rocket::async_trait] 819 impl<'r> FromRequest<'r> for WsAccessTokenHeader { 820 type Error = (); 821 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { 822 let headers = request.headers(); 823 let access_token = headers 824 .get_one("Authorization") 825 .and_then(|a| a.rsplit("Bearer ").next().map(String::from)); 826 Outcome::Success(Self { access_token }) 827 } 828 }