// darkfi/net/transport/quic.rs

/* This file is part of DarkFi (https://dark.fi)
 *
 * Copyright (C) 2020-2026 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
18
use std::{
    collections::HashMap,
    io,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
    pin::Pin,
    sync::{Arc, OnceLock},
    task::{Context, Poll},
    time::Duration,
};

use async_trait::async_trait;
use futures::{
    future::{select, Either},
    pin_mut,
};
use futures_rustls::rustls::{self, version::TLS13};
use quinn_smol::{
    crypto::rustls::{QuicClientConfig, QuicServerConfig},
    ClientConfig, Endpoint, RecvStream, SendStream, ServerConfig, TransportConfig, VarInt,
};
use smol::{
    io::{AsyncRead, AsyncWrite},
    lock::{Mutex, OnceCell},
    Timer,
};
use tracing::debug;
use url::Url;

use super::{
    tls::{
        generate_certificate, ClientCertificateVerifier, ServerCertificateVerifier, TLS_DNS_NAME,
    },
    PtListener, PtStream,
};
53
/// Registry key for a shared QUIC endpoint: address family plus local port.
///
/// The IP address itself is intentionally not part of the key, so for example
/// `0.0.0.0:9000` and `10.0.0.1:9000` resolve to the same entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct EndpointKey {
    /// `true` for IPv6 addresses, `false` for IPv4.
    is_ipv6: bool,
    /// Local UDP port of the endpoint.
    port: u16,
}

impl EndpointKey {
    /// Derive a key from `addr`, keeping only its family and port.
    fn from_addr(addr: SocketAddr) -> Self {
        let port = addr.port();
        let is_ipv6 = matches!(addr, SocketAddr::V6(_));
        Self { is_ipv6, port }
    }
}
65
66/// Global registry of QUIC endpoints, keyed by (addr_family, port).
67/// This enables transparent endpoint sharing between Dialer and Listener.
68static ENDPOINT_REGISTRY: OnceLock<Mutex<EndpointRegistry>> = OnceLock::new();
69
70struct EndpointRegistry {
71    endpoints: HashMap<EndpointKey, Endpoint>,
72}
73
74impl EndpointRegistry {
75    fn new() -> Self {
76        Self { endpoints: HashMap::new() }
77    }
78
79    /// Find an endpoint suitable for dialing the given target address.
80    fn find_for_target(&self, target: SocketAddr) -> Option<Endpoint> {
81        let is_ipv6 = target.is_ipv6();
82        self.endpoints.iter().find(|(k, _)| k.is_ipv6 == is_ipv6).map(|(_, ep)| ep.clone())
83    }
84}
85
86fn registry() -> &'static Mutex<EndpointRegistry> {
87    ENDPOINT_REGISTRY.get_or_init(|| Mutex::new(EndpointRegistry::new()))
88}
89
90/// Register an endpoint for the given bind address.
91/// Returns the endpoint (may be existing if already registered).
92async fn register_endpoint(bind_addr: SocketAddr) -> io::Result<Endpoint> {
93    let mut reg = registry().lock().await;
94
95    let key = EndpointKey::from_addr(bind_addr);
96
97    // Check if we already have an endpoint for this (family, port)
98    if bind_addr.port() != 0 {
99        if let Some(endpoint) = reg.endpoints.get(&key) {
100            debug!(
101                target: "net::quic::registry",
102                "[QUIC] Reusing existing {} endpoint on port {}",
103                if key.is_ipv6 { "IPv6" } else { "IPv4" },
104                key.port,
105            );
106            return Ok(endpoint.clone())
107        }
108    }
109
110    // Create new dual-mode endpoint
111    let endpoint = create_dual_endpoint(bind_addr).await?;
112    let actual_port = endpoint.local_addr()?.port();
113
114    let actual_key = EndpointKey { is_ipv6: key.is_ipv6, port: actual_port };
115
116    debug!(
117        target: "net::quic::registry",
118        "[QUIC] Created new {} QUIC endpoint on port {}",
119        if actual_key.is_ipv6 { "IPv6" } else { "IPv4" },
120        actual_port,
121    );
122
123    reg.endpoints.insert(actual_key, endpoint.clone());
124
125    Ok(endpoint)
126}
127
128/// Get an endpoint suitable for dialing the given target address.
129/// If no matching endpoint exist, creates a new one.
130async fn get_endpoint_for_target(target: SocketAddr) -> io::Result<Endpoint> {
131    let reg = registry().lock().await;
132    if let Some(endpoint) = reg.find_for_target(target) {
133        debug!(
134            target: "net::quic::registry",
135            "[QUIC] Dialer using existing {} endpoint on port {}",
136            if target.is_ipv6() { "IPv6" } else { "IPv4" },
137            endpoint.local_addr().map(|a| a.port()).unwrap_or(0),
138        );
139        return Ok(endpoint)
140    }
141    drop(reg);
142
143    // No suitable endpoint, create one.
144    let bind_addr: SocketAddr =
145        if target.is_ipv6() { "[::]:0".parse().unwrap() } else { "0.0.0.0:0".parse().unwrap() };
146
147    debug!(
148        target: "net::quic::registry",
149        "[QUIC] Creating new {} endpoint for dialing",
150        if target.is_ipv6() { "IPv6" } else { "IPv4" },
151    );
152
153    register_endpoint(bind_addr).await
154}
155
156/// Create an endpoint configured for both client and server roles
157async fn create_dual_endpoint(bind_addr: SocketAddr) -> io::Result<Endpoint> {
158    let server_config = create_server_config()?;
159    let client_config = create_client_config()?;
160
161    let endpoint = Endpoint::server(server_config, bind_addr)
162        .map_err(|e| io::Error::other(format!("Failed to create QUIC endpoint: {e}")))?;
163
164    endpoint.set_default_client_config(client_config);
165
166    Ok(endpoint)
167}
168
169/// Create QUIC client configuration with our TLS config
170fn create_client_config() -> io::Result<ClientConfig> {
171    let (certificate, secret_key) = generate_certificate()?;
172
173    let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
174
175    let tls_config = rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13])
176        .dangerous()
177        .with_custom_certificate_verifier(server_cert_verifier)
178        .with_client_auth_cert(vec![certificate], secret_key)
179        .map_err(|e| io::Error::other(format!("Failed to create QUIC client TLS config: {e}")))?;
180
181    let quic_config: QuicClientConfig = tls_config
182        .try_into()
183        .map_err(|e| io::Error::other(format!("Failed to create QUIC client config: {e}")))?;
184
185    let mut config = ClientConfig::new(Arc::new(quic_config));
186
187    // Configure transport parameters
188    let mut transport = TransportConfig::default();
189    transport.keep_alive_interval(Some(Duration::from_secs(15)));
190    transport.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
191    config.transport_config(Arc::new(transport));
192
193    Ok(config)
194}
195
196/// Create QUIC server configuration with our TLS config
197fn create_server_config() -> io::Result<ServerConfig> {
198    let (certificate, secret_key) = generate_certificate()?;
199
200    let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
201
202    let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
203        .with_client_cert_verifier(client_cert_verifier)
204        .with_single_cert(vec![certificate], secret_key)
205        .map_err(|e| io::Error::other(format!("Failed to create QUIC server TLS config: {e}")))?;
206
207    let quic_config: QuicServerConfig = tls_config
208        .try_into()
209        .map_err(|e| io::Error::other(format!("Failed to create QUIC server config: {e}")))?;
210
211    let mut config = ServerConfig::with_crypto(Arc::new(quic_config));
212
213    // Configure transport parameters
214    let mut transport = TransportConfig::default();
215    transport.keep_alive_interval(Some(Duration::from_secs(15)));
216    transport.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
217    config.transport_config(Arc::new(transport));
218
219    Ok(config)
220}
221
222/// Wrapper around quinn's bidirectional stream to implement PtStream
223pub struct QuicStream {
224    send: SendStream,
225    recv: RecvStream,
226}
227
228impl QuicStream {
229    fn new(send: SendStream, recv: RecvStream) -> Self {
230        Self { send, recv }
231    }
232}
233
234impl AsyncRead for QuicStream {
235    fn poll_read(
236        mut self: Pin<&mut Self>,
237        cx: &mut Context<'_>,
238        buf: &mut [u8],
239    ) -> Poll<io::Result<usize>> {
240        Pin::new(&mut self.recv)
241            .poll_read(cx, buf)
242            .map_err(|e| io::Error::other(format!("QUIC read error: {e}")))
243    }
244}
245
246impl AsyncWrite for QuicStream {
247    fn poll_write(
248        mut self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250        buf: &[u8],
251    ) -> Poll<io::Result<usize>> {
252        Pin::new(&mut self.send)
253            .poll_write(cx, buf)
254            .map_err(|e| io::Error::other(format!("QUIC write error: {e}")))
255    }
256
257    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
258        Pin::new(&mut self.send)
259            .poll_flush(cx)
260            .map_err(|e| io::Error::other(format!("QUIC flush error: {e}")))
261    }
262
263    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
264        Pin::new(&mut self.send)
265            .poll_close(cx)
266            .map_err(|e| io::Error::other(format!("QUIC close error: {e}")))
267    }
268}
269
270/// QUIC Dialer implementation.
271///
272/// Automatically shares endpoint with QuicListener when one exists,
273/// enabling NAT hole punching without any special configuration.
274#[derive(Clone, Debug)]
275pub struct QuicDialer;
276
277impl QuicDialer {
278    /// Instantiate a new [`QuicDialer`] object
279    ///
280    /// The actual endpoint is selected at dial-time based on the target.
281    pub(crate) async fn new() -> io::Result<Self> {
282        Ok(Self {})
283    }
284
285    /// Internal dial function
286    pub(crate) async fn do_dial(
287        &self,
288        socket_addr: SocketAddr,
289        timeout: Option<Duration>,
290    ) -> io::Result<QuicStream> {
291        // Get appropriate endpoint for target address family
292        let endpoint = get_endpoint_for_target(socket_addr).await?;
293
294        debug!(
295            target: "net::quic::do_dial",
296            "[QUIC] Dialing {} {} from local {}",
297            if socket_addr.is_ipv6() { "IPv6" } else { "IPv4" },
298            socket_addr,
299            endpoint.local_addr().map(|a| a.to_string()).unwrap_or_default(),
300        );
301
302        let connect = async {
303            // Connect to the remote endpoint
304            let connection = endpoint
305                .connect(socket_addr, TLS_DNS_NAME)
306                .map_err(|e| io::Error::other(format!("QUIC connect error: {e}")))?
307                .await
308                .map_err(|e| io::Error::other(format!("QUIC connection error: {e}")))?;
309
310            // Open a bidirectional stream
311            let (send, recv) = connection
312                .open_bi()
313                .await
314                .map_err(|e| io::Error::other(format!("QUIC stream error: {e}")))?;
315
316            Ok(QuicStream::new(send, recv))
317        };
318
319        match timeout {
320            Some(t) => {
321                let timer = Timer::after(t);
322                pin_mut!(timer);
323                pin_mut!(connect);
324
325                match select(connect, timer).await {
326                    Either::Left((Ok(stream), _)) => Ok(stream),
327                    Either::Left((Err(e), _)) => Err(e),
328                    Either::Right((_, _)) => Err(io::ErrorKind::TimedOut.into()),
329                }
330            }
331            None => connect.await,
332        }
333    }
334}
335
336/// QUIC Listener implementation
337///
338/// When created, registers its endpoint so that QuicDialer can share it,
339/// enabling NAT hole punching automatically.
340#[derive(Debug, Clone)]
341pub struct QuicListener {
342    /// When the user puts a port of 0, the OS will assign a random port.
343    /// We get it from the listener so we know what the true endpoint is.
344    pub port: Arc<OnceCell<u16>>,
345}
346
347impl QuicListener {
348    /// Instantiate a new [`QuicListener`]
349    pub async fn new() -> io::Result<Self> {
350        Ok(Self { port: Arc::new(OnceCell::new()) })
351    }
352
353    /// Internal listen function
354    pub(crate) async fn do_listen(
355        &self,
356        socket_addr: SocketAddr,
357    ) -> io::Result<QuicListenerIntern> {
358        let endpoint = register_endpoint(socket_addr).await?;
359
360        let local_addr = endpoint.local_addr()?;
361
362        debug!(
363            target: "net::quic::do_listen",
364            "[QUIC] Listening on {} QUIC endpoint: {}",
365            if local_addr.is_ipv6() { "IPv6" } else { "IPv4" },
366            local_addr,
367        );
368
369        self.port.set(local_addr.port()).await.expect("fatal port already set for QuicListener");
370
371        Ok(QuicListenerIntern { endpoint })
372    }
373}
374
375/// Internal QUIC Listener implementation, used with `PtListener`
376pub struct QuicListenerIntern {
377    endpoint: Endpoint,
378}
379
380#[async_trait]
381impl PtListener for QuicListenerIntern {
382    async fn next(&self) -> io::Result<(Box<dyn PtStream>, Url)> {
383        // Wait for an incoming connection
384        let incoming =
385            self.endpoint.accept().await.ok_or_else(|| {
386                io::Error::new(io::ErrorKind::ConnectionAborted, "Endpoint closed")
387            })?;
388
389        let peer_addr = incoming.remote_address();
390
391        let connection =
392            incoming.await.map_err(|e| io::Error::other(format!("QUIC accept error: {e}")))?;
393
394        // Accept a bidirectional stream from the client
395        let (send, recv) = connection
396            .accept_bi()
397            .await
398            .map_err(|e| io::Error::other(format!("QUIC stream accept error: {e}")))?;
399
400        let url = Url::parse(&format!("quic://{peer_addr}")).map_err(|e| {
401            io::Error::new(io::ErrorKind::InvalidData, format!("Invalid peer address: {e}"))
402        })?;
403
404        Ok((Box::new(QuicStream::new(send, recv)), url))
405    }
406}