tokio_dual_stack

Dual-stack TcpListener.
git clone https://git.philomathiclife.com/repos/tokio_dual_stack
Log | Files | Refs | README

lib.rs (21080B)


      1 //! [![git]](https://git.philomathiclife.com/tokio_dual_stack/log.html) [![crates-io]](https://crates.io/crates/tokio_dual_stack) [![docs-rs]](crate)
      2 //!
      3 //! [git]: https://git.philomathiclife.com/git_badge.svg
      4 //! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
      5 //! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
      6 //!
      7 //! `tokio_dual_stack` is a library that adds a "dual-stack" [`TcpListener`].
      8 //!
      9 //! ## Why is this useful?
     10 //!
     11 //! Only certain platforms offer the ability for one socket to handle both IPv6 and IPv4 requests
     12 //! (e.g., OpenBSD does not). For the platforms that do, it is often dependent on runtime configuration
     13 //! (e.g., [`IPV6_V6ONLY`](https://www.man7.org/linux/man-pages/man7/ipv6.7.html)). Additionally, those platforms
     14 //! that support it often require the "wildcard" IPv6 address to be used (i.e., `::`) which has the unfortunate
     15 //! consequence of preventing other services from using the same protocol port.
     16 //!
     17 //! There are a few ways to work around this issue. One is to deploy the same service twice: one that uses
     18 //! an IPv6 socket and the other that uses an IPv4 socket. This can complicate deployments (e.g., the application
     19 //! may not have been written with the expectation that multiple deployments could be running at the same time) in
     20 //! addition to using more resources. Another is for the application to manually handle each socket (e.g.,
     21 //! [`select`](https://docs.rs/tokio/latest/tokio/macro.select.html)/[`join`](https://docs.rs/tokio/latest/tokio/macro.join.html)
     22 //! each [`TcpListener::accept`]).
     23 //!
     24 //! [`DualStackTcpListener`] chooses an implementation similar to what the equivalent `select` would do, while
     25 //! also preventing one socket from "starving" the other by alternating which socket is first given the
     26 //! opportunity to `TcpListener::accept` a connection. This has the nice benefit of exposing an API similar to
     27 //! that of a single `TcpListener` while offering performance similar to a single socket that handles both IPv6
     28 //! and IPv4 requests.
     29 #![expect(
     30     clippy::multiple_crate_versions,
     31     reason = "dependencies haven't updated to most recent crates"
     32 )]
     33 use core::{
     34     net::{SocketAddr, SocketAddrV4, SocketAddrV6},
     35     pin::Pin,
     36     sync::atomic::{AtomicBool, Ordering},
     37     task::{Context, Poll},
     38 };
     39 use pin_project_lite::pin_project;
     40 use std::io::{Error, ErrorKind, Result};
     41 use tokio::net::{self, TcpListener, TcpSocket, TcpStream, ToSocketAddrs};
     42 /// Prevents [`Sealed`] from being publicly implementable.
     43 mod private {
     44     /// Marker trait to prevent [`super::Tcp`] from being publicly implementable.
     45     #[expect(unnameable_types, reason = "want Tcp to be 'sealed'")]
     46     pub trait Sealed {}
     47 }
     48 use private::Sealed;
     49 /// TCP "listener".
     50 ///
     51 /// This `trait` is sealed and cannot be implemented for types outside of `tokio_dual_stack`.
     52 ///
     53 /// This exists primarily as a way to define type constructors or polymorphic functions
     54 /// that can user either a [`TcpListener`] or [`DualStackTcpListener`].
     55 ///
     56 /// # Examples
     57 ///
     58 /// ```no_run
     59 /// # use core::convert::Infallible;
     60 /// # use tokio_dual_stack::Tcp;
     61 /// async fn main_loop<T: Tcp>(listener: T) -> Infallible {
     62 ///     loop {
     63 ///         match listener.accept().await {
     64 ///             Ok((_, socket)) => println!("Client socket: {socket}"),
     65 ///             Err(e) => println!("TCP connection failure: {e}"),
     66 ///         }
     67 ///     }
     68 /// }
     69 /// ```
     70 pub trait Tcp: Sealed + Sized {
     71     /// Creates a new TCP listener, which will be bound to the specified address(es).
     72     ///
     73     /// The returned listener is ready for accepting connections.
     74     ///
     75     /// Binding with a port number of 0 will request that the OS assigns a port to this listener.
     76     /// The port allocated can be queried via the `local_addr` method.
     77     ///
     78     /// The address type can be any implementor of the [`ToSocketAddrs`] trait. If `addr` yields
     79     /// multiple addresses, bind will be attempted with each of the addresses until one succeeds
     80     /// and returns the listener. If none of the addresses succeed in creating a listener, the
     81     /// error returned from the last attempt (the last address) is returned.
     82     ///
     83     /// This function sets the `SO_REUSEADDR` option on the socket.
     84     ///
     85     /// # Examples
     86     ///
     87     /// ```no_run
     88     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
     89     /// # use std::io::Result;
     90     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
     91     /// #[tokio::main(flavor = "current_thread")]
     92     /// async fn main() -> Result<()> {
     93     ///     let listener = DualStackTcpListener::bind(
     94     ///         [
     95     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
     96     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
     97     ///         ]
     98     ///         .as_slice(),
     99     ///     )
    100     ///     .await?;
    101     ///     Ok(())
    102     /// }
    103     /// ```
    104     fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>>;
    105     /// Accepts a new incoming connection from this listener.
    106     ///
    107     /// This function will yield once a new TCP connection is established. When established,
    108     /// the corresponding `TcpStream` and the remote peer’s address will be returned.
    109     ///
    110     /// # Cancel safety
    111     ///
    112     /// This method is cancel safe. If the method is used as the event in a
    113     /// [`tokio::select!`](https://docs.rs/tokio/latest/tokio/macro.select.html)
    114     /// statement and some other branch completes first, then it is guaranteed that no new
    115     /// connections were accepted by this method.
    116     ///
    117     /// # Examples
    118     ///
    119     /// ```no_run
    120     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    121     /// # use std::io::Result;
    122     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    123     /// #[tokio::main(flavor = "current_thread")]
    124     /// async fn main() -> Result<()> {
    125     ///     match DualStackTcpListener::bind(
    126     ///         [
    127     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    128     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    129     ///         ]
    130     ///         .as_slice(),
    131     ///     )
    132     ///     .await?.accept().await {
    133     ///         Ok((_, addr)) => println!("new client: {addr}"),
    134     ///         Err(e) => println!("couldn't get client: {e}"),
    135     ///     }
    136     ///     Ok(())
    137     /// }
    138     /// ```
    139     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync;
    140     /// Polls to accept a new incoming connection to this listener.
    141     ///
    142     /// If there is no connection to accept, `Poll::Pending` is returned and the current task will be notified by
    143     /// a waker. Note that on multiple calls to `poll_accept`, only the `Waker` from the `Context` passed to the
    144     /// most recent call is scheduled to receive a wakeup.
    145     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
    146 }
    147 impl Sealed for TcpListener {}
    148 impl Tcp for TcpListener {
    149     #[inline]
    150     fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>> {
    151         Self::bind(addr)
    152     }
    153     #[inline]
    154     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
    155         self.accept()
    156     }
    157     #[inline]
    158     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
    159         self.poll_accept(cx)
    160     }
    161 }
    162 /// "Dual-stack" TCP listener.
    163 ///
    164 /// IPv6 and IPv4 TCP listener.
    165 #[derive(Debug)]
    166 pub struct DualStackTcpListener {
    167     /// IPv6 TCP listener.
    168     ip6: TcpListener,
    169     /// IPv4 TCP listener.
    170     ip4: TcpListener,
    171     /// `true` iff [`Self::ip6::accept`] should be `poll`ed first; otherwise [`Self::ip4::accept`] is `poll`ed
    172     /// first.
    173     ///
    174     /// This exists to prevent one IP version from "starving" another. Each time [`Self::accept`] or
    175     /// [`Self::poll_accept`] is called, it's overwritten with the opposite `bool`.
    176     ///
    177     /// Note we could make this a `core::cell::Cell`; but for maximal flexibility and consistency with `TcpListener`,
    178     /// we use an `AtomicBool`. This among other things means `DualStackTcpListener` will implement `Sync`.
    179     ip6_first: AtomicBool,
    180 }
    181 impl DualStackTcpListener {
    182     /// Creates `Self` using the [`TcpListener`]s returned from [`TcpSocket::listen`].
    183     ///
    184     /// [`Self::bind`] is useful when the behavior of [`TcpListener::bind`] is sufficient; however if the underlying
    185     /// `TcpSocket`s need to be configured differently, then one must call this function instead.
    186     ///
    187     /// # Errors
    188     ///
    189     /// Errors iff [`TcpSocket::local_addr`] does for either socket, the underlying sockets use the same IP version,
    190     /// or [`TcpSocket::listen`] errors for either socket.
    191     ///
    192     /// Note on Windows-based platforms `TcpSocket::local_addr` will error if [`TcpSocket::bind`] was not called.
    193     ///
    194     /// # Examples
    195     ///
    196     /// ```no_run
    197     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    198     /// # use std::io::Result;
    199     /// # use tokio_dual_stack::DualStackTcpListener;
    200     /// # use tokio::net::TcpSocket;
    201     /// #[tokio::main(flavor = "current_thread")]
    202     /// async fn main() -> Result<()> {
    203     ///     let ip6 = TcpSocket::new_v6()?;
    204     ///     ip6.bind(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)))?;
    205     ///     let ip4 = TcpSocket::new_v4()?;
    206     ///     ip4.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)))?;
    207     ///     let listener = DualStackTcpListener::from_sockets((ip6, 1024), (ip4, 1024))?;
    208     ///     Ok(())
    209     /// }
    210     /// ```
    211     #[inline]
    212     pub fn from_sockets(
    213         (socket_1, backlog_1): (TcpSocket, u32),
    214         (socket_2, backlog_2): (TcpSocket, u32),
    215     ) -> Result<Self> {
    216         socket_1.local_addr().and_then(|sock| {
    217             socket_2.local_addr().and_then(|sock_2| {
    218                 if sock.is_ipv6() {
    219                     if sock_2.is_ipv4() {
    220                         socket_1.listen(backlog_1).and_then(|ip6| {
    221                             socket_2.listen(backlog_2).map(|ip4| Self {
    222                                 ip6,
    223                                 ip4,
    224                                 ip6_first: AtomicBool::new(true),
    225                             })
    226                         })
    227                     } else {
    228                         Err(Error::new(
    229                             ErrorKind::InvalidData,
    230                             "TcpSockets are the same IP version",
    231                         ))
    232                     }
    233                 } else if sock_2.is_ipv6() {
    234                     socket_1.listen(backlog_1).and_then(|ip4| {
    235                         socket_2.listen(backlog_2).map(|ip6| Self {
    236                             ip6,
    237                             ip4,
    238                             ip6_first: AtomicBool::new(true),
    239                         })
    240                     })
    241                 } else {
    242                     Err(Error::new(
    243                         ErrorKind::InvalidData,
    244                         "TcpSockets are the same IP version",
    245                     ))
    246                 }
    247             })
    248         })
    249     }
    250     /// Returns the local address of each socket that the listeners are bound to.
    251     ///
    252     /// This can be useful, for example, when binding to port 0 to figure out which port was actually bound.
    253     ///
    254     /// # Errors
    255     ///
    256     /// Errors iff [`TcpListener::local_addr`] does for either listener.
    257     ///
    258     /// # Examples
    259     ///
    260     /// ```no_run
    261     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    262     /// # use std::io::Result;
    263     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    264     /// #[tokio::main(flavor = "current_thread")]
    265     /// async fn main() -> Result<()> {
    266     ///     let ip6 = SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0);
    267     ///     let ip4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080);
    268     ///     assert_eq!(
    269     ///         DualStackTcpListener::bind([SocketAddr::V6(ip6), SocketAddr::V4(ip4)].as_slice())
    270     ///             .await?
    271     ///             .local_addr()?,
    272     ///         (ip6, ip4)
    273     ///     );
    274     ///     Ok(())
    275     /// }
    276     /// ```
    277     #[expect(clippy::unreachable, reason = "we want to crash when there is a bug")]
    278     #[inline]
    279     pub fn local_addr(&self) -> Result<(SocketAddrV6, SocketAddrV4)> {
    280         self.ip6.local_addr().and_then(|ip6| {
    281             self.ip4.local_addr().map(|ip4| {
    282                 (
    283                     if let SocketAddr::V6(sock6) = ip6 {
    284                         sock6
    285                     } else {
    286                         unreachable!("there is a bug in DualStackTcpListener::bind")
    287                     },
    288                     if let SocketAddr::V4(sock4) = ip4 {
    289                         sock4
    290                     } else {
    291                         unreachable!("there is a bug in DualStackTcpListener::bind")
    292                     },
    293                 )
    294             })
    295         })
    296     }
    297     /// Sets the value for the `IP_TTL` option on both sockets.
    298     ///
    299     /// This value sets the time-to-live field that is used in every packet sent from each socket.
    300     /// `ttl_ip6` is the `IP_TTL` value for the IPv6 socket and `ttl_ip4` is the `IP_TTL` value for the
    301     /// IPv4 socket.
    302     ///
    303     /// # Errors
    304     ///
    305     /// Errors iff [`TcpListener::set_ttl`] does for either listener.
    306     ///
    307     /// # Examples
    308     ///
    309     /// ```no_run
    310     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    311     /// # use std::io::Result;
    312     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    313     /// #[tokio::main(flavor = "current_thread")]
    314     /// async fn main() -> Result<()> {
    315     ///     DualStackTcpListener::bind(
    316     ///         [
    317     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    318     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    319     ///         ]
    320     ///         .as_slice(),
    321     ///     )
    322     ///     .await?.set_ttl(100, 100).expect("could not set TTL");
    323     ///     Ok(())
    324     /// }
    325     /// ```
    326     #[inline]
    327     pub fn set_ttl(&self, ttl_ip6: u32, ttl_ip4: u32) -> Result<()> {
    328         self.ip6
    329             .set_ttl(ttl_ip6)
    330             .and_then(|()| self.ip4.set_ttl(ttl_ip4))
    331     }
    332     /// Gets the values of the `IP_TTL` option for both sockets.
    333     ///
    334     /// The first `u32` represents the `IP_TTL` value for the IPv6 socket and the second `u32` is the
    335     /// `IP_TTL` value for the IPv4 socket. For more information about this option, see [`Self::set_ttl`].
    336     ///
    337     /// # Errors
    338     ///
    339     /// Errors iff [`TcpListener::ttl`] does for either listener.
    340     ///
    341     /// # Examples
    342     ///
    343     /// ```no_run
    344     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    345     /// # use std::io::Result;
    346     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    347     /// #[tokio::main(flavor = "current_thread")]
    348     /// async fn main() -> Result<()> {
    349     ///     let listener = DualStackTcpListener::bind(
    350     ///         [
    351     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    352     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    353     ///         ]
    354     ///         .as_slice(),
    355     ///     )
    356     ///     .await?;
    357     ///     listener.set_ttl(100, 100).expect("could not set TTL");
    358     ///     assert_eq!(listener.ttl()?, (100, 100));
    359     ///     Ok(())
    360     /// }
    361     /// ```
    362     #[inline]
    363     pub fn ttl(&self) -> Result<(u32, u32)> {
    364         self.ip6
    365             .ttl()
    366             .and_then(|ip6| self.ip4.ttl().map(|ip4| (ip6, ip4)))
    367     }
    368 }
    369 pin_project! {
    370     /// `Future` returned by [`DualStackTcpListener::accept]`.
    371     struct AcceptFut<
    372         F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    373         F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    374     > {
    375         // Accept future for one `TcpListener`.
    376         #[pin]
    377         fut_1: F,
    378         // Accept future for the other `TcpListener`.
    379         #[pin]
    380         fut_2: F2,
    381     }
    382 }
    383 impl<
    384     F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    385     F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    386 > Future for AcceptFut<F, F2>
    387 {
    388     type Output = Result<(TcpStream, SocketAddr)>;
    389     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    390         let this = self.project();
    391         match this.fut_1.poll(cx) {
    392             Poll::Ready(res) => Poll::Ready(res),
    393             Poll::Pending => this.fut_2.poll(cx),
    394         }
    395     }
    396 }
    397 impl Sealed for DualStackTcpListener {}
    398 impl Tcp for DualStackTcpListener {
    399     #[inline]
    400     async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
    401         match net::lookup_host(addr).await {
    402             Ok(socks) => {
    403                 let mut last_err = None;
    404                 let mut ip6_opt = None;
    405                 let mut ip4_opt = None;
    406                 for sock in socks {
    407                     match ip6_opt {
    408                         None => match ip4_opt {
    409                             None => {
    410                                 let is_ip6 = sock.is_ipv6();
    411                                 match TcpListener::bind(sock).await {
    412                                     Ok(ip) => {
    413                                         if is_ip6 {
    414                                             ip6_opt = Some(ip);
    415                                         } else {
    416                                             ip4_opt = Some(ip);
    417                                         }
    418                                     }
    419                                     Err(err) => last_err = Some(err),
    420                                 }
    421                             }
    422                             Some(ip4) => {
    423                                 if sock.is_ipv6() {
    424                                     match TcpListener::bind(sock).await {
    425                                         Ok(ip6) => {
    426                                             return Ok(Self {
    427                                                 ip6,
    428                                                 ip4,
    429                                                 ip6_first: AtomicBool::new(true),
    430                                             });
    431                                         }
    432                                         Err(err) => last_err = Some(err),
    433                                     }
    434                                 }
    435                                 ip4_opt = Some(ip4);
    436                             }
    437                         },
    438                         Some(ip6) => {
    439                             if sock.is_ipv4() {
    440                                 match TcpListener::bind(sock).await {
    441                                     Ok(ip4) => {
    442                                         return Ok(Self {
    443                                             ip6,
    444                                             ip4,
    445                                             ip6_first: AtomicBool::new(true),
    446                                         });
    447                                     }
    448                                     Err(err) => last_err = Some(err),
    449                                 }
    450                             }
    451                             ip6_opt = Some(ip6);
    452                         }
    453                     }
    454                 }
    455                 Err(last_err.unwrap_or_else(|| {
    456                     Error::new(
    457                         ErrorKind::InvalidInput,
    458                         "could not resolve to an IPv6 and IPv4 address",
    459                     )
    460                 }))
    461             }
    462             Err(err) => Err(err),
    463         }
    464     }
    465     #[inline]
    466     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
    467         // The correctness of code does not depend on `self.ip6_first`; therefore
    468         // we elect for the most performant `Ordering`.
    469         if self.ip6_first.swap(false, Ordering::Relaxed) {
    470             AcceptFut {
    471                 fut_1: self.ip6.accept(),
    472                 fut_2: self.ip4.accept(),
    473             }
    474         } else {
    475             // The correctness of code does not depend on `self.ip6_first`; therefore
    476             // we elect for the most performant `Ordering`.
    477             self.ip6_first.store(true, Ordering::Relaxed);
    478             AcceptFut {
    479                 fut_1: self.ip4.accept(),
    480                 fut_2: self.ip6.accept(),
    481             }
    482         }
    483     }
    484     #[inline]
    485     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
    486         // The correctness of code does not depend on `self.ip6_first`; therefore
    487         // we elect for the most performant `Ordering`.
    488         if self.ip6_first.swap(false, Ordering::Relaxed) {
    489             self.ip6.poll_accept(cx)
    490         } else {
    491             // The correctness of code does not depend on `self.ip6_first`; therefore
    492             // we elect for the most performant `Ordering`.
    493             self.ip6_first.store(true, Ordering::Relaxed);
    494             self.ip4.poll_accept(cx)
    495         }
    496     }
    497 }