webauthn_rp

WebAuthn Level 3 Relying Party (RP) library.
git clone https://git.philomathiclife.com/repos/webauthn_rp
Log | Files | Refs | README

hash_set.rs (15026B)


      1 use super::{super::request::TimedCeremony, BuildIdentityHasher};
      2 #[cfg(doc)]
      3 use core::hash::Hasher;
      4 use core::hash::{BuildHasher, Hash};
      5 use hashbrown::{
      6     Equivalent, TryReserveError,
      7     hash_set::{Drain, Entry, ExtractIf, HashSet},
      8 };
      9 #[cfg(not(feature = "serializable_server_state"))]
     10 use std::time::Instant;
     11 #[cfg(feature = "serializable_server_state")]
     12 use std::time::SystemTime;
/// [`HashSet`] that has a maximum [`HashSet::capacity`] and length and allocates exactly once.
///
/// Due to how `HashSet` removes values, inserting a value after removing another one could otherwise cause a new
/// allocation. To avoid this, the allocated capacity is always at least twice [`Self::max_len`].
///
/// This is useful in situations when the underlying values are expected to be removed, and one wants to ensure the
/// set does not grow unbounded. When `T` is a [`TimedCeremony`], helper methods (e.g.,
/// [`Self::insert_remove_all_expired`]) are provided that automatically remove expired values. Note the intended
/// use case is for `T` to be based on a server-side randomly generated value; thus the default hasher builder is
/// [`BuildIdentityHasher`]. In the event this is not true, one MUST use a more appropriate [`Hasher`].
///
/// Only the mutable methods of `HashSet` are re-defined in order to ensure [`Self::max_len`] is never exceeded.
/// For all other methods, first call [`Self::as_ref`] or [`Self::into`].
///
/// [`Self::into`]: struct.MaxLenHashSet.html#impl-Into<U>-for-T
// Field `0` is the wrapped set; field `1` is the maximum length returned by `max_len`, set at construction to
// half of the capacity actually allocated.
#[derive(Debug)]
pub struct MaxLenHashSet<T, S = BuildIdentityHasher>(HashSet<T, S>, usize);
     31 impl<T> MaxLenHashSet<T, BuildIdentityHasher> {
     32     /// [`HashSet::with_capacity_and_hasher`] using `2 * max_len` and `BuildIdentityHasher`.
     33     ///
     34     /// Note since the actual capacity allocated may exceed the requested capacity, [`Self::max_len`] may exceed
     35     /// `max_len`.
     36     ///
     37     /// # Panics
     38     ///
     39     /// `panic`s if `max_len > usize::MAX / 2`. Note since [`HashSet::with_capacity_and_hasher`] `panic`s
     40     /// for much smaller values than `usize::MAX / 2`—even when `T` is a zero-sized type (ZST)—this is not an
     41     /// additional `panic` than what would already occur. The only difference is the message reported.
     42     #[inline]
     43     #[must_use]
     44     pub fn new(max_len: usize) -> Self {
     45         Self::with_hasher(max_len, BuildIdentityHasher)
     46     }
     47 }
     48 impl<T, S> MaxLenHashSet<T, S> {
     49     /// Capacity we allocate.
     50     ///
     51     /// # Errors
     52     ///
     53     /// Errors iff `max_len > usize::MAX / 2`.
     54     const fn requested_capacity(max_len: usize) -> Result<usize, TryReserveError> {
     55         if max_len <= usize::MAX >> 1u8 {
     56             Ok(max_len << 1u8)
     57         } else {
     58             Err(TryReserveError::CapacityOverflow)
     59         }
     60     }
     61     /// Returns the immutable maximum length allowed by `self`.
     62     #[inline]
     63     pub const fn max_len(&self) -> usize {
     64         self.1
     65     }
     66     /// [`HashSet::clear`].
     67     #[inline]
     68     pub fn clear(&mut self) {
     69         self.0.clear();
     70     }
     71     /// [`HashSet::drain`].
     72     #[inline]
     73     pub fn drain(&mut self) -> Drain<'_, T> {
     74         self.0.drain()
     75     }
     76     /// [`HashSet::extract_if`].
     77     #[inline]
     78     pub fn extract_if<F: FnMut(&T) -> bool>(&mut self, f: F) -> ExtractIf<'_, T, F> {
     79         self.0.extract_if(f)
     80     }
     81     /// [`HashSet::with_capacity_and_hasher`] using `2 * max_len` and `hasher`.
     82     ///
     83     /// Note since the actual capacity allocated may exceed the requested capacity, [`Self::max_len`] may exceed
     84     /// `max_len`.
     85     ///
     86     /// # Panics
     87     ///
     88     /// `panic`s if `max_len > usize::MAX / 2`. Note since [`HashSet::with_capacity_and_hasher`] `panic`s
     89     /// for much smaller values than `usize::MAX / 2`—even when `T` is a zero-sized type (ZST)—this is not an
     90     /// additional `panic` than what would already occur. The only difference is the message reported.
     91     #[expect(
     92         clippy::expect_used,
     93         reason = "purpose of this function is to panic if the hash set cannot be allocated"
     94     )]
     95     #[inline]
     96     #[must_use]
     97     pub fn with_hasher(max_len: usize, hasher: S) -> Self {
     98         let set = HashSet::with_capacity_and_hasher(Self::requested_capacity(max_len).expect("HashSet::with_hasher must be passed a maximum length that does not exceed usize::MAX / 2"), hasher);
     99         let len = set.capacity() >> 1u8;
    100         Self(set, len)
    101     }
    102     /// [`HashSet::retain`].
    103     #[inline]
    104     pub fn retain<F: FnMut(&T) -> bool>(&mut self, f: F) {
    105         self.0.retain(f);
    106     }
    107 }
    108 impl<T: TimedCeremony, S> MaxLenHashSet<T, S> {
    109     /// Removes all expired ceremonies.
    110     #[inline]
    111     pub fn remove_expired_ceremonies(&mut self) {
    112         // Even though it's more accurate to check the current `Instant` for each ceremony, we elect to capture
    113         // the `Instant` we begin iteration for performance reasons. It's unlikely an appreciable amount of
    114         // additional ceremonies would be removed.
    115         #[cfg(not(feature = "serializable_server_state"))]
    116         let now = Instant::now();
    117         #[cfg(feature = "serializable_server_state")]
    118         let now = SystemTime::now();
    119         self.retain(|v| v.expiration() >= now);
    120     }
    121     /// Removes the first encountered expired ceremony.
    122     #[inline]
    123     pub fn remove_first_expired_ceremony(&mut self) {
    124         #[cfg(not(feature = "serializable_server_state"))]
    125         let now = Instant::now();
    126         #[cfg(feature = "serializable_server_state")]
    127         let now = SystemTime::now();
    128         drop(self.0.extract_if(|v| v.expiration() < now).next());
    129     }
    130 }
    131 impl<T: Eq + Hash, S: BuildHasher> MaxLenHashSet<T, S> {
    132     /// [`HashSet::with_hasher`] using `hasher` followed by [`HashSet::try_reserve`] using `2 * max_len`.
    133     ///
    134     /// Note since the actual capacity allocated may exceed the requested capacity, [`Self::max_len`] may exceed
    135     /// `max_len`.
    136     ///
    137     /// # Errors
    138     ///
    139     /// Errors iff `max_len > usize::MAX / 2` or [`HashSet::try_reserve`] does.
    140     #[inline]
    141     pub fn try_with_hasher(max_len: usize, hasher: S) -> Result<Self, TryReserveError> {
    142         Self::requested_capacity(max_len).and_then(|additional| {
    143             let mut set = HashSet::with_hasher(hasher);
    144             set.try_reserve(additional).map(|()| {
    145                 let len = set.capacity() >> 1u8;
    146                 Self(set, len)
    147             })
    148         })
    149     }
    150     /// [`HashSet::get_or_insert`].
    151     ///
    152     /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the
    153     /// set.
    154     #[inline]
    155     pub fn get_or_insert(&mut self, value: T) -> Option<&T> {
    156         if self.0.len() == self.1 {
    157             self.0.get(&value)
    158         } else {
    159             Some(self.0.get_or_insert(value))
    160         }
    161     }
    162     /// [`HashSet::get_or_insert_with`].
    163     ///
    164     /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the
    165     /// set.
    166     #[inline]
    167     pub fn get_or_insert_with<Q: Equivalent<T> + Hash + ?Sized, F: FnOnce(&Q) -> T>(
    168         &mut self,
    169         value: &Q,
    170         f: F,
    171     ) -> Option<&T> {
    172         if self.0.len() == self.1 {
    173             self.0.get(value)
    174         } else {
    175             Some(self.0.get_or_insert_with(value, f))
    176         }
    177     }
    178     /// [`HashSet::remove`].
    179     #[inline]
    180     pub fn remove<Q: Equivalent<T> + Hash + ?Sized>(&mut self, value: &Q) -> bool {
    181         self.0.remove(value)
    182     }
    183     /// [`HashSet::take`].
    184     #[inline]
    185     pub fn take<Q: Equivalent<T> + Hash + ?Sized>(&mut self, value: &Q) -> Option<T> {
    186         self.0.take(value)
    187     }
    188     /// [`HashSet::insert`].
    189     ///
    190     /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the
    191     /// set.
    192     #[inline]
    193     pub fn insert(&mut self, value: T) -> Option<bool> {
    194         let full = self.0.len() == self.1;
    195         if let Entry::Vacant(ent) = self.0.entry(value) {
    196             if full {
    197                 None
    198             } else {
    199                 _ = ent.insert();
    200                 Some(true)
    201             }
    202         } else {
    203             Some(false)
    204         }
    205     }
    206     /// [`HashSet::replace`].
    207     ///
    208     /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the
    209     /// set.
    210     #[inline]
    211     pub fn replace(&mut self, value: T) -> Option<Option<T>> {
    212         // Ideally we would use the Entry API to avoid searching multiple times, but one can't while also using
    213         // `replace` since there is no `OccupiedEntry::replace`.
    214         if self.0.contains(&value) {
    215             Some(self.0.replace(value))
    216         } else if self.0.len() == self.1 {
    217             None
    218         } else {
    219             _ = self.0.insert(value);
    220             Some(None)
    221         }
    222     }
    223     /// [`HashSet::entry`].
    224     ///
    225     /// `None` is returned iff [`HashSet::len`] `==` [`Self::max_len`] and `value` does not already exist in the
    226     /// set.
    227     #[inline]
    228     pub fn entry(&mut self, value: T) -> Option<Entry<'_, T, S>> {
    229         let full = self.0.len() == self.1;
    230         match self.0.entry(value) {
    231             ent @ Entry::Occupied(_) => Some(ent),
    232             ent @ Entry::Vacant(_) => {
    233                 if full {
    234                     None
    235                 } else {
    236                     Some(ent)
    237                 }
    238             }
    239         }
    240     }
    241 }
    242 impl<T: Eq + Hash + TimedCeremony, S: BuildHasher> MaxLenHashSet<T, S> {
    243     /// [`Self::insert`] except the first encountered expired ceremony is removed in the event [`Self::max_len`]
    244     /// items have been added.
    245     #[inline]
    246     pub fn insert_remove_expired(&mut self, value: T) -> Option<bool> {
    247         if self.0.len() == self.1 {
    248             #[cfg(not(feature = "serializable_server_state"))]
    249             let now = Instant::now();
    250             #[cfg(feature = "serializable_server_state")]
    251             let now = SystemTime::now();
    252             if self.0.extract_if(|v| v.expiration() < now).next().is_some() {
    253                 Some(self.0.insert(value))
    254             } else {
    255                 self.0.contains(&value).then_some(false)
    256             }
    257         } else {
    258             Some(self.0.insert(value))
    259         }
    260     }
    261     /// [`Self::insert`] except all expired ceremones are removed in the event [`Self::max_len`] items have
    262     /// been added.
    263     #[inline]
    264     pub fn insert_remove_all_expired(&mut self, value: T) -> Option<bool> {
    265         if self.0.len() == self.1 {
    266             self.remove_expired_ceremonies();
    267         }
    268         if self.0.len() == self.1 {
    269             self.0.contains(&value).then_some(false)
    270         } else {
    271             Some(self.0.insert(value))
    272         }
    273     }
    274     /// [`Self::entry`] except the first encountered expired ceremony is removed in the event [`Self::max_len`]
    275     /// items have been added.
    276     #[inline]
    277     pub fn entry_remove_expired(&mut self, value: T) -> Option<Entry<'_, T, S>> {
    278         if self.0.len() == self.1 {
    279             #[cfg(not(feature = "serializable_server_state"))]
    280             let now = Instant::now();
    281             #[cfg(feature = "serializable_server_state")]
    282             let now = SystemTime::now();
    283             if self.0.extract_if(|v| v.expiration() < now).next().is_some() {
    284                 Some(self.0.entry(value))
    285             } else if let ent @ Entry::Occupied(_) = self.0.entry(value) {
    286                 Some(ent)
    287             } else {
    288                 None
    289             }
    290         } else {
    291             Some(self.0.entry(value))
    292         }
    293     }
    294     /// [`Self::entry`] except all expired ceremones are removed in the event [`Self::max_len`] items have
    295     /// been added.
    296     #[inline]
    297     pub fn entry_remove_all_expired(&mut self, value: T) -> Option<Entry<'_, T, S>> {
    298         if self.0.len() == self.1 {
    299             self.remove_expired_ceremonies();
    300         }
    301         if self.0.len() == self.1 {
    302             if let ent @ Entry::Occupied(_) = self.0.entry(value) {
    303                 Some(ent)
    304             } else {
    305                 None
    306             }
    307         } else {
    308             Some(self.0.entry(value))
    309         }
    310     }
    311 }
    312 impl<T, S> AsRef<HashSet<T, S>> for MaxLenHashSet<T, S> {
    313     #[inline]
    314     fn as_ref(&self) -> &HashSet<T, S> {
    315         &self.0
    316     }
    317 }
    318 impl<T, S> From<MaxLenHashSet<T, S>> for HashSet<T, S> {
    319     #[inline]
    320     fn from(value: MaxLenHashSet<T, S>) -> Self {
    321         value.0
    322     }
    323 }
    324 #[cfg(test)]
    325 mod tests {
    326     use super::{Equivalent, MaxLenHashSet, TimedCeremony};
    327     use core::hash::{Hash, Hasher};
    328     #[cfg(not(feature = "serializable_server_state"))]
    329     use std::time::Instant;
    330     #[cfg(feature = "serializable_server_state")]
    331     use std::time::SystemTime;
    332     #[derive(Clone, Copy)]
    333     struct Ceremony {
    334         id: usize,
    335         #[cfg(not(feature = "serializable_server_state"))]
    336         exp: Instant,
    337         #[cfg(feature = "serializable_server_state")]
    338         exp: SystemTime,
    339     }
    340     impl Default for Ceremony {
    341         fn default() -> Self {
    342             Self {
    343                 id: 0,
    344                 #[cfg(not(feature = "serializable_server_state"))]
    345                 exp: Instant::now(),
    346                 #[cfg(feature = "serializable_server_state")]
    347                 exp: SystemTime::now(),
    348             }
    349         }
    350     }
    351     impl PartialEq for Ceremony {
    352         fn eq(&self, other: &Self) -> bool {
    353             self.id == other.id
    354         }
    355     }
    356     impl Eq for Ceremony {}
    357     impl Hash for Ceremony {
    358         fn hash<H: Hasher>(&self, state: &mut H) {
    359             self.id.hash(state);
    360         }
    361     }
    362     impl TimedCeremony for Ceremony {
    363         #[cfg(not(feature = "serializable_server_state"))]
    364         fn expiration(&self) -> Instant {
    365             self.exp
    366         }
    367         #[cfg(feature = "serializable_server_state")]
    368         fn expiration(&self) -> SystemTime {
    369             self.exp
    370         }
    371     }
    372     impl Equivalent<Ceremony> for usize {
    373         fn equivalent(&self, key: &Ceremony) -> bool {
    374             *self == key.id
    375         }
    376     }
    377     #[test]
    378     fn hash_set_insert_removed() {
    379         const REQ_MAX_LEN: usize = 8;
    380         let mut set = MaxLenHashSet::new(REQ_MAX_LEN);
    381         let cap = set.as_ref().capacity();
    382         let max_len = set.max_len();
    383         assert_eq!(cap >> 1u8, max_len);
    384         assert!(max_len >= REQ_MAX_LEN);
    385         let mut cer = Ceremony::default();
    386         for i in 0..max_len {
    387             assert!(set.as_ref().capacity() <= cap);
    388             cer.id = i;
    389             assert_eq!(set.insert(cer), Some(true));
    390         }
    391         assert!(set.as_ref().capacity() <= cap);
    392         assert_eq!(set.as_ref().len(), max_len);
    393         for i in 0..max_len {
    394             assert!(set.as_ref().contains(&i));
    395         }
    396         cer.id = cap;
    397         assert_eq!(set.insert_remove_expired(cer), Some(true));
    398         assert!(set.as_ref().capacity() <= cap);
    399         assert_eq!(set.as_ref().len(), max_len);
    400         let mut counter = 0;
    401         for i in 0..max_len {
    402             counter += usize::from(set.as_ref().contains(&i));
    403         }
    404         assert_eq!(counter, max_len - 1);
    405         assert!(set.as_ref().contains(&(max_len - 1)));
    406     }
    407 }