vw_small

Hardened fork of Vaultwarden (https://github.com/dani-garcia/vaultwarden) with fewer features.
git clone https://git.philomathiclife.com/repos/vw_small
Log | Files | Refs | README

commit cd856fed66dec4fd6305804e06044ed926bf4809
parent cabbfe1221c3b704dbc43d53829719859e253dc2
Author: Zack Newman <zack@philomathiclife.com>
Date:   Wed,  3 Jan 2024 22:08:32 -0700

ed25519 webauthn. cleanup

Diffstat:
MCargo.toml | 11+++++------
MREADME.md | 41+++++++++++++++++++++++++++++++++++++----
Mconfig.toml | 1+
Msrc/api/core/accounts.rs | 149++++++++++++++++++++++++-------------------------------------------------------
Msrc/api/core/ciphers.rs | 91+++++++++++++++++++++++--------------------------------------------------------
Msrc/api/core/emergency_access.rs | 112++++++++++++++++++++++++++++++++++++-------------------------------------------
Msrc/api/core/events.rs | 66++++++++++++++++--------------------------------------------------
Msrc/api/core/mod.rs | 18+++++++-----------
Msrc/api/core/organizations.rs | 273+++++++++++++++++++++++---------------------------------------------------------
Msrc/api/core/public.rs | 76+++++-----------------------------------------------------------------------
Msrc/api/core/sends.rs | 81+++++++++++++++++++++++++------------------------------------------------------
Msrc/api/core/two_factor/authenticator.rs | 77+++++++++++++++++++++++------------------------------------------------------
Asrc/api/core/two_factor/duo.rs | 38++++++++++++++++++++++++++++++++++++++
Asrc/api/core/two_factor/email.rs | 57+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Msrc/api/core/two_factor/mod.rs | 73++++++++++++++++++++++---------------------------------------------------
Msrc/api/core/two_factor/protected_actions.rs | 101++++++-------------------------------------------------------------------------
Msrc/api/core/two_factor/webauthn.rs | 250++++++++++++++++++++++++++++++++++---------------------------------------------
Asrc/api/core/two_factor/yubikey.rs | 41+++++++++++++++++++++++++++++++++++++++++
Msrc/api/identity.rs | 270++++++++++++++++++++++++++-----------------------------------------------------
Msrc/api/mod.rs | 11+++++------
Msrc/api/notifications.rs | 73+++++++++++++++++++++++++++++++++++--------------------------------------
Msrc/api/web.rs | 15++++-----------
Msrc/auth.rs | 99+++++++++++++------------------------------------------------------------------
Msrc/config.rs | 7+++++--
Msrc/crypto.rs | 27---------------------------
Msrc/db/models/mod.rs | 6+++---
Msrc/db/models/org_policy.rs | 31+++++++++++++++++++++++++++----
Msrc/db/models/organization.rs | 95++++++++++++++++++++++---------------------------------------------------------
Msrc/db/models/two_factor.rs | 796+++++++++++++++++++++++++++++++++++++++++++++++++------------------------------
Msrc/db/models/user.rs | 12+++++++++---
Msrc/db/schemas/sqlite/schema.rs | 42++++++++++++++++++++----------------------
Msrc/util.rs | 2+-
32 files changed, 1325 insertions(+), 1717 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml @@ -24,8 +24,6 @@ data-encoding = { version = "2.5.0", default-features = false } diesel = { version = "2.1.4", default-features = false, features = ["32-column-tables", "chrono", "r2d2", "sqlite"] } jsonwebtoken = { version = "9.2.0", default-features = false, features = ["use_pem"] } libsqlite3-sys = { version = "0.27.0", default-features = false, features = ["bundled"] } -num-derive = { version = "0.4.1", default-features = false } -num-traits = { version = "0.2.17", default-features = false } openssl = { version = "0.10.62", default-features = false } paste = { version = "1.0.14", default-features = false } rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } @@ -34,9 +32,9 @@ ring = { version = "0.17.7", default-features = false } rmpv = { version = "1.0.1", default-features = false } rocket = { version = "0.5.0", default-features = false, features = ["json", "tls"] } rocket_ws = { version = "0.1.0", default-features = false, features = ["tokio-tungstenite"] } -semver = { version = "1.0.20", default-features = false } -serde = { version = "1.0.193", default-features = false } -serde_json = { version = "1.0.108", default-features = false } +semver = { version = "1.0.21", default-features = false } +serde = { version = "1.0.194", default-features = false } +serde_json = { version = "1.0.110", default-features = false } tokio = { version = "1.35.1", default-features = false } tokio-tungstenite = { version = "0.20.1", default-features = false } toml = { version = "0.8.8", default-features = false, features = ["parse"] } @@ -46,7 +44,8 @@ uuid = { version = "1.6.1", default-features = false, features = ["v4"] } webauthn-rs = { version = "0.4.8", default-features = false, features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] } [patch.crates-io] -webauthn-rs-core = { git = "https://git.philomathiclife.com/repos/webauthn-rs-core", tag = "v0.4.9" } +webauthn-rs-core = { git = "https://git.philomathiclife.com/repos/webauthn-rs-core", tag = "v0.4.10" } +webauthn-rs-proto = { git = "https://git.philomathiclife.com/repos/webauthn-rs-proto", tag = "v0.4.10" } [profile.release] lto = true diff --git a/README.md b/README.md @@ -13,10 +13,27 @@ This crate has first-class support for OpenBSD-stable; and when compiled/install [`unveil(2)`](https://man.openbsd.org/amd64/unveil.2) to lock down the daemon. This crate does not support all of the features Vaultwarden supports. To some fewer features _is_ a feature. In particular, this crate assumes a small-scale environment; thus -only SQLite is supported for the database, there is no HTTP(S) client, no SMTP client, no DNS resolver, no support for groups, no admin panel, no attachment support, no send support, -no push notifications, and only WebAuthn and TOTP are supported for 2FA. +the following are true and likely won’t change in the future: -This crate makes a better attempt and performing state-changing operations in an atomic fashion (e.g., instead of mutating two database tables in separate transactions allowing +* No containers +* WebAuthn and TOTP are the only forms of 2FA +* SQLite is the only supported database engine +* HTTPS is required +* No HTTP(S) client +* No SMTP client +* No DNS resolver +* No groups +* No admin panel +* No attachments +* No sends +* No push notifications +* No log in with device +* No recovery code +* No emergency access +* No log in via the API +* No automatic jobs (e.g., purging trash) + +This crate makes a better attempt at performing state-changing operations in an atomic fashion (e.g., instead of mutating two database tables in separate transactions allowing for the possibility the first change occurs without the second, both changes are done as a single transaction). ## Config file @@ -33,6 +50,7 @@ ip=<IPv6_or_IPv4_address> password_iterations=<100000-4294967295> port=<0-65535> web_vault_enabled=<true/false> +webauthn_require_yubi=<true/false> workers=<1-255> [tls] cert=<absolute_path_to_complete_X509_certificate> @@ -49,13 +67,28 @@ database_timeout=30 db_connection_retries=15 password_iterations=600000 web_vault_enabled=true +webauthn_require_yubi=false workers=<number_of_CPU_cores> [tls] ciphers=["TLS_CHACHA20_POLY1305_SHA256","TLS_AES_256_GCM_SHA384","TLS_AES_128_GCM_SHA256","TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256","TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256","TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384","TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256","TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384","TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"] prefer_server_cipher_order=false ``` -When `database_timeout` is `0`, there is no timeout; otherwise the value represents the maximum _seconds_ allowed for a database connection to be made. +When `database_timeout` is `0`, there is no timeout; otherwise the value represents the maximum seconds allowed for a database connection to be made. +When `webauthn_require_yubi` is `true`, then WebAuthn registrations require a FIDO2 YubiKey with firmware 5.2.a, 5.4.b, 5.5.c, or 5.6.d. + +## Directory hierachy + +The running directory must conform to the following: + +```bash +$PWD/ + config.toml + data/ + web-vault/ +``` + +Where `web-vault` must exist if `web_vault_enabled=true` and must be the output of an extracted [`bw_web_builds`](https://github.com/dani-garcia/bw_web_builds/releases). ### Status diff --git a/config.toml b/config.toml @@ -6,6 +6,7 @@ ip="fdb5:d87:ae42:1::1" #password_iterations=600000 port=8443 #web_vault_enabled=true +webauthn_require_yubi=true workers=4 [tls] cert="/etc/ssl/pmd.philomathiclife.com.fullchain" diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs @@ -4,11 +4,12 @@ use crate::{ PasswordOrOtpData, UpdateType, }, auth::{decode_delete, decode_verify_email, ClientHeaders, Headers}, - config, crypto, + config, db::{ models::{AuthRequest, Cipher, Device, DeviceType, Folder, User, UserKdfType}, DbConn, }, + error::Error, }; use chrono::Utc; use rocket::serde::json::Json; @@ -59,7 +60,7 @@ pub fn routes() -> Vec<rocket::Route> { ] } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(dead_code, non_snake_case)] pub struct RegisterData { Email: String, @@ -77,7 +78,7 @@ pub struct RegisterData { OrganizationUserId: Option<String>, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct KeysData { EncryptedPrivateKey: String, @@ -101,15 +102,16 @@ fn enforce_password_hint_setting(password_hint: &Option<String>) -> EmptyResult #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/register", data = "<data>")] -fn register(data: JsonUpcase<RegisterData>, _conn: DbConn) -> JsonResult { - err!("Registration is permanently disabled.") +fn register(data: JsonUpcase<RegisterData>) -> Error { + const MSG: &str = "Registration is permanently disabled."; + Error::new(MSG, MSG) } #[get("/accounts/profile")] async fn profile(headers: Headers, conn: DbConn) -> Json<Value> { Json(headers.user.to_json(&conn).await) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct ProfileData { // Culture: String, // Ignored, always use en-US @@ -391,7 +393,7 @@ async fn post_sstamp( ) -> EmptyResult { let otp_data: PasswordOrOtpData = data.into_inner().data; let mut user = headers.user; - otp_data.validate(&user, true, &conn).await?; + otp_data.validate(&user)?; Device::delete_all_by_user(&user.uuid, &conn).await?; user.reset_security_stamp(); let save_result = user.save(&conn).await; @@ -408,12 +410,9 @@ struct EmailTokenData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/email-token", data = "<data>")] -fn post_email_token( - data: JsonUpcase<EmailTokenData>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - err!("Email change is not allowed."); +fn post_email_token(data: JsonUpcase<EmailTokenData>, _headers: Headers) -> Error { + const MSG: &str = "E-mail change is not allowed."; + Error::new(MSG, MSG) } #[derive(Deserialize)] @@ -428,19 +427,16 @@ struct ChangeEmailData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/email", data = "<data>")] -fn post_email( - data: JsonUpcase<ChangeEmailData>, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Email change is not allowed."); +fn post_email(data: JsonUpcase<ChangeEmailData>, _headers: Headers) -> Error { + const MSG: &str = "E-mail change is not allowed."; + Error::new(MSG, MSG) } #[allow(clippy::needless_pass_by_value)] #[post("/accounts/verify-email")] -fn post_verify_email(_headers: Headers) -> EmptyResult { - err!("Cannot verify email address") +fn post_verify_email(_headers: Headers) -> Error { + const MSG: &str = "E-mail is disabled."; + Error::new(MSG, MSG) } #[derive(Deserialize)] @@ -479,8 +475,9 @@ struct DeleteRecoverData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/delete-recover", data = "<data>")] -fn post_delete_recover(data: JsonUpcase<DeleteRecoverData>, _conn: DbConn) -> EmptyResult { - err!("Please contact the administrator to delete your account"); +fn post_delete_recover(data: JsonUpcase<DeleteRecoverData>) -> Error { + const MSG: &str = "Account deletion is disabled with at this endpoint."; + Error::new(MSG, MSG) } #[derive(Deserialize)] @@ -525,14 +522,13 @@ async fn delete_account( ) -> EmptyResult { let otp_data: PasswordOrOtpData = data.into_inner().data; let user = headers.user; - otp_data.validate(&user, true, &conn).await?; + otp_data.validate(&user)?; user.delete(&conn).await } -#[allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)] +#[allow(clippy::needless_pass_by_value)] #[get("/accounts/revision-date")] -fn revision_date(headers: Headers) -> JsonResult { - let revision_date = headers.user.updated_at.timestamp_millis(); - Ok(Json(json!(revision_date))) +fn revision_date(headers: Headers) -> Json<Value> { + Json(json!(headers.user.updated_at.timestamp_millis())) } #[derive(Deserialize)] @@ -543,8 +539,9 @@ struct PasswordHintData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/password-hint", data = "<data>")] -fn password_hint(data: JsonUpcase<PasswordHintData>, _conn: DbConn) -> EmptyResult { - err!("This server is not configured to provide password hints.") +fn password_hint(data: JsonUpcase<PasswordHintData>) -> Error { + const MSG: &str = "Password hints are disabled."; + Error::new(MSG, MSG) } #[derive(Deserialize)] @@ -601,43 +598,17 @@ fn verify_password(data: JsonUpcase<SecretVerificationRequest>, headers: Headers Ok(()) } -async fn _api_key( - data: JsonUpcase<PasswordOrOtpData>, - rotate: bool, - headers: Headers, - conn: DbConn, -) -> JsonResult { - use crate::util::format_date; - let otp_data: PasswordOrOtpData = data.into_inner().data; - let mut user = headers.user; - otp_data.validate(&user, true, &conn).await?; - if rotate || user.api_key.is_none() { - user.api_key = Some(crypto::generate_api_key()); - user.save(&conn).await.expect("Error saving API key"); - } - Ok(Json(json!({ - "ApiKey": user.api_key, - "RevisionDate": format_date(&user.updated_at), - "Object": "apiKey", - }))) -} - +const API_DISABLED_MSG: &str = "API access is disabled."; +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/api-key", data = "<data>")] -async fn api_key( - data: JsonUpcase<PasswordOrOtpData>, - headers: Headers, - conn: DbConn, -) -> JsonResult { - _api_key(data, false, headers, conn).await +fn api_key(data: JsonUpcase<PasswordOrOtpData>, _headers: Headers) -> Error { + Error::new(API_DISABLED_MSG, API_DISABLED_MSG) } +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/rotate-api-key", data = "<data>")] -async fn rotate_api_key( - data: JsonUpcase<PasswordOrOtpData>, - headers: Headers, - conn: DbConn, -) -> JsonResult { - _api_key(data, true, headers, conn).await +fn rotate_api_key(data: JsonUpcase<PasswordOrOtpData>, _headers: Headers) -> Error { + Error::new(API_DISABLED_MSG, API_DISABLED_MSG) } // This variant is deprecated: https://github.com/bitwarden/server/pull/2682 @@ -701,52 +672,22 @@ impl<'r> FromRequest<'r> for KnownDevice { #[allow(non_snake_case)] struct PushToken; -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/devices/identifier/<uuid>/token", data = "<data>")] -fn post_device_token( - uuid: &str, - data: JsonUpcase<PushToken>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - Ok(()) -} -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +fn post_device_token(uuid: &str, data: JsonUpcase<PushToken>, _headers: Headers) {} +#[allow(unused_variables, clippy::needless_pass_by_value)] #[put("/devices/identifier/<uuid>/token", data = "<data>")] -fn put_device_token( - uuid: &str, - data: JsonUpcase<PushToken>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - Ok(()) -} -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +fn put_device_token(uuid: &str, data: JsonUpcase<PushToken>, _headers: Headers) {} +#[allow(unused_variables)] #[put("/devices/identifier/<uuid>/clear-token")] -fn put_clear_device_token(uuid: &str, _conn: DbConn) -> EmptyResult { - Ok(()) -} +const fn put_clear_device_token(uuid: &str) {} // On upstream server, both PUT and POST are declared. Implementing the POST method in case it would be useful somewhere -#[allow(clippy::unnecessary_wraps)] +#[allow(unused_variables)] #[post("/devices/identifier/<uuid>/clear-token")] -fn post_clear_device_token(uuid: &str, conn: DbConn) -> EmptyResult { - put_clear_device_token(uuid, conn) -} +const fn post_clear_device_token(uuid: &str) {} -#[derive(Debug, Deserialize)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct AuthRequestRequest { accessCode: String, @@ -819,7 +760,7 @@ async fn get_auth_request(uuid: &str, conn: DbConn) -> JsonResult { ))) } -#[derive(Debug, Deserialize)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct AuthResponseRequest { deviceIdentifier: String, diff --git a/src/api/core/ciphers.rs b/src/api/core/ciphers.rs @@ -9,6 +9,7 @@ use crate::{ }, DbConn, }, + error::Error, }; use chrono::{NaiveDateTime, Utc}; use rocket::fs::TempFile; @@ -200,7 +201,7 @@ async fn get_cipher_details(uuid: &str, headers: Headers, conn: DbConn) -> JsonR get_cipher(uuid, headers, conn).await } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(dead_code, non_snake_case)] pub struct CipherData { // Id is optional as it is included only in bulk share @@ -236,14 +237,14 @@ pub struct CipherData { LastKnownRevisionDate: Option<String>, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct PartialCipherData { FolderId: Option<String>, Favorite: bool, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(dead_code, non_snake_case)] struct Attachments2Data { FileName: String, @@ -892,6 +893,7 @@ async fn share_cipher_by_uuid( )) } +const ATTACHMENTS_DISABLED_MSG: &str = "Attachments are disabled."; /// v2 API for downloading an attachment. This just redirects the client to /// the actual location of an attachment. /// @@ -900,8 +902,8 @@ async fn share_cipher_by_uuid( /// redirects to the same location as before the v2 API. #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/ciphers/<uuid>/attachment/<attachment_id>")] -fn get_attachment(uuid: &str, attachment_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Attachments are disabled") +fn get_attachment(uuid: &str, attachment_id: &str, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[derive(Deserialize)] @@ -923,9 +925,8 @@ fn post_attachment_v2( uuid: &str, data: JsonUpcase<AttachmentRequestData>, _headers: Headers, - _conn: DbConn, -) -> JsonResult { - err!("Attachments are disabled") +) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(dead_code)] @@ -951,10 +952,8 @@ fn post_attachment_v2_data( attachment_id: &str, data: Form<UploadData<'_>>, _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Attachments are disabled") +) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } /// Legacy API for creating an attachment associated with a cipher. @@ -964,14 +963,8 @@ fn post_attachment_v2_data( format = "multipart/form-data", data = "<data>" )] -fn post_attachment( - uuid: &str, - data: Form<UploadData<'_>>, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Attachments are disabled") +fn post_attachment(uuid: &str, data: Form<UploadData<'_>>, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -980,14 +973,8 @@ fn post_attachment( format = "multipart/form-data", data = "<data>" )] -fn post_attachment_admin( - uuid: &str, - data: Form<UploadData<'_>>, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Attachments are disabled") +fn post_attachment_admin(uuid: &str, data: Form<UploadData<'_>>, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -1001,58 +988,32 @@ fn post_attachment_share( attachment_id: &str, data: Form<UploadData<'_>>, _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Attachments are disabled") +) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/ciphers/<uuid>/attachment/<attachment_id>/delete-admin")] -fn delete_attachment_post_admin( - uuid: &str, - attachment_id: &str, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Attachments are disabled") +fn delete_attachment_post_admin(uuid: &str, attachment_id: &str, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/ciphers/<uuid>/attachment/<attachment_id>/delete")] -fn delete_attachment_post( - uuid: &str, - attachment_id: &str, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Attachments are disabled") +fn delete_attachment_post(uuid: &str, attachment_id: &str, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[delete("/ciphers/<uuid>/attachment/<attachment_id>")] -fn delete_attachment( - uuid: &str, - attachment_id: &str, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Attachments are disabled") +fn delete_attachment(uuid: &str, attachment_id: &str, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[delete("/ciphers/<uuid>/attachment/<attachment_id>/admin")] -fn delete_attachment_admin( - uuid: &str, - attachment_id: &str, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Attachments are disabled") +fn delete_attachment_admin(uuid: &str, attachment_id: &str, _headers: Headers) -> Error { + Error::new(ATTACHMENTS_DISABLED_MSG, ATTACHMENTS_DISABLED_MSG) } #[post("/ciphers/<uuid>/delete")] @@ -1275,7 +1236,7 @@ async fn delete_all( ) -> EmptyResult { let data: PasswordOrOtpData = data.into_inner().data; let mut user = headers.user; - data.validate(&user, true, &conn).await?; + data.validate(&user)?; if let Some(org_data) = organization { // Organization ID in query params, purging organization vault match UserOrganization::find_by_user_and_org(&user.uuid, &org_data.org_id, &conn).await { diff --git a/src/api/core/emergency_access.rs b/src/api/core/emergency_access.rs @@ -1,9 +1,10 @@ use crate::{ - api::{EmptyResult, JsonResult, JsonUpcase, NumberOrString}, + api::{JsonUpcase, NumberOrString}, auth::Headers, - db::DbConn, + error::Error, }; -use rocket::Route; +use rocket::{serde::json::Json, Route}; +use serde_json::Value; pub fn routes() -> Vec<Route> { routes![ @@ -29,19 +30,27 @@ pub fn routes() -> Vec<Route> { } #[allow(clippy::needless_pass_by_value)] #[get("/emergency-access/trusted")] -fn get_contacts(_headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn get_contacts(_headers: Headers) -> Json<Value> { + Json(json!({ + "Data": Vec::<Value>::new(), + "Object": "list", + "ContinuationToken": null + })) } - #[allow(clippy::needless_pass_by_value)] #[get("/emergency-access/granted")] -fn get_grantees(_headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") -} -#[allow(unused_variables, clippy::needless_pass_by_value)] +fn get_grantees(_headers: Headers) -> Json<Value> { + Json(json!({ + "Data": Vec::<Value>::new(), + "Object": "list", + "ContinuationToken": null + })) +} +const ACCESS_NOT_ALLOWED_MSG: &str = "Emergency access is not allowed."; +#[allow(unused_variables)] #[get("/emergency-access/<emer_id>")] -fn get_emergency_access(emer_id: &str, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn get_emergency_access(emer_id: &str) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[derive(Deserialize)] @@ -53,34 +62,26 @@ struct EmergencyAccessUpdateData { } #[allow(unused_variables, clippy::needless_pass_by_value)] #[put("/emergency-access/<emer_id>", data = "<data>")] -fn put_emergency_access( - emer_id: &str, - data: JsonUpcase<EmergencyAccessUpdateData>, - _conn: DbConn, -) -> JsonResult { - err!("Emergency access is not allowed.") +fn put_emergency_access(emer_id: &str, data: JsonUpcase<EmergencyAccessUpdateData>) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>", data = "<data>")] -fn post_emergency_access( - emer_id: &str, - data: JsonUpcase<EmergencyAccessUpdateData>, - _conn: DbConn, -) -> JsonResult { - err!("Emergency access is not allowed.") +fn post_emergency_access(emer_id: &str, data: JsonUpcase<EmergencyAccessUpdateData>) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[delete("/emergency-access/<emer_id>")] -fn delete_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> EmptyResult { - err!("Emergency access is not allowed.") +fn delete_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/delete")] -fn post_delete_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> EmptyResult { - err!("Emergency access is not allowed.") +fn post_delete_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[derive(Deserialize)] @@ -93,18 +94,14 @@ struct EmergencyAccessInviteData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/invite", data = "<data>")] -fn send_invite( - data: JsonUpcase<EmergencyAccessInviteData>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - err!("Emergency access is not allowed.") +fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/reinvite")] -fn resend_invite(emer_id: &str, _headers: Headers, _conn: DbConn) -> EmptyResult { - err!("Emergency access is not allowed.") +fn resend_invite(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[derive(Deserialize)] @@ -115,13 +112,8 @@ struct AcceptData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/accept", data = "<data>")] -fn accept_invite( - emer_id: &str, - data: JsonUpcase<AcceptData>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - err!("Emergency access is not allowed.") +fn accept_invite(emer_id: &str, data: JsonUpcase<AcceptData>, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[derive(Deserialize)] @@ -136,39 +128,38 @@ fn confirm_emergency_access( emer_id: &str, data: JsonUpcase<ConfirmData>, _headers: Headers, - _conn: DbConn, -) -> JsonResult { - err!("Emergency access is not allowed.") +) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/initiate")] -fn initiate_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn initiate_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/approve")] -fn approve_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn approve_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/reject")] -fn reject_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn reject_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/view")] -fn view_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn view_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/emergency-access/<emer_id>/takeover")] -fn takeover_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access is not allowed.") +fn takeover_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[derive(Deserialize)] @@ -184,13 +175,12 @@ fn password_emergency_access( emer_id: &str, data: JsonUpcase<EmergencyAccessPasswordData>, _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - err!("Emergency access is not allowed.") +) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/emergency-access/<emer_id>/policies")] -fn policies_emergency_access(emer_id: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Emergency access not valid.") +fn policies_emergency_access(emer_id: &str, _headers: Headers) -> Error { + Error::new(ACCESS_NOT_ALLOWED_MSG, ACCESS_NOT_ALLOWED_MSG) } diff --git a/src/api/core/events.rs b/src/api/core/events.rs @@ -1,7 +1,6 @@ use crate::{ - api::{EmptyResult, JsonResult, JsonUpcaseVec}, + api::JsonUpcaseVec, auth::{AdminHeaders, Headers}, - db::DbConn, }; use rocket::{form::FromForm, serde::json::Json, Route}; use serde_json::Value; @@ -19,83 +18,50 @@ struct EventRange { } // Upstream: https://github.com/bitwarden/server/blob/9ecf69d9cabce732cf2c57976dd9afa5728578fb/src/Api/Controllers/EventsController.cs#LL84C35-L84C41 -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +#[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<org_id>/events?<data..>")] -fn get_org_events( - org_id: &str, - data: EventRange, - _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - Ok(Json(json!({ +fn get_org_events(org_id: &str, data: EventRange, _headers: AdminHeaders) -> Json<Value> { + Json(json!({ "Data": Vec::<Value>::new(), "Object": "list", "ContinuationToken": None::<&str>, - }))) + })) } -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +#[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/ciphers/<cipher_id>/events?<data..>")] -fn get_cipher_events( - cipher_id: &str, - data: EventRange, - _headers: Headers, - _conn: DbConn, -) -> JsonResult { - Ok(Json(json!({ +fn get_cipher_events(cipher_id: &str, data: EventRange, _headers: Headers) -> Json<Value> { + Json(json!({ "Data": Vec::<Value>::new(), "Object": "list", "ContinuationToken": None::<&str>, - }))) + })) } -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +#[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<org_id>/users/<user_org_id>/events?<data..>")] fn get_user_events( org_id: &str, user_org_id: &str, data: EventRange, _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - Ok(Json(json!({ +) -> Json<Value> { + Json(json!({ "Data": Vec::<Value>::new(), "Object": "list", "ContinuationToken": None::<&str>, - }))) + })) } pub fn main_routes() -> Vec<Route> { routes![post_events_collect,] } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct EventCollection; // Upstream: // https://github.com/bitwarden/server/blob/8a22c0479e987e756ce7412c48a732f9002f0a2d/src/Events/Controllers/CollectController.cs // https://github.com/bitwarden/server/blob/8a22c0479e987e756ce7412c48a732f9002f0a2d/src/Core/Services/Implementations/EventService.cs -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/collect", format = "application/json", data = "<data>")] -fn post_events_collect( - data: JsonUpcaseVec<EventCollection>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - Ok(()) -} +fn post_events_collect(data: JsonUpcaseVec<EventCollection>, _headers: Headers) {} diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs @@ -41,14 +41,14 @@ pub fn events_routes() -> Vec<Route> { // Move this somewhere else // use crate::{ - api::{JsonResult, JsonUpcase, Notify}, + api::{JsonResult, JsonUpcase}, auth::Headers, db::DbConn, }; use rocket::{serde::json::Json, Catcher, Route}; use serde_json::Value; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize)] #[allow(non_snake_case)] struct GlobalDomain { Type: i32, @@ -82,7 +82,7 @@ fn _get_eq_domains(headers: Headers, no_excluded: bool) -> Json<Value> { })) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct EquivDomainData { ExcludedGlobalEquivalentDomains: Option<Vec<i32>>, @@ -94,7 +94,6 @@ async fn post_eq_domains( data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbConn, - _nt: Notify<'_>, ) -> JsonResult { let data: EquivDomainData = data.into_inner().data; let excluded_globals = data.ExcludedGlobalEquivalentDomains.unwrap_or_default(); @@ -112,20 +111,17 @@ async fn put_eq_domains( data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbConn, - nt: Notify<'_>, ) -> JsonResult { - post_eq_domains(data, headers, conn, nt).await + post_eq_domains(data, headers, conn).await } #[allow(unused_variables)] #[get("/hibp/breach?<username>")] -fn hibp_breach(username: &str) -> JsonResult { - Err(Error::empty().with_code(404)) +fn hibp_breach(username: &str) -> Error { + Error::empty().with_code(404) } -// We use DbConn here to let the alive healthcheck also verify the database connection. -#[allow(clippy::needless_pass_by_value)] #[get("/alive")] -fn alive(_conn: DbConn) -> Json<String> { +fn alive() -> Json<String> { now() } diff --git a/src/api/core/organizations.rs b/src/api/core/organizations.rs @@ -5,12 +5,11 @@ use crate::{ PasswordOrOtpData, UpdateType, }, auth::{self, AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OwnerHeaders}, - crypto, db::{ models::{ Cipher, Collection, CollectionCipher, CollectionUser, OrgPolicy, OrgPolicyErr, - OrgPolicyType, Organization, OrganizationApiKey, TwoFactor, User, UserOrgStatus, - UserOrgType, UserOrganization, + OrgPolicyType, Organization, TwoFactorType, User, UserOrgStatus, UserOrgType, + UserOrganization, }, DbConn, }, @@ -18,7 +17,6 @@ use crate::{ util, }; use core::convert; -use num_traits::FromPrimitive; use rocket::serde::json::Json; use rocket::Route; use serde_json::Value; @@ -118,7 +116,7 @@ struct OrgData { _PlanType: NumberOrString, // Ignored, always use the same plan } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct OrganizationUpdateData { BillingEmail: String, @@ -149,16 +147,17 @@ struct OrgKeyData { PublicKey: String, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct OrgBulkIds { Ids: Vec<String>, } +const ORG_CREATION_NOT_ALLOWED_MSG: &str = "Organization creation is not allowed"; #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations", data = "<data>")] -fn create_organization(_headers: Headers, data: JsonUpcase<OrgData>, _conn: DbConn) -> JsonResult { - err!("User not allowed to create organizations") +fn create_organization(_headers: Headers, data: JsonUpcase<OrgData>) -> Error { + Error::new(ORG_CREATION_NOT_ALLOWED_MSG, ORG_CREATION_NOT_ALLOWED_MSG) } #[delete("/organizations/<org_id>", data = "<data>")] @@ -169,7 +168,7 @@ async fn delete_organization( conn: DbConn, ) -> EmptyResult { let data: PasswordOrOtpData = data.into_inner().data; - data.validate(&headers.user, true, &conn).await?; + data.validate(&headers.user)?; match Organization::find_by_uuid(org_id, &conn).await { None => err!("Organization not found"), Some(org) => org.delete(&conn).await, @@ -493,7 +492,7 @@ async fn delete_organization_collection( _delete_organization_collection(org_id, col_id, &headers, &conn).await } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case, dead_code)] struct DeleteCollectionData { Id: String, @@ -514,7 +513,7 @@ async fn post_organization_collection_delete( _delete_organization_collection(org_id, col_id, &headers, &conn).await } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct BulkCollectionIds { Ids: Vec<String>, @@ -751,16 +750,11 @@ struct InviteData { Collections: Option<Vec<CollectionData>>, AccessAll: Option<bool>, } - +const INVITATIONS_NOT_ALLOWED_MSG: &str = "Invitations are not allowed."; #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/users/invite", data = "<data>")] -fn send_invite( - org_id: &str, - data: JsonUpcase<InviteData>, - _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("No more organizations are allowed.") +fn send_invite(org_id: &str, data: JsonUpcase<InviteData>, _headers: AdminHeaders) -> Error { + Error::new(INVITATIONS_NOT_ALLOWED_MSG, INVITATIONS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/users/reinvite", data = "<data>")] @@ -768,7 +762,6 @@ fn bulk_reinvite_user( org_id: &str, data: JsonUpcase<OrgBulkIds>, _headers: AdminHeaders, - _conn: DbConn, ) -> Json<Value> { Json(json!({ "Data": Vec::<Value>::new(), @@ -779,13 +772,8 @@ fn bulk_reinvite_user( #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/users/<user_org>/reinvite")] -fn reinvite_user( - org_id: &str, - user_org: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Invitations are not allowed.") +fn reinvite_user(org_id: &str, user_org: &str, _headers: AdminHeaders) -> Error { + Error::new(INVITATIONS_NOT_ALLOWED_MSG, INVITATIONS_NOT_ALLOWED_MSG) } #[derive(Deserialize)] @@ -797,13 +785,8 @@ struct AcceptData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/users/<_org_user_id>/accept", data = "<data>")] -fn accept_invite( - org_id: &str, - _org_user_id: &str, - data: JsonUpcase<AcceptData>, - _conn: DbConn, -) -> EmptyResult { - err!("No more organizations are allowed.") +fn accept_invite(org_id: &str, _org_user_id: &str, data: JsonUpcase<AcceptData>) -> Error { + Error::new(INVITATIONS_NOT_ALLOWED_MSG, INVITATIONS_NOT_ALLOWED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -812,8 +795,6 @@ fn bulk_confirm_invite( org_id: &str, data: JsonUpcase<Value>, _headers: AdminHeaders, - _conn: DbConn, - _nt: Notify<'_>, ) -> Json<Value> { Json(json!({ "Data": Vec::<Value>::new(), @@ -829,10 +810,8 @@ fn confirm_invite( org_user_id: &str, data: JsonUpcase<Value>, _headers: AdminHeaders, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("No more organizations are allowed.") +) -> Error { + Error::new(INVITATIONS_NOT_ALLOWED_MSG, INVITATIONS_NOT_ALLOWED_MSG) } #[get("/organizations/<org_id>/users/<org_user_id>?<data..>")] @@ -1208,7 +1187,7 @@ async fn get_policy( _headers: AdminHeaders, conn: DbConn, ) -> JsonResult { - let Some(pol_type_enum) = OrgPolicyType::from_i32(pol_type) else { + let Ok(pol_type_enum) = OrgPolicyType::try_from(pol_type) else { err!("Invalid or unsupported policy type") }; let policy = (OrgPolicy::find_by_org_and_type(org_id, pol_type_enum, &conn).await).map_or_else( @@ -1235,20 +1214,17 @@ async fn put_policy( conn: DbConn, ) -> JsonResult { let data: PolicyData = data.into_inner(); - let Some(pol_type_enum) = OrgPolicyType::from_i32(pol_type) else { + let Ok(pol_type_enum) = OrgPolicyType::try_from(pol_type) else { err!("Invalid or unsupported policy type") }; // When enabling the TwoFactorAuthentication policy, remove this org's members that do have 2FA if pol_type_enum == OrgPolicyType::TwoFactorAuthentication && data.enabled { for member in UserOrganization::find_by_org(org_id, &conn).await { - let user_twofactor_disabled = TwoFactor::find_by_user(&member.user_uuid, &conn) - .await - .is_empty(); // Policy only applies to non-Owner/non-Admin members who have accepted joining the org // Invited users still need to accept the invite and will get an error when they try to accept the invite. - if user_twofactor_disabled - && member.atype < UserOrgType::Admin + if member.atype < UserOrgType::Admin && member.status != i32::from(UserOrgStatus::Invited) + && !TwoFactorType::has_twofactor(&member.user_uuid, &conn).await? { member.delete(&conn).await?; } @@ -1340,7 +1316,7 @@ fn _empty_data_json() -> Value { }) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case, dead_code)] struct OrgImportGroupData { Name: String, // "GroupName" @@ -1348,7 +1324,7 @@ struct OrgImportGroupData { Users: Vec<String>, // ["uid=user,ou=People,dc=example,dc=com"] } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct OrgImportUserData { Email: String, // "user@maildomain.net" @@ -1357,7 +1333,7 @@ struct OrgImportUserData { Deleted: bool, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct OrgImportData { #[allow(dead_code)] @@ -1635,24 +1611,20 @@ async fn _restore_organization_user( Ok(()) } -#[allow( - unused_variables, - clippy::needless_pass_by_value, - clippy::unnecessary_wraps -)] +#[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<org_id>/groups")] -fn get_groups(org_id: &str, _headers: ManagerHeadersLoose, _conn: DbConn) -> JsonResult { - Ok(Json(json!({ +fn get_groups(org_id: &str, _headers: ManagerHeadersLoose) -> Json<Value> { + Json(json!({ "Data": Vec::<Value>::new(), "Object": "list", "ContinuationToken": null, - }))) + })) } #[derive(Deserialize)] #[allow(non_snake_case)] struct GroupRequest; -#[derive(Deserialize, Serialize)] +#[derive(Serialize)] #[allow(non_snake_case)] struct SelectionReadOnly { Id: String, @@ -1673,6 +1645,7 @@ impl SelectionReadOnly { } } +const GROUPS_DISABLED_MSG: &str = "Groups are disabled."; #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/groups/<group_id>", data = "<data>")] fn post_group( @@ -1680,20 +1653,14 @@ fn post_group( group_id: &str, data: JsonUpcase<GroupRequest>, _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - err!("Group support is disabled") +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/groups", data = "<data>")] -fn post_groups( - org_id: &str, - _headers: AdminHeaders, - data: JsonUpcase<GroupRequest>, - _conn: DbConn, -) -> JsonResult { - err!("Group support is disabled") +fn post_groups(org_id: &str, _headers: AdminHeaders, data: JsonUpcase<GroupRequest>) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -1703,69 +1670,43 @@ fn put_group( group_id: &str, data: JsonUpcase<GroupRequest>, _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - err!("Group support is disabled") +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<_org_id>/groups/<group_id>/details")] -fn get_group_details( - _org_id: &str, - group_id: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - err!("Group support is disabled"); +fn get_group_details(_org_id: &str, group_id: &str, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/groups/<group_id>/delete")] -fn post_delete_group( - org_id: &str, - group_id: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled"); +fn post_delete_group(org_id: &str, group_id: &str, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[delete("/organizations/<org_id>/groups/<group_id>")] -fn delete_group( - org_id: &str, - group_id: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled"); +fn delete_group(org_id: &str, group_id: &str, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[delete("/organizations/<org_id>/groups", data = "<data>")] -fn bulk_delete_groups( - org_id: &str, - data: JsonUpcase<OrgBulkIds>, - _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled"); +fn bulk_delete_groups(org_id: &str, data: JsonUpcase<OrgBulkIds>, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<_org_id>/groups/<group_id>")] -fn get_group(_org_id: &str, group_id: &str, _headers: AdminHeaders, _conn: DbConn) -> JsonResult { - err!("Group support is disabled"); +fn get_group(_org_id: &str, group_id: &str, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<_org_id>/groups/<group_id>/users")] -fn get_group_users( - _org_id: &str, - group_id: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - err!("Group support is disabled"); +fn get_group_users(_org_id: &str, group_id: &str, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[put("/organizations/<org_id>/groups/<group_id>/users", data = "<data>")] @@ -1774,20 +1715,14 @@ fn put_group_users( group_id: &str, _headers: AdminHeaders, data: JsonVec<String>, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled"); +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<_org_id>/users/<user_id>/groups")] -fn get_user_groups( - _org_id: &str, - user_id: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - err!("Group support is disabled") +fn get_user_groups(_org_id: &str, user_id: &str, _headers: AdminHeaders) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[derive(Deserialize)] @@ -1801,9 +1736,8 @@ fn post_user_groups( org_user_id: &str, data: JsonUpcase<OrganizationUserUpdateGroupsRequest>, _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled") +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -1813,9 +1747,8 @@ fn put_user_groups( org_user_id: &str, data: JsonUpcase<OrganizationUserUpdateGroupsRequest>, _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled") +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -1825,9 +1758,8 @@ fn post_delete_group_user( group_id: &str, org_user_id: &str, _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled") +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -1837,9 +1769,8 @@ fn delete_group_user( group_id: &str, org_user_id: &str, _headers: AdminHeaders, - _conn: DbConn, -) -> EmptyResult { - err!("Group support is disabled") +) -> Error { + Error::new(GROUPS_DISABLED_MSG, GROUPS_DISABLED_MSG) } #[derive(Deserialize)] @@ -1867,7 +1798,7 @@ async fn get_organization_keys(org_id: &str, conn: DbConn) -> JsonResult { "PrivateKey": org.private_key, }))) } - +const PASS_RESET_MSG: &str = "Password reset is not supported on an email-disabled instance."; #[allow(unused_variables)] #[put( "/organizations/<org_id>/users/<org_user_id>/reset-password", @@ -1879,7 +1810,6 @@ async fn put_reset_password( _headers: AdminHeaders, data: JsonUpcase<OrganizationUserResetPasswordRequest>, conn: DbConn, - _nt: Notify<'_>, ) -> EmptyResult { let Some(org) = Organization::find_by_uuid(org_id, &conn).await else { err!("Required organization not found") @@ -1890,20 +1820,15 @@ async fn put_reset_password( err!("User to reset isn't member of required organization") }; match User::find_by_uuid(&org_user.user_uuid, &conn).await { - Some(_) => err!("Password reset is not supported on an email-disabled instance."), + Some(_) => err!(PASS_RESET_MSG), None => err!("User not found"), } } #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/organizations/<org_id>/users/<org_user_id>/reset-password-details")] -fn get_reset_password_details( - org_id: &str, - org_user_id: &str, - _headers: AdminHeaders, - _conn: DbConn, -) -> JsonResult { - err!("Password reset is not supported on an email-disabled instance.") +fn get_reset_password_details(org_id: &str, org_user_id: &str, _headers: AdminHeaders) -> Error { + Error::new(PASS_RESET_MSG, PASS_RESET_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -1916,9 +1841,8 @@ fn put_reset_password_enrollment( org_user_id: &str, _headers: Headers, data: JsonUpcase<OrganizationUserResetPasswordEnrollmentRequest>, - _conn: DbConn, -) -> EmptyResult { - err!("Password reset is not supported on an email-disabled instance.") +) -> Error { + Error::new(PASS_RESET_MSG, PASS_RESET_MSG) } // This is a new function active since the v2022.9.x clients. @@ -1964,61 +1888,18 @@ async fn get_org_export(org_id: &str, headers: AdminHeaders, conn: DbConn) -> Js })) } } - -async fn _api_key( - org_id: &str, - data: JsonUpcase<PasswordOrOtpData>, - rotate: bool, - headers: AdminHeaders, - conn: DbConn, -) -> JsonResult { - let data: PasswordOrOtpData = data.into_inner().data; - let user = headers.user; - // Validate the admin users password/otp - data.validate(&user, true, &conn).await?; - let org_api_key = - if let Some(mut org_api_key) = OrganizationApiKey::find_by_org_uuid(org_id, &conn).await { - if rotate { - org_api_key.api_key = crypto::generate_api_key(); - org_api_key.revision_date = chrono::Utc::now().naive_utc(); - org_api_key - .save(&conn) - .await - .expect("Error rotating organization API Key"); - } - org_api_key - } else { - let api_key = crypto::generate_api_key(); - let new_org_api_key = OrganizationApiKey::new(String::from(org_id), api_key); - new_org_api_key - .save(&conn) - .await - .expect("Error creating organization API Key"); - new_org_api_key - }; - Ok(Json(json!({ - "ApiKey": org_api_key.api_key, - "RevisionDate": util::format_date(&org_api_key.revision_date), - "Object": "apiKey", - }))) -} - +const API_DISABLED_MSG: &str = "API access is disabled."; +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/api-key", data = "<data>")] -async fn api_key( - org_id: &str, - data: JsonUpcase<PasswordOrOtpData>, - headers: AdminHeaders, - conn: DbConn, -) -> JsonResult { - _api_key(org_id, data, false, headers, conn).await +fn api_key(org_id: &str, data: JsonUpcase<PasswordOrOtpData>, _headers: AdminHeaders) -> Error { + Error::new(API_DISABLED_MSG, API_DISABLED_MSG) } - +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/organizations/<org_id>/rotate-api-key", data = "<data>")] -async fn rotate_api_key( +fn rotate_api_key( org_id: &str, data: JsonUpcase<PasswordOrOtpData>, - headers: AdminHeaders, - conn: DbConn, -) -> JsonResult { - _api_key(org_id, data, true, headers, conn).await + _headers: AdminHeaders, +) -> Error { + Error::new(API_DISABLED_MSG, API_DISABLED_MSG) } diff --git a/src/api/core/public.rs b/src/api/core/public.rs @@ -1,13 +1,5 @@ -use crate::{ - api::{EmptyResult, JsonUpcase}, - auth, - db::{models::OrganizationApiKey, DbConn}, -}; -use chrono::Utc; -use rocket::{ - request::{self, FromRequest, Outcome}, - Request, Route, -}; +use crate::{api::JsonUpcase, error::Error}; +use rocket::Route; pub fn routes() -> Vec<Route> { routes![ldap_import] @@ -34,65 +26,7 @@ struct OrgImportData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/public/organization/import", data = "<data>")] -fn ldap_import(data: JsonUpcase<OrgImportData>, _token: PublicToken, _conn: DbConn) -> EmptyResult { - err!("LDAP import is permanently disabled.") -} - -struct PublicToken(String); - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for PublicToken { - type Error = &'static str; - async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> { - let headers = request.headers(); - // Get access_token - let access_token: &str = match headers.get_one("Authorization") { - Some(a) => match a.rsplit("Bearer ").next() { - Some(split) => split, - None => err_handler!("No access token provided"), - }, - None => err_handler!("No access token provided"), - }; - // Check JWT token is valid and get device and user from it - let Ok(claims) = auth::decode_api_org(access_token) else { - err_handler!("Invalid claim") - }; - // Check if time is between claims.nbf and claims.exp - let time_now = Utc::now().naive_utc().timestamp(); - if time_now < claims.nbf { - err_handler!("Token issued in the future"); - } - if time_now > claims.exp { - err_handler!("Token expired"); - } - // Check if claims.iss is host|claims.scope[0] - let Outcome::Success(host) = auth::Host::from_request(request).await else { - err_handler!("Error getting Host") - }; - let complete_host = format!("{}|{}", host.host, claims.scope[0]); - if complete_host != claims.iss { - err_handler!("Token not issued by this server"); - } - // Check if claims.sub is org_api_key.uuid - // Check if claims.client_sub is org_api_key.org_uuid - let Some(org_uuid) = claims.client_id.strip_prefix("organization.") else { - err_handler!("Malformed client_id") - }; - let org_api_key = match DbConn::from_request(request).await { - Outcome::Success(conn) => { - match OrganizationApiKey::find_by_org_uuid(org_uuid, &conn).await { - Some(org_api_key) => org_api_key, - None => err_handler!("Invalid client_id"), - } - } - Outcome::Error(_) | Outcome::Forward(_) => err_handler!("Error getting DB"), - }; - if org_api_key.org_uuid != claims.client_sub { - err_handler!("Token not issued for this org"); - } - if org_api_key.uuid != claims.sub { - err_handler!("Token not issued for this client"); - } - Outcome::Success(Self(claims.client_sub)) - } +fn ldap_import(data: JsonUpcase<OrgImportData>) -> Error { + const MSG: &str = "LDAP import is permanently disabled."; + Error::new(MSG, MSG) } diff --git a/src/api/core/sends.rs b/src/api/core/sends.rs @@ -1,7 +1,7 @@ use crate::{ - api::{EmptyResult, JsonResult, JsonUpcase, Notify, NumberOrString}, - auth::{ClientIp, Headers, Host}, - db::DbConn, + api::{JsonUpcase, NumberOrString}, + auth::Headers, + error::Error, util::{SafeString, UpCase}, }; use chrono::{DateTime, Utc}; @@ -47,29 +47,24 @@ struct SendData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/sends")] -fn get_sends(_headers: Headers, _conn: DbConn) -> Json<Value> { +fn get_sends(_headers: Headers) -> Json<Value> { Json(json!({ "Data": Vec::<Value>::new(), "Object": "list", "ContinuationToken": null })) } - +const SENDS_DISABLED_MSG: &str = "Sends are disabled."; #[allow(unused_variables, clippy::needless_pass_by_value)] #[get("/sends/<uuid>")] -fn get_send(uuid: &str, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Sends are permanently disabled.") +fn get_send(uuid: &str, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/sends", data = "<data>")] -fn post_send( - data: JsonUpcase<SendData>, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Sends are permanently disabled.") +fn post_send(data: JsonUpcase<SendData>, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(dead_code)] @@ -87,19 +82,14 @@ struct UploadDataV2<'f> { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/sends/file", format = "multipart/form-data", data = "<data>")] -fn post_send_file( - data: Form<UploadData<'_>>, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Sends are permanently disabled.") +fn post_send_file(data: Form<UploadData<'_>>, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/sends/file/v2", data = "<data>")] -fn post_send_file_v2(data: JsonUpcase<SendData>, _headers: Headers, _conn: DbConn) -> JsonResult { - err!("Sends are permanently disabled.") +fn post_send_file_v2(data: JsonUpcase<SendData>, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -113,10 +103,8 @@ fn post_send_file_v2_data( file_id: &str, data: Form<UploadDataV2<'_>>, _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> EmptyResult { - err!("Sends are permanently disabled.") +) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[derive(Deserialize)] @@ -127,27 +115,14 @@ struct SendAccessData { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/sends/access/<access_id>", data = "<data>")] -fn post_access( - access_id: &str, - data: JsonUpcase<SendAccessData>, - _conn: DbConn, - _ip: ClientIp, - _nt: Notify<'_>, -) -> JsonResult { - err!("Sends are permanently disabled.") +fn post_access(access_id: &str, data: JsonUpcase<SendAccessData>) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")] -fn post_access_file( - send_id: &str, - file_id: &str, - data: JsonUpcase<SendAccessData>, - _host: Host, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Sends are permanently disabled.") +fn post_access_file(send_id: &str, file_id: &str, data: JsonUpcase<SendAccessData>) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] @@ -158,24 +133,18 @@ fn download_send(send_id: SafeString, file_id: SafeString, t: &str) -> Option<Na #[allow(unused_variables, clippy::needless_pass_by_value)] #[put("/sends/<id>", data = "<data>")] -fn put_send( - id: &str, - data: JsonUpcase<SendData>, - _headers: Headers, - _conn: DbConn, - _nt: Notify<'_>, -) -> JsonResult { - err!("Sends are permanently disabled.") +fn put_send(id: &str, data: JsonUpcase<SendData>, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[delete("/sends/<id>")] -fn delete_send(id: &str, _headers: Headers, _conn: DbConn, _nt: Notify<'_>) -> EmptyResult { - err!("Sends are permanently disabled.") +fn delete_send(id: &str, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } #[allow(unused_variables, clippy::needless_pass_by_value)] #[put("/sends/<id>/remove-password")] -fn put_remove_password(id: &str, _headers: Headers, _conn: DbConn, _nt: Notify<'_>) -> JsonResult { - err!("Sends are permanently disabled.") +fn put_remove_password(id: &str, _headers: Headers) -> Error { + Error::new(SENDS_DISABLED_MSG, SENDS_DISABLED_MSG) } diff --git a/src/api/core/two_factor/authenticator.rs b/src/api/core/two_factor/authenticator.rs @@ -2,10 +2,8 @@ use crate::{ api::{EmptyResult, JsonResult, JsonUpcase, NumberOrString, PasswordOrOtpData}, auth::{ClientIp, Headers}, crypto, - db::{ - models::{TwoFactor, TwoFactorType}, - DbConn, - }, + db::{models::Totp, DbConn}, + error::Error, }; use data_encoding::BASE32; use rocket::serde::json::Json; @@ -27,11 +25,10 @@ async fn generate_authenticator( ) -> JsonResult { let data: PasswordOrOtpData = data.into_inner().data; let user = headers.user; - data.validate(&user, false, &conn).await?; - let type_ = i32::from(TwoFactorType::Authenticator); - let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await; - let (enabled, key) = match twofactor { - Some(tf) => (true, tf.data), + data.validate(&user)?; + let totp = Totp::find_by_user(&user.uuid, &conn).await?; + let (enabled, key) = match totp { + Some(t) => (true, t.token), _ => (false, crypto::encode_random_bytes::<20>(&BASE32)), }; Ok(Json(json!({ @@ -41,7 +38,7 @@ async fn generate_authenticator( }))) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize)] #[allow(non_snake_case)] struct EnableAuthenticatorData { Key: String, @@ -64,8 +61,7 @@ async fn activate_authenticator( MasterPasswordHash: data.MasterPasswordHash, Otp: data.Otp, } - .validate(&user, true, &conn) - .await?; + .validate(&user)?; // Validate key as base32 and 20 bytes length let decoded_key: Vec<u8> = match BASE32.decode(key.as_bytes()) { Ok(decoded) => decoded, @@ -75,7 +71,7 @@ async fn activate_authenticator( err!("Invalid key length") } // Validate the token provided with the key, and save new twofactor - validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?; + validate_totp_code(user.uuid, &token, key.to_uppercase(), &headers.ip, &conn).await?; Ok(Json(json!({ "Enabled": true, "Key": key, @@ -93,9 +89,9 @@ async fn activate_authenticator_put( } pub async fn validate_totp_code_str( - user_uuid: &str, + user_uuid: String, totp_code: &str, - secret: &str, + secret: String, ip: &ClientIp, conn: &DbConn, ) -> EmptyResult { @@ -106,9 +102,9 @@ pub async fn validate_totp_code_str( } #[allow(clippy::integer_division, clippy::redundant_else)] async fn validate_totp_code( - user_uuid: &str, + user_uuid: String, totp_code: &str, - secret: &str, + secret: String, ip: &ClientIp, conn: &DbConn, ) -> EmptyResult { @@ -116,49 +112,22 @@ async fn validate_totp_code( let Ok(decoded_secret) = BASE32.decode(secret.as_bytes()) else { err!("Invalid TOTP secret") }; - let mut twofactor = (TwoFactor::find_by_user_and_type( - user_uuid, - i32::from(TwoFactorType::Authenticator), - conn, - ) - .await) - .map_or_else( - || { - TwoFactor::new( - user_uuid.to_owned(), - TwoFactorType::Authenticator, - secret.to_owned(), - ) - }, - |tf| tf, - ); - // Get the current system time in UNIX Epoch (UTC) + let mut totp = Totp::find_by_user(user_uuid.as_str(), conn) + .await? + .unwrap_or_else(|| Totp::new(user_uuid, secret)); let current_time = chrono::Utc::now(); let current_timestamp = u64::try_from(current_time.timestamp()).expect("underflow"); let time_step = current_timestamp / 30u64; - // We need to calculate the time offset and cast it as a u64. - // Since we only have times into the future and the totp generator needs, a u64 instead of the default i64. - let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, current_timestamp); - // Check the given code equals the generated one and if the time_step is larger than the one last used. - if generated == totp_code && time_step > twofactor.last_used() { - // Save the last used time step so only totp time steps higher then this one are allowed. - // This will also save a newly created twofactor if the code is correct. - twofactor.set_last_used(time_step); - twofactor.save(conn).await?; - Ok(()) - } else if generated == totp_code && time_step <= twofactor.last_used() { - warn!("This TOTP or a TOTP code within 0 steps back or forward has already been used!"); - err!(format!( - "Invalid TOTP code! Server time: {} IP: {}", - current_time.format("%F %T UTC"), - ip.ip - )); + if time_step > totp.get_last_used() + && totp_custom::<Sha1>(30, 6, &decoded_secret, current_timestamp) == totp_code + { + totp.set_last_used(time_step); + totp.replace(conn).await } else { - // Else no valid code received, deny access - err!(format!( + Err(Error::from(format!( "Invalid TOTP code! Server time: {} IP: {}", current_time.format("%F %T UTC"), ip.ip - )); + ))) } } diff --git a/src/api/core/two_factor/duo.rs b/src/api/core/two_factor/duo.rs @@ -0,0 +1,38 @@ +use crate::{ + api::{JsonUpcase, PasswordOrOtpData}, + auth::Headers, + error::Error, +}; +use rocket::Route; + +pub fn routes() -> Vec<Route> { + routes![get_duo, activate_duo, activate_duo_put,] +} +const DUO_DISABLED_MSG: &str = "Duo is disabled."; +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/get-duo", data = "<data>")] +fn get_duo(data: JsonUpcase<PasswordOrOtpData>, _headers: Headers) -> Error { + Error::new(DUO_DISABLED_MSG, DUO_DISABLED_MSG) +} + +#[derive(Deserialize)] +#[allow(non_snake_case, dead_code)] +struct EnableDuoData { + Host: String, + SecretKey: String, + IntegrationKey: String, + MasterPasswordHash: Option<String>, + Otp: Option<String>, +} + +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/duo", data = "<data>")] +fn activate_duo(data: JsonUpcase<EnableDuoData>, _headers: Headers) -> Error { + Error::new(DUO_DISABLED_MSG, DUO_DISABLED_MSG) +} + +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[put("/two-factor/duo", data = "<data>")] +fn activate_duo_put(data: JsonUpcase<EnableDuoData>, _headers: Headers) -> Error { + Error::new(DUO_DISABLED_MSG, DUO_DISABLED_MSG) +} diff --git a/src/api/core/two_factor/email.rs b/src/api/core/two_factor/email.rs @@ -0,0 +1,57 @@ +use crate::{ + api::{JsonUpcase, PasswordOrOtpData}, + auth::Headers, + error::Error, +}; +use rocket::Route; + +pub fn routes() -> Vec<Route> { + routes![get_email, send_email_login, send_email, email,] +} + +#[derive(Deserialize)] +#[allow(non_snake_case, dead_code)] +struct SendEmailLoginData { + Email: String, + MasterPasswordHash: String, +} +const EMAIL_DISABLED_MSG: &str = "E-mail is disabled."; +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/send-email-login", data = "<data>")] +fn send_email_login(data: JsonUpcase<SendEmailLoginData>) -> Error { + Error::new(EMAIL_DISABLED_MSG, EMAIL_DISABLED_MSG) +} + +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/get-email", data = "<data>")] +fn get_email(data: JsonUpcase<PasswordOrOtpData>, _headers: Headers) -> Error { + Error::new(EMAIL_DISABLED_MSG, EMAIL_DISABLED_MSG) +} + +#[derive(Deserialize)] +#[allow(non_snake_case, dead_code)] +struct SendEmailData { + Email: String, + MasterPasswordHash: Option<String>, + Otp: Option<String>, +} + +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/send-email", data = "<data>")] +fn send_email(data: JsonUpcase<SendEmailData>, _headers: Headers) -> Error { + Error::new(EMAIL_DISABLED_MSG, EMAIL_DISABLED_MSG) +} + +#[derive(Deserialize)] +#[allow(non_snake_case, dead_code)] +struct EmailData { + Email: String, + Token: String, + MasterPasswordHash: Option<String>, + Otp: Option<String>, +} +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[put("/two-factor/email", data = "<data>")] +fn email(data: JsonUpcase<EmailData>, _headers: Headers) -> Error { + Error::new(EMAIL_DISABLED_MSG, EMAIL_DISABLED_MSG) +} diff --git a/src/api/core/two_factor/mod.rs b/src/api/core/two_factor/mod.rs @@ -2,16 +2,19 @@ use crate::{ api::{JsonResult, JsonUpcase, NumberOrString, PasswordOrOtpData}, auth::{ClientHeaders, Headers}, db::{ - models::{OrgPolicyType, TwoFactor, UserOrgType, UserOrganization}, + models::{OrgPolicyType, TwoFactorType, UserOrgType, UserOrganization}, DbConn, }, }; pub mod authenticator; -pub mod protected_actions; +mod duo; +mod email; +mod protected_actions; use rocket::serde::json::Json; use rocket::Route; use serde_json::Value; pub mod webauthn; +mod yubikey; pub fn routes() -> Vec<Route> { let mut routes = routes![ @@ -23,71 +26,41 @@ pub fn routes() -> Vec<Route> { recover, ]; routes.append(&mut authenticator::routes()); + routes.append(&mut duo::routes()); + routes.append(&mut email::routes()); routes.append(&mut protected_actions::routes()); routes.append(&mut webauthn::routes()); + routes.append(&mut yubikey::routes()); routes } #[get("/two-factor")] async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> { - let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &conn).await; - let twofactors_json: Vec<Value> = twofactors.iter().map(TwoFactor::to_json_provider).collect(); Json(json!({ - "Data": twofactors_json, + "Data": TwoFactorType::get_factors(&headers.user.uuid, &conn).await.expect("unable to get two factors"), "Object": "list", "ContinuationToken": null, })) } +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/two-factor/get-recover", data = "<data>")] -async fn get_recover( - data: JsonUpcase<PasswordOrOtpData>, - headers: Headers, - conn: DbConn, -) -> JsonResult { - let data: PasswordOrOtpData = data.into_inner().data; - let user = headers.user; - data.validate(&user, true, &conn).await?; - Ok(Json(json!({ - "Code": user.totp_recover, - "Object": "twoFactorRecover" - }))) +fn get_recover(data: JsonUpcase<PasswordOrOtpData>, _headers: Headers) -> JsonResult { + err!("recovery codes are disabled") } #[derive(Deserialize)] -#[allow(non_snake_case)] +#[allow(non_snake_case, dead_code)] struct RecoverTwoFactor { MasterPasswordHash: String, Email: String, RecoveryCode: String, } +#[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/two-factor/recover", data = "<data>")] -async fn recover( - data: JsonUpcase<RecoverTwoFactor>, - _client_headers: ClientHeaders, - conn: DbConn, -) -> JsonResult { - let data: RecoverTwoFactor = data.into_inner().data; - use crate::db::models::User; - // Get the user - let Some(mut user) = User::find_by_mail(&data.Email, &conn).await else { - err!("Username or password is incorrect. Try again.") - }; - // Check password - if !user.check_valid_password(&data.MasterPasswordHash) { - err!("Username or password is incorrect. Try again.") - } - // Check if recovery code is correct - if !user.check_valid_recovery_code(&data.RecoveryCode) { - err!("Recovery code is incorrect. Try again.") - } - // Remove all twofactors from the user - TwoFactor::delete_all_by_user(&user.uuid, &conn).await?; - // Remove the recovery code, not needed without twofactors - user.totp_recover = None; - user.save(&conn).await?; - Ok(Json(Value::Object(serde_json::Map::new()))) +fn recover(data: JsonUpcase<RecoverTwoFactor>, _client_headers: ClientHeaders) -> JsonResult { + err!("recovery codes are disabled") } #[derive(Deserialize)] @@ -111,14 +84,12 @@ async fn disable_twofactor( MasterPasswordHash: data.MasterPasswordHash, Otp: data.Otp, } - .validate(&user, true, &conn) - .await?; + .validate(&user)?; let type_ = data.Type.into_i32()?; - if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await { - twofactor.delete(&conn).await?; - } - let twofactor_disabled = TwoFactor::find_by_user(&user.uuid, &conn).await.is_empty(); - if twofactor_disabled { + TwoFactorType::try_from(type_)? + .delete_by_user(&user.uuid, &conn) + .await?; + if !TwoFactorType::has_twofactor(&user.uuid, &conn).await? { for user_org in UserOrganization::find_by_user_and_policy( &user.uuid, OrgPolicyType::TwoFactorAuthentication, @@ -158,7 +129,7 @@ async fn disable_twofactor_put( // The HTML part is hidden via the CSS patches done via the bw_web_build repo #[allow(clippy::needless_pass_by_value)] #[get("/two-factor/get-device-verification-settings")] -fn get_device_verification_settings(_headers: Headers, _conn: DbConn) -> Json<Value> { +fn get_device_verification_settings(_headers: Headers) -> Json<Value> { Json(json!({ "isDeviceVerificationSectionEnabled":false, "unknownDeviceVerificationEnabled":false, diff --git a/src/api/core/two_factor/protected_actions.rs b/src/api/core/two_factor/protected_actions.rs @@ -1,55 +1,17 @@ -use crate::{ - api::{EmptyResult, JsonUpcase}, - auth::Headers, - crypto, - db::{ - models::{TwoFactor, TwoFactorType}, - DbConn, - }, - error::{Error, MapResult}, -}; -use chrono::{Duration, NaiveDateTime, Utc}; +use crate::{api::JsonUpcase, auth::Headers, error::Error}; use rocket::Route; -use serde_json; pub fn routes() -> Vec<Route> { routes![request_otp, verify_otp] } - -/// Data stored in the TwoFactor table in the db -#[derive(Serialize, Deserialize, Debug)] -struct ProtectedActionData { - /// Token issued to validate the protected action - token: String, - /// UNIX timestamp of token issue. - token_sent: i64, - // The total amount of attempts - attempts: u8, -} - -impl ProtectedActionData { - fn from_json(string: &str) -> Result<Self, Error> { - let res: Result<Self, serde_json::Error> = serde_json::from_str(string); - match res { - Ok(x) => Ok(x), - Err(_) => err!("Could not decode ProtectedActionData from string"), - } - } - fn add_attempt(&mut self) { - self.attempts = self - .attempts - .checked_add(1) - .expect("add attempts overflowed"); - } -} - +const DEVICE_LOG_IN_MSG: &str = "Log in via device is disabled."; #[allow(clippy::needless_pass_by_value)] #[post("/accounts/request-otp")] -fn request_otp(_headers: Headers, _conn: DbConn) -> EmptyResult { - err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.") +fn request_otp(_headers: Headers) -> Error { + Error::new(DEVICE_LOG_IN_MSG, DEVICE_LOG_IN_MSG) } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize)] #[allow(non_snake_case)] struct ProtectedActionVerify { OTP: String, @@ -57,55 +19,6 @@ struct ProtectedActionVerify { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/verify-otp", data = "<data>")] -fn verify_otp( - data: JsonUpcase<ProtectedActionVerify>, - _headers: Headers, - _conn: DbConn, -) -> EmptyResult { - err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); -} - -pub async fn validate_protected_action_otp( - otp: &str, - user_uuid: &str, - delete_if_valid: bool, - conn: &DbConn, -) -> EmptyResult { - let pa = TwoFactor::find_by_user_and_type( - user_uuid, - i32::from(TwoFactorType::ProtectedActions), - conn, - ) - .await - .map_res( - "Protected action token not found, try sending the code again or restart the process", - )?; - let mut pa_data = ProtectedActionData::from_json(&pa.data)?; - pa_data.add_attempt(); - // Delete the token after x attempts if it has been used too many times - // We use the 6, which should be more then enough for invalid attempts and multiple valid checks - if pa_data.attempts > 6 { - pa.delete(conn).await?; - err!("Token has expired") - } - // Check if the token has expired (Using the email 2fa expiration time) - let date = NaiveDateTime::from_timestamp_opt(pa_data.token_sent, 0) - .expect("Protected Action token timestamp invalid."); - let max_time = 600; - if date - .checked_add_signed(Duration::seconds(max_time)) - .expect("Duration add overflowed") - < Utc::now().naive_utc() - { - pa.delete(conn).await?; - err!("Token has expired") - } - if !crypto::ct_eq(&pa_data.token, otp) { - pa.save(conn).await?; - err!("Token is invalid") - } - if delete_if_valid { - pa.delete(conn).await?; - } - Ok(()) +fn verify_otp(data: JsonUpcase<ProtectedActionVerify>, _headers: Headers) -> Error { + Error::new(DEVICE_LOG_IN_MSG, DEVICE_LOG_IN_MSG) } diff --git a/src/api/core/two_factor/webauthn.rs b/src/api/core/two_factor/webauthn.rs @@ -1,21 +1,19 @@ use crate::{ - api::{EmptyResult, JsonResult, JsonUpcase, NumberOrString, PasswordOrOtpData}, + api::{EmptyResult, JsonResult, JsonUpcase, PasswordOrOtpData}, auth::Headers, config, db::{ - models::{TwoFactor, TwoFactorType, WebAuthn, WebauthnRegistration}, + models::{WebAuthn, WebAuthnAuth, WebAuthnChallenge, WebAuthnInfo, WebAuthnReg}, DbConn, }, error::Error, }; use rocket::serde::json::Json; use rocket::Route; -use serde_json::Value; use url::Url; use webauthn_rs::prelude::{ AttestationCa, AttestationCaList, AuthenticatorAttachment, PublicKeyCredential, - RegisterPublicKeyCredential, SecurityKeyAuthentication, SecurityKeyRegistration, Uuid, - Webauthn, WebauthnBuilder, WebauthnError, + RegisterPublicKeyCredential, Uuid, Webauthn, WebauthnBuilder, WebauthnError, }; pub fn routes() -> Vec<Route> { @@ -46,22 +44,61 @@ async fn get_webauthn( ) -> JsonResult { let data: PasswordOrOtpData = data.into_inner().data; let user = headers.user; - data.validate(&user, false, &conn).await?; - let (enabled, regs) = get_tf_entry(&user.uuid, i32::from(TwoFactorType::Webauthn), &conn) - .await - .map_or_else( - || Ok((false, Vec::new())), - |tf| { - tf.get_webauthn_registrations() - .map(|regs| (tf.enabled, regs)) - }, - )?; + data.validate(&user)?; + let keys = WebAuthnInfo::get_all_by_user(&user.uuid, &conn).await?; Ok(Json(json!({ - "Enabled": enabled, - "Keys": regs.iter().map(WebauthnRegistration::to_json).collect::<Value>(), + "Enabled": !keys.is_empty(), + "Keys": keys, "Object": "twoFactorWebAuthn" }))) } +// When YubiKey enforcement is enabled, we generate the allowed list of attestation CAs. +fn get_attestation_list() -> Option<AttestationCaList> { + config::get_config().webauthn_require_yubi.then(|| { + let mut attest_list = AttestationCaList::default(); + let mut ca = AttestationCa::yubico_u2f_root_ca_serial_457200631(); + ca.aaguids.clear(); + // We only allow FIDO2 YubiKeys with firmware 5.2.a, 5.4.b, 5.5.c, or 5.6.d (https://support.yubico.com/hc/en-us/articles/360016648959-YubiKey-Hardware-FIDO2-AAGUIDs). + // YubiKey 5 (USB-A, No NFC), YubiKey 5 Nano, YubiKey 5 Nano CSPN, YubiKey 5C, YubiKey 5C CSPN, + // YubiKey 5C Nano, and YubiKey 5C Nano CSPN. + ca.aaguids + .insert(Uuid::try_parse("ee882879-721c-4913-9775-3dfcce97072a").expect("invaild UUID")); + // YubiKey 5 NFC, YubiKey 5 NFC CSPN, YubiKey 5C NFC, and YubiKey 5C NFC CSPN. + ca.aaguids + .insert(Uuid::try_parse("2fc0579f-8113-47ea-b116-bb5a8db9202a").expect("invaild UUID")); + // YubiKey 5 NFC FIPS and YubiKey 5C NFC FIPS. + ca.aaguids + .insert(Uuid::try_parse("c1f9a0bc-1dd2-404a-b27f-8e29047a43fd").expect("invaild UUID")); + // YubiKey 5 Nano FIPS, YubiKey 5C FIPS, and YubiKey 5C Nano FIPS. + ca.aaguids + .insert(Uuid::try_parse("73bb0cd4-e502-49b8-9c6f-b59445bf720b").expect("invaild UUID")); + // YubiKey 5Ci and YubiKey 5Ci CSPN. + ca.aaguids + .insert(Uuid::try_parse("c5ef55ff-ad9a-4b9f-b580-adebafe026d0").expect("invaild UUID")); + // YubiKey 5Ci FIPS. + ca.aaguids + .insert(Uuid::try_parse("85203421-48f9-4355-9bc8-8a53846e5083").expect("invaild UUID")); + // Security Key By Yubico. + ca.aaguids + .insert(Uuid::try_parse("b92c3f9a-c014-4056-887f-140a2501163b").expect("invaild UUID")); + // Security Key NFC (USB-A, USB-C). + ca.aaguids + .insert(Uuid::try_parse("149a2021-8ef6-4133-96b8-81f8d5b7f1f5").expect("invaild UUID")); + // YubiKey Bio Series. + ca.aaguids + .insert(Uuid::try_parse("d8522d9f-575b-4866-88a9-ba99fa02f35b").expect("invaild UUID")); + // Security Key NFC Black (USB-A, USB-C). + ca.aaguids + .insert(Uuid::try_parse("a4e9fc6d-4cbe-4758-b8ba-37598bb5bbaa").expect("invaild UUID")); + // Security Key NFC - Enterprise Edition (USB-A, USB-C). + ca.aaguids + .insert(Uuid::try_parse("0bb43545-fd2c-4185-87dd-feb0b2916ace").expect("invaild UUID")); + attest_list + .insert(ca) + .expect("unable to add attestation CAs"); + attest_list + }) +} #[post("/two-factor/get-webauthn-challenge", data = "<data>")] async fn generate_webauthn_challenge( @@ -71,30 +108,22 @@ async fn generate_webauthn_challenge( ) -> JsonResult { let data: PasswordOrOtpData = data.into_inner().data; let user = headers.user; - data.validate(&user, false, &conn).await?; - // We only allow YubiKeys with firmware 5.2 or 5.4. - let mut attest_list = AttestationCaList::default(); - let mut ca = AttestationCa::yubico_u2f_root_ca_serial_457200631(); - ca.aaguids.clear(); - ca.aaguids - .insert(Uuid::try_parse("ee882879-721c-4913-9775-3dfcce97072a").expect("invaild UUID")); - attest_list.insert(ca)?; + data.validate(&user)?; + let (attest_list, auth_plat) = get_attestation_list().map_or_else( + || (None, None), + |attest| (Some(attest), Some(AuthenticatorAttachment::CrossPlatform)), + ); let (challenge, registration) = build_webauthn()?.start_securitykey_registration( Uuid::try_parse(user.uuid.as_str()).expect("unable to create UUID"), user.email.as_str(), user.name.as_str(), Some(WebAuthn::get_all_credentials_by_user(&user.uuid, &conn).await?), - Some(attest_list), - Some(AuthenticatorAttachment::CrossPlatform), + attest_list, + auth_plat, )?; - // We replace any existing registration challenges. - TwoFactor::new( - user.uuid, - TwoFactorType::WebauthnRegisterChallenge, - serde_json::to_string(&registration)?, - ) - .replace_challenge(&conn) - .await?; + WebAuthnChallenge::Reg(WebAuthnReg::new(user.uuid, &registration)?) + .replace(&conn) + .await?; let mut challenge_value = serde_json::to_value(challenge.public_key)?; challenge_value["status"] = "ok".into(); challenge_value["errorMessage"] = "".into(); @@ -104,7 +133,7 @@ async fn generate_webauthn_challenge( #[derive(Deserialize)] #[serde(rename_all = "camelCase")] struct EnableWebauthnData { - id: u32, + id: i64, name: String, device_response: RegisterPublicKeyCredential, master_password_hash: String, @@ -122,54 +151,23 @@ async fn activate_webauthn( MasterPasswordHash: Some(data.master_password_hash), Otp: None, } - .validate(&user, true, &conn) - .await?; + .validate(&user)?; // Retrieve and delete the saved challenge state - let tf_challenge = get_tf_entry( - &user.uuid, - i32::from(TwoFactorType::WebauthnRegisterChallenge), - &conn, - ) - .await - .ok_or_else(|| Error::from(String::from("no webauthn challenge")))?; - let registration = serde_json::from_str::<SecurityKeyRegistration>(&tf_challenge.data)?; - tf_challenge.delete_challenge(&conn).await?; + let chall = WebAuthnReg::find_by_user(&user.uuid, &conn) + .await? + .ok_or_else(|| Error::from(String::from("no webauthn challenge")))?; + let registration = chall.security_key_reg()?; + WebAuthnChallenge::Reg(chall).delete(&conn).await?; // Verify the credentials with the saved state let security_key = build_webauthn()?.finish_securitykey_registration(&data.device_response, &registration)?; - let cred_id = security_key.cred_id().to_string(); - let regs = match get_tf_entry(&user.uuid, i32::from(TwoFactorType::Webauthn), &conn).await { - None => { - let regs = vec![WebauthnRegistration { - id: data.id, - name: data.name, - security_key, - }]; - let tf = TwoFactor::new( - user.uuid, - TwoFactorType::Webauthn, - serde_json::to_string(&regs)?, - ); - tf.insert_insert_webauthn(tf.create_webauthn(cred_id), &conn) - .await?; - regs - } - Some(mut tf) => { - let mut regs = tf.get_webauthn_registrations()?; - regs.push(WebauthnRegistration { - id: data.id, - name: data.name, - security_key, - }); - tf.data = serde_json::to_string(&regs)?; - tf.update_insert_webauthn(tf.create_webauthn(cred_id), &conn) - .await?; - regs - } - }; + WebAuthn::new(user.uuid.clone(), data.id, data.name, &security_key)? + .insert(&conn) + .await?; + let keys = WebAuthnInfo::get_all_by_user(user.uuid.as_str(), &conn).await?; Ok(Json(json!({ - "Enabled": true, - "Keys": regs.iter().map(WebauthnRegistration::to_json).collect::<Value>(), + "Enabled": !keys.is_empty(), + "Keys": keys, "Object": "twoFactorU2f" }))) } @@ -186,7 +184,7 @@ async fn activate_webauthn_put( #[derive(Deserialize)] #[allow(non_snake_case)] struct DeleteU2FData { - Id: NumberOrString, + Id: i64, MasterPasswordHash: String, } @@ -202,52 +200,24 @@ async fn delete_webauthn( { err!("Invalid password"); } - let mut tf = get_tf_entry( - &headers.user.uuid, - i32::from(TwoFactorType::Webauthn), - &conn, - ) - .await - .ok_or_else(|| Error::from(String::from("no twofactor entries")))?; - let mut regs = tf.get_webauthn_registrations()?; - let id = u32::try_from(data.data.Id.into_i32()?).expect("underflow"); - let Some(item_pos) = regs.iter().position(|r| r.id == id) else { - err!("Webauthn entry not found") - }; - let old_cred = regs.remove(item_pos).security_key.cred_id().to_string(); - tf.data = serde_json::to_string(&regs)?; - tf.update_delete_webauthn(old_cred, &conn).await?; - drop(tf); + WebAuthn::delete_by_user_uuid_and_id(&headers.user.uuid, data.data.Id, &conn).await?; + let keys = WebAuthnInfo::get_all_by_user(&headers.user.uuid, &conn).await?; Ok(Json(json!({ - "Enabled": true, - "Keys": regs.iter().map(WebauthnRegistration::to_json).collect::<Value>(), + "Enabled": !keys.is_empty(), + "Keys": keys, "Object": "twoFactorU2f" }))) } -async fn get_tf_entry(user_uuid: &str, type_: i32, conn: &DbConn) -> Option<TwoFactor> { - TwoFactor::find_by_user_and_type(user_uuid, type_, conn).await -} - -pub async fn generate_webauthn_login(user_uuid: &str, conn: &DbConn) -> JsonResult { - let tf = get_tf_entry(user_uuid, i32::from(TwoFactorType::Webauthn), conn) - .await - .ok_or_else(|| Error::from(String::from("no twofactor entries")))?; - let regs = tf.get_webauthn_registrations()?; - if regs.is_empty() { - err!("No Webauthn devices registered") +pub async fn generate_webauthn_login(user_uuid: String, conn: &DbConn) -> JsonResult { + let keys = WebAuthn::get_all_security_keys(user_uuid.as_str(), conn).await?; + if keys.is_empty() { + err!("No WebAuthn devices registered") } - let (challenge, auth) = build_webauthn()? - .start_securitykey_authentication(&WebauthnRegistration::to_security_keys(regs))?; - // Save the challenge state for later validation. - TwoFactor::new( - user_uuid.into(), - TwoFactorType::WebauthnLoginChallenge, - serde_json::to_string(&auth)?, - ) - .replace_challenge(conn) - .await?; - // Return challenge to the clients + let (challenge, auth) = build_webauthn()?.start_securitykey_authentication(keys.as_slice())?; + WebAuthnChallenge::Auth(WebAuthnAuth::new(user_uuid, &auth)?) + .replace(conn) + .await?; Ok(Json(serde_json::to_value(challenge.public_key)?)) } @@ -256,34 +226,30 @@ pub async fn validate_webauthn_login( response: &str, conn: &DbConn, ) -> EmptyResult { - let tf_challenge = get_tf_entry( - user_uuid, - i32::from(TwoFactorType::WebauthnLoginChallenge), - conn, - ) - .await - .ok_or_else(|| Error::from(String::from("no webauthn challenge")))?; - let security_key_authentication = - serde_json::from_str::<SecurityKeyAuthentication>(&tf_challenge.data)?; - tf_challenge.delete_challenge(conn).await?; + let chall = WebAuthnAuth::find_by_user(user_uuid, conn) + .await? + .ok_or_else(|| Error::from(String::from("no webauthn challenge")))?; + let security_key_authentication = chall.security_key_auth()?; + WebAuthnChallenge::Auth(chall).delete(conn).await?; let resp = serde_json::from_str::<PublicKeyCredential>(response)?; - let mut tf = get_tf_entry(user_uuid, i32::from(TwoFactorType::Webauthn), conn) - .await - .ok_or_else(|| Error::from(String::from("no twofactor entries")))?; - let mut regs = tf.get_webauthn_registrations()?; let auth = build_webauthn()?.finish_securitykey_authentication(&resp, &security_key_authentication)?; + let mut web = WebAuthn::get_by_cred_id(&resp.id, conn) + .await? + .ok_or_else(|| Error::from(String::from("no matching webauthn entry")))?; if auth.needs_update() { - for reg in &mut regs { - if let Some(update) = reg.security_key.update_credential(&auth) { - if update { - tf.data = serde_json::to_string(&regs)?; - tf.update_webauthn(conn).await?; - } - return Ok(()); + let mut sec_key = web.security_key()?; + if let Some(update) = sec_key.update_credential(&auth) { + if update { + web.set_security_key(&sec_key)?; + web.update(conn).await?; + Ok(()) + } else { + unreachable!("webauthn credential no longer needs to be updated") } + } else { + unreachable!("webauthn credential no longer matches challenge") } - Err(Error::from(String::from("Credential not present"))) } else { Ok(()) } diff --git a/src/api/core/two_factor/yubikey.rs b/src/api/core/two_factor/yubikey.rs @@ -0,0 +1,41 @@ +use crate::{ + api::{JsonUpcase, PasswordOrOtpData}, + auth::Headers, + error::Error, +}; +use rocket::Route; + +pub fn routes() -> Vec<Route> { + routes![generate_yubikey, activate_yubikey, activate_yubikey_put,] +} + +#[derive(Deserialize)] +#[allow(non_snake_case, dead_code)] +struct EnableYubikeyData { + Key1: Option<String>, + Key2: Option<String>, + Key3: Option<String>, + Key4: Option<String>, + Key5: Option<String>, + Nfc: bool, + MasterPasswordHash: Option<String>, + Otp: Option<String>, +} +const YUBI_DISABLED_MSG: &str = "Yubico OTP is disabled."; +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/get-yubikey", data = "<data>")] +fn generate_yubikey(data: JsonUpcase<PasswordOrOtpData>, _headers: Headers) -> Error { + Error::new(YUBI_DISABLED_MSG, YUBI_DISABLED_MSG) +} + +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[post("/two-factor/yubikey", data = "<data>")] +fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, _headers: Headers) -> Error { + Error::new(YUBI_DISABLED_MSG, YUBI_DISABLED_MSG) +} + +#[allow(unused_variables, clippy::needless_pass_by_value)] +#[put("/two-factor/yubikey", data = "<data>")] +fn activate_yubikey_put(data: JsonUpcase<EnableYubikeyData>, _headers: Headers) -> Error { + Error::new(YUBI_DISABLED_MSG, YUBI_DISABLED_MSG) +} diff --git a/src/api/identity.rs b/src/api/identity.rs @@ -3,16 +3,15 @@ use crate::{ core::accounts::{PreloginData, RegisterData, _prelogin}, ApiResult, EmptyResult, JsonResult, JsonUpcase, }, - auth::{self, generate_organization_api_key_login_claims, ClientHeaders, ClientIp}, + auth::{ClientHeaders, ClientIp}, config, db::{ - models::{AuthRequest, Device, OrganizationApiKey, TwoFactor, TwoFactorType, User}, + models::{AuthRequest, Device, TwoFactorType, User}, DbConn, }, - error::MapResult, + error::{Error, MapResult}, util, }; -use num_traits::FromPrimitive; use rocket::serde::json::Json; use rocket::{ form::{Form, FromForm}, @@ -43,15 +42,6 @@ async fn login(data: Form<ConnectData>, client_header: ClientHeaders, conn: DbCo _check_is_some(&data.device_type, "device_type cannot be blank")?; _password_login(data, &mut user_uuid, &conn, &client_header.ip).await } - "client_credentials" => { - _check_is_some(&data.client_id, "client_id cannot be blank")?; - _check_is_some(&data.client_secret, "client_secret cannot be blank")?; - _check_is_some(&data.scope, "scope cannot be blank")?; - _check_is_some(&data.device_identifier, "device_identifier cannot be blank")?; - _check_is_some(&data.device_name, "device_name cannot be blank")?; - _check_is_some(&data.device_type, "device_type cannot be blank")?; - _api_key_login(data, &mut user_uuid, &conn, &client_header.ip).await - } t => err!("Invalid type", t), }; login_result @@ -156,7 +146,12 @@ async fn _password_login( ) } let (mut device, _) = get_device(&data, conn, &user).await; - let twofactor_token = twofactor_auth(&user.uuid, &data, &mut device, ip, conn).await?; + let (access_token, expires_in) = device.refresh_tokens(&user, scope_vec); + let kdf = user.client_kdf_type; + let kdf_iter = user.client_kdf_iter(); + let kdf_mem = user.client_kdf_memory(); + let kdf_par = user.client_kdf_parallelism(); + let twofactor_token = twofactor_auth(user.uuid, &data, &mut device, ip, conn).await?; // Common // --- // Disabled this variable, it was used to generate the JWT @@ -164,7 +159,6 @@ async fn _password_login( // See: https://github.com/dani-garcia/vaultwarden/issues/4156 // --- // let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, conn).await; - let (access_token, expires_in) = device.refresh_tokens(&user, scope_vec); device.save(conn).await?; let mut result = json!({ "access_token": access_token, @@ -173,10 +167,10 @@ async fn _password_login( "refresh_token": device.refresh_token, "Key": user.akey, "PrivateKey": user.private_key, - "Kdf": user.client_kdf_type, - "KdfIterations": user.client_kdf_iter(), - "KdfMemory": user.client_kdf_memory(), - "KdfParallelism": user.client_kdf_parallelism(), + "Kdf": kdf, + "KdfIterations": kdf_iter, + "KdfMemory": kdf_mem, + "KdfParallelism": kdf_par, "ResetMasterPassword": false,// TODO: Same as above "scope": scope, "unofficialServer": true, @@ -192,116 +186,6 @@ async fn _password_login( Ok(Json(result)) } -async fn _api_key_login( - data: ConnectData, - user_uuid: &mut Option<String>, - conn: &DbConn, - ip: &ClientIp, -) -> JsonResult { - // Validate scope - match data.scope.as_ref().unwrap().as_ref() { - "api" => _user_api_key_login(data, user_uuid, conn, ip).await, - "api.organization" => _organization_api_key_login(data, conn, ip).await, - _ => err!("Scope not supported"), - } -} - -async fn _user_api_key_login( - data: ConnectData, - user_uuid: &mut Option<String>, - conn: &DbConn, - ip: &ClientIp, -) -> JsonResult { - // Get the user via the client_id - let client_id = data.client_id.as_ref().unwrap(); - let Some(client_user_uuid) = client_id.strip_prefix("user.") else { - err!("Malformed client_id", format!("IP: {}.", ip.ip)) - }; - let Some(user) = User::find_by_uuid(client_user_uuid, conn).await else { - err!("Invalid client_id", format!("IP: {}.", ip.ip)) - }; - // Set the user_uuid here to be passed back used for event logging. - *user_uuid = Some(user.uuid.clone()); - // Check if the user is disabled - if !user.enabled { - err!( - "This user has been disabled (API key login)", - format!("IP: {}. Username: {}.", ip.ip, user.email) - ) - } - // Check API key. Note that API key logins bypass 2FA. - let client_secret = data.client_secret.as_ref().unwrap(); - if !user.check_valid_api_key(client_secret) { - err!( - "Incorrect client_secret", - format!("IP: {}. Username: {}.", ip.ip, user.email) - ) - } - let (mut device, _) = get_device(&data, conn, &user).await; - let scope_vec = vec!["api".into()]; - // --- - // Disabled this variable, it was used to generate the JWT - // Because this might get used in the future, and is add by the Bitwarden Server, lets keep it, but then commented out - // See: https://github.com/dani-garcia/vaultwarden/issues/4156 - // --- - // let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, conn).await; - let (access_token, expires_in) = device.refresh_tokens(&user, scope_vec); - device.save(conn).await?; - info!( - "User {} logged in successfully via API key. IP: {}", - user.email, ip.ip - ); - // Note: No refresh_token is returned. The CLI just repeats the - // client_credentials login flow when the existing token expires. - let result = json!({ - "access_token": access_token, - "expires_in": expires_in, - "token_type": "Bearer", - "Key": user.akey, - "PrivateKey": user.private_key, - "Kdf": user.client_kdf_type, - "KdfIterations": user.client_kdf_iter(), - "KdfMemory": user.client_kdf_memory(), - "KdfParallelism": user.client_kdf_parallelism(), - "ResetMasterPassword": false, // TODO: Same as above - "scope": "api", - "unofficialServer": true, - }); - Ok(Json(result)) -} - -async fn _organization_api_key_login( - data: ConnectData, - conn: &DbConn, - ip: &ClientIp, -) -> JsonResult { - // Get the org via the client_id - let client_id = data.client_id.as_ref().unwrap(); - let Some(org_uuid) = client_id.strip_prefix("organization.") else { - err!("Malformed client_id", format!("IP: {}.", ip.ip)) - }; - let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(org_uuid, conn).await else { - err!("Invalid client_id", format!("IP: {}.", ip.ip)) - }; - // Check API key. - let client_secret = data.client_secret.as_ref().unwrap(); - if !org_api_key.check_valid_api_key(client_secret) { - err!( - "Incorrect client_secret", - format!("IP: {}. Organization: {}.", ip.ip, org_api_key.org_uuid) - ) - } - let claim = generate_organization_api_key_login_claims(org_api_key.uuid, org_api_key.org_uuid); - let access_token = auth::encode_jwt(&claim); - Ok(Json(json!({ - "access_token": access_token, - "expires_in": 3600i32, - "token_type": "Bearer", - "scope": "api.organization", - "unofficialServer": true, - }))) -} - /// Retrieves an existing device or creates a new device from ConnectData and the User async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> (Device, bool) { // On iOS, device_type sends "iOS", on others it sends a number @@ -325,76 +209,96 @@ async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> (Device, } async fn twofactor_auth( - user_uuid: &str, + user_uuid: String, data: &ConnectData, device: &mut Device, ip: &ClientIp, conn: &DbConn, ) -> ApiResult<Option<String>> { - let twofactors = TwoFactor::find_by_user(user_uuid, conn).await; - // No twofactor token if twofactor is disabled - if twofactors.is_empty() { - return Ok(None); - } - let twofactor_ids: Vec<_> = twofactors.iter().map(|tf| tf.atype).collect(); - let selected_id = data.two_factor_provider.unwrap_or(twofactor_ids[0]); // If we aren't given a two factor provider, assume the first one - let Some(ref twofactor_code) = data.two_factor_token else { - err_json!( - _json_err_twofactor(&twofactor_ids, user_uuid, conn).await?, - "2FA token not provided" - ) - }; - let selected_twofactor = twofactors - .into_iter() - .find(|tf| tf.atype == selected_id && tf.enabled); - use crate::api::core::two_factor as _tf; - let selected_data = _selected_data(selected_twofactor); - match TwoFactorType::from_i32(selected_id) { - Some(TwoFactorType::Authenticator) => { - _tf::authenticator::validate_totp_code_str( - user_uuid, - twofactor_code, - &selected_data?, - ip, - conn, + let (authn, totp_token) = TwoFactorType::get_factors(user_uuid.as_str(), conn).await?; + if authn || totp_token.is_some() { + let Some(ref token) = data.two_factor_token else { + err_json!( + _json_err_twofactor(authn, totp_token.is_some(), user_uuid, conn).await?, + "2FA token not provided" ) - .await?; - } - Some(TwoFactorType::Webauthn) => { - _tf::webauthn::validate_webauthn_login(user_uuid, twofactor_code, conn).await?; + }; + // If no provider is given, we use a fallback. + // The fallback is prioritized to be WebAuthn if possible. + let tf_type = data + .two_factor_provider + .map(|prov| { + TwoFactorType::try_from(prov) + .map_err(Error::from) + .and_then(|tf| { + if matches!(tf, TwoFactorType::WebAuthn) { + if authn { + Ok(tf) + } else { + const MSG: &str = "no webauthn registrations"; + Err(Error::new(MSG, MSG)) + } + } else if totp_token.is_some() { + Ok(tf) + } else { + const MSG: &str = "no totp registrations"; + Err(Error::new(MSG, MSG)) + } + }) + }) + .transpose()? + .unwrap_or({ + if authn { + TwoFactorType::WebAuthn + } else { + TwoFactorType::Totp + } + }); + use crate::api::core::two_factor as _tf; + match tf_type { + TwoFactorType::Totp => { + _tf::authenticator::validate_totp_code_str( + user_uuid, + token, + totp_token.unwrap_or_else(|| { + unreachable!("no totp registrations, but we verified there are") + }), + ip, + conn, + ) + .await?; + } + TwoFactorType::WebAuthn => { + _tf::webauthn::validate_webauthn_login(user_uuid.as_str(), token, conn).await?; + } } - _ => err!("Invalid two factor provider"), + device.delete_twofactor_remember(); } - device.delete_twofactor_remember(); Ok(None) } -fn _selected_data(tf: Option<TwoFactor>) -> ApiResult<String> { - tf.map(|t| t.data).map_res("Two factor doesn't exist") -} - async fn _json_err_twofactor( - providers: &[i32], - user_uuid: &str, + authn: bool, + totp: bool, + user_uuid: String, conn: &DbConn, ) -> ApiResult<Value> { use crate::api::core::two_factor; + let auth_num = i32::from(TwoFactorType::WebAuthn); + let totp_num = i32::from(TwoFactorType::Totp); + let providers = [auth_num, totp_num]; let mut result = json!({ "error" : "invalid_grant", "error_description" : "Two factor required.", - "TwoFactorProviders" : providers, + "TwoFactorProviders" : if authn { if totp { providers.as_slice() } else { &providers[..1] } } else if totp { &providers[1..] } else { [].as_slice() }, "TwoFactorProviders2" : {} // { "0" : null } }); - for provider in providers { - result["TwoFactorProviders2"][provider.to_string()] = Value::Null; - - if matches!( - TwoFactorType::from_i32(*provider), - Some(TwoFactorType::Webauthn) - ) { - let request = two_factor::webauthn::generate_webauthn_login(user_uuid, conn).await?; - result["TwoFactorProviders2"][provider.to_string()] = request.0; - } + if authn { + let request = two_factor::webauthn::generate_webauthn_login(user_uuid, conn).await?; + result["TwoFactorProviders2"][auth_num.to_string()] = request.0; + } + if totp { + result["TwoFactorProviders2"][totp_num.to_string()] = Value::Null; } Ok(result) } @@ -406,13 +310,14 @@ async fn prelogin(data: JsonUpcase<PreloginData>, conn: DbConn) -> Json<Value> { #[allow(unused_variables, clippy::needless_pass_by_value)] #[post("/accounts/register", data = "<data>")] -fn identity_register(data: JsonUpcase<RegisterData>, _conn: DbConn) -> JsonResult { - err!("No more registerations allowed.") +fn identity_register(data: JsonUpcase<RegisterData>) -> Error { + const MSG: &str = "No more registerations allowed."; + Error::new(MSG, MSG) } // https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts // https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs -#[derive(Debug, Clone, Default, FromForm)] +#[derive(Clone, Default, FromForm)] #[allow(non_snake_case)] struct ConnectData { #[field(name = uncased("grant_type"))] @@ -428,6 +333,7 @@ struct ConnectData { client_id: Option<String>, // web, cli, desktop, browser, mobile #[field(name = uncased("client_secret"))] #[field(name = uncased("clientsecret"))] + #[allow(dead_code)] client_secret: Option<String>, #[field(name = uncased("password"))] password: Option<String>, diff --git a/src/api/mod.rs b/src/api/mod.rs @@ -20,7 +20,7 @@ pub use crate::api::{ web::catchers as web_catchers, web::routes as web_routes, }; -use crate::db::{models::User, DbConn}; +use crate::db::models::User; use crate::error::Error; use crate::util; use rocket::serde::json::Json; @@ -45,16 +45,15 @@ impl PasswordOrOtpData { /// Tokens used via this struct can be used multiple times during the process /// First for the validation to continue, after that to enable or validate the following actions /// This is different per caller, so it can be adjusted to delete the token or not - async fn validate(&self, user: &User, delete_if_valid: bool, conn: &DbConn) -> EmptyResult { - use crate::api::core::two_factor::protected_actions::validate_protected_action_otp; + fn validate(&self, user: &User) -> EmptyResult { match (self.MasterPasswordHash.as_deref(), self.Otp.as_deref()) { (Some(pw_hash), None) => { if !user.check_valid_password(pw_hash) { err!("Invalid password"); } } - (None, Some(otp)) => { - validate_protected_action_otp(otp, &user.uuid, delete_if_valid, conn).await?; + (None, Some(_)) => { + err!("Login via device is not allowed. You must use the master password.") } _ => err!("No validation provided"), } @@ -62,7 +61,7 @@ impl PasswordOrOtpData { } } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Clone)] #[serde(untagged)] enum NumberOrString { Number(i32), diff --git a/src/api/notifications.rs b/src/api/notifications.rs @@ -53,7 +53,7 @@ pub fn routes() -> Vec<Route> { routes![anonymous_websockets_hub, websockets_hub] } -#[derive(FromForm, Debug)] +#[derive(FromForm)] struct WsAccessToken { access_token: Option<String>, } @@ -178,13 +178,12 @@ fn websockets_hub<'r>( }} }) } -#[allow(clippy::unnecessary_wraps)] #[get("/anonymous-hub?<token..>")] fn anonymous_websockets_hub<'r>( ws: rocket_ws::WebSocket, token: String, _ip: ClientIp, -) -> Result<rocket_ws::Stream!['r], Error> { +) -> rocket_ws::Stream!['r] { let (mut rx, guard) = { let subscriptions = Arc::clone(ws_anonymous_subscriptions()); // Add a channel to send messages to this client to the map @@ -193,49 +192,47 @@ fn anonymous_websockets_hub<'r>( // Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map (rx, WSAnonymousEntryMapGuard::new(subscriptions, token)) }; - Ok({ - rocket_ws::Stream! { ws => { - let mut ws_copy = ws; - let _guard = guard; - let mut interval = time::interval(Duration::from_secs(15)); - loop { - tokio::select! { - res = ws_copy.next() => { - match res { - Some(Ok(message)) => { - match message { - // Respond to any pings - Message::Ping(ping) => yield Message::Pong(ping), - Message::Pong(_) => {/* Ignored */}, - // We should receive an initial message with the protocol and version, and we will reply to it - Message::Text(ref message) => { - let msg = message.strip_suffix(char::from(RECORD_SEPARATOR)).unwrap_or(message); - if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) { - yield Message::binary(INITIAL_RESPONSE); - continue; - } + rocket_ws::Stream! { ws => { + let mut ws_copy = ws; + let _guard = guard; + let mut interval = time::interval(Duration::from_secs(15)); + loop { + tokio::select! { + res = ws_copy.next() => { + match res { + Some(Ok(message)) => { + match message { + // Respond to any pings + Message::Ping(ping) => yield Message::Pong(ping), + Message::Pong(_) => {/* Ignored */}, + // We should receive an initial message with the protocol and version, and we will reply to it + Message::Text(ref message) => { + let msg = message.strip_suffix(char::from(RECORD_SEPARATOR)).unwrap_or(message); + if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) { + yield Message::binary(INITIAL_RESPONSE); + continue; } - // Prevent sending anything back when a `Close` Message is received. - // Just break the loop - Message::Close(_) => break, - // Just echo anything else the client sends - _ => yield message, } + // Prevent sending anything back when a `Close` Message is received. + // Just break the loop + Message::Close(_) => break, + // Just echo anything else the client sends + _ => yield message, } - _ => break, } + _ => break, } - res = rx.recv() => { - match res { - Some(res) => yield res, - None => break, - } + } + res = rx.recv() => { + match res { + Some(res) => yield res, + None => break, } - _ = interval.tick() => yield Message::Ping(create_ping()) } + _ = interval.tick() => yield Message::Ping(create_ping()) } - }} - }) + } + }} } fn serialize(val: &Value) -> Vec<u8> { use rmpv::encode::write_value; diff --git a/src/api/web.rs b/src/api/web.rs @@ -1,5 +1,5 @@ use crate::{ - api::{core::now, EmptyResult}, + api::core::now, auth::decode_file_download, config::{self, Config}, error::Error, @@ -41,16 +41,14 @@ async fn web_index() -> Cached<Option<NamedFile>> { false, ) } -#[allow(clippy::unnecessary_wraps)] #[head("/")] -const fn web_index_head() -> EmptyResult { +const fn web_index_head() { // Add an explicit HEAD route to prevent uptime monitoring services from // generating "No matching routes for HEAD /" error messages. // // Rocket automatically implements a HEAD route when there's a matching GET // route, but relying on this behavior also means a spurious error gets // logged due to <https://github.com/SergioBenitez/Rocket/issues/1098>. - Ok(()) } #[get("/app-id.json")] @@ -111,17 +109,12 @@ async fn attachments(uuid: SafeString, file_id: SafeString, token: String) -> Op .ok() } -use crate::db::DbConn; -#[allow(clippy::needless_pass_by_value)] #[get("/alive")] -fn alive(_conn: DbConn) -> Json<String> { +fn alive() -> Json<String> { now() } -#[allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)] #[head("/alive")] -fn alive_head(_conn: DbConn) -> EmptyResult { - Ok(()) -} +const fn alive_head() {} #[get("/vw_static/<filename>", rank = 2)] fn static_files(filename: &str) -> Result<(ContentType, &'static [u8]), Error> { match filename { diff --git a/src/auth.rs b/src/auth.rs @@ -4,7 +4,6 @@ use crate::{ }; use chrono::{Duration, Utc}; use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; -use num_traits::FromPrimitive; use openssl::pkey::{Id, PKey}; use serde::de::DeserializeOwned; use serde::ser::Serialize; @@ -97,23 +96,6 @@ fn get_jwt_verifyemail_issuer() -> &'static str { .expect("JWT_VERIFYEMAIL_ISSUER must be initialized in main") .as_str() } -static JWT_ORG_API_KEY_ISSUER: OnceLock<String> = OnceLock::new(); -#[inline] -fn init_jwt_org_api_key_issuer() { - JWT_ORG_API_KEY_ISSUER - .set(format!( - "{}|api.organization", - config::get_config().domain_origin() - )) - .expect("JWT_ORG_API_KEY_ISSUER must only be initialized once"); -} -#[inline] -fn get_jwt_org_api_key_issuer() -> &'static str { - JWT_ORG_API_KEY_ISSUER - .get() - .expect("JWT_ORG_API_KEY_ISSUER must be initialized in main") - .as_str() -} static JWT_FILE_DOWNLOAD_ISSUER: OnceLock<String> = OnceLock::new(); #[inline] fn init_jwt_file_download_issuer() { @@ -152,10 +134,11 @@ fn init_ed_keys() -> Result<(), Error> { if ed_key.id() == Id::ED25519 { ed_key } else { - return Err(Error::from(format!( + let msg = format!( "{} is not a private Ed25519 key", Config::PRIVATE_ED25519_KEY - ))); + ); + return Err(Error::new(msg.as_str(), msg.as_str())); } }; ED_KEYS @@ -163,7 +146,10 @@ fn init_ed_keys() -> Result<(), Error> { EncodingKey::from_ed_pem(priv_pem.as_slice())?, DecodingKey::from_ed_pem(ed_key.public_key_to_pem()?.as_slice())?, )) - .map_err(|_| Error::from(String::from("ED_KEYS must only be initialized once"))) + .map_err(|_| { + const MSG: &str = "ED_KEYS must only be initialized once"; + Error::new(MSG, MSG) + }) } #[inline] fn get_private_ed_key() -> &'static EncodingKey { @@ -187,7 +173,6 @@ pub fn init_values() { init_jwt_invite_issuer(); init_jwt_delete_issuer(); init_jwt_verifyemail_issuer(); - init_jwt_org_api_key_issuer(); init_jwt_file_download_issuer(); init_ed_keys().expect("error creating Ed25519 keys"); } @@ -244,14 +229,11 @@ pub fn decode_delete(token: &str) -> Result<BasicJwtClaims, Error> { pub fn decode_verify_email(token: &str) -> Result<BasicJwtClaims, Error> { decode_jwt(token, get_jwt_verifyemail_issuer().to_owned()) } -pub fn decode_api_org(token: &str) -> Result<OrgApiKeyLoginJwtClaims, Error> { - decode_jwt(token, get_jwt_org_api_key_issuer().to_owned()) -} pub fn decode_file_download(token: &str) -> Result<FileDownloadClaims, Error> { decode_jwt(token, get_jwt_file_download_issuer().to_owned()) } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct LoginJwtClaims { // Not before pub nbf: i64, @@ -285,7 +267,7 @@ pub struct LoginJwtClaims { pub amr: Vec<String>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct InviteJwtClaims { // Not before nbf: i64, @@ -301,56 +283,7 @@ pub struct InviteJwtClaims { invited_by_email: Option<String>, } -#[derive(Debug, Serialize, Deserialize)] -struct EmergencyAccessInviteJwtClaims { - // Not before - nbf: i64, - // Expiration time - exp: i64, - // Issuer - iss: String, - // Subject - sub: String, - email: String, - emer_id: String, - grantor_name: String, - grantor_email: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct OrgApiKeyLoginJwtClaims { - // Not before - pub nbf: i64, - // Expiration time - pub exp: i64, - // Issuer - pub iss: String, - // Subject - pub sub: String, - pub client_id: String, - pub client_sub: String, - pub scope: Vec<String>, -} - -pub fn generate_organization_api_key_login_claims( - uuid: String, - org_id: String, -) -> OrgApiKeyLoginJwtClaims { - let time_now = Utc::now().naive_utc(); - OrgApiKeyLoginJwtClaims { - nbf: time_now.timestamp(), - exp: (time_now.checked_add_signed(Duration::hours(1))) - .expect("Duration add overflowed") - .timestamp(), - iss: get_jwt_org_api_key_issuer().to_owned(), - sub: uuid, - client_id: format!("organization.{org_id}"), - client_sub: org_id, - scope: vec!["api.organization".into()], - } -} - -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct FileDownloadClaims { // Not before nbf: i64, @@ -362,7 +295,7 @@ pub struct FileDownloadClaims { pub sub: String, pub file_id: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct BasicJwtClaims { // Not before nbf: i64, @@ -384,8 +317,8 @@ use rocket::{ request::{FromRequest, Outcome, Request}, }; -pub struct Host { - pub host: String, +struct Host { + host: String, } #[rocket::async_trait] @@ -570,7 +503,7 @@ impl<'r> FromRequest<'r> for OrgHeaders { device: headers.device, user, org_user_type: { - if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) { + if let Ok(org_usr_type) = UserOrgType::try_from(org_user.atype) { org_usr_type } else { // This should only happen if the DB is corrupted @@ -589,11 +522,11 @@ impl<'r> FromRequest<'r> for OrgHeaders { pub struct AdminHeaders { pub host: String, - device: Device, + pub device: Device, pub user: User, pub org_user_type: UserOrgType, pub client_version: Option<String>, - ip: ClientIp, + pub ip: ClientIp, } #[rocket::async_trait] diff --git a/src/config.rs b/src/config.rs @@ -64,14 +64,14 @@ impl From<ParseError> for ConfigErr { Self::Url(value) } } -#[derive(Debug, serde::Deserialize)] +#[derive(serde::Deserialize)] struct Tls { ciphers: Option<Vec<CipherSuite>>, cert: String, key: String, prefer_server_cipher_order: Option<bool>, } -#[derive(Debug, serde::Deserialize)] +#[derive(serde::Deserialize)] struct ConfigFile { database_max_conns: Option<NonZeroU8>, database_timeout: Option<u16>, @@ -81,6 +81,7 @@ struct ConfigFile { password_iterations: Option<u32>, port: u16, tls: Tls, + webauthn_require_yubi: Option<bool>, web_vault_enabled: Option<bool>, workers: Option<NonZeroU8>, } @@ -93,6 +94,7 @@ pub struct Config { pub password_iterations: u32, pub rocket: rocket::Config, pub web_vault_enabled: bool, + pub webauthn_require_yubi: bool, } impl Config { #[inline] @@ -165,6 +167,7 @@ impl Config { }, rocket, web_vault_enabled: config_file.web_vault_enabled.unwrap_or(true), + webauthn_require_yubi: config_file.webauthn_require_yubi.unwrap_or(false), }) } } diff --git a/src/crypto.rs b/src/crypto.rs @@ -35,33 +35,6 @@ pub fn encode_random_bytes<const N: usize>(e: &Encoding) -> String { e.encode(&get_random_bytes::<N>()) } -/// Generates a random string over a specified alphabet. -fn get_random_string(alphabet: &[u8], num_chars: usize) -> String { - // Ref: https://rust-lang-nursery.github.io/rust-cookbook/algorithms/randomness.html - use rand::Rng; - let mut rng = rand::thread_rng(); - (0..num_chars) - .map(|_| { - let i = rng.gen_range(0..alphabet.len()); - char::from(alphabet[i]) - }) - .collect() -} - -/// Generates a random alphanumeric string. -fn get_random_string_alphanum(num_chars: usize) -> String { - const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ - abcdefghijklmnopqrstuvwxyz\ - 0123456789"; - get_random_string(ALPHABET, num_chars) -} - -/// Generates a personal API key. -/// Upstream uses 30 chars, which is ~178 bits of entropy. -pub fn generate_api_key() -> String { - get_random_string_alphanum(30) -} - // // Constant time compare // diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs @@ -15,8 +15,8 @@ pub use self::device::{Device, DeviceType}; pub use self::favorite::Favorite; pub use self::folder::{Folder, FolderCipher}; pub use self::org_policy::{OrgPolicy, OrgPolicyErr, OrgPolicyType}; -pub use self::organization::{ - Organization, OrganizationApiKey, UserOrgStatus, UserOrgType, UserOrganization, +pub use self::organization::{Organization, UserOrgStatus, UserOrgType, UserOrganization}; +pub use self::two_factor::{ + Totp, TwoFactorType, WebAuthn, WebAuthnAuth, WebAuthnChallenge, WebAuthnInfo, WebAuthnReg, }; -pub use self::two_factor::{TwoFactor, TwoFactorType, WebAuthn, WebauthnRegistration}; pub use self::user::{User, UserKdfType, UserStampException}; diff --git a/src/db/models/org_policy.rs b/src/db/models/org_policy.rs @@ -1,7 +1,8 @@ -use super::{TwoFactor, UserOrgStatus, UserOrgType, UserOrganization}; +use super::{UserOrgStatus, UserOrgType, UserOrganization}; use crate::api::EmptyResult; +use crate::db::models::TwoFactorType; use crate::db::DbConn; -use crate::error::MapResult; +use crate::error::{Error, MapResult}; use crate::util::{self, UpCase}; use diesel::result::{self, DatabaseErrorKind}; use serde::Deserialize; @@ -20,7 +21,7 @@ db_object! { } // https://github.com/bitwarden/server/blob/b86a04cef9f1e1b82cf18e49fc94e017c641130c/src/Core/Enums/PolicyType.cs -#[derive(Copy, Clone, Eq, PartialEq, num_derive::FromPrimitive)] +#[derive(Copy, Clone, Eq, PartialEq)] pub enum OrgPolicyType { TwoFactorAuthentication = 0, MasterPassword = 1, @@ -45,6 +46,25 @@ impl From<OrgPolicyType> for i32 { } } } +impl TryFrom<i32> for OrgPolicyType { + type Error = Error; + fn try_from(value: i32) -> Result<Self, Self::Error> { + match value { + 0i32 => Ok(Self::TwoFactorAuthentication), + 1i32 => Ok(Self::MasterPassword), + 2i32 => Ok(Self::PasswordGenerator), + 3i32 => Ok(Self::SingleOrg), + 5i32 => Ok(Self::PersonalOwnership), + 6i32 => Ok(Self::DisableSend), + 7i32 => Ok(Self::SendOptions), + 8i32 => Ok(Self::ResetPassword), + _ => { + const MSG: &str = "i32 is invalid OrgPolicyType"; + Err(Error::new(MSG, MSG)) + } + } + } +} // https://github.com/bitwarden/server/blob/5cbdee137921a19b1f722920f0fa3cd45af2ef0f/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs #[derive(Deserialize)] @@ -222,7 +242,10 @@ impl OrgPolicy { conn: &DbConn, ) -> OrgPolicyResult { // Enforce TwoFactor/TwoStep login - if TwoFactor::find_by_user(user_uuid, conn).await.is_empty() { + if !TwoFactorType::has_twofactor(user_uuid, conn) + .await + .expect("unable to get two factor information") + { match Self::find_by_org_and_type(org_uuid, OrgPolicyType::TwoFactorAuthentication, conn) .await { diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs @@ -1,9 +1,8 @@ -use super::{CollectionUser, OrgPolicy, OrgPolicyType, TwoFactor, User}; -use crate::crypto; +use super::{CollectionUser, OrgPolicy, OrgPolicyType, User}; +use crate::db::models::TwoFactorType; +use crate::error::Error; use crate::util; -use chrono::{NaiveDateTime, Utc}; use diesel::result::{self, DatabaseErrorKind}; -use num_traits::FromPrimitive; use serde_json::Value; use std::cmp::Ordering; @@ -31,16 +30,6 @@ db_object! { pub reset_password_key: Option<String>, pub external_id: Option<String>, } - - #[derive(AsChangeset, Insertable, Queryable)] - #[diesel(table_name = organization_api_key)] - pub struct OrganizationApiKey { - pub uuid: String, - pub org_uuid: String, - atype: i32, - pub api_key: String, - pub revision_date: NaiveDateTime, - } } pub enum UserOrgStatus { Revoked = -1, @@ -59,7 +48,7 @@ impl From<UserOrgStatus> for i32 { } } -#[derive(Clone, Copy, Eq, PartialEq, num_derive::FromPrimitive)] +#[derive(Clone, Copy, Eq, PartialEq)] pub enum UserOrgType { Owner = 0, Admin = 1, @@ -76,6 +65,21 @@ impl From<UserOrgType> for i32 { } } } +impl TryFrom<i32> for UserOrgType { + type Error = Error; + fn try_from(value: i32) -> Result<Self, Error> { + match value { + 0i32 => Ok(Self::Owner), + 1i32 => Ok(Self::Admin), + 2i32 => Ok(Self::User), + 3i32 => Ok(Self::Manager), + _ => { + const MSG: &str = "i32 is not valid UserOrgType"; + Err(Error::new(MSG, MSG)) + } + } + } +} impl From<UserOrgType> for usize { fn from(value: UserOrgType) -> Self { match value { @@ -138,7 +142,7 @@ impl PartialEq<i32> for UserOrgType { impl PartialOrd<i32> for UserOrgType { fn partial_cmp(&self, other: &i32) -> Option<Ordering> { - if let Some(other) = Self::from_i32(*other) { + if let Ok(other) = Self::try_from(*other) { return Some(self.cmp(&other)); } None @@ -164,7 +168,7 @@ impl PartialEq<UserOrgType> for i32 { impl PartialOrd<UserOrgType> for i32 { fn partial_cmp(&self, other: &UserOrgType) -> Option<Ordering> { - if let Some(self_type) = UserOrgType::from_i32(*self) { + if let Ok(self_type) = UserOrgType::try_from(*self) { return Some(self_type.cmp(other)); } None @@ -277,21 +281,6 @@ impl UserOrganization { } } -impl OrganizationApiKey { - pub fn new(org_uuid: String, api_key: String) -> Self { - Self { - uuid: util::get_uuid(), - org_uuid, - atype: 0, // Type 0 is the default and only type we support currently - api_key, - revision_date: Utc::now().naive_utc(), - } - } - pub fn check_valid_api_key(&self, api_key: &str) -> bool { - crypto::ct_eq(&self.api_key, api_key) - } -} - use crate::api::EmptyResult; use crate::db::DbConn; use crate::error::MapResult; @@ -393,7 +382,9 @@ impl UserOrganization { self.status }; - let twofactor_enabled = !TwoFactor::find_by_user(&user.uuid, conn).await.is_empty(); + let twofactor_enabled = TwoFactorType::has_twofactor(&user.uuid, conn) + .await + .expect("unable to get two factor information"); let groups: Vec<String> = Vec::new(); let collections: Vec<Value> = if include_collections { CollectionUser::find_by_organization_and_user_uuid( @@ -559,8 +550,8 @@ impl UserOrganization { db_run! { conn: { users_organizations::table .filter(users_organizations::user_uuid.eq(user_uuid)) - .filter(users_organizations::status.eq(i32::from(UserOrgStatus::Accepted))) - .or_filter(users_organizations::status.eq(i32::from(UserOrgStatus::Confirmed))) + .filter(users_organizations::status.eq(i32::from(UserOrgStatus::Accepted)) + .or(users_organizations::status.eq(i32::from(UserOrgStatus::Confirmed)))) .count() .first::<i64>(conn) .unwrap_or(0) @@ -706,40 +697,6 @@ impl UserOrganization { } } -impl OrganizationApiKey { - pub async fn save(&self, conn: &DbConn) -> EmptyResult { - db_run! { conn: - { - match diesel::replace_into(organization_api_key::table) - .values(OrganizationApiKeyDb::to_db(self)) - .execute(conn) - { - Ok(_) => Ok(()), - // Record already exists and causes a Foreign Key Violation because replace_into() wants to delete the record first. - Err(result::Error::DatabaseError(DatabaseErrorKind::ForeignKeyViolation, _)) => { - diesel::update(organization_api_key::table) - .filter(organization_api_key::uuid.eq(&self.uuid)) - .set(OrganizationApiKeyDb::to_db(self)) - .execute(conn) - .map_res("Error saving organization") - } - Err(e) => Err(e.into()), - }.map_res("Error saving organization") - - } - } - } - - pub async fn find_by_org_uuid(org_uuid: &str, conn: &DbConn) -> Option<Self> { - db_run! { conn: { - organization_api_key::table - .filter(organization_api_key::org_uuid.eq(org_uuid)) - .first::<OrganizationApiKeyDb>(conn) - .ok().from_db() - }} - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/db/models/two_factor.rs b/src/db/models/two_factor.rs @@ -1,148 +1,334 @@ -use crate::{api::EmptyResult, db::DbConn, error::Error, util}; -use num_traits::FromPrimitive; -use serde_json::Value; +use crate::{api::EmptyResult, db::DbConn, error::Error}; +use serde::ser::{Serialize, SerializeStruct, Serializer}; +use serde_json::de; use tokio::task; -use webauthn_rs::prelude::{CredentialID, SecurityKey}; - +use webauthn_rs::prelude::{ + CredentialID, SecurityKey, SecurityKeyAuthentication, SecurityKeyRegistration, +}; db_object! { - #[derive(AsChangeset, Insertable, Queryable)] - #[diesel(table_name = twofactor)] - pub struct TwoFactor { - uuid: String, - user_uuid: String, - pub atype: i32, - pub enabled: bool, - pub data: String, - last_used: i64, - } - #[derive(Insertable)] + #[derive(Insertable, Queryable)] #[diesel(table_name = webauthn)] pub struct WebAuthn { - pub credential_id: String, - uuid: String, + credential_id: String, + pub user_uuid: String, + pub id: i64, + pub name: String, + security_key: String, + } + #[derive(Insertable, Queryable)] + #[diesel(table_name = webauthn_auth)] + pub struct WebAuthnAuth { + pub user_uuid: String, + data: String, + } + #[derive(Insertable, Queryable)] + #[diesel(table_name = webauthn_reg)] + pub struct WebAuthnReg { + pub user_uuid: String, + data: String, + } + #[derive(Insertable, Queryable)] + #[diesel(table_name = totp)] + pub struct Totp { + user_uuid: String, + pub token: String, + last_used: i64, } } -#[derive(Deserialize, Serialize)] -pub struct WebauthnRegistration { - pub id: u32, - pub name: String, - pub security_key: SecurityKey, -} - -impl WebauthnRegistration { - pub fn to_json(&self) -> Value { - json!({ - "id": self.id, - "name": self.name, - "migrated": false, +impl WebAuthn { + pub fn new( + user_uuid: String, + id: i64, + name: String, + security_key: &SecurityKey, + ) -> Result<Self, Error> { + Ok(Self { + credential_id: security_key.cred_id().to_string(), + user_uuid, + id, + name, + security_key: serde_json::to_string(security_key)?, }) } - pub fn to_security_keys(source: Vec<Self>) -> Vec<SecurityKey> { - let len = source.len(); - source - .into_iter() - .fold(Vec::with_capacity(len), |mut keys, reg| { - keys.push(reg.security_key); - keys - }) + pub fn credential_id(&self) -> Result<CredentialID, Error> { + CredentialID::try_from(self.credential_id.as_str()) + .map_err(|()| Error::from(String::from("invalid credential ID"))) + } + pub fn security_key(&self) -> Result<SecurityKey, Error> { + serde_json::from_str(self.security_key.as_str()).map_err(Error::from) + } + pub fn set_security_key(&mut self, security_key: &SecurityKey) -> Result<(), Error> { + self.security_key = serde_json::to_string(security_key)?; + Ok(()) } } -impl TwoFactor { - pub fn last_used(&self) -> u64 { - u64::try_from(self.last_used).expect("underflow") +impl WebAuthnAuth { + pub fn new( + user_uuid: String, + security_key_auth: &SecurityKeyAuthentication, + ) -> Result<Self, Error> { + Ok(Self { + user_uuid, + data: serde_json::to_string(security_key_auth)?, + }) + } + pub fn security_key_auth(&self) -> Result<SecurityKeyAuthentication, Error> { + serde_json::from_str(&self.data).map_err(Error::from) } - pub fn set_last_used(&mut self, last: u64) { - self.last_used = i64::try_from(last).expect("overflow"); +} +impl WebAuthnReg { + pub fn new( + user_uuid: String, + security_key_reg: &SecurityKeyRegistration, + ) -> Result<Self, Error> { + Ok(Self { + user_uuid, + data: serde_json::to_string(security_key_reg)?, + }) } - pub fn get_webauthn_registrations(&self) -> Result<Vec<WebauthnRegistration>, Error> { + pub fn security_key_reg(&self) -> Result<SecurityKeyRegistration, Error> { serde_json::from_str(&self.data).map_err(Error::from) } - pub fn create_webauthn(&self, credential_id: String) -> WebAuthn { - WebAuthn { - credential_id, - uuid: self.uuid.clone(), - } +} +#[derive(Queryable)] +pub struct WebAuthnInfo { + id: i64, + name: String, +} +impl Serialize for WebAuthnInfo { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + let mut s = serializer.serialize_struct("WebAuthnInfo", 3)?; + s.serialize_field("id", &self.id)?; + s.serialize_field("name", self.name.as_str())?; + s.serialize_field("migrated", &false)?; + s.end() } } +/// Represents a WebAuthn challenge. +pub enum WebAuthnChallenge { + Auth(WebAuthnAuth), + Reg(WebAuthnReg), +} -#[derive(num_derive::FromPrimitive)] +#[derive(Clone, Copy)] pub enum TwoFactorType { - Authenticator = 0, - Webauthn = 7, - WebauthnRegisterChallenge = 1003, - WebauthnLoginChallenge = 1004, - ProtectedActions = 2000, + Totp = 0, + WebAuthn = 7, } impl From<TwoFactorType> for i32 { fn from(value: TwoFactorType) -> Self { match value { - TwoFactorType::Authenticator => 0i32, - TwoFactorType::Webauthn => 7i32, - TwoFactorType::WebauthnRegisterChallenge => 1003i32, - TwoFactorType::WebauthnLoginChallenge => 1004i32, - TwoFactorType::ProtectedActions => 2000i32, + TwoFactorType::Totp => 0i32, + TwoFactorType::WebAuthn => 7i32, } } } -/// Local methods -impl TwoFactor { - pub fn new(user_uuid: String, atype: TwoFactorType, data: String) -> Self { +impl TryFrom<i32> for TwoFactorType { + type Error = Error; + fn try_from(value: i32) -> Result<Self, Self::Error> { + match value { + 0i32 => Ok(Self::Totp), + 7i32 => Ok(Self::WebAuthn), + _ => Err(Error::from(String::from( + "i32 is not a valid TwoFactorType", + ))), + } + } +} +impl Serialize for TwoFactorType { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + let mut s = serializer.serialize_struct("TwoFactorType", 3)?; + s.serialize_field("Enabled", &true)?; + s.serialize_field("Type", &i32::from(*self))?; + s.serialize_field("Object", "twoFactorProvider")?; + s.end() + } +} +impl Totp { + pub const fn new(user_uuid: String, token: String) -> Self { Self { - uuid: util::get_uuid(), user_uuid, - atype: i32::from(atype), - enabled: true, - data, + token, last_used: 0, } } - - pub fn to_json_provider(&self) -> Value { - json!({ - "Enabled": self.enabled, - "Type": self.atype, - "Object": "twoFactorProvider" + pub fn get_last_used(&self) -> u64 { + u64::try_from(self.last_used).expect("underflow") + } + pub fn set_last_used(&mut self, last_used: u64) { + self.last_used = i64::try_from(last_used).expect("overflow"); + } +} +impl TwoFactorType { + #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] + pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::{totp, webauthn}; + use diesel::prelude::{Connection, ExpressionMethods, RunQueryDsl}; + use diesel::result; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + con.transaction(|con| { + diesel::delete(webauthn::table) + .filter(webauthn::user_uuid.eq(user_uuid)) + .execute(con) + .and_then(|_| { + diesel::delete(totp::table) + .filter(totp::user_uuid.eq(user_uuid)) + .execute(con) + .map(|_| ()) + }) + .map_err(result::Error::into) + }) + }) + } + #[allow(clippy::clone_on_ref_ptr)] + pub async fn delete_by_user(self, user_uuid: &str, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::{totp, webauthn}; + use diesel::prelude::{ExpressionMethods, RunQueryDsl}; + use diesel::result; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + match self { + Self::Totp => diesel::delete(totp::table) + .filter(totp::user_uuid.eq(user_uuid)) + .execute(con), + Self::WebAuthn => diesel::delete(webauthn::table) + .filter(webauthn::user_uuid.eq(user_uuid)) + .execute(con), + } + .map(|_| ()) + .map_err(result::Error::into) + }) + } + #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] + pub async fn has_twofactor(user_uuid: &str, conn: &DbConn) -> Result<bool, Error> { + use crate::db::__sqlite_schema::{totp, webauthn}; + use diesel::prelude::{Connection, ExpressionMethods, QueryDsl, RunQueryDsl}; + use diesel::result; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + con.transaction(|con| { + webauthn::table + .count() + .filter(webauthn::user_uuid.eq(user_uuid)) + .get_result::<i64>(con) + .map_err(result::Error::into) + .and_then(|count| { + if count == 0 { + totp::table + .count() + .filter(totp::user_uuid.eq(user_uuid)) + .get_result::<i64>(con) + .map_err(result::Error::into) + .map(|count| count > 0) + } else { + Ok(true) + } + }) + }) + }) + } + /// The `bool` represents if WebAuthn is enabled. + /// The `Option` represents if TOTP is enabled; and if so, contains the secret token. + #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] + pub async fn get_factors( + user_uuid: &str, + conn: &DbConn, + ) -> Result<(bool, Option<String>), Error> { + use crate::db::__sqlite_schema::{totp, webauthn}; + use diesel::prelude::{Connection, ExpressionMethods, QueryDsl, RunQueryDsl}; + use diesel::result; + use diesel::OptionalExtension; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + con.transaction(|con| { + webauthn::table + .count() + .filter(webauthn::user_uuid.eq(user_uuid)) + .get_result::<i64>(con) + .and_then(|count| { + let authn = count > 0; + totp::table + .select(totp::token) + .filter(totp::user_uuid.eq(user_uuid)) + .first(con) + .optional() + .map(|token| (authn, token)) + }) + .map_err(result::Error::into) + }) + }) + } +} +impl WebAuthnInfo { + #[allow(clippy::clone_on_ref_ptr)] + pub async fn get_all_by_user(user_uuid: &str, conn: &DbConn) -> Result<Vec<Self>, Error> { + use crate::db::__sqlite_schema::webauthn; + use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; + use diesel::result; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + webauthn::table + .select((webauthn::id, webauthn::name)) + .filter(webauthn::user_uuid.eq(user_uuid)) + .load::<Self>(con) + .map_err(result::Error::into) }) } } -/// Database methods -impl TwoFactor { - #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn save(&self, conn: &DbConn) -> EmptyResult { - if matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::Webauthn | TwoFactorType::WebauthnLoginChallenge | TwoFactorType::WebauthnRegisterChallenge)) - { - err!("TwoFactor::save must not be called when atype is Webauthn, WebauthnLoginChallenge, or WebauthnRegisterChallenge") - } - use crate::db::__sqlite_schema::twofactor; - use __sqlite_model::TwoFactorDb; - use diesel::prelude::RunQueryDsl; +impl WebAuthn { + #[allow(clippy::clone_on_ref_ptr)] + pub async fn get_all_security_keys( + user_uuid: &str, + conn: &DbConn, + ) -> Result<Vec<SecurityKey>, Error> { + use crate::db::__sqlite_schema::webauthn; + use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - diesel::replace_into(twofactor::table) - .values(TwoFactorDb::to_db(self)) - .execute(con) + webauthn::table + .select(webauthn::security_key) + .filter(webauthn::user_uuid.eq(user_uuid)) + .load::<String>(con) .map_err(result::Error::into) - .map(|_| ()) + .and_then(|keys| { + let len = keys.len(); + keys.into_iter() + .try_fold(Vec::with_capacity(len), |mut sec_keys, key| { + de::from_str::<SecurityKey>(key.as_str()).map(|sec| { + sec_keys.push(sec); + sec_keys + }) + }) + .map_err(Error::from) + }) }) } #[allow(clippy::clone_on_ref_ptr)] - pub async fn replace_challenge(&self, conn: &DbConn) -> EmptyResult { - if !matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::WebauthnLoginChallenge | TwoFactorType::WebauthnRegisterChallenge)) - { - err!("TwoFactor::replace_challenge must only be called when atype is WebauthnLoginChallenge or WebauthnRegisterChallenge") - } - use crate::db::__sqlite_schema::twofactor; - use __sqlite_model::TwoFactorDb; + pub async fn insert(self, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::webauthn; + use __sqlite_model::WebAuthnDb; use diesel::prelude::RunQueryDsl; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - diesel::replace_into(twofactor::table) - .values(TwoFactorDb::to_db(self)) + diesel::insert_into(webauthn::table) + .values(WebAuthnDb::to_db(&self)) .execute(con) .map_err(result::Error::into) .and_then(|count| { @@ -150,26 +336,23 @@ impl TwoFactor { Ok(()) } else { Err(Error::from(String::from( - "exactly one webauthn challenge would not have been replaced in twofactor", + "exactly one row would not have been inserted into webauthn", ))) } }) }) } #[allow(clippy::clone_on_ref_ptr)] - pub async fn delete_challenge(self, conn: &DbConn) -> EmptyResult { - if !matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::WebauthnLoginChallenge | TwoFactorType::WebauthnRegisterChallenge)) - { - err!("TwoFactor::delete_challenge must only be called when atype is WebauthnLoginChallenge or WebauthnRegisterChallenge") - } - use crate::db::__sqlite_schema::twofactor; + pub async fn update(self, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::webauthn; use diesel::prelude::{ExpressionMethods, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - diesel::delete(twofactor::table) - .filter(twofactor::uuid.eq(self.uuid)) + diesel::update(webauthn::table) + .set(webauthn::security_key.eq(self.security_key)) + .filter(webauthn::credential_id.eq(self.credential_id)) .execute(con) .map_err(result::Error::into) .and_then(|count| { @@ -177,27 +360,75 @@ impl TwoFactor { Ok(()) } else { Err(Error::from(String::from( - "exactly one webauthn challenge would not have been deleted from twofactor", + "exactly one webauthn row would not have been updated", ))) } }) }) } #[allow(clippy::clone_on_ref_ptr)] - pub async fn update_webauthn(&self, conn: &DbConn) -> EmptyResult { - if !matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::Webauthn)) - { - err!("TwoFactor::update_webauthn must only be called when atype is Webauthn") - } - use crate::db::__sqlite_schema::twofactor; + pub async fn get_all_credentials_by_user( + user_uuid: &str, + conn: &DbConn, + ) -> Result<Vec<CredentialID>, Error> { + use crate::db::__sqlite_schema::webauthn; + use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; + use diesel::result; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + webauthn::table + .select(webauthn::credential_id) + .filter(webauthn::user_uuid.eq(user_uuid)) + .load::<String>(con) + .map_err(result::Error::into) + .and_then(|ids| { + let len = ids.len(); + ids.into_iter() + .try_fold(Vec::with_capacity(len), |mut cred_ids, id| { + CredentialID::try_from(id.as_str()) + .map_err(|()| Error::from(String::from("invalid credential ID"))) + .map(|cred_id| { + cred_ids.push(cred_id); + cred_ids + }) + }) + }) + }) + } + #[allow(clippy::clone_on_ref_ptr)] + pub async fn get_by_cred_id(credential_id: &str, conn: &DbConn) -> Result<Option<Self>, Error> { + use crate::db::{FromDb, __sqlite_schema::webauthn}; + use __sqlite_model::WebAuthnDb; + use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; + use diesel::result; + use diesel::OptionalExtension; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + webauthn::table + .filter(webauthn::credential_id.eq(credential_id)) + .first::<WebAuthnDb>(con) + .optional() + .map_err(result::Error::into) + .map(FromDb::from_db) + }) + } + #[allow(clippy::clone_on_ref_ptr)] + pub async fn delete_by_user_uuid_and_id( + user_uuid: &str, + id: i64, + conn: &DbConn, + ) -> EmptyResult { + use crate::db::__sqlite_schema::webauthn; use diesel::prelude::{ExpressionMethods, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - diesel::update(twofactor::table) - .set(twofactor::data.eq(&self.data)) - .filter(twofactor::uuid.eq(&self.uuid)) + diesel::delete(webauthn::table) + .filter(webauthn::user_uuid.eq(user_uuid)) + .filter(webauthn::id.eq(id)) .execute(con) .map_err(result::Error::into) .and_then(|count| { @@ -205,246 +436,209 @@ impl TwoFactor { Ok(()) } else { Err(Error::from(String::from( - "exactly one webauthn entry would not have been updated in twofactor", + "exactly one webauthn row would not have been removed for the user", ))) } }) }) } - #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn delete(self, conn: &DbConn) -> EmptyResult { - use crate::db::__sqlite_schema::{twofactor, webauthn}; - use diesel::prelude::{Connection, ExpressionMethods, RunQueryDsl}; +} + +impl WebAuthnAuth { + #[allow(clippy::clone_on_ref_ptr)] + pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Result<Option<Self>, Error> { + use crate::db::{FromDb, __sqlite_schema::webauthn_auth}; + use __sqlite_model::WebAuthnAuthDb; + use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; use diesel::result; + use diesel::OptionalExtension; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - con.transaction(|con| { - diesel::delete(webauthn::table) - .filter(webauthn::uuid.eq(&self.uuid)) - .execute(con) - .and_then(|_| { - diesel::delete(twofactor::table) - .filter(twofactor::uuid.eq(self.uuid)) - .execute(con) - }) - }) - .map_err(result::Error::into) - .map(|_| ()) + webauthn_auth::table + .filter(webauthn_auth::user_uuid.eq(user_uuid)) + .first::<WebAuthnAuthDb>(con) + .optional() + .map_err(result::Error::into) + .map(FromDb::from_db) + }) + } +} + +impl WebAuthnReg { + #[allow(clippy::clone_on_ref_ptr)] + pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Result<Option<Self>, Error> { + use crate::db::{FromDb, __sqlite_schema::webauthn_reg}; + use __sqlite_model::WebAuthnRegDb; + use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; + use diesel::result; + use diesel::OptionalExtension; + let mut con_res = conn.conn.clone().lock_owned().await; + let con = con_res.as_mut().expect("unable to get a pooled connection"); + task::block_in_place(move || { + webauthn_reg::table + .filter(webauthn_reg::user_uuid.eq(user_uuid)) + .first::<WebAuthnRegDb>(con) + .optional() + .map_err(result::Error::into) + .map(FromDb::from_db) }) } +} + +impl WebAuthnChallenge { #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn update_delete_webauthn(&self, cred_id: String, conn: &DbConn) -> EmptyResult { - if !matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::Webauthn)) - { - err!("TwoFactor::update_delete_webauthn must only be called when atype is Webauthn") - } - use crate::db::__sqlite_schema::{twofactor, webauthn}; + pub async fn delete_all(user_uuid: &str, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::{webauthn_auth, webauthn_reg}; use diesel::prelude::{Connection, ExpressionMethods, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { con.transaction(|con| { - diesel::delete(webauthn::table) - .filter(webauthn::credential_id.eq(cred_id)) - // We add this filter to ensure that the passed - // Credential ID is associated with the correct UUID. - .filter(webauthn::uuid.eq(&self.uuid)) + diesel::delete(webauthn_auth::table) + .filter(webauthn_auth::user_uuid.eq(user_uuid)) .execute(con) - .map_err(result::Error::into) - .and_then(|count| { - if count == 1 { - diesel::update(twofactor::table) - .set(twofactor::data.eq(&self.data)) - .filter(twofactor::uuid.eq(&self.uuid)) - .execute(con) - .map_err(result::Error::into) - .and_then(|count| { - if count == 1 { - Ok(()) - } else { - Err(Error::from(String::from( - "exactly one webauthn entry in twofactor would not have been updated", - ))) - } - }) - } else { - Err(Error::from(String::from( - "exactly one entry would not have been deleted from webauthn", - ))) - } + .and_then(|_| { + diesel::delete(webauthn_reg::table) + .filter(webauthn_reg::user_uuid.eq(user_uuid)) + .execute(con) + .map(|_| ()) }) + .map_err(result::Error::into) }) }) } - #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn insert_insert_webauthn(&self, authn: WebAuthn, conn: &DbConn) -> EmptyResult { - if !matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::Webauthn)) - { - err!("TwoFactor::insert_insert_webauthn must only be called when atype is Webauthn") - } - use crate::db::__sqlite_schema::{twofactor, webauthn}; - use __sqlite_model::{TwoFactorDb, WebAuthnDb}; - use diesel::prelude::{Connection, RunQueryDsl}; + #[allow(clippy::clone_on_ref_ptr)] + pub async fn delete(self, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::{webauthn_auth, webauthn_reg}; + use diesel::prelude::{ExpressionMethods, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - con.transaction(|con| { - diesel::insert_into(twofactor::table) - .values(TwoFactorDb::to_db(self)) - .execute(con) - .map_err(result::Error::into) - .and_then(|count| { - if count == 1 { - diesel::insert_into(webauthn::table) - .values(WebAuthnDb::to_db(&authn)) - .execute(con) - .map_err(result::Error::into) - .and_then(|count| { - if count == 1 { - Ok(()) - } else { - Err(Error::from(String::from( - "exactly one entry would not have been inserted into webauthn", - ))) - } - }) - } else { - Err(Error::from(String::from("exactly one webauthn entry would have not been inserted into twofactor"))) - } - }) + match self { + Self::Auth(chal) => diesel::delete(webauthn_auth::table) + .filter(webauthn_auth::user_uuid.eq(chal.user_uuid.as_str())) + .execute(con), + Self::Reg(chal) => diesel::delete(webauthn_reg::table) + .filter(webauthn_reg::user_uuid.eq(chal.user_uuid.as_str())) + .execute(con), + } + .map_err(result::Error::into) + .and_then(|count| { + if count == 1 { + Ok(()) + } else { + Err(Error::from(String::from( + "exactly one webauthn challenge would not have been removed", + ))) + } }) }) } #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn update_insert_webauthn(&self, authn: WebAuthn, conn: &DbConn) -> EmptyResult { - if !matches!(TwoFactorType::from_i32(self.atype), Some(tf) if matches!(tf, TwoFactorType::Webauthn)) - { - err!("TwoFactor::update_insert_webauthn must only be called when atype is Webauthn") - } - use crate::db::__sqlite_schema::{twofactor, webauthn}; - use __sqlite_model::WebAuthnDb; + pub async fn replace(&self, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::{webauthn_auth, webauthn_reg}; + use __sqlite_model::{WebAuthnAuthDb, WebAuthnRegDb}; use diesel::prelude::{Connection, ExpressionMethods, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { con.transaction(|con| { - diesel::update(twofactor::table) - .set(twofactor::data.eq(&self.data)) - .filter(twofactor::uuid.eq(&self.uuid)) - .execute(con) - .map_err(result::Error::into) - .and_then(|count| { - if count == 1 { - diesel::insert_into(webauthn::table) - .values(WebAuthnDb::to_db(&authn)) - .execute(con) - .map_err(result::Error::into) - .and_then(|count| { - if count == 1 { - Ok(()) - } else { - Err(Error::from(String::from( - "exactly one entry would not have been inserted into webauthn", - ))) - } - }) - } else { - Err(Error::from(String::from("exactly one webauthn entry would have not been updated in twofactor"))) + match *self { + Self::Auth(ref chal) => diesel::update(webauthn_auth::table) + .set(webauthn_auth::data.eq(&chal.data)) + .filter(webauthn_auth::user_uuid.eq(chal.user_uuid.as_str())) + .execute(con), + Self::Reg(ref chal) => diesel::update(webauthn_reg::table) + .set(webauthn_reg::data.eq(&chal.data)) + .filter(webauthn_reg::user_uuid.eq(chal.user_uuid.as_str())) + .execute(con), + } + .map_err(result::Error::into) + .and_then(|count| { + if count == 0 { + match *self { + Self::Auth(ref chal) => diesel::insert_into(webauthn_auth::table) + .values(WebAuthnAuthDb::to_db(chal)) + .execute(con), + Self::Reg(ref chal) => diesel::insert_into(webauthn_reg::table) + .values(WebAuthnRegDb::to_db(chal)) + .execute(con), } - }) + .map_err(result::Error::into) + .and_then(|count| { + if count == 1 { + Ok(()) + } else { + Err(Error::from(String::from( + "exactly one webauthn challenge would not have been inserted/updated", + ))) + } + }) + } else { + Ok(()) + } + }) }) }) } +} - pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { - db_run! { conn: { - twofactor::table - .filter(twofactor::user_uuid.eq(user_uuid)) - .filter(twofactor::atype.lt(1000i32)) // Filter implementation types - .load::<TwoFactorDb>(conn) - .expect("Error loading twofactor") - .from_db() - }} - } - - pub async fn find_by_user_and_type(user_uuid: &str, atype: i32, conn: &DbConn) -> Option<Self> { - db_run! { conn: { - twofactor::table - .filter(twofactor::user_uuid.eq(user_uuid)) - .filter(twofactor::atype.eq(atype)) - .first::<TwoFactorDb>(conn) - .ok() - .from_db() - }} - } +impl Totp { #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { - use crate::db::__sqlite_schema::{twofactor, webauthn}; - use diesel::prelude::{Connection, ExpressionMethods, QueryDsl, RunQueryDsl}; + pub async fn replace(self, conn: &DbConn) -> EmptyResult { + use crate::db::__sqlite_schema::totp; + use __sqlite_model::TotpDb; + use diesel::prelude::{ExpressionMethods, RunQueryDsl}; use diesel::result; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - con.transaction(|con| { - diesel::delete(webauthn::table) - .filter( - webauthn::uuid.eq_any( - twofactor::table - .filter(twofactor::user_uuid.eq(user_uuid)) - .select(twofactor::uuid), - ), - ) - .execute(con) - .and_then(|_| { - diesel::delete(twofactor::table) - .filter(twofactor::user_uuid.eq(user_uuid)) + diesel::update(totp::table) + .set(totp::last_used.eq(self.last_used)) + .filter(totp::user_uuid.eq(&self.user_uuid)) + .execute(con) + .map_err(result::Error::into) + .and_then(|count| { + if count == 1 { + Ok(()) + } else { + diesel::insert_into(totp::table) + .values(TotpDb::to_db(&self)) .execute(con) - }) - }) - .map_err(result::Error::into) - .map(|_| ()) + .map_err(result::Error::into) + .and_then(|count| { + if count == 1 { + Ok(()) + } else { + Err(Error::from(String::from( + "exactly one totp row was not inserted/updated", + ))) + } + }) + } + }) }) } -} -impl WebAuthn { - #[allow(clippy::clone_on_ref_ptr, clippy::shadow_unrelated)] - pub async fn get_all_credentials_by_user( - user_uuid: &str, - conn: &DbConn, - ) -> Result<Vec<CredentialID>, Error> { - use crate::db::__sqlite_schema::{twofactor, webauthn}; + #[allow(clippy::clone_on_ref_ptr)] + pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Result<Option<Self>, Error> { + use crate::db::{FromDb, __sqlite_schema::totp}; + use __sqlite_model::TotpDb; use diesel::prelude::{ExpressionMethods, QueryDsl, RunQueryDsl}; use diesel::result; + use diesel::OptionalExtension; let mut con_res = conn.conn.clone().lock_owned().await; let con = con_res.as_mut().expect("unable to get a pooled connection"); task::block_in_place(move || { - webauthn::table - .select(webauthn::credential_id) - .filter( - webauthn::uuid.eq_any( - twofactor::table - .filter(twofactor::user_uuid.eq(user_uuid)) - .select(twofactor::uuid), - ), - ) - .load::<String>(con) + totp::table + .filter(totp::user_uuid.eq(user_uuid)) + .first::<TotpDb>(con) + .optional() .map_err(result::Error::into) - .and_then(|ids| { - let len = ids.len(); - ids.into_iter() - .try_fold(Vec::with_capacity(len), |mut cred_ids, id| { - CredentialID::try_from(id.as_str()) - .map_err(|()| Error::from(String::from("invalid credential ID"))) - .map(|cred_id| { - cred_ids.push(cred_id); - cred_ids - }) - }) - }) + .map(FromDb::from_db) }) } } diff --git a/src/db/models/user.rs b/src/db/models/user.rs @@ -247,7 +247,10 @@ impl User { } } -use super::{Cipher, Device, Favorite, Folder, TwoFactor, UserOrgType, UserOrganization}; +use super::{ + Cipher, Device, Favorite, Folder, TwoFactorType, UserOrgType, UserOrganization, + WebAuthnChallenge, +}; use crate::api::EmptyResult; use crate::db::DbConn; use crate::error::MapResult; @@ -259,7 +262,9 @@ impl User { for c in UserOrganization::find_confirmed_by_user(&self.uuid, conn).await { orgs_json.push(c.to_json(conn).await); } - let twofactor_enabled = !TwoFactor::find_by_user(&self.uuid, conn).await.is_empty(); + let twofactor_enabled = TwoFactorType::has_twofactor(&self.uuid, conn) + .await + .expect("unable to get two factor info"); // TODO: Might want to save the status field in the DB let status = if self.password_hash.is_empty() { UserStatus::Invited @@ -333,7 +338,8 @@ impl User { Favorite::delete_all_by_user(&self.uuid, conn).await?; Folder::delete_all_by_user(&self.uuid, conn).await?; Device::delete_all_by_user(&self.uuid, conn).await?; - TwoFactor::delete_all_by_user(&self.uuid, conn).await?; + TwoFactorType::delete_all_by_user(&self.uuid, conn).await?; + WebAuthnChallenge::delete_all(&self.uuid, conn).await?; db_run! {conn: { diesel::delete(users::table.filter(users::uuid.eq(self.uuid))) .execute(conn) diff --git a/src/db/schemas/sqlite/schema.rs b/src/db/schemas/sqlite/schema.rs @@ -91,12 +91,6 @@ table! { } table! { - invitations (email) { - email -> Text, - } -} - -table! { org_policies (uuid) { uuid -> Text, org_uuid -> Text, @@ -105,15 +99,6 @@ table! { data -> Text, } } -table! { - organization_api_key (uuid, org_uuid) { - uuid -> Text, - org_uuid -> Text, - atype -> Integer, - api_key -> Text, - revision_date -> Timestamp, - } -} table! { organizations (uuid) { @@ -126,12 +111,9 @@ table! { } table! { - twofactor (uuid) { - uuid -> Text, + totp (user_uuid) { user_uuid -> Text, - atype -> Integer, - enabled -> Bool, - data -> Text, + token -> Text, last_used -> BigInt, } } @@ -198,13 +180,29 @@ table! { table! { webauthn (credential_id) { credential_id -> Text, - uuid -> Text, + user_uuid -> Text, + id -> BigInt, + name -> Text, + security_key -> Text, + } +} + +table! { + webauthn_auth (user_uuid) { + user_uuid -> Text, + data -> Text, + } +} + +table! { + webauthn_reg (user_uuid) { + user_uuid -> Text, + data -> Text, } } joinable!(folders_ciphers -> ciphers (cipher_uuid)); joinable!(folders_ciphers -> folders (folder_uuid)); -allow_tables_to_appear_in_same_query!(twofactor, webauthn,); allow_tables_to_appear_in_same_query!( ciphers, ciphers_collections, diff --git a/src/util.rs b/src/util.rs @@ -308,7 +308,7 @@ use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visi use serde_json::{self, Value}; type JsonMap = serde_json::Map<String, Value>; -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct UpCase<T: DeserializeOwned> { #[serde(deserialize_with = "upcase_deserialize")] #[serde(flatten)]