Coverage Report

Created: 2026-05-18 06:32

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/git/checkouts/nss-rs-71e20fe79ef91440/9b94ca3/src/agentio.rs
Line
Count
Source
1
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4
// option. This file may not be copied, modified, or distributed
5
// except according to those terms.
6
7
#![expect(
8
    clippy::unwrap_used,
9
    reason = "Let's assume the use of `unwrap` was checked when the use of `unsafe` was reviewed."
10
)]
11
12
use std::{
13
    cmp::min,
14
    convert::{TryFrom as _, TryInto as _},
15
    fmt::{self, Display, Formatter},
16
    mem,
17
    ops::Deref,
18
    os::raw::{c_uint, c_void},
19
    pin::Pin,
20
    ptr::{null, null_mut},
21
};
22
23
use log::trace;
24
25
use crate::{
26
    SECStatus,
27
    constants::{ContentType, Epoch},
28
    err::{Error, PR_SetError, Res, nspr},
29
    null_safe_slice,
30
    p11::hex_with_len,
31
    prio,
32
    selfencrypt::hex,
33
    ssl::{self, PRInt32, PRInt64, PRIntn, PRUint8, PRUint16},
34
};
35
36
// Alias common types.
37
type PrFd = *mut prio::PRFileDesc;
38
type PrStatus = prio::PRStatus::Type;
39
const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS;
40
const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE;
41
42
/// Expose the value inside a pinned box as an untyped C pointer.
///
/// The returned pointer addresses the heap allocation holding `T` (not the
/// `Box` itself), so moving the `Pin<Box<T>>` handle does not change it.
/// It is only meaningful while that box is alive.
pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
    // `T: Unpin`, so handing out `&mut T` from the pin is allowed.
    let target: &mut T = pin.as_mut().get_mut();
    let raw: *mut T = target;
    raw.cast::<c_void>()
}
46
47
/// A slice of the output.
48
#[derive(Default)]
49
pub struct Record {
50
    pub epoch: Epoch,
51
    pub ct: ContentType,
52
    pub data: Vec<u8>,
53
}
54
55
impl Record {
56
    #[must_use]
57
2.25k
    pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self {
58
2.25k
        Self {
59
2.25k
            epoch,
60
2.25k
            ct,
61
2.25k
            data: data.to_vec(),
62
2.25k
        }
63
2.25k
    }
64
65
    // Shoves this record into the socket, returns true if blocked.
66
484
    pub(crate) fn write(self, fd: *mut prio::PRFileDesc) -> Res<()> {
67
        unsafe {
68
484
            ssl::SSL_RecordLayerData(
69
484
                fd,
70
484
                self.epoch,
71
484
                ssl::SSLContentType::Type::from(self.ct),
72
484
                self.data.as_ptr(),
73
484
                c_uint::try_from(self.data.len())?,
74
            )
75
        }
76
484
    }
77
}
78
79
impl fmt::Debug for Record {
80
0
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
81
0
        write!(
82
0
            f,
83
            "Record {:?}:{:?} {}",
84
            self.epoch,
85
            self.ct,
86
0
            hex_with_len(&self.data[..])
87
        )
88
0
    }
89
}
90
91
#[derive(Debug, Default)]
92
pub struct RecordList {
93
    records: Vec<Record>,
94
}
95
96
impl RecordList {
97
2.25k
    fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) {
98
2.25k
        self.records.push(Record::new(epoch, ct, data));
99
2.25k
    }
100
101
2.25k
    unsafe extern "C" fn ingest(
102
2.25k
        _fd: *mut prio::PRFileDesc,
103
2.25k
        epoch: PRUint16,
104
2.25k
        ct: ssl::SSLContentType::Type,
105
2.25k
        data: *const PRUint8,
106
2.25k
        len: c_uint,
107
2.25k
        arg: *mut c_void,
108
2.25k
    ) -> SECStatus {
109
2.25k
        let Ok(epoch) = Epoch::try_from(epoch) else {
110
0
            return ssl::SECFailure;
111
        };
112
2.25k
        let Ok(ct) = ContentType::try_from(ct) else {
113
0
            return ssl::SECFailure;
114
        };
115
2.25k
        let Some(records) = (unsafe { arg.cast::<Self>().as_mut() }) else {
116
0
            return ssl::SECFailure;
117
        };
118
2.25k
        let slice = unsafe { null_safe_slice(data, len) };
119
2.25k
        records.append(epoch, ct, slice);
120
2.25k
        ssl::SECSuccess
121
2.25k
    }
