darkfi/net/transport/tcp.rs

1/* This file is part of DarkFi (https://dark.fi)
2 *
3 * Copyright (C) 2020-2026 Dyne.org foundation
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU Affero General Public License as
7 * published by the Free Software Foundation, either version 3 of the
8 * License, or (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU Affero General Public License for more details.
14 *
15 * You should have received a copy of the GNU Affero General Public License
16 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 */
18
19use std::{io, sync::Arc, time::Duration};
20
21use async_trait::async_trait;
22use futures::{
23    future::{select, Either},
24    pin_mut,
25};
26use futures_rustls::{TlsAcceptor, TlsStream};
27use smol::{
28    lock::OnceCell,
29    net::{SocketAddr, TcpListener as SmolTcpListener, TcpStream},
30    Async, Timer,
31};
32use socket2::{Domain, Socket, TcpKeepalive, Type};
33use tracing::debug;
34use url::Url;
35
36use super::{PtListener, PtStream};
37
38trait SocketExt {
39    fn enable_reuse_port(&self) -> io::Result<()>;
40}
41
42impl SocketExt for Socket {
43    fn enable_reuse_port(&self) -> io::Result<()> {
44        #[cfg(target_family = "unix")]
45        self.set_reuse_port(true)?;
46
47        // On Windows SO_REUSEPORT means the same thing as SO_REUSEADDR
48        #[cfg(target_family = "windows")]
49        self.set_reuse_address(true)?;
50
51        Ok(())
52    }
53}
54
55/// TCP Dialer implementation
56#[derive(Debug, Clone)]
57pub struct TcpDialer {
58    /// TTL to set for opened sockets, or `None` for default.
59    ttl: Option<u32>,
60}
61
62impl TcpDialer {
63    /// Instantiate a new [`TcpDialer`] with optional TTL.
64    pub(crate) async fn new(ttl: Option<u32>) -> io::Result<Self> {
65        Ok(Self { ttl })
66    }
67
68    /// Internal helper function to create a TCP socket.
69    async fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
70        let domain = if socket_addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
71        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
72
73        if socket_addr.is_ipv6() {
74            socket.set_only_v6(true)?;
75        }
76
77        if let Some(ttl) = self.ttl {
78            socket.set_ttl_v4(ttl)?;
79        }
80
81        socket.set_tcp_nodelay(true)?;
82        let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(20));
83        socket.set_tcp_keepalive(&keepalive)?;
84        socket.enable_reuse_port()?;
85
86        Ok(socket)
87    }
88
89    /// Internal dial function
90    pub(crate) async fn do_dial(
91        &self,
92        socket_addr: SocketAddr,
93        timeout: Option<Duration>,
94    ) -> io::Result<TcpStream> {
95        debug!(target: "net::tcp::do_dial", "Dialing {socket_addr} with TCP...");
96        let socket = self.create_socket(socket_addr).await?;
97
98        socket.set_nonblocking(true)?;
99
100        // Sync start socket connect. A WouldBlock error means this
101        // connection is in progress.
102        match socket.connect(&socket_addr.into()) {
103            Ok(()) => {}
104            Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
105            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
106            Err(err) => return Err(err),
107        };
108
109        let stream = Async::new_nonblocking(std::net::TcpStream::from(socket))?;
110
111        // Wait until the async object becomes writable.
112        let connect = async move {
113            stream.writable().await?;
114            match stream.get_ref().take_error()? {
115                Some(err) => Err(err),
116                None => Ok(stream),
117            }
118        };
119
120        // If a timeout is configured, run both the connect and timeout
121        // futures and return whatever finishes first. Otherwise wait on
122        // the connect future.
123        match timeout {
124            Some(t) => {
125                let timeout = Timer::after(t);
126                pin_mut!(timeout);
127                pin_mut!(connect);
128
129                match select(connect, timeout).await {
130                    Either::Left((Ok(stream), _)) => Ok(TcpStream::from(stream)),
131                    Either::Left((Err(e), _)) => Err(e),
132                    Either::Right((_, _)) => Err(io::ErrorKind::TimedOut.into()),
133                }
134            }
135            None => {
136                let stream = connect.await?;
137                Ok(TcpStream::from(stream))
138            }
139        }
140    }
141}
142
143/// TCP Listener implementation
144#[derive(Debug, Clone)]
145pub struct TcpListener {
146    /// Size of the listen backlog for listen sockets
147    backlog: i32,
148    /// When the user puts a port of 0, the OS will assign a random port.
149    /// We get it from the listener so we know what the true endpoint is.
150    pub port: Arc<OnceCell<u16>>,
151}
152
153impl TcpListener {
154    /// Instantiate a new [`TcpListener`] with given backlog size.
155    pub async fn new(backlog: i32) -> io::Result<Self> {
156        Ok(Self { backlog, port: Arc::new(OnceCell::new()) })
157    }
158
159    /// Internal helper function to create a TCP socket.
160    async fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
161        let domain = if socket_addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
162        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
163
164        if socket_addr.is_ipv6() {
165            socket.set_only_v6(true)?;
166        }
167
168        socket.set_tcp_nodelay(true)?;
169        let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(20));
170        socket.set_tcp_keepalive(&keepalive)?;
171        socket.enable_reuse_port()?;
172
173        Ok(socket)
174    }
175
176    /// Internal listen function
177    pub(crate) async fn do_listen(&self, socket_addr: SocketAddr) -> io::Result<SmolTcpListener> {
178        let socket = self.create_socket(socket_addr).await?;
179        socket.bind(&socket_addr.into())?;
180        socket.listen(self.backlog)?;
181        socket.set_nonblocking(true)?;
182
183        let listener = std::net::TcpListener::from(socket);
184        let local_port = listener.local_addr()?.port();
185        let listener = smol::Async::<std::net::TcpListener>::try_from(listener)?;
186
187        self.port.set(local_port).await.expect("fatal port already set for TcpListener");
188
189        Ok(SmolTcpListener::from(listener))
190    }
191}
192
193#[async_trait]
194impl PtListener for SmolTcpListener {
195    async fn next(&self) -> io::Result<(Box<dyn PtStream>, Url)> {
196        let (stream, peer_addr) = match self.accept().await {
197            Ok((s, a)) => (s, a),
198            Err(e) => return Err(e),
199        };
200
201        let url = match Url::parse(&format!("tcp://{peer_addr}")) {
202            Ok(v) => v,
203            Err(e) => {
204                return Err(io::Error::new(
205                    io::ErrorKind::InvalidData,
206                    format!("Invalid peer address '{peer_addr}': {e}"),
207                ))
208            }
209        };
210        Ok((Box::new(stream), url))
211    }
212}
213
214#[async_trait]
215impl PtListener for (TlsAcceptor, SmolTcpListener) {
216    async fn next(&self) -> io::Result<(Box<dyn PtStream>, Url)> {
217        let (stream, peer_addr) = match self.1.accept().await {
218            Ok((s, a)) => (s, a),
219            Err(e) => return Err(e),
220        };
221
222        let stream = match self.0.accept(stream).await {
223            Ok(v) => v,
224            Err(e) => return Err(e),
225        };
226
227        let url = match Url::parse(&format!("tcp+tls://{peer_addr}")) {
228            Ok(v) => v,
229            Err(e) => {
230                return Err(io::Error::new(
231                    io::ErrorKind::InvalidData,
232                    format!("Invalid peer address '{peer_addr}': {e}"),
233                ))
234            }
235        };
236
237        Ok((Box::new(TlsStream::Server(stream)), url))
238    }
239}