commit c27faa018ad4353e5a57e97d0b21967a9095f0e7
parent 2db533ea55dd75ddb5fd5058f0caa5b5acc4b1fe
Author: Zack Newman <zack@philomathiclife.com>
Date: Sat, 18 Nov 2023 16:21:20 -0700
fix toctou race condition in rsa files
Diffstat:
6 files changed, 102 insertions(+), 132 deletions(-)
diff --git a/config.toml b/config.toml
@@ -1,8 +1,13 @@
database_max_conns=4
+#database_timeout=30
db_connection_retries=8
domain="pmd.philomathiclife.com"
ip="fdb5:d87:ae42:1::1"
+org_attachment_limit=0
+#password_iterations=600000
port=8443
+user_attachment_limit=0
+#web_vault_enabled=true
workers=4
[tls]
cert="/etc/ssl/pmd.philomathiclife.com.fullchain"
diff --git a/src/auth.rs b/src/auth.rs
@@ -5,8 +5,11 @@ use crate::{
use chrono::{Duration, Utc};
use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
use num_traits::FromPrimitive;
+use openssl::rsa::Rsa;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
+use std::fs::File;
+use std::io::{Read, Write};
use std::sync::OnceLock;
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
@@ -129,47 +132,48 @@ fn get_jwt_file_download_issuer() -> &'static str {
.expect("JWT_FILE_DOWNLOAD_ISSUER must be initialized in main")
.as_str()
}
-static PRIVATE_RSA_KEY: OnceLock<EncodingKey> = OnceLock::new();
+static RSA_KEYS: OnceLock<(EncodingKey, DecodingKey)> = OnceLock::new();
+#[allow(clippy::verbose_file_reads)]
#[inline]
-fn init_private_rsa_key() {
- let key = std::fs::read(Config::PRIVATE_RSA_KEY)
- .unwrap_or_else(|e| panic!("Error loading private RSA Key. \n{e}"));
+pub fn init_rsa_keys() -> Result<(), Error> {
+ let mut priv_file = File::options()
+ .create(true)
+ .read(true)
+ .write(true)
+ .open(Config::PRIVATE_RSA_KEY)?;
+ let mut priv_pem = Vec::with_capacity(2048);
+ let priv_key = if priv_file.read_to_end(&mut priv_pem)? == 0 {
+ let rsa_key = openssl::rsa::Rsa::generate(2048)?;
+ priv_pem = rsa_key.private_key_to_pem()?;
+ priv_file.write_all(priv_pem.as_slice())?;
+ rsa_key
+ } else {
+ Rsa::private_key_from_pem(priv_pem.as_slice())?
+ };
assert!(
- PRIVATE_RSA_KEY
- .set(
- EncodingKey::from_rsa_pem(&key)
- .unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{e}"))
- )
+ RSA_KEYS
+ .set((
+ EncodingKey::from_rsa_pem(priv_pem.as_slice())?,
+ DecodingKey::from_rsa_pem(priv_key.public_key_to_pem()?.as_slice())?
+ ))
.is_ok(),
- "PRIVATE_RSA_KEY must only be initialized once"
- )
+ "RSA_KEYS must only be initialized once"
+ );
+ Ok(())
}
#[inline]
fn get_private_rsa_key() -> &'static EncodingKey {
- PRIVATE_RSA_KEY
+ &RSA_KEYS
.get()
- .expect("PRIVATE_RSA_KEY must be initialized in main")
-}
-static PUBLIC_RSA_KEY: OnceLock<DecodingKey> = OnceLock::new();
-#[inline]
-fn init_public_rsa_key() {
- let key = std::fs::read(Config::PUBLIC_RSA_KEY)
- .unwrap_or_else(|e| panic!("Error loading public RSA Key. \n{e}"));
- assert!(
- PUBLIC_RSA_KEY
- .set(
- DecodingKey::from_rsa_pem(&key)
- .unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{e}"))
- )
- .is_ok(),
- "PUBLIC_RSA_KEY must only be initialized once"
- )
+ .expect("RSA_KEYS must be initialized in main")
+ .0
}
#[inline]
fn get_public_rsa_key() -> &'static DecodingKey {
- PUBLIC_RSA_KEY
+ &RSA_KEYS
.get()
- .expect("PUBLIC_RSA_KEY must be initialized in main")
+ .expect("RSA_KEYS must be initialized in main")
+ .1
}
#[inline]
pub fn init_values() {
@@ -182,11 +186,6 @@ pub fn init_values() {
init_jwt_org_api_key_issuer();
init_jwt_file_download_issuer();
}
-#[inline]
-pub fn init_rsa_keys() {
- init_private_rsa_key();
- init_public_rsa_key();
-}
pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
match jsonwebtoken::encode(get_jwt_header(), claims, get_private_rsa_key()) {
Ok(token) => token,
diff --git a/src/config.rs b/src/config.rs
@@ -2,19 +2,19 @@ use core::fmt::{self, Display, Formatter};
use core::num::NonZeroU8;
use rocket::config::{CipherSuite, LogLevel, TlsConfig};
use rocket::data::{Limits, ToByteUnit};
-use std::env;
use std::error;
use std::fs;
use std::io::Error;
use std::net::IpAddr;
+use std::path::PathBuf;
use std::sync::OnceLock;
use toml::{self, de};
use url::{ParseError, Url};
static CONFIG: OnceLock<Config> = OnceLock::new();
#[inline]
-pub fn init_config() {
+pub fn init_config(cur_dir: &mut PathBuf) {
CONFIG
- .set(Config::load().expect("valid TOML config file at 'config.toml'"))
+ .set(Config::load(cur_dir).expect("valid TOML config file at 'config.toml'"))
.expect("CONFIG must only be initialized once")
}
#[inline]
@@ -96,7 +96,7 @@ pub struct Config {
}
impl Config {
#[inline]
- pub fn load() -> Result<Self, ConfigErr> {
+ pub fn load(cur_dir: &mut PathBuf) -> Result<Self, ConfigErr> {
let config_file =
toml::from_str::<ConfigFile>(fs::read_to_string("config.toml")?.as_str())?;
let mut tls = TlsConfig::from_paths(config_file.tls.cert, config_file.tls.key);
@@ -114,8 +114,7 @@ impl Config {
Some(prefer) => tls.with_preferred_server_cipher_order(prefer),
None => tls,
};
- let mut tmp_folder = env::current_dir()?;
- tmp_folder.push(Self::TMP_FOLDER);
+ cur_dir.push(Self::TMP_FOLDER);
let mut rocket = rocket::Config {
address: config_file.ip,
cli_colors: false,
@@ -125,18 +124,29 @@ impl Config {
.limit("file", 525.megabytes()),
log_level: LogLevel::Off,
port: config_file.port,
- temp_dir: tmp_folder.into(),
+ temp_dir: cur_dir.into(),
tls: Some(tls),
..Default::default()
};
if let Some(count) = config_file.workers {
rocket.workers = count.get() as usize;
}
- let domain =
- Url::parse(format!("https://{}:{}", config_file.domain, config_file.port).as_str())?;
+ let domain = Url::parse(
+ format!(
+ "https://{}{}",
+ config_file.domain,
+ if config_file.port == 443 {
+ String::new()
+ } else {
+ config_file.port.to_string()
+ }
+ )
+ .as_str(),
+ )?;
if domain.domain().is_none() {
return Err(ConfigErr::BadDomain);
}
+ cur_dir.pop();
Ok(Self {
database_max_conns: config_file
.database_max_conns
@@ -155,15 +165,14 @@ impl Config {
}
}
impl Config {
- pub const ATTACHMENTS_FOLDER: &str = "data/attachments";
- pub const DATA_FOLDER: &str = "data";
- pub const DATABASE_URL: &str = "data/db.sqlite3";
- pub const ICON_CACHE_FOLDER: &str = "data/icon_cache";
- pub const PRIVATE_RSA_KEY: &str = "data/rsa_key.pem";
- pub const PUBLIC_RSA_KEY: &str = "data/rsa_key.pub.pem";
- pub const SENDS_FOLDER: &str = "data/sends";
- pub const TMP_FOLDER: &str = "data/tmp";
- pub const WEB_VAULT_FOLDER: &str = "web-vault/";
+ pub const ATTACHMENTS_FOLDER: &'static str = "data/attachments";
+ pub const DATA_FOLDER: &'static str = "data";
+ pub const DATABASE_URL: &'static str = "data/db.sqlite3";
+ pub const ICON_CACHE_FOLDER: &'static str = "data/icon_cache";
+ pub const PRIVATE_RSA_KEY: &'static str = "data/rsa_key.pem";
+ pub const SENDS_FOLDER: &'static str = "data/sends";
+ pub const TMP_FOLDER: &'static str = "data/tmp";
+ pub const WEB_VAULT_FOLDER: &'static str = "web-vault/";
#[inline]
pub fn domain_origin(&self) -> String {
self.domain.origin().ascii_serialization()
diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs
@@ -1,6 +1,8 @@
use crate::config::Config;
use serde_json::Value;
+use std::fs;
use std::io::ErrorKind;
+use std::path::Path;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
@@ -109,25 +111,33 @@ impl Attachment {
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult {
db_run! { conn: {
- crate::util::retry(
- || diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
- 10,
- )
- .map_res("Error deleting attachment")?;
-
- let file_path = &self.get_file_path();
-
- match crate::util::delete_file(file_path) {
- // Ignore "file not found" errors. This can happen when the
- // upstream caller has already cleaned up the file as part of
- // its own error handling.
- Err(e) if e.kind() == ErrorKind::NotFound => {
- debug!("File '{}' already deleted.", file_path);
- Ok(())
+ crate::util::retry(
+ || diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
+ 10,
+ )
+ .map_res("Error deleting attachment")?;
+ let file_path = &self.get_file_path();
+ let path = Path::new(&file_path);
+ if let Err(err) = fs::remove_file(path) {
+ if err.kind() != ErrorKind::NotFound {
+ return Err(err.into())
+ }
+ }
+ match path.parent() {
+ None => Ok(()),
+ Some(parent) => fs::remove_dir(parent).or_else(|err| match fs::read_dir(parent) {
+ Err(err2) => if err2.kind() == ErrorKind::NotFound {
+ Ok(())
+ } else {
+ Err(err2.into())
+ },
+ Ok(dir) => if dir.count() == 0 {
+ Err(err.into())
+ } else {
+ Ok(())
+ }
+ })
}
- Err(e) => Err(e.into()),
- _ => Ok(()),
- }
}}
}
diff --git a/src/main.rs b/src/main.rs
@@ -61,6 +61,7 @@ mod util;
use config::Config;
pub use error::{Error, MapResult};
use std::env;
+use std::path::PathBuf;
use std::sync::Arc;
use tokio::runtime::Builder;
@@ -68,16 +69,16 @@ fn main() -> Result<(), Error> {
let mut promises = priv_sep::pledge_init()?;
let mut cur_dir = env::current_dir()?;
priv_sep::unveil_read(cur_dir.as_path())?;
- static_init();
+ static_init(&mut cur_dir);
cur_dir.push(Config::DATA_FOLDER);
unveil_create_read_write(cur_dir)?;
check_data_folder();
- check_rsa_keys().expect("error creating keys");
+ auth::init_rsa_keys().expect("error creating or reading RSA keys");
check_web_vault();
+ create_dir(Config::ATTACHMENTS_FOLDER, "attachments folder");
create_dir(Config::ICON_CACHE_FOLDER, "icon cache");
- create_dir(Config::TMP_FOLDER, "tmp folder");
create_dir(Config::SENDS_FOLDER, "sends folder");
- create_dir(Config::ATTACHMENTS_FOLDER, "attachments folder");
+ create_dir(Config::TMP_FOLDER, "tmp folder");
Builder::new_multi_thread()
.enable_all()
.build()
@@ -103,8 +104,8 @@ fn main() -> Result<(), Error> {
)
}
#[inline]
-fn static_init() {
- config::init_config();
+fn static_init(cur_dir: &mut PathBuf) {
+ config::init_config(cur_dir);
auth::init_values();
api::init_ws_users();
api::init_ws_anonymous_subscriptions();
@@ -127,23 +128,6 @@ fn check_data_folder() {
exit(1);
}
}
-fn check_rsa_keys() -> Result<(), crate::error::Error> {
- // If the RSA keys don't exist, try to create them
- let priv_path = Config::PRIVATE_RSA_KEY;
- let pub_path = Config::PUBLIC_RSA_KEY;
- if !util::file_exists(priv_path) {
- let rsa_key = openssl::rsa::Rsa::generate(2048)?;
- let priv_key = rsa_key.private_key_to_pem()?;
- crate::util::write_file(priv_path, priv_key.as_slice())?;
- }
- if !util::file_exists(pub_path) {
- let rsa_key = openssl::rsa::Rsa::private_key_from_pem(&std::fs::read(priv_path)?)?;
- let pub_key = rsa_key.public_key_to_pem()?;
- crate::util::write_file(pub_path, pub_key.as_slice())?;
- }
- auth::init_rsa_keys();
- Ok(())
-}
fn check_web_vault() {
if !config::get_config().web_vault_enabled {
diff --git a/src/util.rs b/src/util.rs
@@ -261,44 +261,7 @@ impl<'r> FromParam<'r> for SafeString {
}
}
}
-//
-// File handling
-//
-use std::{
- fs::{self, File},
- io::Result as IOResult,
- path::Path,
-};
-
-pub fn file_exists(path: &str) -> bool {
- Path::new(path).exists()
-}
-
-pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> {
- use std::io::Write;
- let mut f = match File::create(path) {
- Ok(file) => file,
- Err(e) => {
- return Err(From::from(e));
- }
- };
-
- f.write_all(content)?;
- f.flush()?;
- Ok(())
-}
-
-pub fn delete_file(path: &str) -> IOResult<()> {
- let res = fs::remove_file(path);
-
- if let Some(parent) = Path::new(path).parent() {
- // If the directory isn't empty, this returns an error, which we ignore
- // We only want to delete the folder if it's empty
- fs::remove_dir(parent).ok();
- }
-
- res
-}
+use std::path::Path;
pub fn get_display_size(size: i32) -> String {
const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"];