rpz

Response policy zone (RPZ) file generator.
git clone https://git.philomathiclife.com/repos/rpz
Log | Files | Refs | README

config.rs (16820B)


      1 extern crate alloc;
      2 use alloc::borrow::Cow;
      3 use core::{
      4     fmt::{self, Display, Formatter},
      5     time::Duration,
      6 };
      7 use rpz::file::{AbsFilePath, HttpUrl};
      8 use serde::de::{Deserialize, Deserializer, Error, MapAccess, SeqAccess, Unexpected, Visitor};
      9 use std::collections::HashSet;
/// The TOML config file.
///
/// Values are produced by the hand-written `Deserialize` implementation for this type, which
/// rejects unknown keys, duplicate keys, URLs repeated within one list or shared between lists,
/// and configurations where `local_dir` is absent and every URL list is empty.
#[derive(Debug)]
pub struct Config {
    /// The maximum amount of time allowed for an HTTP(S) file to be downloaded.
    ///
    /// Read from the `timeout` key as a whole number of seconds that must fit in a `u32`;
    /// `None` when the key is absent.
    pub timeout: Option<Duration>,
    /// The absolute file path for the [response policy zone (RPZ)](https://en.wikipedia.org/wiki/Response_policy_zone) file.
    ///
    /// `None` when the `rpz` key is absent.
    pub rpz: Option<AbsFilePath<false>>,
    /// The absolute file path to the directory that contains local (un)block files.
    ///
    /// `None` when the `local_dir` key is absent. The tests in this file show that a value
    /// without a trailing `/` is accepted and comes back with the `/` appended.
    pub local_dir: Option<AbsFilePath<true>>,
    /// The unique absolute HTTP(S) URLs to [Adblock-style](https://adguard-dns.io/kb/general/dns-filtering-syntax/#adblock-style-syntax)
    /// block lists.
    ///
    /// Empty when the `adblock` key is absent.
    pub adblock: HashSet<HttpUrl>,
    /// The unique absolute HTTP(S) URLs to [domains-only](https://adguard-dns.io/kb/general/dns-filtering-syntax/#domains-only-syntax)
    /// block lists.
    ///
    /// Empty when the `domain` key is absent.
    pub domain: HashSet<HttpUrl>,
    /// The unique absolute HTTP(S) URLs to [`hosts(5)`-style](https://adguard-dns.io/kb/general/dns-filtering-syntax/#etc-hosts-syntax)
    /// block lists.
    ///
    /// Empty when the `hosts` key is absent.
    pub hosts: HashSet<HttpUrl>,
    /// The unique absolute HTTP(S) URLs to [wildcard domain](https://pgl.yoyo.org/adservers/serverlist.php?hostformat=adblock&showintro=0&mimetype=plaintext)
    /// block lists.
    ///
    /// Empty when the `wildcard` key is absent.
    pub wildcard: HashSet<HttpUrl>,
}
     32 impl Display for Config {
     33     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
     34         /// Helper function that writes the `Url`s in a `HashSet<HttpUrl>`.
     35         fn keys(set: &HashSet<HttpUrl>, f: &mut Formatter<'_>, name: &str) -> fmt::Result {
     36             write!(f, "{name}: [").and_then(|()| {
     37                 set.iter()
     38                     .try_fold((), |(), url| write!(f, "{url}, "))
     39                     .and_then(|()| f.write_str("], "))
     40             })
     41         }
     42         write!(
     43             f,
     44             "Config {{ timeout: {} rpz: {}, local_dir: {}, ",
     45             self.timeout
     46                 .map_or_else(String::new, |dur| dur.as_secs().to_string()),
     47             &self
     48                 .rpz
     49                 .as_ref()
     50                 .map_or_else(|| Cow::Owned(String::new()), |file| file.to_string_lossy()),
     51             &self
     52                 .local_dir
     53                 .as_ref()
     54                 .map_or_else(|| Cow::Owned(String::new()), |dir| dir.to_string_lossy())
     55         )
     56         .and_then(|()| {
     57             keys(&self.adblock, f, "adblock")
     58                 .and_then(|()| keys(&self.domain, f, "domain"))
     59                 .and_then(|()| keys(&self.hosts, f, "hosts"))
     60                 .and_then(|()| keys(&self.wildcard, f, "wildcard"))
     61                 .and_then(|()| f.write_str("}"))
     62         })
     63     }
     64 }
     65 impl<'de> Deserialize<'de> for Config {
     66     #[expect(clippy::too_many_lines, reason = "this is fine")]
     67     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     68     where
     69         D: Deserializer<'de>,
     70     {
     71         /// Config file fields.
     72         enum Field {
     73             /// Timeout field.
     74             Timeout,
     75             /// RPZ field.
     76             Rpz,
     77             /// Local directory field.
     78             LocalDir,
     79             /// Adblock URLs field.
     80             Adblock,
     81             /// Domain-only URLs field.
     82             Domain,
     83             /// hosts(5) URLs field.
     84             Hosts,
     85             /// Wildcard URLs field.
     86             Wildcard,
     87         }
     88         impl<'d> Deserialize<'d> for Field {
     89             fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     90             where
     91                 D: Deserializer<'d>,
     92             {
     93                 /// `Visitor` for `Field`.
     94                 struct FieldVisitor;
     95                 impl<'de> Visitor<'de> for FieldVisitor {
     96                     type Value = Field;
     97                     fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
     98                         formatter.write_str(
     99                             "'timeout', 'rpz', 'local_dir', 'adblock', 'domain', 'hosts', or 'wildcard'",
    100                         )
    101                     }
    102                     fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    103                     where
    104                         E: Error,
    105                     {
    106                         match v {
    107                             "timeout" => Ok(Field::Timeout),
    108                             "rpz" => Ok(Field::Rpz),
    109                             "local_dir" => Ok(Field::LocalDir),
    110                             "adblock" => Ok(Field::Adblock),
    111                             "domain" => Ok(Field::Domain),
    112                             "hosts" => Ok(Field::Hosts),
    113                             "wildcard" => Ok(Field::Wildcard),
    114                             _ => Err(E::unknown_field(v, &VARIANTS)),
    115                         }
    116                     }
    117                 }
    118                 deserializer.deserialize_identifier(FieldVisitor)
    119             }
    120         }
    121         /// `Visitor` for `Config`.
    122         struct ConfigVisitor;
    123         impl<'d> Visitor<'d> for ConfigVisitor {
    124             type Value = Config;
    125             fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
    126                 formatter.write_str("struct Config")
    127             }
    128             #[expect(
    129                 clippy::as_conversions,
    130                 clippy::cast_lossless,
    131                 clippy::too_many_lines,
    132                 reason = "carefull verify use is correct"
    133             )]
    134             fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
    135             where
    136                 A: MapAccess<'d>,
    137             {
    138                 /// Verifies that the `HashSet`s are pairwise disjoint.
    139                 #[expect(
    140                     clippy::arithmetic_side_effects,
    141                     clippy::indexing_slicing,
    142                     reason = "carefully verify use is correct"
    143                 )]
    144                 fn hash_overlap<E: Error>(maps: &[&HashSet<HttpUrl>]) -> Result<(), E> {
    145                     /// Verifies the intersection of `left` and `right` is empty.
    146                     fn url_overlap<E: Error>(
    147                         left: &HashSet<HttpUrl>,
    148                         right: &HashSet<HttpUrl>,
    149                     ) -> Result<(), E> {
    150                         let (mut iter, urls) = if left.len() <= right.len() {
    151                             (left.iter(), right)
    152                         } else {
    153                             (right.iter(), left)
    154                         };
    155                         iter.try_fold((), |(), url| {
    156                             if urls.contains(url) {
    157                                 Err(E::invalid_type(
    158                                     Unexpected::Other(url.to_string().as_str()),
    159                                     &"unique URLs across the block list types",
    160                                 ))
    161                             } else {
    162                                 Ok(())
    163                             }
    164                         })
    165                     }
    166                     maps.iter().enumerate().try_fold((), |(), (idx, map)| {
    167                         maps[idx + 1..]
    168                             .iter()
    169                             .try_fold((), |(), map2| url_overlap(map, map2))
    170                     })
    171                 }
    172                 /// Wrapper around a `HashSet` that is deserializable.
    173                 struct Urls(HashSet<HttpUrl>);
    174                 impl<'de> Deserialize<'de> for Urls {
    175                     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    176                     where
    177                         D: Deserializer<'de>,
    178                     {
    179                         /// `Visitor` for `Urls`.
    180                         struct HashVisitor;
    181                         impl<'d> Visitor<'d> for HashVisitor {
    182                             type Value = Urls;
    183                             fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
    184                                 formatter.write_str("struct Urls")
    185                             }
    186                             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    187                             where
    188                                 A: SeqAccess<'d>,
    189                             {
    190                                 let mut urls = HashSet::with_capacity(seq.size_hint().unwrap_or(0));
    191                                 while let Some(url) = seq.next_element::<HttpUrl>()? {
    192                                     urls.replace(url).map_or(Ok(()), |dup| {
    193                                         Err(Error::invalid_value(
    194                                             Unexpected::Other(dup.to_string().as_str()),
    195                                             &"a set of unique URLs",
    196                                         ))
    197                                     })?;
    198                                 }
    199                                 Ok(Urls(urls))
    200                             }
    201                         }
    202                         deserializer.deserialize_seq(HashVisitor)
    203                     }
    204                 }
    205                 let mut timeout = None;
    206                 let mut rpz = None;
    207                 let mut local_dir = None;
    208                 let mut ad = None;
    209                 let mut dom = None;
    210                 let mut hst = None;
    211                 let mut wc = None;
    212                 while let Some(key) = map.next_key()? {
    213                     match key {
    214                         Field::Timeout => {
    215                             if timeout.is_some() {
    216                                 return Err(Error::duplicate_field("timeout"));
    217                             }
    218                             timeout = Some(Duration::from_secs(map.next_value::<u32>()? as u64));
    219                         }
    220                         Field::Rpz => {
    221                             if rpz.is_some() {
    222                                 return Err(Error::duplicate_field("rpz"));
    223                             }
    224                             rpz = Some(map.next_value::<AbsFilePath<false>>()?);
    225                         }
    226                         Field::LocalDir => {
    227                             if local_dir.is_some() {
    228                                 return Err(Error::duplicate_field("local_dir"));
    229                             }
    230                             local_dir = Some(map.next_value::<AbsFilePath<true>>()?);
    231                         }
    232                         Field::Adblock => {
    233                             if ad.is_some() {
    234                                 return Err(Error::duplicate_field("adblock"));
    235                             }
    236                             ad = Some(map.next_value::<Urls>()?);
    237                         }
    238                         Field::Domain => {
    239                             if dom.is_some() {
    240                                 return Err(Error::duplicate_field("domain"));
    241                             }
    242                             dom = Some(map.next_value::<Urls>()?);
    243                         }
    244                         Field::Hosts => {
    245                             if hst.is_some() {
    246                                 return Err(Error::duplicate_field("hosts"));
    247                             }
    248                             hst = Some(map.next_value::<Urls>()?);
    249                         }
    250                         Field::Wildcard => {
    251                             if wc.is_some() {
    252                                 return Err(Error::duplicate_field("wildcard"));
    253                             }
    254                             wc = Some(map.next_value::<Urls>()?);
    255                         }
    256                     }
    257                 }
    258                 if local_dir.is_none()
    259                     && ad.as_ref().map_or(true, |urls| urls.0.is_empty())
    260                     && dom.as_ref().map_or(true, |urls| urls.0.is_empty())
    261                     && hst.as_ref().map_or(true, |urls| urls.0.is_empty())
    262                     && wc.as_ref().map_or(true, |urls| urls.0.is_empty())
    263                 {
    264                     Err(Error::invalid_type(
    265                         Unexpected::Other("no block list URLs or directory"),
    266                         &"at least one block list entry (i.e., 'local_dir', 'adblock', 'domain', 'hosts', or 'wildcard' must exist and not be empty)",
    267                     ))
    268                 } else {
    269                     let adblock = ad.map_or_else(HashSet::new, |urls| urls.0);
    270                     let domain = dom.map_or_else(HashSet::new, |urls| urls.0);
    271                     let hosts = hst.map_or_else(HashSet::new, |urls| urls.0);
    272                     let wildcard = wc.map_or_else(HashSet::new, |urls| urls.0);
    273                     hash_overlap([&adblock, &domain, &hosts, &wildcard].as_slice()).map(|()| {
    274                         Config {
    275                             timeout,
    276                             rpz,
    277                             local_dir,
    278                             adblock,
    279                             domain,
    280                             hosts,
    281                             wildcard,
    282                         }
    283                     })
    284                 }
    285             }
    286         }
    287         /// `Config` fields.
    288         const VARIANTS: [&str; 7] = [
    289             "timeout",
    290             "rpz",
    291             "local_dir",
    292             "adblock",
    293             "domain",
    294             "hosts",
    295             "wildcard",
    296         ];
    297         deserializer.deserialize_struct("Config", &VARIANTS, ConfigVisitor)
    298     }
    299 }
    300 #[cfg(test)]
    301 mod tests {
    302     use crate::Config;
    303     #[test]
    304     fn test_missing_fields() {
    305         assert!(toml::from_str::<Config>("").is_err());
    306         assert!(toml::from_str::<Config>("timeout=15").is_err());
    307         assert!(toml::from_str::<Config>(r#"rpz="/foo""#).is_err());
    308         assert!(toml::from_str::<Config>(r#"local_dir="/foo/""#).is_ok());
    309         assert!(toml::from_str::<Config>(r#"adblock=["https://foo.com/foo"]"#).is_ok());
    310     }
    311     #[test]
    312     fn test_invalid_fields() {
    313         assert!(toml::from_str::<Config>("bob=15").is_err());
    314         assert!(toml::from_str::<Config>(r#"foo=["https://foo.com/foo"]"#).is_err());
    315     }
    316     #[test]
    317     fn test_timeout() {
    318         assert!(toml::from_str::<Config>(r#"local_dir="/foo/""#).is_ok());
    319         assert!(toml::from_str::<Config>(
    320             r#"timeout=15
    321 local_dir="/foo/""#
    322         )
    323         .is_ok());
    324         assert!(toml::from_str::<Config>(
    325             r#"timeout=4294967295
    326 local_dir="/foo/""#
    327         )
    328         .is_ok());
    329         assert!(toml::from_str::<Config>(
    330             r#"timeout=-1
    331 local_dir="/foo/""#
    332         )
    333         .is_err());
    334         assert!(toml::from_str::<Config>(
    335             r#"timeout=0
    336 local_dir="/foo/""#
    337         )
    338         .is_ok());
    339         assert!(toml::from_str::<Config>(
    340             r#"timeout=4294967296
    341 local_dir="/foo/""#
    342         )
    343         .is_err());
    344     }
    345     #[test]
    346     fn test_arrays() {
    347         assert!(toml::from_str::<Config>(
    348             r#"adblock=["https://foo.com/foo","https://foo.com/foo"]"#
    349         )
    350         .is_err());
    351         assert!(toml::from_str::<Config>(r#"adblock=["https://foo.com/foo"]"#).is_ok());
    352         assert!(toml::from_str::<Config>(
    353             r#"adblock=["https://foo.com/foo"]
    354 domain=["https://foo.com/foo"]"#
    355         )
    356         .is_err());
    357     }
    358     #[test]
    359     fn test_urls() {
    360         assert!(toml::from_str::<Config>(r#"adblock=["file://foo.com/foo"]"#).is_err());
    361         assert!(toml::from_str::<Config>(r#"adblock=["foo.com/foo"]"#).is_err());
    362         assert!(toml::from_str::<Config>(r#"adblock=["http://foo.com/foo"]"#).is_ok());
    363         assert!(toml::from_str::<Config>(
    364             r#"adblock=[]
    365 domain=["https:///foo"]"#
    366         )
    367         .is_ok());
    368         assert!(toml::from_str::<Config>(r#"adblock=[""]"#).is_err());
    369         assert!(toml::from_str::<Config>(r#"adblock=["https://"]"#).is_err());
    370         assert!(toml::from_str::<Config>(r#"adblock=["ftp://foo.com/foo"]"#).is_err());
    371     }
    372     #[test]
    373     fn test_paths() {
    374         assert!(toml::from_str::<Config>(
    375             r#"rpz="/foo/"
    376 wildcard=["https://foo.com/foo"]"#
    377         )
    378         .is_err());
    379         assert!(toml::from_str::<Config>(
    380             r#"rpz="foo"
    381 wildcard=["https://foo.com/foo"]"#
    382         )
    383         .is_err());
    384         assert!(toml::from_str::<Config>(
    385             r#"rpz="/foo"
    386 wildcard=["https://foo.com/foo"]"#
    387         )
    388         .is_ok());
    389         assert!(toml::from_str::<Config>(r#"local_dir="foo/""#).is_err());
    390         assert!(toml::from_str::<Config>(r#"local_dir="/foo/""#).is_ok());
    391         // Directories are allowed to not have a trailing `/`, but they will get it
    392         // added.
    393         assert!(toml::from_str::<Config>(
    394             r#"local_dir="/foo"
    395 wildcard=["https://foo.com/foo"]"#
    396         )
    397         .map_or(false, |config| config.local_dir.map_or(false, |dir| dir
    398             .to_str()
    399             .map_or(false, |val| val == String::from("/foo/")))));
    400     }
    401 }