122
123
    /// Create a new record list.
124
3.54k
    pub(crate) fn setup(fd: *mut prio::PRFileDesc) -> Res<Pin<Box<Self>>> {
125
3.54k
        let mut records = Box::pin(Self::default());
126
        unsafe {
127
3.54k
            ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records))
128
0
        }?;
129
3.54k
        Ok(records)
130
3.54k
    }
131
}
132
133
impl Deref for RecordList {
134
    type Target = [Record];
135
0
    fn deref(&self) -> &[Record] {
136
0
        &self.records
137
0
    }
138
}
139
140
pub struct RecordListIter(std::vec::IntoIter<Record>);
141
142
impl Iterator for RecordListIter {
143
    type Item = Record;
144
4.03k
    fn next(&mut self) -> Option<Self::Item> {
145
4.03k
        self.0.next()
146
4.03k
    }
147
}
148
149
impl IntoIterator for RecordList {
150
    type Item = Record;
151
    type IntoIter = RecordListIter;
152
1.77k
    fn into_iter(self) -> Self::IntoIter {
153
1.77k
        RecordListIter(self.records.into_iter())
154
1.77k
    }
155
}
156
157
pub struct AgentIoInputContext<'a> {
158
    input: &'a mut AgentIoInput,
159
}
160
161
impl Drop for AgentIoInputContext<'_> {
162
805
    fn drop(&mut self) {
163
805
        self.input.reset();
164
805
    }
165
}
166
167
#[derive(Debug, Default)]
168
struct AgentIoInput {
169
    // input is data that is read by TLS.
170
    input: *const u8,
171
    // input_available is how much data is left for reading.
172
    available: usize,
173
}
174
175
impl AgentIoInput {
176
805
    fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
177
805
        assert!(self.input.is_null());
178
805
        self.input = input.as_ptr();
179
805
        self.available = input.len();
180
805
        trace!("AgentIoInput wrap {:p}", self.input);
181
805
        AgentIoInputContext { input: self }
182
805
    }
183
184
    // Take the data provided as input and provide it to the TLS stack.
185
0
    fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> {
186
0
        let amount = min(self.available, count);
187
0
        if amount == 0 {
188
0
            unsafe {
189
0
                PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0);
190
0
            }
191
0
            return Err(Error::NoDataAvailable);
192
0
        }
193
194
        #[expect(
195
            clippy::disallowed_methods,
196
            reason = "We just checked if this was empty."
197
        )]
198
0
        let src = unsafe { std::slice::from_raw_parts(self.input, amount) };
199
0
        trace!("[{self}] read {}", hex(src));
200
0
        let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) };
201
0
        dst.copy_from_slice(src);
202
0
        self.input = self.input.wrapping_add(amount);
203
0
        self.available -= amount;
204
0
        Ok(amount)
205
0
    }
206
207
805
    fn reset(&mut self) {
208
805
        trace!("[{self}] reset");
209
805
        self.input = null();
210
805
        self.available = 0;
211
805
    }
212
}
213
214
impl Display for AgentIoInput {
215
0
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
216
0
        write!(f, "AgentIoInput {:p}", self.input)
217
0
    }
218
}
219
220
#[derive(Debug, Default)]
221
pub struct AgentIo {
222
    // input collects the input we might provide to TLS.
223
    input: AgentIoInput,
224
225
    // output contains data that is written by TLS.
226
    output: Vec<u8>,
227
}
228
229
impl AgentIo {
230
0
    unsafe fn borrow(fd: &mut PrFd) -> &mut Self {
231
0
        unsafe { (**fd).secret.cast::<Self>().as_mut().unwrap() }
232
0
    }
233
234
805
    pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
235
805
        assert_eq!(self.output.len(), 0);
236
805
        self.input.wrap(input)
237
805
    }
238
239
    // Stage output from TLS into the output buffer.
240
0
    fn save_output(&mut self, buf: *const u8, count: usize) {
241
0
        let slice = unsafe { null_safe_slice(buf, count) };
242
0
        trace!("[{self}] save output {}", hex(slice));
243
0
        self.output.extend_from_slice(slice);
244
0
    }
