Coverage Report

Created: 2025-11-16 06:37

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/tls/certificate.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use crate::identity::Identity;
16
use crate::tls::{Error, IdentityVerifier, OutboundConnector};
17
use base64::engine::general_purpose::STANDARD;
18
use bytes::Bytes;
19
use itertools::Itertools;
20
use std::{cmp, iter};
21
22
use rustls::client::Resumption;
23
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
24
25
use rustls::server::WebPkiClientVerifier;
26
use rustls::{ClientConfig, RootCertStore, ServerConfig, server};
27
use rustls_pemfile::Item;
28
use std::io::Cursor;
29
use std::str::FromStr;
30
use std::sync::Arc;
31
use std::time::{Duration, SystemTime, UNIX_EPOCH};
32
use tracing::warn;
33
34
use crate::tls;
35
use x509_parser::certificate::X509Certificate;
36
37
#[derive(Clone, Debug)]
38
pub struct Certificate {
39
    pub expiry: Expiration,
40
    pub der: CertificateDer<'static>,
41
}
42
43
/// The validity window of a certificate, expressed as wall-clock times.
#[derive(Clone, Debug)]
pub struct Expiration {
    /// Earliest instant at which the certificate is valid.
    pub not_before: SystemTime,
    /// Instant after which the certificate is no longer valid.
    pub not_after: SystemTime,
}
48
49
#[derive(Debug)]
50
pub struct WorkloadCertificate {
51
    /// cert is the leaf certificate
52
    pub cert: Certificate,
53
    /// chain is the entire trust chain, excluding the leaf and root
54
    pub chain: Vec<Certificate>,
55
    pub private_key: PrivateKeyDer<'static>,
56
57
    /// precomputed roots. This is used for verification
58
    root_store: Arc<RootCertStore>,
59
    /// original roots, used for debugging
60
    pub roots: Vec<Certificate>,
61
}
62
63
0
pub fn identity_from_connection(conn: &server::ServerConnection) -> Option<Identity> {
64
    use x509_parser::prelude::*;
65
0
    conn.peer_certificates()
66
0
        .and_then(|certs| certs.first())
67
0
        .and_then(|cert| match X509Certificate::from_der(cert) {
68
0
            Ok((_, a)) => Some(a),
69
0
            Err(e) => {
70
0
                warn!("invalid certificate: {e}");
71
0
                None
72
            }
73
0
        })
74
0
        .and_then(|cert| match identities(cert) {
75
0
            Ok(ids) => ids.into_iter().next(),
76
0
            Err(e) => {
77
0
                warn!("failed to extract identity: {}", e);
78
0
                None
79
            }
80
0
        })
81
0
}
82
83
0
pub fn identities(cert: X509Certificate) -> Result<Vec<Identity>, Error> {
84
    use x509_parser::prelude::*;
85
0
    let names = cert
86
0
        .subject_alternative_name()?
87
0
        .map(|x| &x.value.general_names);
88
89
0
    if let Some(names) = names {
90
0
        return Ok(names
91
0
            .iter()
92
0
            .filter_map(|n| {
93
0
                let id = match n {
94
0
                    GeneralName::URI(uri) => Identity::from_str(uri),
95
0
                    _ => return None,
96
                };
97
98
0
                match id {
99
0
                    Ok(id) => Some(id),
100
0
                    Err(err) => {
101
0
                        warn!("SAN {n} could not be parsed: {err}");
102
0
                        None
103
                    }
104
                }
105
0
            })
106
0
            .collect());
107
0
    }
108
0
    Ok(Vec::default())
109
0
}
110
111
impl Certificate {
112
    // TODO: I would love to parse this once, but ran into lifetime issues.
113
0
    fn parsed(&self) -> X509Certificate<'_> {
114
0
        x509_parser::parse_x509_certificate(&self.der)
115
0
            .expect("certificate was already parsed successfully before")
116
0
            .1
117
0
    }
118
119
0
    pub fn as_pem(&self) -> String {
120
0
        der_to_pem(&self.der, CERTIFICATE)
121
0
    }
122
123
0
    pub fn identity(&self) -> Option<Identity> {
124
0
        self.parsed()
125
0
            .subject_alternative_name()
126
0
            .ok()
127
0
            .flatten()
128
0
            .and_then(|ext| {
129
0
                ext.value
130
0
                    .general_names
131
0
                    .iter()
132
0
                    .filter_map(|n| match n {
133
0
                        x509_parser::extensions::GeneralName::URI(uri) => Some(uri),
134
0
                        _ => None,
135
0
                    })
136
0
                    .next()
137
0
            })
138
0
            .and_then(|san| Identity::from_str(san).ok())
139
0
    }
140
141
    #[cfg(test)]
