darkfi/net/transport/tls.rs

/* This file is part of DarkFi (https://dark.fi)
 *
 * Copyright (C) 2020-2026 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
18
19use std::{io, sync::Arc};
20
21use futures_rustls::{
22    rustls::{
23        self,
24        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
25        pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
26        server::danger::{ClientCertVerified, ClientCertVerifier},
27        version::TLS13,
28        ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
29    },
30    TlsAcceptor, TlsConnector, TlsStream,
31};
32use rcgen::string::Ia5String;
33use tracing::error;
34use x509_parser::{
35    parse_x509_certificate,
36    prelude::{GeneralName, ParsedExtension, X509Certificate},
37};
38
39/// Validate certificate DNSName.
40fn validate_dnsname(cert: &X509Certificate) -> std::result::Result<(), rustls::Error> {
41    #[rustfmt::skip]
42        let oid = x509_parser::oid_registry::asn1_rs::oid!(2.5.29.17);
43    let Ok(Some(extension)) = cert.get_extension_unique(&oid) else {
44        return Err(rustls::CertificateError::BadEncoding.into())
45    };
46
47    let dns_name = match extension.parsed_extension() {
48        ParsedExtension::SubjectAlternativeName(altname) => {
49            if altname.general_names.len() != 1 {
50                return Err(rustls::CertificateError::BadEncoding.into())
51            }
52
53            match altname.general_names[0] {
54                GeneralName::DNSName(dns_name) => dns_name,
55                _ => return Err(rustls::CertificateError::BadEncoding.into()),
56            }
57        }
58
59        _ => return Err(rustls::CertificateError::BadEncoding.into()),
60    };
61
62    if dns_name != "dark.fi" {
63        return Err(rustls::CertificateError::BadEncoding.into())
64    }
65
66    Ok(())
67}
68
69#[derive(Debug)]
70struct ServerCertificateVerifier;
71impl ServerCertVerifier for ServerCertificateVerifier {
72    fn verify_server_cert(
73        &self,
74        end_entity: &CertificateDer,
75        _intermediates: &[CertificateDer],
76        _server_name: &ServerName,
77        _ocsp_response: &[u8],
78        _now: UnixTime,
79    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
80        // Read the DER-encoded certificate into a buffer
81        let mut buf = Vec::with_capacity(end_entity.len());
82        for byte in end_entity.iter() {
83            buf.push(*byte);
84        }
85
86        // Parse the certificate
87        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
88            error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
89            return Err(rustls::CertificateError::BadEncoding.into())
90        };
91
92        // Validate DNSName
93        validate_dnsname(&cert)?;
94
95        Ok(ServerCertVerified::assertion())
96    }
97
98    fn verify_tls12_signature(
99        &self,
100        _message: &[u8],
101        _cert: &CertificateDer,
102        _dss: &DigitallySignedStruct,
103    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
104        unreachable!()
105    }
106
107    fn verify_tls13_signature(
108        &self,
109        message: &[u8],
110        cert: &CertificateDer,
111        dss: &DigitallySignedStruct,
112    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
113        // Verify we're using the correct signature scheme
114        if dss.scheme != SignatureScheme::ED25519 {
115            return Err(rustls::CertificateError::BadSignature.into())
116        }
117
118        // Read the DER-encoded certificate into a buffer
119        let mut buf = Vec::with_capacity(cert.len());
120        for byte in cert.iter() {
121            buf.push(*byte);
122        }
123
124        // Parse the certificate and extract the public key
125        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
126            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server TLS certificate");
127            return Err(rustls::CertificateError::BadEncoding.into())
128        };
129
130        let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
131            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server public key");
132            return Err(rustls::CertificateError::BadEncoding.into())
133        };
134
135        // Verify the signature
136        let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
137            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature");
138            return Err(rustls::CertificateError::BadSignature.into())
139        };
140
141        if let Err(e) = public_key.verify(message, &signature) {
142            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature: {e}");
143            return Err(rustls::CertificateError::BadSignature.into())
144        }
145
146        Ok(HandshakeSignatureValid::assertion())
147    }
148
149    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
150        vec![SignatureScheme::ED25519]
151    }
152}
153
154#[derive(Debug)]
155struct ClientCertificateVerifier;
156impl ClientCertVerifier for ClientCertificateVerifier {
157    fn offer_client_auth(&self) -> bool {
158        true
159    }
160
161    fn client_auth_mandatory(&self) -> bool {
162        true
163    }
164
165    fn root_hint_subjects(&self) -> &[DistinguishedName] {
166        &[]
167    }
168
169    fn verify_client_cert(
170        &self,
171        end_entity: &CertificateDer,
172        _intermediates: &[CertificateDer],
173        _now: UnixTime,
174    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
175        // Read the DER-encoded certificate into a buffer
176        let mut cert = Vec::with_capacity(end_entity.len());
177        for byte in end_entity.iter() {
178            cert.push(*byte);
179        }
180
181        // Parse the certificate
182        let Ok((_, cert)) = parse_x509_certificate(&cert) else {
183            error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
184            return Err(rustls::CertificateError::BadEncoding.into())
185        };
186
187        // Validate DNSName
188        validate_dnsname(&cert)?;
189
190        Ok(ClientCertVerified::assertion())
191    }
192
193    fn verify_tls12_signature(
194        &self,
195        _message: &[u8],
196        _cert: &CertificateDer,
197        _dss: &DigitallySignedStruct,
198    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
199        unreachable!()
200    }
201
202    fn verify_tls13_signature(
203        &self,
204        message: &[u8],
205        cert: &CertificateDer,
206        dss: &DigitallySignedStruct,
207    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
208        // Verify we're using the correct signature scheme
209        if dss.scheme != SignatureScheme::ED25519 {
210            return Err(rustls::CertificateError::BadSignature.into())
211        }
212
213        // Read the DER-encoded certificate into a buffer
214        let mut buf = Vec::with_capacity(cert.len());
215        for byte in cert.iter() {
216            buf.push(*byte);
217        }
218
219        // Parse the certificate and extract the public key
220        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
221            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server TLS certificate");
222            return Err(rustls::CertificateError::BadEncoding.into())
223        };
224
225        let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
226            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed parsing server public key");
227            return Err(rustls::CertificateError::BadEncoding.into())
228        };
229
230        // Verify the signature
231        let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
232            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature");
233            return Err(rustls::CertificateError::BadSignature.into())
234        };
235
236        if let Err(e) = public_key.verify(message, &signature) {
237            error!(target: "net::tls::verify_tls13_signature", "[net::tls] Failed verifying server signature: {e}");
238            return Err(rustls::CertificateError::BadSignature.into())
239        }
240
241        Ok(HandshakeSignatureValid::assertion())
242    }
243
244    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
245        vec![SignatureScheme::ED25519]
246    }
247}
248
249pub struct TlsUpgrade {
250    /// TLS server configuration
251    server_config: Arc<ServerConfig>,
252    /// TLS client configuration
253    client_config: Arc<ClientConfig>,
254}
255
256impl TlsUpgrade {
257    pub async fn new() -> io::Result<Self> {
258        // On each instantiation, generate a new keypair and certificate
259        let Ok(keypair) = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519) else {
260            return Err(io::Error::other("Failed to generate TLS keypair"))
261        };
262
263        let Ok(mut cert_params) = rcgen::CertificateParams::new(&[]) else {
264            return Err(io::Error::other("Failed to generate TLS params"))
265        };
266
267        cert_params.subject_alt_names =
268            vec![rcgen::SanType::DnsName(Ia5String::try_from("dark.fi").unwrap())];
269        cert_params.extended_key_usages = vec![
270            rcgen::ExtendedKeyUsagePurpose::ClientAuth,
271            rcgen::ExtendedKeyUsagePurpose::ServerAuth,
272        ];
273
274        let Ok(certificate) = cert_params.self_signed(&keypair) else {
275            return Err(io::Error::other("Failed to sign TLS certificate"))
276        };
277        let certificate = certificate.der();
278
279        let keypair_der = keypair.serialize_der();
280        let Ok(secret_key_der) = PrivateKeyDer::try_from(keypair_der) else {
281            return Err(io::Error::other("Failed to deserialize DER TLS secret"))
282        };
283
284        // Server-side config
285        let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
286        let server_config = Arc::new(
287            ServerConfig::builder_with_protocol_versions(&[&TLS13])
288                .with_client_cert_verifier(client_cert_verifier)
289                .with_single_cert(vec![certificate.clone()], secret_key_der.clone_key())
290                .unwrap(),
291        );
292
293        // Client-side config
294        let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
295        let client_config = Arc::new(
296            ClientConfig::builder_with_protocol_versions(&[&TLS13])
297                .dangerous()
298                .with_custom_certificate_verifier(server_cert_verifier)
299                .with_client_auth_cert(vec![certificate.clone()], secret_key_der)
300                .unwrap(),
301        );
302
303        Ok(Self { server_config, client_config })
304    }
305
306    pub async fn upgrade_dialer_tls<IO>(self, stream: IO) -> io::Result<TlsStream<IO>>
307    where
308        IO: super::PtStream,
309    {
310        let server_name = ServerName::try_from("dark.fi").unwrap();
311        let connector = TlsConnector::from(self.client_config);
312        let stream = connector.connect(server_name, stream).await?;
313        Ok(TlsStream::Client(stream))
314    }
315
316    // TODO: Try to find a transparent way for this instead of implementing
317    // the function separately for every transport type.
318    pub async fn upgrade_listener_tcp_tls(
319        self,
320        listener: smol::net::TcpListener,
321    ) -> io::Result<(TlsAcceptor, smol::net::TcpListener)> {
322        Ok((TlsAcceptor::from(self.server_config), listener))
323    }
324}