darkfi/net/transport/
tcp.rs1use std::{io, sync::Arc, time::Duration};
20
21use async_trait::async_trait;
22use futures::{
23 future::{select, Either},
24 pin_mut,
25};
26use futures_rustls::{TlsAcceptor, TlsStream};
27use smol::{
28 lock::OnceCell,
29 net::{SocketAddr, TcpListener as SmolTcpListener, TcpStream},
30 Async, Timer,
31};
32use socket2::{Domain, Socket, TcpKeepalive, Type};
33use tracing::debug;
34use url::Url;
35
36use super::{PtListener, PtStream};
37
/// Extra, platform-aware configuration helpers for a raw [`Socket`].
trait SocketExt {
    /// Allow this socket's local port to be reused, using whatever mechanism
    /// the current platform provides.
    fn enable_reuse_port(&self) -> io::Result<()>;
}
41
42impl SocketExt for Socket {
43 fn enable_reuse_port(&self) -> io::Result<()> {
44 #[cfg(target_family = "unix")]
45 self.set_reuse_port(true)?;
46
47 #[cfg(target_family = "windows")]
49 self.set_reuse_address(true)?;
50
51 Ok(())
52 }
53}
54
/// Dials outbound TCP connections on sockets configured by this module.
#[derive(Debug, Clone)]
pub struct TcpDialer {
    /// Optional TTL to apply to every dialed socket; `None` keeps the OS default.
    ttl: Option<u32>,
}
61
62impl TcpDialer {
63 pub(crate) async fn new(ttl: Option<u32>) -> io::Result<Self> {
65 Ok(Self { ttl })
66 }
67
68 async fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
70 let domain = if socket_addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
71 let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
72
73 if socket_addr.is_ipv6() {
74 socket.set_only_v6(true)?;
75 }
76
77 if let Some(ttl) = self.ttl {
78 socket.set_ttl_v4(ttl)?;
79 }
80
81 socket.set_tcp_nodelay(true)?;
82 let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(20));
83 socket.set_tcp_keepalive(&keepalive)?;
84 socket.enable_reuse_port()?;
85
86 Ok(socket)
87 }
88
89 pub(crate) async fn do_dial(
91 &self,
92 socket_addr: SocketAddr,
93 timeout: Option<Duration>,
94 ) -> io::Result<TcpStream> {
95 debug!(target: "net::tcp::do_dial", "Dialing {socket_addr} with TCP...");
96 let socket = self.create_socket(socket_addr).await?;
97
98 socket.set_nonblocking(true)?;
99
100 match socket.connect(&socket_addr.into()) {
103 Ok(()) => {}
104 Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
105 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
106 Err(err) => return Err(err),
107 };
108
109 let stream = Async::new_nonblocking(std::net::TcpStream::from(socket))?;
110
111 let connect = async move {
113 stream.writable().await?;
114 match stream.get_ref().take_error()? {
115 Some(err) => Err(err),
116 None => Ok(stream),
117 }
118 };
119
120 match timeout {
124 Some(t) => {
125 let timeout = Timer::after(t);
126 pin_mut!(timeout);
127 pin_mut!(connect);
128
129 match select(connect, timeout).await {
130 Either::Left((Ok(stream), _)) => Ok(TcpStream::from(stream)),
131 Either::Left((Err(e), _)) => Err(e),
132 Either::Right((_, _)) => Err(io::ErrorKind::TimedOut.into()),
133 }
134 }
135 None => {
136 let stream = connect.await?;
137 Ok(TcpStream::from(stream))
138 }
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct TcpListener {
146 backlog: i32,
148 pub port: Arc<OnceCell<u16>>,
151}
152
153impl TcpListener {
154 pub async fn new(backlog: i32) -> io::Result<Self> {
156 Ok(Self { backlog, port: Arc::new(OnceCell::new()) })
157 }
158
159 async fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
161 let domain = if socket_addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
162 let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
163
164 if socket_addr.is_ipv6() {
165 socket.set_only_v6(true)?;
166 }
167
168 socket.set_tcp_nodelay(true)?;
169 let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(20));
170 socket.set_tcp_keepalive(&keepalive)?;
171 socket.enable_reuse_port()?;
172
173 Ok(socket)
174 }
175
176 pub(crate) async fn do_listen(&self, socket_addr: SocketAddr) -> io::Result<SmolTcpListener> {
178 let socket = self.create_socket(socket_addr).await?;
179 socket.bind(&socket_addr.into())?;
180 socket.listen(self.backlog)?;
181 socket.set_nonblocking(true)?;
182
183 let listener = std::net::TcpListener::from(socket);
184 let local_port = listener.local_addr()?.port();
185 let listener = smol::Async::<std::net::TcpListener>::try_from(listener)?;
186
187 self.port.set(local_port).await.expect("fatal port already set for TcpListener");
188
189 Ok(SmolTcpListener::from(listener))
190 }
191}
192
193#[async_trait]
194impl PtListener for SmolTcpListener {
195 async fn next(&self) -> io::Result<(Box<dyn PtStream>, Url)> {
196 let (stream, peer_addr) = match self.accept().await {
197 Ok((s, a)) => (s, a),
198 Err(e) => return Err(e),
199 };
200
201 let url = match Url::parse(&format!("tcp://{peer_addr}")) {
202 Ok(v) => v,
203 Err(e) => {
204 return Err(io::Error::new(
205 io::ErrorKind::InvalidData,
206 format!("Invalid peer address '{peer_addr}': {e}"),
207 ))
208 }
209 };
210 Ok((Box::new(stream), url))
211 }
212}
213
214#[async_trait]
215impl PtListener for (TlsAcceptor, SmolTcpListener) {
216 async fn next(&self) -> io::Result<(Box<dyn PtStream>, Url)> {
217 let (stream, peer_addr) = match self.1.accept().await {
218 Ok((s, a)) => (s, a),
219 Err(e) => return Err(e),
220 };
221
222 let stream = match self.0.accept(stream).await {
223 Ok(v) => v,
224 Err(e) => return Err(e),
225 };
226
227 let url = match Url::parse(&format!("tcp+tls://{peer_addr}")) {
228 Ok(v) => v,
229 Err(e) => {
230 return Err(io::Error::new(
231 io::ErrorKind::InvalidData,
232 format!("Invalid peer address '{peer_addr}': {e}"),
233 ))
234 }
235 };
236
237 Ok((Box::new(TlsStream::Server(stream)), url))
238 }
239}