142
    pub fn names(&self) -> Vec<String> {
143
        let reg = oid_registry::OidRegistry::default().with_x509();
144
145
        self.parsed()
146
            .subject
147
            .iter()
148
            .flat_map(|dn| {
149
                dn.iter().map(|x| {
150
                    reg.get(x.attr_type()).unwrap().sn().to_string() + "/" + x.as_str().unwrap()
151
                })
152
            })
153
            .chain(
154
                self.parsed()
155
                    .subject_alternative_name()
156
                    .ok()
157
                    .flatten()
158
                    .iter()
159
                    .flat_map(|ext| ext.value.general_names.iter().map(|n| n.to_string())),
160
            )
161
            .collect()
162
    }
163
164
0
    pub fn serial(&self) -> String {
165
0
        self.parsed().serial.to_string()
166
0
    }
167
168
0
    pub fn expiration(&self) -> Expiration {
169
0
        self.expiry.clone()
170
0
    }
171
}
172
173
0
fn expiration(cert: X509Certificate) -> Expiration {
174
0
    Expiration {
175
0
        not_before: UNIX_EPOCH
176
0
            + Duration::from_secs(
177
0
                cert.validity
178
0
                    .not_before
179
0
                    .timestamp()
180
0
                    .try_into()
181
0
                    .unwrap_or_default(),
182
0
            ),
183
0
        not_after: UNIX_EPOCH
184
0
            + Duration::from_secs(
185
0
                cert.validity
186
0
                    .not_after
187
0
                    .timestamp()
188
0
                    .try_into()
189
0
                    .unwrap_or_default(),
190
0
            ),
191
0
    }
192
0
}
193
194
0
fn parse_cert(mut cert: Vec<u8>) -> Result<Certificate, Error> {
195
0
    let mut reader = std::io::BufReader::new(Cursor::new(&mut cert));
196
0
    let parsed = rustls_pemfile::read_one(&mut reader)
197
0
        .map_err(|e| Error::CertificateParseError(e.to_string()))?
198
0
        .ok_or_else(|| Error::CertificateParseError("no certificate".to_string()))?;
199
0
    let Item::X509Certificate(der) = parsed else {
200
0
        return Err(Error::CertificateParseError("no certificate".to_string()));
201
    };
202
203
0
    let (_, cert) = x509_parser::parse_x509_certificate(&der)?;
204
0
    Ok(Certificate {
205
0
        der: der.clone(),
206
0
        expiry: expiration(cert),
207
0
    })
208
0
}
209
210
0
fn parse_cert_multi(mut cert: &[u8]) -> Result<Vec<Certificate>, Error> {
211
0
    let mut reader = std::io::BufReader::new(Cursor::new(&mut cert));
212
0
    let parsed: Result<Vec<_>, _> = rustls_pemfile::read_all(&mut reader).collect();
213
0
    parsed
214
0
        .map_err(|e| Error::CertificateParseError(e.to_string()))?
215
0
        .into_iter()
216
0
        .map(|p| {
217
0
            let Item::X509Certificate(der) = p else {
218
0
                return Err(Error::CertificateParseError("no certificate".to_string()));
219
            };
220
0
            let (_, cert) = x509_parser::parse_x509_certificate(&der)?;
221
0
            Ok(Certificate {
222
0
                der: der.clone(),
223
0
                expiry: expiration(cert),
224
0
            })
225
0
        })
226
0
        .collect()
227
0
}
228
229
0
fn parse_key(mut key: &[u8]) -> Result<PrivateKeyDer<'static>, Error> {
230
0
    let mut reader = std::io::BufReader::new(Cursor::new(&mut key));
231
0
    let parsed = rustls_pemfile::read_one(&mut reader)
232
0
        .map_err(|e| Error::CertificateParseError(e.to_string()))?
233
0
        .ok_or_else(|| Error::CertificateParseError("no key".to_string()))?;
234
0
    match parsed {
235
0
        Item::Pkcs8Key(c) => Ok(PrivateKeyDer::Pkcs8(c)),
236
0
        _ => Err(Error::CertificateParseError("no key".to_string())),
237
    }