245
246
2.57k
    pub fn take_output(&mut self) -> Vec<u8> {
247
2.57k
        trace!("[{self}] take output");
248
2.57k
        mem::take(&mut self.output)
249
2.57k
    }
250
}
251
252
impl Display for AgentIo {
253
0
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
254
0
        write!(f, "AgentIo")
255
0
    }
256
}
257
258
2.57k
unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus {
259
    unsafe {
260
2.57k
        (*fd).secret = null_mut();
261
2.57k
        if let Some(dtor) = (*fd).dtor {
262
2.57k
            dtor(fd);
263
2.57k
        }
264
    }
265
2.57k
    PR_SUCCESS
266
2.57k
}
267
268
0
unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: PRInt32) -> PrStatus {
269
0
    let io = unsafe { AgentIo::borrow(&mut fd) };
270
0
    let Ok(a) = usize::try_from(amount) else {
271
0
        return PR_FAILURE;
272
    };
273
0
    match io.input.read_input(buf.cast(), a) {
274
0
        Ok(_) => PR_SUCCESS,
275
0
        Err(_) => PR_FAILURE,
276
    }
277
0
}
278
279
0
unsafe extern "C" fn agent_recv(
280
0
    mut fd: PrFd,
281
0
    buf: *mut c_void,
282
0
    amount: PRInt32,
283
0
    flags: PRIntn,
284
0
    _timeout: prio::PRIntervalTime,
285
0
) -> prio::PRInt32 {
286
0
    let io = unsafe { AgentIo::borrow(&mut fd) };
287
0
    if flags != 0 {
288
0
        return PR_FAILURE;
289
0
    }
290
0
    let Ok(a) = usize::try_from(amount) else {
291
0
        return PR_FAILURE;
292
    };
293
0
    io.input.read_input(buf.cast(), a).map_or(PR_FAILURE, |v| {
294
0
        prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE)
295
0
    })
296
0
}
297
298
0
unsafe extern "C" fn agent_write(mut fd: PrFd, buf: *const c_void, amount: PRInt32) -> PrStatus {
299
0
    let io = unsafe { AgentIo::borrow(&mut fd) };
300
0
    usize::try_from(amount).map_or(PR_FAILURE, |a| {
301
0
        io.save_output(buf.cast(), a);
302
0
        amount
303
0
    })
304
0
}
305
306
0
unsafe extern "C" fn agent_send(
307
0
    mut fd: PrFd,
308
0
    buf: *const c_void,
309
0
    amount: PRInt32,
310
0
    flags: PRIntn,
311
0
    _timeout: prio::PRIntervalTime,
312
0
) -> PRInt32 {
313
0
    let io = unsafe { AgentIo::borrow(&mut fd) };
314
0
    if flags != 0 {
315
0
        return PR_FAILURE;
316
0
    }
317
0
    usize::try_from(amount).map_or(PR_FAILURE, |a| {
318
0
        io.save_output(buf.cast(), a);
319
0
        amount
320
0
    })
321
0
}
322
323
0
unsafe extern "C" fn agent_available(mut fd: PrFd) -> PRInt32 {
324
0
    let io = unsafe { AgentIo::borrow(&mut fd) };
325
0
    io.input.available.try_into().unwrap_or(PR_FAILURE)
326
0
}
327
328
0
unsafe extern "C" fn agent_available64(mut fd: PrFd) -> PRInt64 {
329
0
    let io = unsafe { AgentIo::borrow(&mut fd) };
330
0
    io.input
331
0
        .available
332
0
        .try_into()
333
0
        .unwrap_or_else(|_| PR_FAILURE.into())
334
0
}
335
336
#[expect(
337
    clippy::cast_possible_truncation,
338
    reason = "Cast is safe because prio::PR_AF_INET is 2."
339
)]
340
4.35k
const unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
341
4.35k
    let Some(a) = (unsafe { addr.as_mut() }) else {
342
0
        return PR_FAILURE;
343
    };
344
    // Cast is safe because prio::PR_AF_INET is 2
345
4.35k
    a.inet.family = prio::PR_AF_INET as PRUint16;
346
4.35k
    a.inet.port = 0;
347
4.35k
    a.inet.ip = 0;
348
4.35k
    PR_SUCCESS
