tokio_dual_stack

Dual-stack TcpListener.
git clone https://git.philomathiclife.com/repos/tokio_dual_stack
Log | Files | Refs | README

lib.rs (21638B)


      1 //! [![git]](https://git.philomathiclife.com/tokio_dual_stack/log.html) [![crates-io]](https://crates.io/crates/tokio_dual_stack) [![docs-rs]](crate)
      2 //!
      3 //! [git]: https://git.philomathiclife.com/git_badge.svg
      4 //! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
      5 //! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
      6 //!
      7 //! `tokio_dual_stack` is a library that adds a "dual-stack" [`TcpListener`].
      8 //!
      9 //! ## Why is this useful?
     10 //!
     11 //! Only certain platforms offer the ability for one socket to handle both IPv6 and IPv4 requests
     12 //! (e.g., OpenBSD does not). For the platforms that do, it is often dependent on runtime configuration
//! (e.g., [`IPV6_V6ONLY`](https://www.man7.org/linux/man-pages/man7/ipv6.7.html)). Additionally, those platforms
     14 //! that support it often require the "wildcard" IPv6 address to be used (i.e., `::`) which has the unfortunate
     15 //! consequence of preventing other services from using the same protocol port.
     16 //!
     17 //! There are a few ways to work around this issue. One is to deploy the same service twice: one that uses
     18 //! an IPv6 socket and the other that uses an IPv4 socket. This can complicate deployments (e.g., the application
     19 //! may not have been written with the expectation that multiple deployments could be running at the same time) in
     20 //! addition to using more resources. Another is for the application to manually handle each socket (e.g.,
     21 //! [`select`](https://docs.rs/tokio/latest/tokio/macro.select.html)/[`join`](https://docs.rs/tokio/latest/tokio/macro.join.html)
     22 //! each [`TcpListener::accept`]).
     23 //!