238
0
}
239
240
impl WorkloadCertificate {
241
0
    pub fn new(key: &[u8], cert: &[u8], chain: Vec<&[u8]>) -> Result<WorkloadCertificate, Error> {
242
0
        let cert = parse_cert(cert.to_vec())?;
243
244
        // The Istio API does something pretty unhelpful, by providing a single chain of certs.
245
        // The last one is the root. However, there may be multiple roots concatenated in that last cert,
246
        // so we will need to split them.
247
0
        let Some(raw_root) = chain.last() else {
248
0
            return Err(Error::InvalidRootCert(
249
0
                "no root certificate present".to_string(),
250
0
            ));
251
        };
252
0
        let roots = parse_cert_multi(raw_root)?;
253
0
        let chain = chain[..cmp::max(0, chain.len() - 1)]
254
0
            .iter()
255
0
            .map(|x| x.to_vec())
256
0
            .map(parse_cert)
257
0
            .collect::<Result<Vec<_>, _>>()?;
258
0
        let key: PrivateKeyDer = parse_key(key)?;
259
260
0
        let mut roots_store = RootCertStore::empty();
261
0
        let (_valid, invalid) =
262
0
            roots_store.add_parsable_certificates(roots.iter().map(|c| c.der.clone()));
263
0
        if invalid > 0 {
264
0
            tracing::warn!("warning: found {invalid} invalid root certs");
265
0
        }
266
0
        Ok(WorkloadCertificate {
267
0
            cert,
268
0
            chain,
269
0
            private_key: key,
270
0
            roots,
271
0
            root_store: Arc::new(roots_store),
272
0
        })
273
0
    }
274
275
0
    pub fn identity(&self) -> Option<Identity> {
276
0
        self.cert.identity()
277
0
    }
278
279
    // TODO: can we precompute some or all of this?
280
281
0
    pub(in crate::tls) fn cert_and_intermediates_der(&self) -> Vec<CertificateDer<'static>> {
282
0
        std::iter::once(self.cert.der.clone())
283
0
            .chain(self.chain.iter().map(|x| x.der.clone()))
284
0
            .collect()
285
0
    }
286
287
0
    pub fn cert_and_intermediates(&self) -> Vec<Certificate> {
288
0
        std::iter::once(self.cert.clone())
289
0
            .chain(self.chain.clone())
290
0
            .collect()
291
0
    }
292
293
0
    pub fn full_chain_and_roots(&self) -> Vec<String> {
294
0
        self.cert_and_intermediates()
295
0
            .into_iter()
296
0
            .map(|c| c.as_pem())
297
0
            .chain(iter::once(self.roots.iter().map(|c| c.as_pem()).join("\n")))
298
0
            .collect()
299
0
    }
300
301
0
    pub fn server_config(&self) -> Result<ServerConfig, Error> {
302
0
        let td = self.cert.identity().map(|i| match i {
303
0
            Identity::Spiffe { trust_domain, .. } => trust_domain,
304
0
        });
305
0
        let raw_client_cert_verifier = WebPkiClientVerifier::builder_with_provider(
306
0
            self.root_store.clone(),
307
0
            crate::tls::lib::provider(),
308
        )
309
0
        .build()?;
310
311
0
        let client_cert_verifier =
312
0
            crate::tls::workload::TrustDomainVerifier::new(raw_client_cert_verifier, td);
313
0
        let mut sc = ServerConfig::builder_with_provider(crate::tls::lib::provider())
314
0
            .with_protocol_versions(tls::TLS_VERSIONS)
315
0
            .expect("server config must be valid")
316
0
            .with_client_cert_verifier(client_cert_verifier)
317
0
            .with_single_cert(
318
0
                self.cert_and_intermediates_der(),
319
0
                self.private_key.clone_key(),
320
0
            )?;
321
0
        sc.alpn_protocols = vec![b"h2".into()];
322
0
        Ok(sc)
323
0
    }
324
325
0
    pub fn outbound_connector(&self, identity: Vec<Identity>) -> Result<OutboundConnector, Error> {
326
0
        let roots = self.root_store.clone();
327
0
        let verifier = IdentityVerifier { roots, identity };
328
0
        let mut cc = ClientConfig::builder_with_provider(crate::tls::lib::provider())
329
0
            .with_protocol_versions(tls::TLS_VERSIONS)
330
0
            .expect("client config must be valid")
331
0
            .dangerous() // Customer verifier is requires "dangerous" opt-in
332
0
            .with_custom_certificate_verifier(Arc::new(verifier))
333
0
            .with_client_auth_cert(
334
0
                self.cert_and_intermediates_der(),
335
0
                self.private_key.clone_key(),
336
0
            )?;
337
0
        cc.alpn_protocols = vec![b"h2".into()];
338
0
        cc.resumption = Resumption::disabled();
339
0
        cc.enable_sni = false;
340
0
        Ok(OutboundConnector {
341
0
            client_config: Arc::new(cc),
342
0
        })
343
0
    }
344
345
0
    pub fn dump_chain(&self) -> Bytes {
346
0
        self.chain.iter().map(|c| c.as_pem()).join("\n").into()
347
0
    }
348
349
0
    pub fn is_expired(&self) -> bool {
350
0
        SystemTime::now() > self.cert.expiry.not_after
351
0
    }
