tokio_dual_stack

Dual-stack TcpListener.
git clone https://git.philomathiclife.com/repos/tokio_dual_stack
Log | Files | Refs | README

lib.rs (20963B)


      1 //! [![git]](https://git.philomathiclife.com/tokio_dual_stack/log.html) [![crates-io]](https://crates.io/crates/tokio_dual_stack) [![docs-rs]](crate)
      2 //!
      3 //! [git]: https://git.philomathiclife.com/git_badge.svg
      4 //! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
      5 //! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
      6 //!
      7 //! `tokio_dual_stack` is a library that adds a "dual-stack" [`TcpListener`].
      8 //!
      9 //! ## Why is this useful?
     10 //!
     11 //! Only certain platforms offer the ability for one socket to handle both IPv6 and IPv4 requests
     12 //! (e.g., OpenBSD does not). For the platforms that do, it is often dependent on runtime configuration
     13 //! (e.g., [`IPV6_V6ONLY`](https://www.man7.org/linux/man-pages/man7/ipv6.7.html)). Additionally, those platforms
     14 //! that support it often require the "wildcard" IPv6 address to be used (i.e., `::`) which has the unfortunate
     15 //! consequence of preventing other services from using the same protocol port.
     16 //!
     17 //! There are a few ways to work around this issue. One is to deploy the same service twice: one that uses
     18 //! an IPv6 socket and the other that uses an IPv4 socket. This can complicate deployments (e.g., the application
     19 //! may not have been written with the expectation that multiple deployments could be running at the same time) in
     20 //! addition to using more resources. Another is for the application to manually handle each socket (e.g.,
     21 //! [`select`](https://docs.rs/tokio/latest/tokio/macro.select.html)/[`join`](https://docs.rs/tokio/latest/tokio/macro.join.html)
     22 //! each [`TcpListener::accept`]).
     23 //!
     24 //! [`DualStackTcpListener`] chooses an implementation similar to what the equivalent `select` would do while
     25 //! also preventing one socket from "starving" the other by alternating which socket is polled first, so that
     26 //! each socket is fairly given an opportunity to `TcpListener::accept` a connection. This has the nice benefit
     27 //! of having a similar API to what a single `TcpListener` would have as well as having similar performance to a
     28 //! socket that does handle both IPv6 and IPv4 requests.
     29 use core::{
     30     net::{SocketAddr, SocketAddrV4, SocketAddrV6},
     31     pin::Pin,
     32     sync::atomic::{AtomicBool, Ordering},
     33     task::{Context, Poll},
     34 };
     35 use pin_project_lite::pin_project;
     36 use std::io::{Error, ErrorKind, Result};
     37 use tokio::net::{self, TcpListener, TcpSocket, TcpStream, ToSocketAddrs};
     38 /// Prevents [`Sealed`] from being publicly implementable.
     39 mod private {
     40     /// Marker trait to prevent [`super::Tcp`] from being publicly implementable.
     41     #[expect(unnameable_types, reason = "want Tcp to be 'sealed'")]
     42     pub trait Sealed {}
     43 }
     44 use private::Sealed;
     45 /// TCP "listener".
     46 ///
     47 /// This `trait` is sealed and cannot be implemented for types outside of `tokio_dual_stack`.
     48 ///
     49 /// This exists primarily as a way to define type constructors or polymorphic functions
     50 /// that can user either a [`TcpListener`] or [`DualStackTcpListener`].
     51 ///
     52 /// # Examples
     53 ///
     54 /// ```no_run
     55 /// # use core::convert::Infallible;
     56 /// # use tokio_dual_stack::Tcp;
     57 /// async fn main_loop<T: Tcp>(listener: T) -> Infallible {
     58 ///     loop {
     59 ///         match listener.accept().await {
     60 ///             Ok((_, socket)) => println!("Client socket: {socket}"),
     61 ///             Err(e) => println!("TCP connection failure: {e}"),
     62 ///         }
     63 ///     }
     64 /// }
     65 /// ```
     66 pub trait Tcp: Sealed + Sized {
     67     /// Creates a new TCP listener, which will be bound to the specified address(es).
     68     ///
     69     /// The returned listener is ready for accepting connections.
     70     ///
     71     /// Binding with a port number of 0 will request that the OS assigns a port to this listener.
     72     /// The port allocated can be queried via the `local_addr` method.
     73     ///
     74     /// The address type can be any implementor of the [`ToSocketAddrs`] trait. If `addr` yields
     75     /// multiple addresses, bind will be attempted with each of the addresses until one succeeds
     76     /// and returns the listener. If none of the addresses succeed in creating a listener, the
     77     /// error returned from the last attempt (the last address) is returned.
     78     ///
     79     /// This function sets the `SO_REUSEADDR` option on the socket.
     80     ///
     81     /// # Examples
     82     ///
     83     /// ```no_run
     84     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
     85     /// # use std::io::Result;
     86     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
     87     /// #[tokio::main(flavor = "current_thread")]
     88     /// async fn main() -> Result<()> {
     89     ///     let listener = DualStackTcpListener::bind(
     90     ///         [
     91     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
     92     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
     93     ///         ]
     94     ///         .as_slice(),
     95     ///     )
     96     ///     .await?;
     97     ///     Ok(())
     98     /// }
     99     /// ```
    100     fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>>;
    101     /// Accepts a new incoming connection from this listener.
    102     ///
    103     /// This function will yield once a new TCP connection is established. When established,
    104     /// the corresponding `TcpStream` and the remote peer’s address will be returned.
    105     ///
    106     /// # Cancel safety
    107     ///
    108     /// This method is cancel safe. If the method is used as the event in a
    109     /// [`tokio::select!`](https://docs.rs/tokio/latest/tokio/macro.select.html)
    110     /// statement and some other branch completes first, then it is guaranteed that no new
    111     /// connections were accepted by this method.
    112     ///
    113     /// # Examples
    114     ///
    115     /// ```no_run
    116     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    117     /// # use std::io::Result;
    118     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    119     /// #[tokio::main(flavor = "current_thread")]
    120     /// async fn main() -> Result<()> {
    121     ///     match DualStackTcpListener::bind(
    122     ///         [
    123     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    124     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    125     ///         ]
    126     ///         .as_slice(),
    127     ///     )
    128     ///     .await?.accept().await {
    129     ///         Ok((_, addr)) => println!("new client: {addr}"),
    130     ///         Err(e) => println!("couldn't get client: {e}"),
    131     ///     }
    132     ///     Ok(())
    133     /// }
    134     /// ```
    135     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync;
    136     /// Polls to accept a new incoming connection to this listener.
    137     ///
    138     /// If there is no connection to accept, `Poll::Pending` is returned and the current task will be notified by
    139     /// a waker. Note that on multiple calls to `poll_accept`, only the `Waker` from the `Context` passed to the
    140     /// most recent call is scheduled to receive a wakeup.
    141     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
    142 }
    143 impl Sealed for TcpListener {}
    144 impl Tcp for TcpListener {
    145     #[inline]
    146     fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>> {
    147         Self::bind(addr)
    148     }
    149     #[inline]
    150     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
    151         self.accept()
    152     }
    153     #[inline]
    154     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
    155         self.poll_accept(cx)
    156     }
    157 }
    158 /// "Dual-stack" TCP listener.
    159 ///
    160 /// IPv6 and IPv4 TCP listener.
    161 #[derive(Debug)]
    162 pub struct DualStackTcpListener {
    163     /// IPv6 TCP listener.
    164     ip6: TcpListener,
    165     /// IPv4 TCP listener.
    166     ip4: TcpListener,
    167     /// `true` iff [`Self::ip6::accept`] should be `poll`ed first; otherwise [`Self::ip4::accept`] is `poll`ed
    168     /// first.
    169     ///
    170     /// This exists to prevent one IP version from "starving" another. Each time [`Self::accept`] or
    171     /// [`Self::poll_accept`] is called, it's overwritten with the opposite `bool`.
    172     ///
    173     /// Note we could make this a `core::cell::Cell`; but for maximal flexibility and consistency with `TcpListener`,
    174     /// we use an `AtomicBool`. This among other things means `DualStackTcpListener` will implement `Sync`.
    175     ip6_first: AtomicBool,
    176 }
    177 impl DualStackTcpListener {
    178     /// Creates `Self` using the [`TcpListener`]s returned from [`TcpSocket::listen`].
    179     ///
    180     /// [`Self::bind`] is useful when the behavior of [`TcpListener::bind`] is sufficient; however if the underlying
    181     /// `TcpSocket`s need to be configured differently, then one must call this function instead.
    182     ///
    183     /// # Errors
    184     ///
    185     /// Errors iff [`TcpSocket::local_addr`] does for either socket, the underlying sockets use the same IP version,
    186     /// or [`TcpSocket::listen`] errors for either socket.
    187     ///
    188     /// Note on Windows-based platforms `TcpSocket::local_addr` will error if [`TcpSocket::bind`] was not called.
    189     ///
    190     /// # Examples
    191     ///
    192     /// ```no_run
    193     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    194     /// # use std::io::Result;
    195     /// # use tokio_dual_stack::DualStackTcpListener;
    196     /// # use tokio::net::TcpSocket;
    197     /// #[tokio::main(flavor = "current_thread")]
    198     /// async fn main() -> Result<()> {
    199     ///     let ip6 = TcpSocket::new_v6()?;
    200     ///     ip6.bind(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)))?;
    201     ///     let ip4 = TcpSocket::new_v4()?;
    202     ///     ip4.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)))?;
    203     ///     let listener = DualStackTcpListener::from_sockets((ip6, 1024), (ip4, 1024))?;
    204     ///     Ok(())
    205     /// }
    206     /// ```
    207     #[inline]
    208     pub fn from_sockets(
    209         (socket_1, backlog_1): (TcpSocket, u32),
    210         (socket_2, backlog_2): (TcpSocket, u32),
    211     ) -> Result<Self> {
    212         socket_1.local_addr().and_then(|sock| {
    213             socket_2.local_addr().and_then(|sock_2| {
    214                 if sock.is_ipv6() {
    215                     if sock_2.is_ipv4() {
    216                         socket_1.listen(backlog_1).and_then(|ip6| {
    217                             socket_2.listen(backlog_2).map(|ip4| Self {
    218                                 ip6,
    219                                 ip4,
    220                                 ip6_first: AtomicBool::new(true),
    221                             })
    222                         })
    223                     } else {
    224                         Err(Error::new(
    225                             ErrorKind::InvalidData,
    226                             "TcpSockets are the same IP version",
    227                         ))
    228                     }
    229                 } else if sock_2.is_ipv6() {
    230                     socket_1.listen(backlog_1).and_then(|ip4| {
    231                         socket_2.listen(backlog_2).map(|ip6| Self {
    232                             ip6,
    233                             ip4,
    234                             ip6_first: AtomicBool::new(true),
    235                         })
    236                     })
    237                 } else {
    238                     Err(Error::new(
    239                         ErrorKind::InvalidData,
    240                         "TcpSockets are the same IP version",
    241                     ))
    242                 }
    243             })
    244         })
    245     }
    246     /// Returns the local address of each socket that the listeners are bound to.
    247     ///
    248     /// This can be useful, for example, when binding to port 0 to figure out which port was actually bound.
    249     ///
    250     /// # Errors
    251     ///
    252     /// Errors iff [`TcpListener::local_addr`] does for either listener.
    253     ///
    254     /// # Examples
    255     ///
    256     /// ```no_run
    257     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    258     /// # use std::io::Result;
    259     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    260     /// #[tokio::main(flavor = "current_thread")]
    261     /// async fn main() -> Result<()> {
    262     ///     let ip6 = SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0);
    263     ///     let ip4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080);
    264     ///     assert_eq!(
    265     ///         DualStackTcpListener::bind([SocketAddr::V6(ip6), SocketAddr::V4(ip4)].as_slice())
    266     ///             .await?
    267     ///             .local_addr()?,
    268     ///         (ip6, ip4)
    269     ///     );
    270     ///     Ok(())
    271     /// }
    272     /// ```
    273     #[expect(clippy::unreachable, reason = "we want to crash when there is a bug")]
    274     #[inline]
    275     pub fn local_addr(&self) -> Result<(SocketAddrV6, SocketAddrV4)> {
    276         self.ip6.local_addr().and_then(|ip6| {
    277             self.ip4.local_addr().map(|ip4| {
    278                 (
    279                     if let SocketAddr::V6(sock6) = ip6 {
    280                         sock6
    281                     } else {
    282                         unreachable!("there is a bug in DualStackTcpListener::bind")
    283                     },
    284                     if let SocketAddr::V4(sock4) = ip4 {
    285                         sock4
    286                     } else {
    287                         unreachable!("there is a bug in DualStackTcpListener::bind")
    288                     },
    289                 )
    290             })
    291         })
    292     }
    293     /// Sets the value for the `IP_TTL` option on both sockets.
    294     ///
    295     /// This value sets the time-to-live field that is used in every packet sent from each socket.
    296     /// `ttl_ip6` is the `IP_TTL` value for the IPv6 socket and `ttl_ip4` is the `IP_TTL` value for the
    297     /// IPv4 socket.
    298     ///
    299     /// # Errors
    300     ///
    301     /// Errors iff [`TcpListener::set_ttl`] does for either listener.
    302     ///
    303     /// # Examples
    304     ///
    305     /// ```no_run
    306     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    307     /// # use std::io::Result;
    308     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    309     /// #[tokio::main(flavor = "current_thread")]
    310     /// async fn main() -> Result<()> {
    311     ///     DualStackTcpListener::bind(
    312     ///         [
    313     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    314     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    315     ///         ]
    316     ///         .as_slice(),
    317     ///     )
    318     ///     .await?.set_ttl(100, 100).expect("could not set TTL");
    319     ///     Ok(())
    320     /// }
    321     /// ```
    322     #[inline]
    323     pub fn set_ttl(&self, ttl_ip6: u32, ttl_ip4: u32) -> Result<()> {
    324         self.ip6
    325             .set_ttl(ttl_ip6)
    326             .and_then(|()| self.ip4.set_ttl(ttl_ip4))
    327     }
    328     /// Gets the values of the `IP_TTL` option for both sockets.
    329     ///
    330     /// The first `u32` represents the `IP_TTL` value for the IPv6 socket and the second `u32` is the
    331     /// `IP_TTL` value for the IPv4 socket. For more information about this option, see [`Self::set_ttl`].
    332     ///
    333     /// # Errors
    334     ///
    335     /// Errors iff [`TcpListener::ttl`] does for either listener.
    336     ///
    337     /// # Examples
    338     ///
    339     /// ```no_run
    340     /// # use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
    341     /// # use std::io::Result;
    342     /// # use tokio_dual_stack::{DualStackTcpListener, Tcp as _};
    343     /// #[tokio::main(flavor = "current_thread")]
    344     /// async fn main() -> Result<()> {
    345     ///     let listener = DualStackTcpListener::bind(
    346     ///         [
    347     ///             SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0)),
    348     ///             SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
    349     ///         ]
    350     ///         .as_slice(),
    351     ///     )
    352     ///     .await?;
    353     ///     listener.set_ttl(100, 100).expect("could not set TTL");
    354     ///     assert_eq!(listener.ttl()?, (100, 100));
    355     ///     Ok(())
    356     /// }
    357     /// ```
    358     #[inline]
    359     pub fn ttl(&self) -> Result<(u32, u32)> {
    360         self.ip6
    361             .ttl()
    362             .and_then(|ip6| self.ip4.ttl().map(|ip4| (ip6, ip4)))
    363     }
    364 }
    365 pin_project! {
    366     /// `Future` returned by [`DualStackTcpListener::accept]`.
    367     struct AcceptFut<
    368         F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    369         F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    370     > {
    371         // Accept future for one `TcpListener`.
    372         #[pin]
    373         fut_1: F,
    374         // Accept future for the other `TcpListener`.
    375         #[pin]
    376         fut_2: F2,
    377     }
    378 }
    379 impl<
    380     F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    381     F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    382 > Future for AcceptFut<F, F2>
    383 {
    384     type Output = Result<(TcpStream, SocketAddr)>;
    385     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    386         let this = self.project();
    387         match this.fut_1.poll(cx) {
    388             Poll::Ready(res) => Poll::Ready(res),
    389             Poll::Pending => this.fut_2.poll(cx),
    390         }
    391     }
    392 }
    393 impl Sealed for DualStackTcpListener {}
    394 impl Tcp for DualStackTcpListener {
    395     #[inline]
    396     async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
    397         match net::lookup_host(addr).await {
    398             Ok(socks) => {
    399                 let mut last_err = None;
    400                 let mut ip6_opt = None;
    401                 let mut ip4_opt = None;
    402                 for sock in socks {
    403                     match ip6_opt {
    404                         None => match ip4_opt {
    405                             None => {
    406                                 let is_ip6 = sock.is_ipv6();
    407                                 match TcpListener::bind(sock).await {
    408                                     Ok(ip) => {
    409                                         if is_ip6 {
    410                                             ip6_opt = Some(ip);
    411                                         } else {
    412                                             ip4_opt = Some(ip);
    413                                         }
    414                                     }
    415                                     Err(err) => last_err = Some(err),
    416                                 }
    417                             }
    418                             Some(ip4) => {
    419                                 if sock.is_ipv6() {
    420                                     match TcpListener::bind(sock).await {
    421                                         Ok(ip6) => {
    422                                             return Ok(Self {
    423                                                 ip6,
    424                                                 ip4,
    425                                                 ip6_first: AtomicBool::new(true),
    426                                             });
    427                                         }
    428                                         Err(err) => last_err = Some(err),
    429                                     }
    430                                 }
    431                                 ip4_opt = Some(ip4);
    432                             }
    433                         },
    434                         Some(ip6) => {
    435                             if sock.is_ipv4() {
    436                                 match TcpListener::bind(sock).await {
    437                                     Ok(ip4) => {
    438                                         return Ok(Self {
    439                                             ip6,
    440                                             ip4,
    441                                             ip6_first: AtomicBool::new(true),
    442                                         });
    443                                     }
    444                                     Err(err) => last_err = Some(err),
    445                                 }
    446                             }
    447                             ip6_opt = Some(ip6);
    448                         }
    449                     }
    450                 }
    451                 Err(last_err.unwrap_or_else(|| {
    452                     Error::new(
    453                         ErrorKind::InvalidInput,
    454                         "could not resolve to an IPv6 and IPv4 address",
    455                     )
    456                 }))
    457             }
    458             Err(err) => Err(err),
    459         }
    460     }
    461     #[inline]
    462     fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
    463         // The correctness of code does not depend on `self.ip6_first`; therefore
    464         // we elect for the most performant `Ordering`.
    465         if self.ip6_first.swap(false, Ordering::Relaxed) {
    466             AcceptFut {
    467                 fut_1: self.ip6.accept(),
    468                 fut_2: self.ip4.accept(),
    469             }
    470         } else {
    471             // The correctness of code does not depend on `self.ip6_first`; therefore
    472             // we elect for the most performant `Ordering`.
    473             self.ip6_first.store(true, Ordering::Relaxed);
    474             AcceptFut {
    475                 fut_1: self.ip4.accept(),
    476                 fut_2: self.ip6.accept(),
    477             }
    478         }
    479     }
    480     #[inline]
    481     fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
    482         // The correctness of code does not depend on `self.ip6_first`; therefore
    483         // we elect for the most performant `Ordering`.
    484         if self.ip6_first.swap(false, Ordering::Relaxed) {
    485             self.ip6.poll_accept(cx)
    486         } else {
    487             // The correctness of code does not depend on `self.ip6_first`; therefore
    488             // we elect for the most performant `Ordering`.
    489             self.ip6_first.store(true, Ordering::Relaxed);
    490             self.ip4.poll_accept(cx)
    491         }
    492     }
    493 }