//! [`DualStackTcpListener`] chooses an implementation similar to what the equivalent `select` would do while
//! also preventing one socket from "starving" the other: each call alternates which socket is given the first
//! opportunity to `TcpListener::accept` a connection. This has the nice benefit of having a similar API to what
//! a single `TcpListener` would have as well as having similar performance to a socket that does handle both
//! IPv6 and IPv4 requests.
     29 #![deny(
     30     unknown_lints,
     31     future_incompatible,
     32     let_underscore,
     33     missing_docs,
     34     nonstandard_style,
     35     refining_impl_trait,
     36     rust_2018_compatibility,
     37     rust_2018_idioms,
     38     rust_2021_compatibility,
     39     rust_2024_compatibility,
     40     unsafe_code,
     41     unused,
     42     warnings,
     43     clippy::all,
     44     clippy::cargo,
     45     clippy::complexity,
     46     clippy::correctness,
     47     clippy::nursery,
     48     clippy::pedantic,
     49     clippy::perf,
     50     clippy::restriction,
     51     clippy::style,
     52     clippy::suspicious
     53 )]
     54 #![expect(
     55     clippy::arbitrary_source_item_ordering,
     56     clippy::blanket_clippy_restriction_lints,
     57     clippy::implicit_return,
     58     reason = "noisy, opinionated, and likely doesn't prevent bugs or improve APIs"
     59 )]
     60 use core::{
     61     future::Future,
     62     net::{SocketAddr, SocketAddrV4, SocketAddrV6},
     63     pin::Pin,
     64     sync::atomic::{AtomicBool, Ordering},
     65     task::{Context, Poll},
     66 };
     67 use pin_project_lite::pin_project;
     68 use std::io::{Error, ErrorKind, Result};
     69 use tokio::net::{self, TcpListener, TcpSocket, TcpStream, ToSocketAddrs};
     70 /// Prevents [`Sealed`] from being publicly implementable.
     71 mod private {
     72     /// Marker trait to prevent [`super::Tcp`] from being publicly implementable.
     73     pub trait Sealed {}
     74 }
     75 use private::Sealed;
     76 /// TCP "listener".
     77 ///
     78 /// This `trait` is sealed and cannot be implemented for types outside of `tokio_dual_stack`.
     79 ///
     80 /// This exists primarily as a way to define type constructors or polymorphic functions
     81 /// that can user either a [`TcpListener`] or [`DualStackTcpListener`].
     82 ///
     83 /// # Examples
     84 ///
     85 /// ```no_run
     86 /// # use core::convert::Infallible;
     87 /// # use tokio_dual_stack::Tcp;
     88 /// async fn main_loop<T: Tcp>(listener: T) -> Infallible {
     89 ///     loop {
     90 ///         match listener.accept().await {
     91 ///             Ok((_, socket)) => println!("Client socket: {socket}"),
     92 ///             Err(e) => println!("TCP connection failure: {e}"),
     93 ///         }
     94 ///     }
     95 /// }
     96 /// ```
     97 pub trait Tcp: Sealed + Sized {
     98     /// Creates a new TCP listener, which will be bound to the specified address(es).
     99     ///
    100     /// The returned listener is ready for accepting connections.
    101     ///
    102     /// Binding with a port number of 0 will request that the OS assigns a port to this listener.
    103     /// The port allocated can be queried via the `local_addr` method.
    104     ///
    105     /// The address type can be any implementor of the [`ToSocketAddrs`] trait. If `addr` yields
    106     /// multiple addresses, bind will be attempted with each of the addresses until one succeeds
    107     /// and returns the listener. If none of the addresses succeed in creating a listener, the
    108     /// error returned from the last attempt (the last address) is returned.
    109     ///
    110     /// This function sets the `SO_REUSEADDR` option on the socket.
    111     ///
    112     /// # Examples
    113     ///
    114     /// ```no_run
    115     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    116     /// # use std::io::Result;
    117     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    118     /// #[tokio::main(flavor = "current_thread")]
    119     /// async fn main() -> Result<()> {
    120     ///     let listener = DualStackTcpListener::bind(
    121     ///         [
    122     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    123     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    124     ///         ]
    125     ///         .as_slice(),
    126     ///     )
    127     ///     .await?;
    128     ///     Ok(())
    129     /// }
    130     /// ```
    131     fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>>;
    132     /// Accepts a new incoming connection from this listener.
    133     ///
    134     /// This function will yield once a new TCP connection is established. When established,
    135     /// the corresponding `TcpStream` and the remote peer’s address will be returned.
    136     ///
    137     /// # Cancel safety
    138     ///
    139     /// This method is cancel safe. If the method is used as the event in a
    140     /// [`tokio::select!`](https://docs.rs/tokio/latest/tokio/macro.select.html)
    141     /// statement and some other branch completes first, then it is guaranteed that no new
    142     /// connections were accepted by this method.
    143     ///
    144     /// # Examples
    145     ///
    146     /// ```no_run
    147     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    148     /// # use std::io::Result;
    149     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    150     /// #[tokio::main(flavor = "current_thread")]
    151     /// async fn main() -> Result<()> {
    152     ///     match DualStackTcpListener::bind(
    153     ///         [
    154     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    155     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    156     ///         ]
    157     ///         .as_slice(),
    158     ///     )
    159     ///     .await?.accept().await {
    160     ///         Ok((_, addr)) => println!("new client: {addr}"),
    161     ///         Err(e) => println!("couldn't get client: {e}"),
    162     ///     }
    163     ///     Ok(())
    164     /// }
    165     /// ```
    166     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync;
    167     /// Polls to accept a new incoming connection to this listener.
    168     ///
    169     /// If there is no connection to accept, `Poll::Pending` is returned and the current task will be notified by
    170     /// a waker. Note that on multiple calls to `poll_accept`, only the `Waker` from the `Context` passed to the
    171     /// most recent call is scheduled to receive a wakeup.
    172     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
    173 }
    174 impl Sealed for TcpListener {}
    175 impl Tcp for TcpListener {
    176     #[inline]
    177     fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>> {
    178         Self::bind(addr)
    179     }
    180     #[inline]
    181     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
    182         self.accept()
    183     }
    184     #[inline]
    185     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
    186         self.poll_accept(cx)
    187     }
    188 }
    189 /// "Dual-stack" TCP listener.
    190 ///
    191 /// IPv6 and IPv4 TCP listener.
    192 #[derive(Debug)]
    193 pub struct DualStackTcpListener {
    194     /// IPv6 TCP listener.
    195     ip6: TcpListener,
    196     /// IPv4 TCP listener.
    197     ip4: TcpListener,
    198     /// `true` iff [`Self::ip6::accept`] should be `poll`ed first; otherwise [`Self::ip4::accept`] is `poll`ed
    199     /// first.
    200     ///
    201     /// This exists to prevent one IP version from "starving" another. Each time [`Self::accept`] or
    202     /// [`Self::poll_accept`] is called, it's overwritten with the opposite `bool`.
    203     ///
    204     /// Note we could make this a `core::cell::Cell`; but for maximal flexibility and consistency with `TcpListener`,
    205     /// we use an `AtomicBool`. This among other things means `DualStackTcpListener` will implement `Sync`.
    206     ip6_first: AtomicBool,
    207 }
    208 impl DualStackTcpListener {
    209     /// Creates `Self` using the [`TcpListener`]s returned from [`TcpSocket::listen`].
    210     ///
    211     /// [`Self::bind`] is useful when the behavior of [`TcpListener::bind`] is sufficient; however if the underlying
    212     /// `TcpSocket`s need to be configured differently, then one must call this function instead.
    213     ///
    214     /// # Errors
    215     ///
    216     /// Errors iff [`TcpSocket::local_addr`] does for either socket, the underlying sockets use the same IP version,
    217     /// or [`TcpSocket::listen`] errors for either socket.
    218     ///
    219     /// Note on Windows-based platforms `TcpSocket::local_addr` will error if [`TcpSocket::bind`] was not called.
    220     ///
    221     /// # Examples
    222     ///
    223     /// ```no_run
    224     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    225     /// # use std::io::Result;
    226     /// # use tokio_dual_stack::DualStackTcpListener;
    227     /// # use tokio::net::TcpSocket;
    228     /// #[tokio::main(flavor = "current_thread")]
    229     /// async fn main() -> Result<()> {
    230     ///     let ip6 = TcpSocket::new_v6()?;
    231     ///     ip6.bind(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)))?;
    232     ///     let ip4 = TcpSocket::new_v4()?;
    233     ///     ip4.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)))?;
    234     ///     let listener = DualStackTcpListener::from_sockets((ip6, 1024), (ip4, 1024))?;
    235     ///     Ok(())
    236     /// }
    237     /// ```
    238     #[inline]
    239     pub fn from_sockets(
    240         (socket_1, backlog_1): (TcpSocket, u32),
    241         (socket_2, backlog_2): (TcpSocket, u32),
    242     ) -> Result<Self> {
    243         socket_1.local_addr().and_then(|sock| {
    244             socket_2.local_addr().and_then(|sock_2| {
    245                 if sock.is_ipv6() {
    246                     if sock_2.is_ipv4() {
    247                         socket_1.listen(backlog_1).and_then(|ip6| {
    248                             socket_2.listen(backlog_2).map(|ip4| Self {
    249                                 ip6,
    250                                 ip4,
    251                                 ip6_first: AtomicBool::new(true),
    252                             })
    253                         })
    254                     } else {
    255                         Err(Error::new(
    256                             ErrorKind::InvalidData,
    257                             "TcpSockets are the same IP version",
    258                         ))
    259                     }
    260                 } else if sock_2.is_ipv6() {
    261                     socket_1.listen(backlog_1).and_then(|ip4| {
    262                         socket_2.listen(backlog_2).map(|ip6| Self {
    263                             ip6,
    264                             ip4,
    265                             ip6_first: AtomicBool::new(true),
    266                         })
    267                     })
    268                 } else {
    269                     Err(Error::new(
    270                         ErrorKind::InvalidData,
    271                         "TcpSockets are the same IP version",
    272                     ))
    273                 }
    274             })
    275         })
    276     }
    277     /// Returns the local address of each socket that the listeners are bound to.
    278     ///
    279     /// This can be useful, for example, when binding to port 0 to figure out which port was actually bound.
    280     ///
    281     /// # Errors
    282     ///
    283     /// Errors iff [`TcpListener::local_addr`] does for either listener.
    284     ///
    285     /// # Examples
    286     ///
    287     /// ```no_run
    288     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    289     /// # use std::io::Result;
    290     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    291     /// #[tokio::main(flavor = "current_thread")]
    292     /// async fn main() -> Result<()> {
    293     ///     let ip6 = SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0);
    294     ///     let ip4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080);
    295     ///     assert_eq!(
    296     ///         DualStackTcpListener::bind([SocketAddr::V6(ip6), SocketAddr::V4(ip4)].as_slice())
    297     ///             .await?
    298     ///             .local_addr()?,
    299     ///         (ip6, ip4)
    300     ///     );
    301     ///     Ok(())
    302     /// }
    303     /// ```
    304     #[expect(clippy::unreachable, reason = "we want to crash when there is a bug")]
    305     #[inline]
    306     pub fn local_addr(&self) -> Result<(SocketAddrV6, SocketAddrV4)> {
    307         self.ip6.local_addr().and_then(|ip6| {
    308             self.ip4.local_addr().map(|ip4| {
    309                 (
    310                     if let SocketAddr::V6(sock6) = ip6 {
    311                         sock6
    312                     } else {
    313                         unreachable!("there is a bug in DualStackTcpListener::bind")
    314                     },
    315                     if let SocketAddr::V4(sock4) = ip4 {
    316                         sock4
    317                     } else {
    318                         unreachable!("there is a bug in DualStackTcpListener::bind")
    319                     },
    320                 )
    321             })
    322         })
    323     }
    324     /// Sets the value for the `IP_TTL` option on both sockets.
    325     ///
    326     /// This value sets the time-to-live field that is used in every packet sent from each socket.
    327     /// `ttl_ip6` is the `IP_TTL` value for the IPv6 socket and `ttl_ip4` is the `IP_TTL` value for the
    328     /// IPv4 socket.
    329     ///
    330     /// # Errors
    331     ///
    332     /// Errors iff [`TcpListener::set_ttl`] does for either listener.
    333     ///
    334     /// # Examples
    335     ///
    336     /// ```no_run
    337     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    338     /// # use std::io::Result;
    339     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    340     /// #[tokio::main(flavor = "current_thread")]
    341     /// async fn main() -> Result<()> {
    342     ///     DualStackTcpListener::bind(
    343     ///         [
    344     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    345     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    346     ///         ]
    347     ///         .as_slice(),
    348     ///     )
    349     ///     .await?.set_ttl(100, 100).expect("could not set TTL");
    350     ///     Ok(())
    351     /// }
    352     /// ```
    353     #[inline]
    354     pub fn set_ttl(&self, ttl_ip6: u32, ttl_ip4: u32) -> Result<()> {
    355         self.ip6
    356             .set_ttl(ttl_ip6)
    357             .and_then(|()| self.ip4.set_ttl(ttl_ip4))
    358     }
    359     /// Gets the values of the `IP_TTL` option for both sockets.
    360     ///
    361     /// The first `u32` represents the `IP_TTL` value for the IPv6 socket and the second `u32` is the
    362     /// `IP_TTL` value for the IPv4 socket. For more information about this option, see [`Self::set_ttl`].
    363     ///
    364     /// # Errors
    365     ///
    366     /// Errors iff [`TcpListener::ttl`] does for either listener.
    367     ///
    368     /// # Examples
    369     ///
    370     /// ```no_run
    371     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    372     /// # use std::io::Result;
    373     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    374     /// #[tokio::main(flavor = "current_thread")]
    375     /// async fn main() -> Result<()> {
    376     ///     let listener = DualStackTcpListener::bind(
    377     ///         [
    378     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    379     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    380     ///         ]
    381     ///         .as_slice(),
    382     ///     )
    383     ///     .await?;
    384     ///     listener.set_ttl(100, 100).expect("could not set TTL");
    385     ///     assert_eq!(listener.ttl()?, (100, 100));
    386     ///     Ok(())
    387     /// }
    388     /// ```
    389     #[inline]
    390     pub fn ttl(&self) -> Result<(u32, u32)> {
    391         self.ip6
    392             .ttl()
    393             .and_then(|ip6| self.ip4.ttl().map(|ip4| (ip6, ip4)))
    394     }
    395 }
    396 pin_project! {
    397     /// `Future` returned by [`DualStackTcpListener::accept]`.
    398     struct AcceptFut<
    399         F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    400         F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    401     > {
    402         // Accept future for one `TcpListener`.
    403         #[pin]
    404         fut_1: F,
    405         // Accept future for the other `TcpListener`.
    406         #[pin]
    407         fut_2: F2,
    408     }
    409 }
    410 impl<
    411     F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    412     F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    413 > Future for AcceptFut<F, F2>
    414 {
    415     type Output = Result<(TcpStream, SocketAddr)>;
    416     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    417         let this = self.project();
    418         match this.fut_1.poll(cx) {
    419             Poll::Ready(res) => Poll::Ready(res),
    420             Poll::Pending => this.fut_2.poll(cx),
    421         }
    422     }
    423 }
    424 impl Sealed for DualStackTcpListener {}
    425 impl Tcp for DualStackTcpListener {
    426     #[inline]
    427     async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
    428         match net::lookup_host(addr).await {
    429             Ok(socks) => {
    430                 let mut last_err = None;
    431                 let mut ip6_opt = None;
    432                 let mut ip4_opt = None;
    433                 for sock in socks {
    434                     match ip6_opt {
    435                         None => match ip4_opt {
    436                             None => {
    437                                 let is_ip6 = sock.is_ipv6();
    438                                 match TcpListener::bind(sock).await {
    439                                     Ok(ip) => {
    440                                         if is_ip6 {
    441                                             ip6_opt = Some(ip);
    442                                         } else {
    443                                             ip4_opt = Some(ip);
    444                                         }
    445                                     }
    446                                     Err(err) => last_err = Some(err),
    447                                 }
    448                             }
    449                             Some(ip4) => {
    450                                 if sock.is_ipv6() {
    451                                     match TcpListener::bind(sock).await {
    452                                         Ok(ip6) => {
    453                                             return Ok(Self {
    454                                                 ip6,
    455                                                 ip4,
    456                                                 ip6_first: AtomicBool::new(true),
    457                                             });
    458                                         }
    459                                         Err(err) => last_err = Some(err),
    460                                     }
    461                                 }
    462                                 ip4_opt = Some(ip4);
    463                             }
    464                         },
    465                         Some(ip6) => {
    466                             if sock.is_ipv4() {
    467                                 match TcpListener::bind(sock).await {
    468                                     Ok(ip4) => {
    469                                         return Ok(Self {
    470                                             ip6,
    471                                             ip4,
    472                                             ip6_first: AtomicBool::new(true),
    473                                         });
    474                                     }
    475                                     Err(err) => last_err = Some(err),
    476                                 }
    477                             }
    478                             ip6_opt = Some(ip6);
    479                         }
    480                     }
    481                 }
    482                 Err(last_err.unwrap_or_else(|| {
    483                     Error::new(
    484                         ErrorKind::InvalidInput,
    485                         "could not resolve to an IPv6 and IPv4 address",
    486                     )
    487                 }))
    488             }
    489             Err(err) => Err(err),
    490         }
    491     }
    492     #[inline]
    493     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
    494         // The correctness of code does not depend on `self.ip6_first`; therefore
    495         // we elect for the most performant `Ordering`.
    496         if self.ip6_first.swap(false, Ordering::Relaxed) {
    497             AcceptFut {
    498                 fut_1: self.ip6.accept(),
    499                 fut_2: self.ip4.accept(),
    500             }
    501         } else {
    502             // The correctness of code does not depend on `self.ip6_first`; therefore
    503             // we elect for the most performant `Ordering`.
    504             self.ip6_first.store(true, Ordering::Relaxed);
    505             AcceptFut {
    506                 fut_1: self.ip4.accept(),
    507                 fut_2: self.ip6.accept(),
    508             }
    509         }
    510     }
    511     #[inline]
    512     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
    513         // The correctness of code does not depend on `self.ip6_first`; therefore
    514         // we elect for the most performant `Ordering`.
    515         if self.ip6_first.swap(false, Ordering::Relaxed) {
    516             self.ip6.poll_accept(cx)
    517         } else {
    518             // The correctness of code does not depend on `self.ip6_first`; therefore
    519             // we elect for the most performant `Ordering`.
    520             self.ip6_first.store(true, Ordering::Relaxed);
    521             self.ip4.poll_accept(cx)
    522         }
    523     }
    524 }