1use std::{io, sync::Arc};
20
21use futures_rustls::{
22 rustls::{
23 self,
24 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
25 pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
26 server::danger::{ClientCertVerified, ClientCertVerifier},
27 version::TLS13,
28 ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
29 },
30 TlsAcceptor, TlsConnector, TlsStream,
31};
32use rcgen::string::Ia5String;
33use tracing::error;
34use x509_parser::{
35 parse_x509_certificate,
36 prelude::{GeneralName, ParsedExtension, X509Certificate},
37};
38
/// DNS name placed as the sole SubjectAlternativeName of generated
/// certificates, required by the certificate verifiers, and passed as the
/// server name when dialing.
pub(crate) const TLS_DNS_NAME: &str = "dark.fi";
41
42fn validate_dnsname(cert: &X509Certificate) -> std::result::Result<(), rustls::Error> {
44 #[rustfmt::skip]
45 let oid = x509_parser::oid_registry::asn1_rs::oid!(2.5.29.17);
46 let Ok(Some(extension)) = cert.get_extension_unique(&oid) else {
47 return Err(rustls::CertificateError::BadEncoding.into())
48 };
49
50 let dns_name = match extension.parsed_extension() {
51 ParsedExtension::SubjectAlternativeName(altname) => {
52 if altname.general_names.len() != 1 {
53 return Err(rustls::CertificateError::BadEncoding.into())
54 }
55
56 match altname.general_names[0] {
57 GeneralName::DNSName(dns_name) => dns_name,
58 _ => return Err(rustls::CertificateError::BadEncoding.into()),
59 }
60 }
61
62 _ => return Err(rustls::CertificateError::BadEncoding.into()),
63 };
64
65 if dns_name != TLS_DNS_NAME {
66 return Err(rustls::CertificateError::BadEncoding.into())
67 }
68
69 Ok(())
70}
71
72fn verify_ed25519_signature(
73 message: &[u8],
74 cert: &CertificateDer,
75 dss: &DigitallySignedStruct,
76) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
77 if dss.scheme != SignatureScheme::ED25519 {
78 return Err(rustls::CertificateError::BadSignature.into())
79 }
80
81 let buf: Vec<u8> = cert.iter().copied().collect();
83
84 let Ok((_, cert)) = parse_x509_certificate(&buf) else {
86 error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed parsing TLS certificate");
87 return Err(rustls::CertificateError::BadEncoding.into())
88 };
89
90 let Ok(public_key) = ed25519_compact::PublicKey::from_der(cert.public_key().raw) else {
91 error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed parsing public key");
92 return Err(rustls::CertificateError::BadEncoding.into())
93 };
94
95 let Ok(signature) = ed25519_compact::Signature::from_slice(dss.signature()) else {
96 error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed verifying signature");
97 return Err(rustls::CertificateError::BadSignature.into())
98 };
99
100 if let Err(e) = public_key.verify(message, &signature) {
101 error!(target: "net::tls::verify_ed25519_signature", "[net::tls] Failed verifying signature: {e}");
102 return Err(rustls::CertificateError::BadSignature.into())
103 }
104
105 Ok(HandshakeSignatureValid::assertion())
106}
107
108#[derive(Debug)]
109pub(crate) struct ServerCertificateVerifier;
110
111impl ServerCertVerifier for ServerCertificateVerifier {
112 fn verify_server_cert(
113 &self,
114 end_entity: &CertificateDer,
115 _intermediates: &[CertificateDer],
116 _server_name: &ServerName,
117 _ocsp_response: &[u8],
118 _now: UnixTime,
119 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
120 let buf: Vec<u8> = end_entity.iter().copied().collect();
122
123 let Ok((_, cert)) = parse_x509_certificate(&buf) else {
125 error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
126 return Err(rustls::CertificateError::BadEncoding.into())
127 };
128
129 validate_dnsname(&cert)?;
131
132 Ok(ServerCertVerified::assertion())
133 }
134
135 fn verify_tls12_signature(
136 &self,
137 _message: &[u8],
138 _cert: &CertificateDer,
139 _dss: &DigitallySignedStruct,
140 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
141 unreachable!()
142 }
143
144 fn verify_tls13_signature(
145 &self,
146 message: &[u8],
147 cert: &CertificateDer,
148 dss: &DigitallySignedStruct,
149 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
150 verify_ed25519_signature(message, cert, dss)
151 }
152
153 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
154 vec![SignatureScheme::ED25519]
155 }
156}
157
158#[derive(Debug)]
159pub(crate) struct ClientCertificateVerifier;
160
161impl ClientCertVerifier for ClientCertificateVerifier {
162 fn offer_client_auth(&self) -> bool {
163 true
164 }
165
166 fn client_auth_mandatory(&self) -> bool {
167 true
168 }
169
170 fn root_hint_subjects(&self) -> &[DistinguishedName] {
171 &[]
172 }
173
174 fn verify_client_cert(
175 &self,
176 end_entity: &CertificateDer,
177 _intermediates: &[CertificateDer],
178 _now: UnixTime,
179 ) -> std::result::Result<ClientCertVerified, rustls::Error> {
180 let buf: Vec<u8> = end_entity.iter().copied().collect();
182
183 let Ok((_, cert)) = parse_x509_certificate(&buf) else {
185 error!(target: "net::tls::verify_server_cert", "[net::tls] Failed parsing server TLS certificate");
186 return Err(rustls::CertificateError::BadEncoding.into())
187 };
188
189 validate_dnsname(&cert)?;
191
192 Ok(ClientCertVerified::assertion())
193 }
194
195 fn verify_tls12_signature(
196 &self,
197 _message: &[u8],
198 _cert: &CertificateDer,
199 _dss: &DigitallySignedStruct,
200 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
201 unreachable!()
202 }
203
204 fn verify_tls13_signature(
205 &self,
206 message: &[u8],
207 cert: &CertificateDer,
208 dss: &DigitallySignedStruct,
209 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
210 verify_ed25519_signature(message, cert, dss)
211 }
212
213 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
214 vec![SignatureScheme::ED25519]
215 }
216}
217
218pub(crate) fn generate_certificate() -> io::Result<(CertificateDer<'static>, PrivateKeyDer<'static>)>
221{
222 let Ok(keypair) = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519) else {
223 return Err(io::Error::other("Failed to generate TLS keypair"))
224 };
225
226 let Ok(mut cert_params) = rcgen::CertificateParams::new(&[]) else {
227 return Err(io::Error::other("Failed to generate TLS params"))
228 };
229
230 cert_params.subject_alt_names =
231 vec![rcgen::SanType::DnsName(Ia5String::try_from(TLS_DNS_NAME).unwrap())];
232 cert_params.extended_key_usages = vec![
233 rcgen::ExtendedKeyUsagePurpose::ClientAuth,
234 rcgen::ExtendedKeyUsagePurpose::ServerAuth,
235 ];
236
237 let Ok(certificate) = cert_params.self_signed(&keypair) else {
238 return Err(io::Error::other("Failed to sign TLS certificate"))
239 };
240
241 let certificate = certificate.der().clone();
242 let keypair_der = keypair.serialize_der();
243
244 let Ok(secret_key_der) = PrivateKeyDer::try_from(keypair_der) else {
245 return Err(io::Error::other("Failed to deserialize DER TLS secret"))
246 };
247
248 Ok((certificate, secret_key_der))
249}
250
251pub struct TlsUpgrade {
252 server_config: Arc<ServerConfig>,
254 client_config: Arc<ClientConfig>,
256}
257
258impl TlsUpgrade {
259 pub async fn new() -> io::Result<Self> {
260 let (certificate, secret_key_der) = generate_certificate()?;
262
263 let client_cert_verifier = Arc::new(ClientCertificateVerifier {});
265 let server_config = Arc::new(
266 ServerConfig::builder_with_protocol_versions(&[&TLS13])
267 .with_client_cert_verifier(client_cert_verifier)
268 .with_single_cert(vec![certificate.clone()], secret_key_der.clone_key())
269 .unwrap(),
270 );
271
272 let server_cert_verifier = Arc::new(ServerCertificateVerifier {});
274 let client_config = Arc::new(
275 ClientConfig::builder_with_protocol_versions(&[&TLS13])
276 .dangerous()
277 .with_custom_certificate_verifier(server_cert_verifier)
278 .with_client_auth_cert(vec![certificate.clone()], secret_key_der)
279 .unwrap(),
280 );
281
282 Ok(Self { server_config, client_config })
283 }
284
285 pub async fn upgrade_dialer_tls<IO>(self, stream: IO) -> io::Result<TlsStream<IO>>
286 where
287 IO: super::PtStream,
288 {
289 let server_name = ServerName::try_from(TLS_DNS_NAME).unwrap();
290 let connector = TlsConnector::from(self.client_config);
291 let stream = connector.connect(server_name, stream).await?;
292 Ok(TlsStream::Client(stream))
293 }
294
295 pub async fn upgrade_listener_tcp_tls(
298 self,
299 listener: smol::net::TcpListener,
300 ) -> io::Result<(TlsAcceptor, smol::net::TcpListener)> {
301 Ok((TlsAcceptor::from(self.server_config), listener))
302 }
303}