/rust/git/checkouts/nss-rs-71e20fe79ef91440/9b94ca3/src/agentio.rs
Line | Count | Source |
1 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
2 | | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
3 | | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
4 | | // option. This file may not be copied, modified, or distributed |
5 | | // except according to those terms. |
6 | | |
7 | | #![expect( |
8 | | clippy::unwrap_used, |
9 | | reason = "Let's assume the use of `unwrap` was checked when the use of `unsafe` was reviewed." |
10 | | )] |
11 | | |
12 | | use std::{ |
13 | | cmp::min, |
14 | | convert::{TryFrom as _, TryInto as _}, |
15 | | fmt::{self, Display, Formatter}, |
16 | | mem, |
17 | | ops::Deref, |
18 | | os::raw::{c_uint, c_void}, |
19 | | pin::Pin, |
20 | | ptr::{null, null_mut}, |
21 | | }; |
22 | | |
23 | | use log::trace; |
24 | | |
25 | | use crate::{ |
26 | | SECStatus, |
27 | | constants::{ContentType, Epoch}, |
28 | | err::{Error, PR_SetError, Res, nspr}, |
29 | | null_safe_slice, |
30 | | p11::hex_with_len, |
31 | | prio, |
32 | | selfencrypt::hex, |
33 | | ssl::{self, PRInt32, PRInt64, PRIntn, PRUint8, PRUint16}, |
34 | | }; |
35 | | |
36 | | // Alias common types. |
37 | | type PrFd = *mut prio::PRFileDesc; |
38 | | type PrStatus = prio::PRStatus::Type; |
39 | | const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS; |
40 | | const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE; |
41 | | |
/// Expose the value held in a pinned box as an untyped pointer.
///
/// The returned pointer addresses the `T` stored inside `pin`. It is a raw
/// pointer, so nothing ties its validity to the box: callers must not use it
/// once the box has been dropped.
pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
    // `T: Unpin` is what makes it legal to obtain a plain `&mut T` here.
    let target: &mut T = pin.as_mut().get_mut();
    std::ptr::from_mut(target).cast::<c_void>()
}
46 | | |
47 | | /// A slice of the output. |
48 | | #[derive(Default)] |
49 | | pub struct Record { |
50 | | pub epoch: Epoch, |
51 | | pub ct: ContentType, |
52 | | pub data: Vec<u8>, |
53 | | } |
54 | | |
55 | | impl Record { |
56 | | #[must_use] |
57 | 2.25k | pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self { |
58 | 2.25k | Self { |
59 | 2.25k | epoch, |
60 | 2.25k | ct, |
61 | 2.25k | data: data.to_vec(), |
62 | 2.25k | } |
63 | 2.25k | } |
64 | | |
65 | | // Shoves this record into the socket, returns true if blocked. |
66 | 484 | pub(crate) fn write(self, fd: *mut prio::PRFileDesc) -> Res<()> { |
67 | | unsafe { |
68 | 484 | ssl::SSL_RecordLayerData( |
69 | 484 | fd, |
70 | 484 | self.epoch, |
71 | 484 | ssl::SSLContentType::Type::from(self.ct), |
72 | 484 | self.data.as_ptr(), |
73 | 484 | c_uint::try_from(self.data.len())?, |
74 | | ) |
75 | | } |
76 | 484 | } |
77 | | } |
78 | | |
79 | | impl fmt::Debug for Record { |
80 | 0 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { |
81 | 0 | write!( |
82 | 0 | f, |
83 | | "Record {:?}:{:?} {}", |
84 | | self.epoch, |
85 | | self.ct, |
86 | 0 | hex_with_len(&self.data[..]) |
87 | | ) |
88 | 0 | } |
89 | | } |
90 | | |
91 | | #[derive(Debug, Default)] |
92 | | pub struct RecordList { |
93 | | records: Vec<Record>, |
94 | | } |
95 | | |
96 | | impl RecordList { |
97 | 2.25k | fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) { |
98 | 2.25k | self.records.push(Record::new(epoch, ct, data)); |
99 | 2.25k | } |
100 | | |
101 | 2.25k | unsafe extern "C" fn ingest( |
102 | 2.25k | _fd: *mut prio::PRFileDesc, |
103 | 2.25k | epoch: PRUint16, |
104 | 2.25k | ct: ssl::SSLContentType::Type, |
105 | 2.25k | data: *const PRUint8, |
106 | 2.25k | len: c_uint, |
107 | 2.25k | arg: *mut c_void, |
108 | 2.25k | ) -> SECStatus { |
109 | 2.25k | let Ok(epoch) = Epoch::try_from(epoch) else { |
110 | 0 | return ssl::SECFailure; |
111 | | }; |
112 | 2.25k | let Ok(ct) = ContentType::try_from(ct) else { |
113 | 0 | return ssl::SECFailure; |
114 | | }; |
115 | 2.25k | let Some(records) = (unsafe { arg.cast::<Self>().as_mut() }) else { |
116 | 0 | return ssl::SECFailure; |
117 | | }; |
118 | 2.25k | let slice = unsafe { null_safe_slice(data, len) }; |
119 | 2.25k | records.append(epoch, ct, slice); |
120 | 2.25k | ssl::SECSuccess |
121 | 2.25k | } |
122 | | |
123 | | /// Create a new record list. |
124 | 3.54k | pub(crate) fn setup(fd: *mut prio::PRFileDesc) -> Res<Pin<Box<Self>>> { |
125 | 3.54k | let mut records = Box::pin(Self::default()); |
126 | | unsafe { |
127 | 3.54k | ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records)) |
128 | 0 | }?; |
129 | 3.54k | Ok(records) |
130 | 3.54k | } |
131 | | } |
132 | | |
133 | | impl Deref for RecordList { |
134 | | type Target = [Record]; |
135 | 0 | fn deref(&self) -> &[Record] { |
136 | 0 | &self.records |
137 | 0 | } |
138 | | } |
139 | | |
140 | | pub struct RecordListIter(std::vec::IntoIter<Record>); |
141 | | |
142 | | impl Iterator for RecordListIter { |
143 | | type Item = Record; |
144 | 4.03k | fn next(&mut self) -> Option<Self::Item> { |
145 | 4.03k | self.0.next() |
146 | 4.03k | } |
147 | | } |
148 | | |
149 | | impl IntoIterator for RecordList { |
150 | | type Item = Record; |
151 | | type IntoIter = RecordListIter; |
152 | 1.77k | fn into_iter(self) -> Self::IntoIter { |
153 | 1.77k | RecordListIter(self.records.into_iter()) |
154 | 1.77k | } |
155 | | } |
156 | | |
157 | | pub struct AgentIoInputContext<'a> { |
158 | | input: &'a mut AgentIoInput, |
159 | | } |
160 | | |
161 | | impl Drop for AgentIoInputContext<'_> { |
162 | 805 | fn drop(&mut self) { |
163 | 805 | self.input.reset(); |
164 | 805 | } |
165 | | } |
166 | | |
167 | | #[derive(Debug, Default)] |
168 | | struct AgentIoInput { |
169 | | // input is data that is read by TLS. |
170 | | input: *const u8, |
171 | | // input_available is how much data is left for reading. |
172 | | available: usize, |
173 | | } |
174 | | |
175 | | impl AgentIoInput { |
176 | 805 | fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { |
177 | 805 | assert!(self.input.is_null()); |
178 | 805 | self.input = input.as_ptr(); |
179 | 805 | self.available = input.len(); |
180 | 805 | trace!("AgentIoInput wrap {:p}", self.input); |
181 | 805 | AgentIoInputContext { input: self } |
182 | 805 | } |
183 | | |
184 | | // Take the data provided as input and provide it to the TLS stack. |
185 | 0 | fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> { |
186 | 0 | let amount = min(self.available, count); |
187 | 0 | if amount == 0 { |
188 | 0 | unsafe { |
189 | 0 | PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0); |
190 | 0 | } |
191 | 0 | return Err(Error::NoDataAvailable); |
192 | 0 | } |
193 | | |
194 | | #[expect( |
195 | | clippy::disallowed_methods, |
196 | | reason = "We just checked if this was empty." |
197 | | )] |
198 | 0 | let src = unsafe { std::slice::from_raw_parts(self.input, amount) }; |
199 | 0 | trace!("[{self}] read {}", hex(src)); |
200 | 0 | let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) }; |
201 | 0 | dst.copy_from_slice(src); |
202 | 0 | self.input = self.input.wrapping_add(amount); |
203 | 0 | self.available -= amount; |
204 | 0 | Ok(amount) |
205 | 0 | } |
206 | | |
207 | 805 | fn reset(&mut self) { |
208 | 805 | trace!("[{self}] reset"); |
209 | 805 | self.input = null(); |
210 | 805 | self.available = 0; |
211 | 805 | } |
212 | | } |
213 | | |
214 | | impl Display for AgentIoInput { |
215 | 0 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { |
216 | 0 | write!(f, "AgentIoInput {:p}", self.input) |
217 | 0 | } |
218 | | } |
219 | | |
220 | | #[derive(Debug, Default)] |
221 | | pub struct AgentIo { |
222 | | // input collects the input we might provide to TLS. |
223 | | input: AgentIoInput, |
224 | | |
225 | | // output contains data that is written by TLS. |
226 | | output: Vec<u8>, |
227 | | } |
228 | | |
229 | | impl AgentIo { |
230 | 0 | unsafe fn borrow(fd: &mut PrFd) -> &mut Self { |
231 | 0 | unsafe { (**fd).secret.cast::<Self>().as_mut().unwrap() } |
232 | 0 | } |
233 | | |
234 | 805 | pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> { |
235 | 805 | assert_eq!(self.output.len(), 0); |
236 | 805 | self.input.wrap(input) |
237 | 805 | } |
238 | | |
239 | | // Stage output from TLS into the output buffer. |
240 | 0 | fn save_output(&mut self, buf: *const u8, count: usize) { |
241 | 0 | let slice = unsafe { null_safe_slice(buf, count) }; |
242 | 0 | trace!("[{self}] save output {}", hex(slice)); |
243 | 0 | self.output.extend_from_slice(slice); |
244 | 0 | } |
245 | | |
246 | 2.57k | pub fn take_output(&mut self) -> Vec<u8> { |
247 | 2.57k | trace!("[{self}] take output"); |
248 | 2.57k | mem::take(&mut self.output) |
249 | 2.57k | } |
250 | | } |
251 | | |
252 | | impl Display for AgentIo { |
253 | 0 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { |
254 | 0 | write!(f, "AgentIo") |
255 | 0 | } |
256 | | } |
257 | | |
258 | 2.57k | unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus { |
259 | | unsafe { |
260 | 2.57k | (*fd).secret = null_mut(); |
261 | 2.57k | if let Some(dtor) = (*fd).dtor { |
262 | 2.57k | dtor(fd); |
263 | 2.57k | } |
264 | | } |
265 | 2.57k | PR_SUCCESS |
266 | 2.57k | } |
267 | | |
268 | 0 | unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: PRInt32) -> PrStatus { |
269 | 0 | let io = unsafe { AgentIo::borrow(&mut fd) }; |
270 | 0 | let Ok(a) = usize::try_from(amount) else { |
271 | 0 | return PR_FAILURE; |
272 | | }; |
273 | 0 | match io.input.read_input(buf.cast(), a) { |
274 | 0 | Ok(_) => PR_SUCCESS, |
275 | 0 | Err(_) => PR_FAILURE, |
276 | | } |
277 | 0 | } |
278 | | |
279 | 0 | unsafe extern "C" fn agent_recv( |
280 | 0 | mut fd: PrFd, |
281 | 0 | buf: *mut c_void, |
282 | 0 | amount: PRInt32, |
283 | 0 | flags: PRIntn, |
284 | 0 | _timeout: prio::PRIntervalTime, |
285 | 0 | ) -> prio::PRInt32 { |
286 | 0 | let io = unsafe { AgentIo::borrow(&mut fd) }; |
287 | 0 | if flags != 0 { |
288 | 0 | return PR_FAILURE; |
289 | 0 | } |
290 | 0 | let Ok(a) = usize::try_from(amount) else { |
291 | 0 | return PR_FAILURE; |
292 | | }; |
293 | 0 | io.input.read_input(buf.cast(), a).map_or(PR_FAILURE, |v| { |
294 | 0 | prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE) |
295 | 0 | }) |
296 | 0 | } |
297 | | |
298 | 0 | unsafe extern "C" fn agent_write(mut fd: PrFd, buf: *const c_void, amount: PRInt32) -> PrStatus { |
299 | 0 | let io = unsafe { AgentIo::borrow(&mut fd) }; |
300 | 0 | usize::try_from(amount).map_or(PR_FAILURE, |a| { |
301 | 0 | io.save_output(buf.cast(), a); |
302 | 0 | amount |
303 | 0 | }) |
304 | 0 | } |
305 | | |
306 | 0 | unsafe extern "C" fn agent_send( |
307 | 0 | mut fd: PrFd, |
308 | 0 | buf: *const c_void, |
309 | 0 | amount: PRInt32, |
310 | 0 | flags: PRIntn, |
311 | 0 | _timeout: prio::PRIntervalTime, |
312 | 0 | ) -> PRInt32 { |
313 | 0 | let io = unsafe { AgentIo::borrow(&mut fd) }; |
314 | 0 | if flags != 0 { |
315 | 0 | return PR_FAILURE; |
316 | 0 | } |
317 | 0 | usize::try_from(amount).map_or(PR_FAILURE, |a| { |
318 | 0 | io.save_output(buf.cast(), a); |
319 | 0 | amount |
320 | 0 | }) |
321 | 0 | } |
322 | | |
323 | 0 | unsafe extern "C" fn agent_available(mut fd: PrFd) -> PRInt32 { |
324 | 0 | let io = unsafe { AgentIo::borrow(&mut fd) }; |
325 | 0 | io.input.available.try_into().unwrap_or(PR_FAILURE) |
326 | 0 | } |
327 | | |
328 | 0 | unsafe extern "C" fn agent_available64(mut fd: PrFd) -> PRInt64 { |
329 | 0 | let io = unsafe { AgentIo::borrow(&mut fd) }; |
330 | 0 | io.input |
331 | 0 | .available |
332 | 0 | .try_into() |
333 | 0 | .unwrap_or_else(|_| PR_FAILURE.into()) |
334 | 0 | } |
335 | | |
336 | | #[expect( |
337 | | clippy::cast_possible_truncation, |
338 | | reason = "Cast is safe because prio::PR_AF_INET is 2." |
339 | | )] |
340 | 4.35k | const unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus { |
341 | 4.35k | let Some(a) = (unsafe { addr.as_mut() }) else { |
342 | 0 | return PR_FAILURE; |
343 | | }; |
344 | | // Cast is safe because prio::PR_AF_INET is 2 |
345 | 4.35k | a.inet.family = prio::PR_AF_INET as PRUint16; |
346 | 4.35k | a.inet.port = 0; |
347 | 4.35k | a.inet.ip = 0; |
348 | 4.35k | PR_SUCCESS |
349 | 4.35k | } |
350 | | |
351 | 1.77k | const unsafe extern "C" fn agent_getsockopt( |
352 | 1.77k | _fd: PrFd, |
353 | 1.77k | opt: *mut prio::PRSocketOptionData, |
354 | 1.77k | ) -> PrStatus { |
355 | 1.77k | let Some(o) = (unsafe { opt.as_mut() }) else { |
356 | 0 | return PR_FAILURE; |
357 | | }; |
358 | 1.77k | if o.option == prio::PRSockOption_PR_SockOpt_Nonblocking { |
359 | 1.77k | o.value.non_blocking = 1; |
360 | 1.77k | return PR_SUCCESS; |
361 | 0 | } |
362 | 0 | PR_FAILURE |
363 | 1.77k | } |
364 | | |
365 | | pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods { |
366 | | file_type: prio::PRDescType::PR_DESC_LAYERED, |
367 | | close: Some(agent_close), |
368 | | read: Some(agent_read), |
369 | | write: Some(agent_write), |
370 | | available: Some(agent_available), |
371 | | available64: Some(agent_available64), |
372 | | fsync: None, |
373 | | seek: None, |
374 | | seek64: None, |
375 | | fileInfo: None, |
376 | | fileInfo64: None, |
377 | | writev: None, |
378 | | connect: None, |
379 | | accept: None, |
380 | | bind: None, |
381 | | listen: None, |
382 | | shutdown: None, |
383 | | recv: Some(agent_recv), |
384 | | send: Some(agent_send), |
385 | | recvfrom: None, |
386 | | sendto: None, |
387 | | poll: None, |
388 | | acceptread: None, |
389 | | transmitfile: None, |
390 | | getsockname: Some(agent_getname), |
391 | | getpeername: Some(agent_getname), |
392 | | reserved_fn_6: None, |
393 | | reserved_fn_5: None, |
394 | | getsocketoption: Some(agent_getsockopt), |
395 | | setsocketoption: None, |
396 | | sendfile: None, |
397 | | connectcontinue: None, |
398 | | reserved_fn_3: None, |
399 | | reserved_fn_2: None, |
400 | | reserved_fn_1: None, |
401 | | reserved_fn_0: None, |
402 | | }; |
403 | | |
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::ptr::addr_of_mut;

    use super::*;

    /// `ingest` must reject an invalid epoch, a null callback argument, and a
    /// content type outside the `u8` range.
    #[test]
    fn ingest_errors() {
        let mut records = RecordList::default();
        let data = [0u8];
        let list_ptr: *mut c_void = addr_of_mut!(records).cast();
        // Each case is (epoch, content type, callback argument).
        let rejected = [
            (999, 0x17, list_ptr),
            (0, 0x17, null_mut()),
            // Test invalid content type (value outside u8 range)
            (0, 256, list_ptr),
        ];
        for (epoch, ct, arg) in rejected {
            let status =
                unsafe { RecordList::ingest(null_mut(), epoch, ct, data.as_ptr(), 1, arg) };
            assert_eq!(status, ssl::SECFailure);
        }
    }

    /// The `Debug`/`Display` output of each type starts with its type name.
    #[test]
    fn formatting() {
        let record = Record::new(Epoch::ApplicationData, 0x17, &[1, 2, 3]);
        assert!(format!("{record:?}").starts_with("Record"));

        let input = AgentIoInput::default();
        assert!(input.to_string().starts_with("AgentIoInput"));

        let io = AgentIo::default();
        assert_eq!(io.to_string(), "AgentIo");
    }
}