1use std::{
20 collections::HashMap,
21 io,
22 net::SocketAddr,
23 pin::Pin,
24 sync::{Arc, OnceLock},
25 task::{Context, Poll},
26 time::Duration,
27};
28
29use async_trait::async_trait;
30use futures::{
31 future::{select, Either},
32 pin_mut,
33};
34use futures_rustls::rustls::{self, version::TLS13};
35use quinn_smol::{
36 crypto::rustls::{QuicClientConfig, QuicServerConfig},
37 ClientConfig, Endpoint, RecvStream, SendStream, ServerConfig, TransportConfig, VarInt,
38};
39use smol::{
40 io::{AsyncRead, AsyncWrite},
41 lock::{Mutex, OnceCell},
42 Timer,
43};
44use tracing::debug;
45use url::Url;
46
47use super::{
48 tls::{
49 generate_certificate, ClientCertificateVerifier, ServerCertificateVerifier, TLS_DNS_NAME,
50 },
51 PtListener, PtStream,
52};
53
/// Registry key identifying an endpoint by address family and local port.
///
/// The IP address itself is deliberately not part of the key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct EndpointKey {
    /// `true` for an IPv6 address, `false` for IPv4.
    is_ipv6: bool,
    /// Local port number.
    port: u16,
}

impl EndpointKey {
    /// Builds the key for `addr` from its address family and port only.
    fn from_addr(addr: SocketAddr) -> Self {
        let (is_ipv6, port) = (addr.is_ipv6(), addr.port());
        Self { is_ipv6, port }
    }
}
65
66static ENDPOINT_REGISTRY: OnceLock<Mutex<EndpointRegistry>> = OnceLock::new();
69
70struct EndpointRegistry {
71 endpoints: HashMap<EndpointKey, Endpoint>,
72}
73
74impl EndpointRegistry {
75 fn new() -> Self {
76 Self { endpoints: HashMap::new() }
77 }
78
79 fn find_for_target(&self, target: SocketAddr) -> Option<Endpoint> {
81 let is_ipv6 = target.is_ipv6();
82 self.endpoints.iter().find(|(k, _)| k.is_ipv6 == is_ipv6).map(|(_, ep)| ep.clone())
83 }
84}
85
86fn registry() -> &'static Mutex<EndpointRegistry> {
87 ENDPOINT_REGISTRY.get_or_init(|| Mutex::new(EndpointRegistry::new()))
88}
89
90async fn register_endpoint(bind_addr: SocketAddr) -> io::Result<Endpoint> {
93 let mut reg = registry().lock().await;
94
95 let key = EndpointKey::from_addr(bind_addr);
96
97 if bind_addr.port() != 0 {
99 if let Some(endpoint) = reg.endpoints.get(&key) {
100 debug!(
101 target: "net::quic::registry",
102 "[QUIC] Reusing existing {} endpoint on port {}",
103 if key.is_ipv6 { "IPv6" } else { "IPv4" },
104 key.port,
105 );
106 return Ok(endpoint.clone())
107 }
108 }
109
110 let endpoint = create_dual_endpoint(bind_addr).await?;
112 let actual_port = endpoint.local_addr()?.port();
113
114 let actual_key = EndpointKey { is_ipv6: key.is_ipv6, port: actual_port };
115
116 debug!(
117 target: "net::quic::registry",
118 "[QUIC] Created new {} QUIC endpoint on port {}",
119 if actual_key.is_ipv6 { "IPv6" } else { "IPv4" },
120 actual_port,
121 );
122
123 reg.endpoints.insert(actual_key, endpoint.clone());
124
125 Ok(endpoint)
126}
127
128async fn get_endpoint_for_target(target: SocketAddr) -> io::Result<Endpoint> {
131 let reg = registry().lock().await;
132 if let Some(endpoint) = reg.find_for_target(target) {
133 debug!(
134 target: "net::quic::registry",
135 "[QUIC] Dialer using existing {} endpoint on port {}",
136 if target.is_ipv6() { "IPv6" } else { "IPv4" },
137 endpoint.local_addr().map(|a| a.port()).unwrap_or(0),
138 );
139 return Ok(endpoint)
140 }
141 drop(reg);
142
143 let bind_addr: SocketAddr =
145 if target.is_ipv6() { "[::]:0".parse().unwrap() } else { "0.0.0.0:0".parse().unwrap() };
146
147 debug!(
148 target: "net::quic::registry",
149 "[QUIC] Creating new {} endpoint for dialing",
150 if target.is_ipv6() { "IPv6" } else { "IPv4" },
151 );
152
153 register_endpoint(bind_addr).await
154}
155
156async fn create_dual_endpoint(bind_addr: SocketAddr) -> io::Result<Endpoint> {
158 let server_config = create_server_config()?;
159 let client_config = create_client_config()?;
160
161 let endpoint = Endpoint::server(server_config, bind_addr)
162 .map_err(|e| io::Error::other(format!("Failed to create QUIC endpoint: {e}")))?;
163
164 endpoint.set_default_client_config(client_config);
165
166 Ok(endpoint)
167}
168
169fn create_client_config() -> io::Result<ClientConfig> {
171 let (certificate, secret_key) = generate_certificate()?;
172
173 let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
174
175 let tls_config = rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13])
176 .dangerous()
177 .with_custom_certificate_verifier(server_cert_verifier)
178 .with_client_auth_cert(vec![certificate], secret_key)
179 .map_err(|e| io::Error::other(format!("Failed to create QUIC client TLS config: {e}")))?;
180
181 let quic_config: QuicClientConfig = tls_config
182 .try_into()
183 .map_err(|e| io::Error::other(format!("Failed to create QUIC client config: {e}")))?;
184
185 let mut config = ClientConfig::new(Arc::new(quic_config));
186
187 let mut transport = TransportConfig::default();
189 transport.keep_alive_interval(Some(Duration::from_secs(15)));
190 transport.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
191 config.transport_config(Arc::new(transport));
192
193 Ok(config)
194}
195
196fn create_server_config() -> io::Result<ServerConfig> {
198 let (certificate, secret_key) = generate_certificate()?;
199
200 let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
201
202 let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
203 .with_client_cert_verifier(client_cert_verifier)
204 .with_single_cert(vec![certificate], secret_key)
205 .map_err(|e| io::Error::other(format!("Failed to create QUIC server TLS config: {e}")))?;
206
207 let quic_config: QuicServerConfig = tls_config
208 .try_into()
209 .map_err(|e| io::Error::other(format!("Failed to create QUIC server config: {e}")))?;
210
211 let mut config = ServerConfig::with_crypto(Arc::new(quic_config));
212
213 let mut transport = TransportConfig::default();
215 transport.keep_alive_interval(Some(Duration::from_secs(15)));
216 transport.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
217 config.transport_config(Arc::new(transport));
218
219 Ok(config)
220}
221
222pub struct QuicStream {
224 send: SendStream,
225 recv: RecvStream,
226}
227
228impl QuicStream {
229 fn new(send: SendStream, recv: RecvStream) -> Self {
230 Self { send, recv }
231 }
232}
233
234impl AsyncRead for QuicStream {
235 fn poll_read(
236 mut self: Pin<&mut Self>,
237 cx: &mut Context<'_>,
238 buf: &mut [u8],
239 ) -> Poll<io::Result<usize>> {
240 Pin::new(&mut self.recv)
241 .poll_read(cx, buf)
242 .map_err(|e| io::Error::other(format!("QUIC read error: {e}")))
243 }
244}
245
246impl AsyncWrite for QuicStream {
247 fn poll_write(
248 mut self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 buf: &[u8],
251 ) -> Poll<io::Result<usize>> {
252 Pin::new(&mut self.send)
253 .poll_write(cx, buf)
254 .map_err(|e| io::Error::other(format!("QUIC write error: {e}")))
255 }
256
257 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
258 Pin::new(&mut self.send)
259 .poll_flush(cx)
260 .map_err(|e| io::Error::other(format!("QUIC flush error: {e}")))
261 }
262
263 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
264 Pin::new(&mut self.send)
265 .poll_close(cx)
266 .map_err(|e| io::Error::other(format!("QUIC close error: {e}")))
267 }
268}
269
270#[derive(Clone, Debug)]
275pub struct QuicDialer;
276
277impl QuicDialer {
278 pub(crate) async fn new() -> io::Result<Self> {
282 Ok(Self {})
283 }
284
285 pub(crate) async fn do_dial(
287 &self,
288 socket_addr: SocketAddr,
289 timeout: Option<Duration>,
290 ) -> io::Result<QuicStream> {
291 let endpoint = get_endpoint_for_target(socket_addr).await?;
293
294 debug!(
295 target: "net::quic::do_dial",
296 "[QUIC] Dialing {} {} from local {}",
297 if socket_addr.is_ipv6() { "IPv6" } else { "IPv4" },
298 socket_addr,
299 endpoint.local_addr().map(|a| a.to_string()).unwrap_or_default(),
300 );
301
302 let connect = async {
303 let connection = endpoint
305 .connect(socket_addr, TLS_DNS_NAME)
306 .map_err(|e| io::Error::other(format!("QUIC connect error: {e}")))?
307 .await
308 .map_err(|e| io::Error::other(format!("QUIC connection error: {e}")))?;
309
310 let (send, recv) = connection
312 .open_bi()
313 .await
314 .map_err(|e| io::Error::other(format!("QUIC stream error: {e}")))?;
315
316 Ok(QuicStream::new(send, recv))
317 };
318
319 match timeout {
320 Some(t) => {
321 let timer = Timer::after(t);
322 pin_mut!(timer);
323 pin_mut!(connect);
324
325 match select(connect, timer).await {
326 Either::Left((Ok(stream), _)) => Ok(stream),
327 Either::Left((Err(e), _)) => Err(e),
328 Either::Right((_, _)) => Err(io::ErrorKind::TimedOut.into()),
329 }
330 }
331 None => connect.await,
332 }
333 }
334}
335
336#[derive(Debug, Clone)]
341pub struct QuicListener {
342 pub port: Arc<OnceCell<u16>>,
345}
346
347impl QuicListener {
348 pub async fn new() -> io::Result<Self> {
350 Ok(Self { port: Arc::new(OnceCell::new()) })
351 }
352
353 pub(crate) async fn do_listen(
355 &self,
356 socket_addr: SocketAddr,
357 ) -> io::Result<QuicListenerIntern> {
358 let endpoint = register_endpoint(socket_addr).await?;
359
360 let local_addr = endpoint.local_addr()?;
361
362 debug!(
363 target: "net::quic::do_listen",
364 "[QUIC] Listening on {} QUIC endpoint: {}",
365 if local_addr.is_ipv6() { "IPv6" } else { "IPv4" },
366 local_addr,
367 );
368
369 self.port.set(local_addr.port()).await.expect("fatal port already set for QuicListener");
370
371 Ok(QuicListenerIntern { endpoint })
372 }
373}
374
375pub struct QuicListenerIntern {
377 endpoint: Endpoint,
378}
379
380#[async_trait]
381impl PtListener for QuicListenerIntern {
382 async fn next(&self) -> io::Result<(Box<dyn PtStream>, Url)> {
383 let incoming =
385 self.endpoint.accept().await.ok_or_else(|| {
386 io::Error::new(io::ErrorKind::ConnectionAborted, "Endpoint closed")
387 })?;
388
389 let peer_addr = incoming.remote_address();
390
391 let connection =
392 incoming.await.map_err(|e| io::Error::other(format!("QUIC accept error: {e}")))?;
393
394 let (send, recv) = connection
396 .accept_bi()
397 .await
398 .map_err(|e| io::Error::other(format!("QUIC stream accept error: {e}")))?;
399
400 let url = Url::parse(&format!("quic://{peer_addr}")).map_err(|e| {
401 io::Error::new(io::ErrorKind::InvalidData, format!("Invalid peer address: {e}"))
402 })?;
403
404 Ok((Box::new(QuicStream::new(send, recv)), url))
405 }
406}