hash_set.rs (15026B)
1 use super::{super::request::TimedCeremony, BuildIdentityHasher}; 2 #[cfg(doc)] 3 use core::hash::Hasher; 4 use core::hash::{BuildHasher, Hash}; 5 use hashbrown::{ 6 Equivalent, TryReserveError, 7 hash_set::{Drain, Entry, ExtractIf, HashSet}, 8 }; 9 #[cfg(not(feature = "serializable_server_state"))] 10 use std::time::Instant; 11 #[cfg(feature = "serializable_server_state")] 12 use std::time::SystemTime; 13 /// [`HashSet`] that has maximum [`HashSet::capacity`] and length and allocates exactly once. 14 /// 15 /// Note due to how `HashSet` removes values, it's possible to insert a value after removing a value and cause 16 /// a new allocation. To avoid this, we ensure that the allocated capacity is at least twice the size of 17 /// the requested maximum length. 18 /// 19 /// This is useful in situations when the underlying values are expected to be removed, and one wants to ensure the 20 /// set does not grow unbounded. When `T` is a [`TimedCeremony`], helper methods (e.g., 21 /// [`Self::insert_remove_all_expired`]) are provided that will automatically remove expired values. Note the 22 /// intended use case is for `T` to be based on a server-side randomly generated value; thus the default [`Hasher`] 23 /// is [`BuildIdentityHasher`]. In the event this is not true, one MUST use a more appropriate `Hasher`. 24 /// 25 /// Only the mutable methods of `HashSet` are re-defined in order to ensure [`Self::max_len`] is never exceeded. 26 /// For all other methods, first call [`Self::as_ref`] or [`Self::into`]. 27 /// 28 /// [`Self::into`]: struct.MaxLenHashSet.html#impl-Into<U>-for-T 29 #[derive(Debug)] 30 pub struct MaxLenHashSet<T, S = BuildIdentityHasher>(HashSet<T, S>, usize); 31 impl<T> MaxLenHashSet<T, BuildIdentityHasher> { 32 /// [`HashSet::with_capacity_and_hasher`] using `2 * max_len` and `BuildIdentityHasher`. 33 /// 34 /// Note since the actual capacity allocated may exceed the requested capacity, [`Self::max_len`] may exceed 35 /// `max_len`. 36 /// 37 /// # Panics 38 /// 39 /// `panic`s if `max_len > usize::MAX / 2`. Note since [`HashSet::with_capacity_and_hasher`] `panic`s 40 /// for much smaller values than `usize::MAX / 2`—even when `T` is a zero-sized type (ZST)—this is not an 41 /// additional `panic` than what would already occur. The only difference is the message reported. 42 #[inline] 43 #[must_use] 44 pub fn new(max_len: usize) -> Self { 45 Self::with_hasher(max_len, BuildIdentityHasher) 46 } 47 } 48 impl<T, S> MaxLenHashSet<T, S> { 49 /// Capacity we allocate. 50 /// 51 /// # Errors 52 /// 53 /// Errors iff `max_len > usize::MAX / 2`. 54 const fn requested_capacity(max_len: usize) -> Result<usize, TryReserveError> { 55 if max_len <= usize::MAX >> 1u8 { 56 Ok(max_len << 1u8) 57 } else { 58 Err(TryReserveError::CapacityOverflow) 59 } 60 } 61 /// Returns the immutable maximum length allowed by `self`. 62 #[inline] 63 pub const fn max_len(&self) -> usize { 64 self.1 65 } 66 /// [`HashSet::clear`]. 67 #[inline] 68 pub fn clear(&mut self) { 69 self.0.clear(); 70 } 71 /// [`HashSet::drain`]. 72 #[inline] 73 pub fn drain(&mut self) -> Drain<'_, T> { 74 self.0.drain() 75 } 76 /// [`HashSet::extract_if`]. 77 #[inline] 78 pub fn extract_if<F: FnMut(&T) -> bool>(&mut self, f: F) -> ExtractIf<'_, T, F> { 79 self.0.extract_if(f) 80 } 81 /// [`HashSet::with_capacity_and_hasher`] using `2 * max_len` and `hasher`. 82 /// 83 /// Note since the actual capacity allocated may exceed the requested capacity, [`Self::max_len`] may exceed 84 /// `max_len`. 85 /// 86 /// # Panics 87 /// 88 /// `panic`s if `max_len > usize::MAX / 2`. Note since [`HashSet::with_capacity_and_hasher`] `panic`s 89 /// for much smaller values than `usize::MAX / 2`—even when `T` is a zero-sized type (ZST)—this is not an 90 /// additional `panic` than what would already occur. The only difference is the message reported. 91 #[expect( 92 clippy::expect_used, 93 reason = "purpose of this function is to panic if the hash set cannot be allocated" 94 )] 95 #[inline] 96 #[must_use] 97 pub fn with_hasher(max_len: usize, hasher: S) -> Self { 98 let set = HashSet::with_capacity_and_hasher(Self::requested_capacity(max_len).expect("HashSet::with_hasher must be passed a maximum length that does not exceed usize::MAX / 2"), hasher); 99 let len = set.capacity() >> 1u8; 100 Self(set, len) 101 } 102 /// [`HashSet::retain`]. 103 #[inline] 104 pub fn retain<F: FnMut(&T) -> bool>(&mut self, f: F) { 105 self.0.retain(f); 106 } 107 } 108 impl<T: TimedCeremony, S> MaxLenHashSet<T, S> { 109 /// Removes all expired ceremonies. 110 #[inline] 111 pub fn remove_expired_ceremonies(&mut self) { 112 // Even though it's more accurate to check the current `Instant` for each ceremony, we elect to capture 113 // the `Instant` we begin iteration for performance reasons. It's unlikely an appreciable amount of 114 // additional ceremonies would be removed. 115 #[cfg(not(feature = "serializable_server_state"))] 116 let now = Instant::now(); 117 #[cfg(feature = "serializable_server_state")] 118 let now = SystemTime::now(); 119 self.retain(|v| v.expiration() >= now); 120 } 121 /// Removes the first encountered expired ceremony. 122 #[inline] 123 pub fn remove_first_expired_ceremony(&mut self) { 124 #[cfg(not(feature = "serializable_server_state"))] 125 let now = Instant::now(); 126 #[cfg(feature = "serializable_server_state")] 127 let now = SystemTime::now(); 128 drop(self.0.extract_if(|v| v.expiration() < now).next()); 129 } 130 } 131 impl<T: Eq + Hash, S: BuildHasher> MaxLenHashSet<T, S> { 132 /// [`HashSet::with_hasher`] using `hasher` followed by [`HashSet::try_reserve`] using `2 * max_len`. 133 /// 134 /// Note since the actual capacity allocated may exceed the requested capacity, [`Self::max_len`] may exceed 135 /// `max_len`. 136 /// 137 /// # Errors 138 /// 139 /// Errors iff `max_len > usize::MAX / 2` or [`HashSet::try_reserve`] does. 140 #[inline] 141 pub fn try_with_hasher(max_len: usize, hasher: S) -> Result<Self, TryReserveError> { 142 Self::requested_capacity(max_len).and_then(|additional| { 143 let mut set = HashSet::with_hasher(hasher); 144 set.try_reserve(additional).map(|()| { 145 let len = set.capacity() >> 1u8; 146 Self(set, len) 147 }) 148 }) 149 } 150 /// [`HashSet::get_or_insert`]. 151 /// 152 /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the 153 /// set. 154 #[inline] 155 pub fn get_or_insert(&mut self, value: T) -> Option<&T> { 156 if self.0.len() == self.1 { 157 self.0.get(&value) 158 } else { 159 Some(self.0.get_or_insert(value)) 160 } 161 } 162 /// [`HashSet::get_or_insert_with`]. 163 /// 164 /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the 165 /// set. 166 #[inline] 167 pub fn get_or_insert_with<Q: Equivalent<T> + Hash + ?Sized, F: FnOnce(&Q) -> T>( 168 &mut self, 169 value: &Q, 170 f: F, 171 ) -> Option<&T> { 172 if self.0.len() == self.1 { 173 self.0.get(value) 174 } else { 175 Some(self.0.get_or_insert_with(value, f)) 176 } 177 } 178 /// [`HashSet::remove`]. 179 #[inline] 180 pub fn remove<Q: Equivalent<T> + Hash + ?Sized>(&mut self, value: &Q) -> bool { 181 self.0.remove(value) 182 } 183 /// [`HashSet::take`]. 184 #[inline] 185 pub fn take<Q: Equivalent<T> + Hash + ?Sized>(&mut self, value: &Q) -> Option<T> { 186 self.0.take(value) 187 } 188 /// [`HashSet::insert`]. 189 /// 190 /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the 191 /// set. 192 #[inline] 193 pub fn insert(&mut self, value: T) -> Option<bool> { 194 let full = self.0.len() == self.1; 195 if let Entry::Vacant(ent) = self.0.entry(value) { 196 if full { 197 None 198 } else { 199 _ = ent.insert(); 200 Some(true) 201 } 202 } else { 203 Some(false) 204 } 205 } 206 /// [`HashSet::replace`]. 207 /// 208 /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the 209 /// set. 210 #[inline] 211 pub fn replace(&mut self, value: T) -> Option<Option<T>> { 212 // Ideally we would use the Entry API to avoid searching multiple times, but one can't while also using 213 // `replace` since there is no `OccupiedEntry::replace`. 214 if self.0.contains(&value) { 215 Some(self.0.replace(value)) 216 } else if self.0.len() == self.1 { 217 None 218 } else { 219 _ = self.0.insert(value); 220 Some(None) 221 } 222 } 223 /// [`HashSet::entry`]. 224 /// 225 /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the 226 /// set. 227 #[inline] 228 pub fn entry(&mut self, value: T) -> Option<Entry<'_, T, S>> { 229 let full = self.0.len() == self.1; 230 match self.0.entry(value) { 231 ent @ Entry::Occupied(_) => Some(ent), 232 ent @ Entry::Vacant(_) => { 233 if full { 234 None 235 } else { 236 Some(ent) 237 } 238 } 239 } 240 } 241 } 242 impl<T: Eq + Hash + TimedCeremony, S: BuildHasher> MaxLenHashSet<T, S> { 243 /// [`Self::insert`] except the first encountered expired ceremony is removed in the event [`Self::max_len`] 244 /// items have been added. 245 #[inline] 246 pub fn insert_remove_expired(&mut self, value: T) -> Option<bool> { 247 if self.0.len() == self.1 { 248 #[cfg(not(feature = "serializable_server_state"))] 249 let now = Instant::now(); 250 #[cfg(feature = "serializable_server_state")] 251 let now = SystemTime::now(); 252 if self.0.extract_if(|v| v.expiration() < now).next().is_some() { 253 Some(self.0.insert(value)) 254 } else { 255 self.0.contains(&value).then_some(false) 256 } 257 } else { 258 Some(self.0.insert(value)) 259 } 260 } 261 /// [`Self::insert`] except all expired ceremones are removed in the event [`Self::max_len`] items have 262 /// been added. 263 #[inline] 264 pub fn insert_remove_all_expired(&mut self, value: T) -> Option<bool> { 265 if self.0.len() == self.1 { 266 self.remove_expired_ceremonies(); 267 } 268 if self.0.len() == self.1 { 269 self.0.contains(&value).then_some(false) 270 } else { 271 Some(self.0.insert(value)) 272 } 273 } 274 /// [`Self::entry`] except the first encountered expired ceremony is removed in the event [`Self::max_len`] 275 /// items have been added. 276 #[inline] 277 pub fn entry_remove_expired(&mut self, value: T) -> Option<Entry<'_, T, S>> { 278 if self.0.len() == self.1 { 279 #[cfg(not(feature = "serializable_server_state"))] 280 let now = Instant::now(); 281 #[cfg(feature = "serializable_server_state")] 282 let now = SystemTime::now(); 283 if self.0.extract_if(|v| v.expiration() < now).next().is_some() { 284 Some(self.0.entry(value)) 285 } else if let ent @ Entry::Occupied(_) = self.0.entry(value) { 286 Some(ent) 287 } else { 288 None 289 } 290 } else { 291 Some(self.0.entry(value)) 292 } 293 } 294 /// [`Self::entry`] except all expired ceremones are removed in the event [`Self::max_len`] items have 295 /// been added. 296 #[inline] 297 pub fn entry_remove_all_expired(&mut self, value: T) -> Option<Entry<'_, T, S>> { 298 if self.0.len() == self.1 { 299 self.remove_expired_ceremonies(); 300 } 301 if self.0.len() == self.1 { 302 if let ent @ Entry::Occupied(_) = self.0.entry(value) { 303 Some(ent) 304 } else { 305 None 306 } 307 } else { 308 Some(self.0.entry(value)) 309 } 310 } 311 } 312 impl<T, S> AsRef<HashSet<T, S>> for MaxLenHashSet<T, S> { 313 #[inline] 314 fn as_ref(&self) -> &HashSet<T, S> { 315 &self.0 316 } 317 } 318 impl<T, S> From<MaxLenHashSet<T, S>> for HashSet<T, S> { 319 #[inline] 320 fn from(value: MaxLenHashSet<T, S>) -> Self { 321 value.0 322 } 323 } 324 #[cfg(test)] 325 mod tests { 326 use super::{Equivalent, MaxLenHashSet, TimedCeremony}; 327 use core::hash::{Hash, Hasher}; 328 #[cfg(not(feature = "serializable_server_state"))] 329 use std::time::Instant; 330 #[cfg(feature = "serializable_server_state")] 331 use std::time::SystemTime; 332 #[derive(Clone, Copy)] 333 struct Ceremony { 334 id: usize, 335 #[cfg(not(feature = "serializable_server_state"))] 336 exp: Instant, 337 #[cfg(feature = "serializable_server_state")] 338 exp: SystemTime, 339 } 340 impl Default for Ceremony { 341 fn default() -> Self { 342 Self { 343 id: 0, 344 #[cfg(not(feature = "serializable_server_state"))] 345 exp: Instant::now(), 346 #[cfg(feature = "serializable_server_state")] 347 exp: SystemTime::now(), 348 } 349 } 350 } 351 impl PartialEq for Ceremony { 352 fn eq(&self, other: &Self) -> bool { 353 self.id == other.id 354 } 355 } 356 impl Eq for Ceremony {} 357 impl Hash for Ceremony { 358 fn hash<H: Hasher>(&self, state: &mut H) { 359 self.id.hash(state); 360 } 361 } 362 impl TimedCeremony for Ceremony { 363 #[cfg(not(feature = "serializable_server_state"))] 364 fn expiration(&self) -> Instant { 365 self.exp 366 } 367 #[cfg(feature = "serializable_server_state")] 368 fn expiration(&self) -> SystemTime { 369 self.exp 370 } 371 } 372 impl Equivalent<Ceremony> for usize { 373 fn equivalent(&self, key: &Ceremony) -> bool { 374 *self == key.id 375 } 376 } 377 #[test] 378 fn hash_set_insert_removed() { 379 const REQ_MAX_LEN: usize = 8; 380 let mut set = MaxLenHashSet::new(REQ_MAX_LEN); 381 let cap = set.as_ref().capacity(); 382 let max_len = set.max_len(); 383 assert_eq!(cap >> 1u8, max_len); 384 assert!(max_len >= REQ_MAX_LEN); 385 let mut cer = Ceremony::default(); 386 for i in 0..max_len { 387 assert!(set.as_ref().capacity() <= cap); 388 cer.id = i; 389 assert_eq!(set.insert(cer), Some(true)); 390 } 391 assert!(set.as_ref().capacity() <= cap); 392 assert_eq!(set.as_ref().len(), max_len); 393 for i in 0..max_len { 394 assert!(set.as_ref().contains(&i)); 395 } 396 cer.id = cap; 397 assert_eq!(set.insert_remove_expired(cer), Some(true)); 398 assert!(set.as_ref().capacity() <= cap); 399 assert_eq!(set.as_ref().len(), max_len); 400 let mut counter = 0; 401 for i in 0..max_len { 402 counter += usize::from(set.as_ref().contains(&i)); 403 } 404 assert_eq!(counter, max_len - 1); 405 assert!(set.as_ref().contains(&(max_len - 1))); 406 } 407 }