// darkfi/net/transport/tls.rs

/* This file is part of DarkFi (https://dark.fi)
 *
 * Copyright (C) 2020-2026 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
18
19use std::{io, sync::Arc};
20
21use futures_rustls::{
22    rustls::{
23        self,
24        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
25        pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
26        server::danger::{ClientCertVerified, ClientCertVerifier},
27        version::TLS13,
28        ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
29    },
30    TlsAcceptor, TlsConnector, TlsStream,
31};
32use rcgen::string::Ia5String;
33use tracing::error;
34use x509_parser::{
35    parse_x509_certificate,
36    prelude::{GeneralName, ParsedExtension, X509Certificate},
37};
38
/// Host name placed in every generated TLS certificate and required of every
/// peer certificate, shared by all transports that upgrade to TLS.
pub(crate) const TLS_DNS_NAME: &str = "dark.fi";
41
42/// Validate certificate DNSName.
43fn validate_dnsname(cert: &X509Certificate) -> std::result::Result<(), rustls::Error> {
44    #[rustfmt::skip]
45    let oid = x509_parser::oid_registry::asn1_rs::oid!(2.5.29.17);
46    let Ok(Some(extension)) = cert.get_extension_unique(&oid) else {
47        return Err(rustls::CertificateError::BadEncoding.into())
48    };
49
50    let dns_name = match extension.parsed_extension() {
51        ParsedExtension::SubjectAlternativeName(altname) => {
52            if altname.general_names.len() != 1 {
53                return Err(rustls::CertificateError::BadEncoding.into())
54            }
55
56            match altname.general_names[0] {
57                GeneralName::DNSName(dns_name) => dns_name,
58                _ => return Err(rustls::CertificateError::BadEncoding.into()),
59            }
60        }
61
62        _ => return Err(rustls::CertificateError::BadEncoding.into()),
63    };
64
65    if dns_name != TLS_DNS_NAME {
66        return Err(rustls::CertificateError::BadEncoding.into())
67    }
68
69    Ok(())
70}
71
72fn verify_ed25519_signature(
73    message: &[u8],
74    cert: &CertificateDer,
75    dss: &DigitallySignedStruct,
76) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
77    if dss.scheme != SignatureScheme::ED25519 {
78        return Err(rustls::CertificateError::BadSignature.into())
79    }
80
81    // Read the DER-encoded certificate into a buffer
82    let buf: Vec<u8> = cert.iter().copied().collect();
83
84    // Parse the cert and extract the public key
85    let Ok((_, cert)) = parse_x509_certificate(&buf) else {
86        error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed parsing TLS certificate");
87        return Err(rustls::CertificateError::BadEncoding.into())
88    };
89
90    let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
91        error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed parsing public key");
92        return Err(rustls::CertificateError::BadEncoding.into())
93    };
94
95    let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
96        error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed verifying signature");
97        return Err(rustls::CertificateError::BadSignature.into())
98    };
99
100    if let Err(e) = public_key.verify(message, &signature) {
101        error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed verifying signature: {e}");
102        return Err(rustls::CertificateError::BadSignature.into())
103    }
104
105    Ok(HandshakeSignatureValid::assertion())
106}
107
108#[derive(Debug)]
109pub(crate) struct ServerCertificateVerifier;
110
111impl ServerCertVerifier for ServerCertificateVerifier {
112    fn verify_server_cert(
113        &self,
114        end_entity: &CertificateDer,
115        _intermediates: &[CertificateDer],
116        _server_name: &ServerName,
117        _ocsp_response: &[u8],
118        _now: UnixTime,
119    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
120        // Read the DER-encoded certificate into a buffer
121        let buf: Vec<u8> = end_entity.iter().copied().collect();
122
123        // Parse the certificate
124        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
125            error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
126            return Err(rustls::CertificateError::BadEncoding.into())
127        };
128
129        // Validate DNSName
130        validate_dnsname(&cert)?;
131
132        Ok(ServerCertVerified::assertion())
133    }
134
135    fn verify_tls12_signature(
136        &self,
137        _message: &[u8],
138        _cert: &CertificateDer,
139        _dss: &DigitallySignedStruct,
140    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
141        unreachable!()
142    }
143
144    fn verify_tls13_signature(
145        &self,
146        message: &[u8],
147        cert: &CertificateDer,
148        dss: &DigitallySignedStruct,
149    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
150        verify_ed25519_signature(message, cert, dss)
151    }
152
153    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
154        vec![SignatureScheme::ED25519]
155    }
156}
157
158#[derive(Debug)]
159pub(crate) struct ClientCertificateVerifier;
160
161impl ClientCertVerifier for ClientCertificateVerifier {
162    fn offer_client_auth(&self) -> bool {
163        true
164    }
165
166    fn client_auth_mandatory(&self) -> bool {
167        true
168    }
169
170    fn root_hint_subjects(&self) -> &[DistinguishedName] {
171        &[]
172    }
173
174    fn verify_client_cert(
175        &self,
176        end_entity: &CertificateDer,
177        _intermediates: &[CertificateDer],
178        _now: UnixTime,
179    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
180        // Read the DER-encoded certificate into a buffer
181        let buf: Vec<u8> = end_entity.iter().copied().collect();
182
183        // Parse the certificate
184        let Ok((_, cert)) = parse_x509_certificate(&buf) else {
185            error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
186            return Err(rustls::CertificateError::BadEncoding.into())
187        };
188
189        // Validate DNSName
190        validate_dnsname(&cert)?;
191
192        Ok(ClientCertVerified::assertion())
193    }
194
195    fn verify_tls12_signature(
196        &self,
197        _message: &[u8],
198        _cert: &CertificateDer,
199        _dss: &DigitallySignedStruct,
200    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
201        unreachable!()
202    }
203
204    fn verify_tls13_signature(
205        &self,
206        message: &[u8],
207        cert: &CertificateDer,
208        dss: &DigitallySignedStruct,
209    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
210        verify_ed25519_signature(message, cert, dss)
211    }
212
213    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
214        vec![SignatureScheme::ED25519]
215    }
216}
217
218/// Generate a self-signed Ed25519 certificate for TLS.
219/// Returns the certificate and private key in DER format.
220pub(crate) fn generate_certificate() -> io::Result<(CertificateDer<'static>, PrivateKeyDer<'static>)>
221{
222    let Ok(keypair) = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519) else {
223        return Err(io::Error::other("Failed to generate TLS keypair"))
224    };
225
226    let Ok(mut cert_params) = rcgen::CertificateParams::new(&[]) else {
227        return Err(io::Error::other("Failed to generate TLS params"))
228    };
229
230    cert_params.subject_alt_names =
231        vec![rcgen::SanType::DnsName(Ia5String::try_from(TLS_DNS_NAME).unwrap())];
232    cert_params.extended_key_usages = vec![
233        rcgen::ExtendedKeyUsagePurpose::ClientAuth,
234        rcgen::ExtendedKeyUsagePurpose::ServerAuth,
235    ];
236
237    let Ok(certificate) = cert_params.self_signed(&keypair) else {
238        return Err(io::Error::other("Failed to sign TLS certificate"))
239    };
240
241    let certificate = certificate.der().clone();
242    let keypair_der = keypair.serialize_der();
243
244    let Ok(secret_key_der) = PrivateKeyDer::try_from(keypair_der) else {
245        return Err(io::Error::other("Failed to deserialize DER TLS secret"))
246    };
247
248    Ok((certificate, secret_key_der))
249}
250
251pub struct TlsUpgrade {
252    /// TLS server configuration
253    server_config: Arc<ServerConfig>,
254    /// TLS client configuration
255    client_config: Arc<ClientConfig>,
256}
257
258impl TlsUpgrade {
259    pub async fn new() -> io::Result<Self> {
260        // On each instantiation, generate a new keypair and certificate
261        let (certificate, secret_key_der) = generate_certificate()?;
262
263        // Server-side config
264        let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
265        let server_config = Arc::new(
266            ServerConfig::builder_with_protocol_versions(&[&TLS13])
267                .with_client_cert_verifier(client_cert_verifier)
268                .with_single_cert(vec![certificate.clone()], secret_key_der.clone_key())
269                .unwrap(),
270        );
271
272        // Client-side config
273        let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
274        let client_config = Arc::new(
275            ClientConfig::builder_with_protocol_versions(&[&TLS13])
276                .dangerous()
277                .with_custom_certificate_verifier(server_cert_verifier)
278                .with_client_auth_cert(vec![certificate.clone()], secret_key_der)
279                .unwrap(),
280        );
281
282        Ok(Self { server_config, client_config })
283    }
284
285    pub async fn upgrade_dialer_tls<IO>(self, stream: IO) -> io::Result<TlsStream<IO>>
286    where
287        IO: super::PtStream,
288    {
289        let server_name = ServerName::try_from(TLS_DNS_NAME).unwrap();
290        let connector = TlsConnector::from(self.client_config);
291        let stream = connector.connect(server_name, stream).await?;
292        Ok(TlsStream::Client(stream))
293    }
294
295    // TODO: Try to find a transparent way for this instead of implementing
296    // the function separately for every transport type.
297    pub async fn upgrade_listener_tcp_tls(
298        self,
299        listener: smol::net::TcpListener,
300    ) -> io::Result<(TlsAcceptor, smol::net::TcpListener)> {
301        Ok((TlsAcceptor::from(self.server_config), listener))
302    }
303}