352
353
0
    pub fn refresh_at(&self) -> SystemTime {
354
0
        let expiry = &self.cert.expiry;
355
0
        match expiry.not_after.duration_since(expiry.not_before) {
356
0
            Ok(valid_for) => expiry.not_before + valid_for / 2,
357
0
            Err(_) => expiry.not_after,
358
        }
359
0
    }
360
361
0
    pub fn get_duration_until_refresh(&self) -> Duration {
362
0
        let expiry = &self.cert.expiry;
363
0
        let halflife = expiry
364
0
            .not_after
365
0
            .duration_since(expiry.not_before)
366
0
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
367
            / 2;
368
        // If now() is earlier than not_before, we need to refresh ASAP, so return 0.
369
0
        let elapsed = SystemTime::now()
370
0
            .duration_since(expiry.not_before)
371
0
            .unwrap_or(halflife);
372
0
        halflife
373
0
            .checked_sub(elapsed)
374
0
            .unwrap_or_else(|| Duration::from_secs(0))
375
0
    }
376
}
377
378
const CERTIFICATE: &str = "CERTIFICATE";
379
380
/// Converts DER encoded data to PEM.
381
0
fn der_to_pem(der: &[u8], label: &str) -> String {
382
    use base64::Engine;
383
0
    let mut ans = String::from("-----BEGIN ");
384
0
    ans.push_str(label);
385
0
    ans.push_str("-----\n");
386
0
    let b64 = STANDARD.encode(der);
387
0
    let line_length = 60;
388
0
    for chunk in b64.chars().collect::<Vec<_>>().chunks(line_length) {
389
0
        ans.extend(chunk);
390
0
        ans.push('\n');
391
0
    }
392
0
    ans.push_str("-----END ");
393
0
    ans.push_str(label);
394
0
    ans.push_str("-----\n");
395
0
    ans
396
0
}
397
398
#[cfg(test)]
mod test {
    use crate::identity::Identity;
    use crate::test_helpers::helpers;
    use crate::tls::WorkloadCertificate;
    use crate::tls::mock::{TEST_ROOT, TEST_ROOT_KEY, TEST_ROOT2, TEST_ROOT2_KEY, TestIdentity};

    use std::str::FromStr;
    use std::sync::Arc;
    use std::time::Duration;
    use std::time::SystemTime;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::net::TcpStream;
    use tokio_rustls::TlsAcceptor;

    /// Two workloads whose leaves are signed by different roots can complete a mutual-TLS
    /// exchange when both are configured with a root bundle containing both roots.
    #[tokio::test]
    async fn multi_root() {
        helpers::initialize_telemetry();
        let id = Identity::from_str("spiffe://td/ns/n/sa/a").unwrap();
        // Joined root
        let mut joined = TEST_ROOT.to_vec();
        joined.push(b'\n');
        joined.extend(TEST_ROOT2);

        // Generate key+cert signed by root1
        let (key, cert) = crate::tls::mock::generate_test_certs_with_root(
            &TestIdentity::Identity(id.clone()),
            SystemTime::now(),
            SystemTime::now() + Duration::from_secs(60),
            None,
            TEST_ROOT_KEY,
        );
        let cert1 =
            WorkloadCertificate::new(key.as_bytes(), cert.as_bytes(), vec![&joined]).unwrap();

        // Generate key+cert signed by root2
        let (key, cert) = crate::tls::mock::generate_test_certs_with_root(
            &TestIdentity::Identity(id.clone()),
            SystemTime::now(),
            SystemTime::now() + Duration::from_secs(60),
            None,
            TEST_ROOT2_KEY,
        );
        let cert2 =
            WorkloadCertificate::new(key.as_bytes(), cert.as_bytes(), vec![&joined]).unwrap();

        // Do a simple handshake between them; we should be able to accept the trusted root
        let server = cert1.server_config().unwrap();
        let tls = TlsAcceptor::from(Arc::new(server));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::task::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut tls = tls.accept(stream).await.unwrap();
            // `write` may write only part of the buffer; `write_all` + `flush` guarantees the
            // full payload is encrypted and handed to the socket.
            tls.write_all(b"serv").await.unwrap();
            tls.flush().await.unwrap();
            // Consume the client's payload so the socket is not closed with unread data
            // (which could turn the close into a connection reset and flake the client read).
            let mut buf = [0u8; 2];
            tls.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"hi");
        });

        let stream = TcpStream::connect(addr).await.unwrap();
        let client = cert2.outbound_connector(vec![id]).unwrap();
        let mut tls = client.connect(stream).await.unwrap();

        tls.write_all(b"hi").await.unwrap();
        tls.flush().await.unwrap();
        let mut buf = [0u8; 4];
        tls.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"serv");

        // Surface any panic from the server side instead of silently dropping it.
        server_task.await.unwrap();
    }
}