349
4.35k
}
350
351
1.77k
const unsafe extern "C" fn agent_getsockopt(
352
1.77k
    _fd: PrFd,
353
1.77k
    opt: *mut prio::PRSocketOptionData,
354
1.77k
) -> PrStatus {
355
1.77k
    let Some(o) = (unsafe { opt.as_mut() }) else {
356
0
        return PR_FAILURE;
357
    };
358
1.77k
    if o.option == prio::PRSockOption_PR_SockOpt_Nonblocking {
359
1.77k
        o.value.non_blocking = 1;
360
1.77k
        return PR_SUCCESS;
361
0
    }
362
0
    PR_FAILURE
363
1.77k
}
364
365
pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods {
366
    file_type: prio::PRDescType::PR_DESC_LAYERED,
367
    close: Some(agent_close),
368
    read: Some(agent_read),
369
    write: Some(agent_write),
370
    available: Some(agent_available),
371
    available64: Some(agent_available64),
372
    fsync: None,
373
    seek: None,
374
    seek64: None,
375
    fileInfo: None,
376
    fileInfo64: None,
377
    writev: None,
378
    connect: None,
379
    accept: None,
380
    bind: None,
381
    listen: None,
382
    shutdown: None,
383
    recv: Some(agent_recv),
384
    send: Some(agent_send),
385
    recvfrom: None,
386
    sendto: None,
387
    poll: None,
388
    acceptread: None,
389
    transmitfile: None,
390
    getsockname: Some(agent_getname),
391
    getpeername: Some(agent_getname),
392
    reserved_fn_6: None,
393
    reserved_fn_5: None,
394
    getsocketoption: Some(agent_getsockopt),
395
    setsocketoption: None,
396
    sendfile: None,
397
    connectcontinue: None,
398
    reserved_fn_3: None,
399
    reserved_fn_2: None,
400
    reserved_fn_1: None,
401
    reserved_fn_0: None,
402
};
403
404
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::ptr::addr_of_mut;

    use super::*;

    #[test]
    fn ingest_errors() {
        let mut records = RecordList::default();
        let data = [0u8];
        unsafe {
            assert_eq!(
                RecordList::ingest(
                    null_mut(),
                    999,
                    0x17,
                    data.as_ptr(),
                    1,
                    addr_of_mut!(records).cast()
                ),
                ssl::SECFailure
            );
            assert_eq!(
                RecordList::ingest(null_mut(), 0, 0x17, data.as_ptr(), 1, null_mut()),
                ssl::SECFailure
            );
            // Test invalid content type (value outside u8 range)
            assert_eq!(
                RecordList::ingest(
                    null_mut(),
                    0,
                    256,
                    data.as_ptr(),
                    1,
                    addr_of_mut!(records).cast()
                ),
                ssl::SECFailure
            );
        }
    }

    #[test]
    fn formatting() {
        let record = Record::new(Epoch::ApplicationData, 0x17, &[1, 2, 3]);
        let dbg = format!("{record:?}");
        assert_eq!(&dbg[..6], "Record");

        let input = AgentIoInput::default();
        let disp = format!("{input}");
        assert_eq!(&disp[..12], "AgentIoInput");

        let io = AgentIo::default();
        assert_eq!(format!("{io}"), "AgentIo");
    }

    /// Reading drains the wrapped slice in order, clamps to what remains, and
    /// dropping the guard resets the cursor.
    #[test]
    fn read_input_drains_wrapped_buffer() {
        let data = [1u8, 2, 3, 4, 5];
        let mut input = AgentIoInput::default();
        let mut out = [0u8; 3];
        {
            let mut ctx = input.wrap(&data);
            // A short destination only takes part of the input.
            assert!(matches!(
                ctx.input.read_input(out.as_mut_ptr(), out.len()),
                Ok(3)
            ));
            assert_eq!(out, [1, 2, 3]);
            assert_eq!(ctx.input.available, 2);
            // A larger request is clamped to what remains.
            assert!(matches!(
                ctx.input.read_input(out.as_mut_ptr(), out.len()),
                Ok(2)
            ));
            assert_eq!(&out[..2], &[4, 5]);
            assert_eq!(ctx.input.available, 0);
        }
        // Dropping the context resets the input.
        assert!(input.input.is_null());
        assert_eq!(input.available, 0);
    }
}