/rust/git/checkouts/nss-rs-71e20fe79ef91440/9b94ca3/src/secrets.rs
Line | Count | Source |
1 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
2 | | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
3 | | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
4 | | // option. This file may not be copied, modified, or distributed |
5 | | // except according to those terms. |
6 | | |
7 | | #![allow(clippy::unwrap_used)] // Let's assume the use of `unwrap` was checked when the use of `unsafe` was reviewed. |
8 | | |
9 | | use std::{convert::TryFrom as _, mem, os::raw::c_void, pin::Pin}; |
10 | | |
11 | | use enum_map::EnumMap; |
12 | | use log::debug; |
13 | | use strum::FromRepr; |
14 | | |
15 | | use crate::{ |
16 | | agentio::as_c_void, |
17 | | constants::Epoch, |
18 | | err::Res, |
19 | | p11::{PK11_ReferenceSymKey, PK11SymKey, SymKey}, |
20 | | prio::PRFileDesc, |
21 | | ssl::{SSLSecretCallback, SSLSecretDirection}, |
22 | | }; |
23 | | |
// Declares NSS's experimental `SSL_SecretCallback` entry point through this
// crate's `experimental_api!` macro (defined elsewhere — presumably it resolves
// the symbol through NSS's experimental-API lookup; TODO confirm).
// `cb` is the function NSS will call with `arg`; in this file the callback is
// `Secrets::secret_available`, registered from `SecretHolder::register`.
experimental_api!(SSL_SecretCallback(
    fd: *mut PRFileDesc,
    cb: SSLSecretCallback,
    arg: *mut c_void,
));
29 | | |
/// Direction (read or write) of a secret handed over by NSS.
///
/// The discriminants are the raw `SSLSecretDirection` constants, so the
/// `from_repr` constructor derived by `strum::FromRepr` maps an NSS value
/// directly onto a variant (returning `None` for anything else).
#[derive(Clone, Copy, Debug, FromRepr)]
// Use i32 for Windows MSVC, unless it is MinGW (see
// https://bugzilla.mozilla.org/show_bug.cgi?id=1960482). All other platforms
// use u32. The repr has to match the type of the `SSLSecretDirection`
// constants on each platform, because they are used as discriminants below.
#[cfg_attr(all(windows, not(target_env = "gnu")), repr(i32))]
#[cfg_attr(not(all(windows, not(target_env = "gnu"))), repr(u32))]
pub enum SecretDirection {
    /// Corresponds to NSS's `ssl_secret_read`.
    Read = SSLSecretDirection::ssl_secret_read,
    /// Corresponds to NSS's `ssl_secret_write`.
    Write = SSLSecretDirection::ssl_secret_write,
}
40 | | |
41 | | impl From<SSLSecretDirection::Type> for SecretDirection { |
42 | 1.93k | fn from(dir: SSLSecretDirection::Type) -> Self { |
43 | 1.93k | Self::from_repr(dir).expect("Invalid secret direction") |
44 | 1.93k | } |
45 | | } |
46 | | |
/// Secrets received from NSS for a single direction, one slot per [`Epoch`].
///
/// A slot counts as empty when its `SymKey` is null (see `has` and `take`).
/// Slots start out as, and are reset to, `SymKey::default()` — presumably a
/// null key; confirm in the `p11` module.
#[derive(Debug, Default)]
pub struct DirectionalSecrets {
    secrets: EnumMap<Epoch, SymKey>,
}
51 | | |
52 | | impl DirectionalSecrets { |
53 | 1.93k | fn put(&mut self, epoch: Epoch, key: SymKey) { |
54 | 1.93k | debug_assert_ne!(epoch, Epoch::Initial); |
55 | 1.93k | self.secrets[epoch] = key; |
56 | 1.93k | } |
57 | | |
58 | 484 | pub fn has(&self, epoch: Epoch) -> bool { |
59 | 484 | !self.secrets[epoch].is_null() |
60 | 484 | } |
61 | | |
62 | 1.45k | pub fn take(&mut self, epoch: Epoch) -> Option<SymKey> { |
63 | 1.45k | if self.secrets[epoch].is_null() { |
64 | 0 | None |
65 | | } else { |
66 | 1.45k | Some(mem::take(&mut self.secrets[epoch])) |
67 | | } |
68 | 1.45k | } |
69 | | } |
70 | | |
/// Secrets received from NSS, kept separately for each direction.
///
/// NSS fills this in through `Secrets::secret_available`, which receives a raw
/// pointer to it; `SecretHolder` keeps it boxed and pinned for that reason.
#[derive(Debug, Default)]
pub struct Secrets {
    /// Secrets reported for `SecretDirection::Read`.
    r: DirectionalSecrets,
    /// Secrets reported for `SecretDirection::Write`.
    w: DirectionalSecrets,
}
76 | | |
77 | | impl Secrets { |
78 | 1.93k | unsafe extern "C" fn secret_available( |
79 | 1.93k | _fd: *mut PRFileDesc, |
80 | 1.93k | epoch: u16, |
81 | 1.93k | dir: SSLSecretDirection::Type, |
82 | 1.93k | secret: *mut PK11SymKey, |
83 | 1.93k | arg: *mut c_void, |
84 | 1.93k | ) { |
85 | 1.93k | let Ok(epoch) = Epoch::try_from(epoch) else { |
86 | 0 | debug_assert!(false, "Invalid epoch"); |
87 | | // Don't touch secrets. |
88 | 0 | return; |
89 | | }; |
90 | 1.93k | let Some(secrets) = (unsafe { arg.cast::<Self>().as_mut() }) else { |
91 | 0 | debug_assert!(false, "No secrets"); |
92 | 0 | return; |
93 | | }; |
94 | 1.93k | secrets.put_raw(epoch, dir, secret); |
95 | 1.93k | } |
96 | | |
97 | 1.93k | fn put_raw(&mut self, epoch: Epoch, dir: SSLSecretDirection::Type, key_ptr: *mut PK11SymKey) { |
98 | 1.93k | let key_ptr = unsafe { PK11_ReferenceSymKey(key_ptr) }; |
99 | 1.93k | let key = SymKey::from_ptr(key_ptr).expect("NSS shouldn't be passing out NULL secrets"); |
100 | 1.93k | self.put(SecretDirection::from(dir), epoch, key); |
101 | 1.93k | } |
102 | | |
103 | 1.93k | fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) { |
104 | 1.93k | debug!("{dir:?} secret available for {epoch:?}: {key:?}"); |
105 | 1.93k | let keys = match dir { |
106 | 968 | SecretDirection::Read => &mut self.r, |
107 | 968 | SecretDirection::Write => &mut self.w, |
108 | | }; |
109 | 1.93k | keys.put(epoch, key); |
110 | 1.93k | } |
111 | | } |
112 | | |
/// Owns the `Secrets` that NSS writes into through the registered callback.
///
/// The `Secrets` live in a pinned box, so the pointer handed to NSS by
/// `register` keeps pointing at the same heap allocation even if the
/// `SecretHolder` itself is moved.
#[derive(Debug)]
pub struct SecretHolder {
    secrets: Pin<Box<Secrets>>,
}
117 | | |
118 | | impl SecretHolder { |
119 | | /// This registers with NSS. The lifetime of this object needs to match the lifetime |
120 | | /// of the connection, or bad things might happen. |
121 | 1.77k | pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> { |
122 | 1.77k | let p = as_c_void(&mut self.secrets); |
123 | 1.77k | unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) } |
124 | 1.77k | } |
125 | | |
126 | 484 | pub fn has(&self, epoch: Epoch) -> bool { |
127 | 484 | self.secrets.r.has(epoch) |
128 | 484 | } |
129 | | |
130 | 484 | pub fn take_read(&mut self, epoch: Epoch) -> Option<SymKey> { |
131 | 484 | self.secrets.r.take(epoch) |
132 | 484 | } |
133 | | |
134 | 968 | pub fn take_write(&mut self, epoch: Epoch) -> Option<SymKey> { |
135 | 968 | self.secrets.w.take(epoch) |
136 | 968 | } |
137 | | } |
138 | | |
139 | | impl Default for SecretHolder { |
140 | 2.57k | fn default() -> Self { |
141 | 2.57k | Self { |
142 | 2.57k | secrets: Box::pin(Secrets::default()), |
143 | 2.57k | } |
144 | 2.57k | } |
145 | | } |