/rust/git/checkouts/nss-rs-71e20fe79ef91440/9b94ca3/src/ext.rs
Line | Count | Source |
1 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
2 | | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
3 | | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
4 | | // option. This file may not be copied, modified, or distributed |
5 | | // except according to those terms. |
6 | | |
7 | | #![expect( |
8 | | clippy::unwrap_used, |
9 | | reason = "Let's assume the use of `unwrap` was checked when the use of `unsafe` was reviewed." |
10 | | )] |
11 | | |
12 | | use std::{ |
13 | | cell::RefCell, |
14 | | convert::TryFrom as _, |
15 | | fmt::{self, Debug, Formatter}, |
16 | | os::raw::{c_uint, c_void}, |
17 | | pin::Pin, |
18 | | rc::Rc, |
19 | | }; |
20 | | |
21 | | use crate::{ |
22 | | SECStatus, |
23 | | agentio::as_c_void, |
24 | | constants::{Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS}, |
25 | | err::Res, |
26 | | nss_prelude::PRBool, |
27 | | null_safe_slice, |
28 | | prio::PRFileDesc, |
29 | | ssl::{ |
30 | | SECFailure, SECSuccess, SSLAlertDescription, SSLExtensionHandler, SSLExtensionWriter, |
31 | | SSLHandshakeType, |
32 | | }, |
33 | | }; |
34 | | |
35 | | experimental_api!(SSL_InstallExtensionHooks( |
36 | | fd: *mut PRFileDesc, |
37 | | extension: u16, |
38 | | writer: SSLExtensionWriter, |
39 | | writer_arg: *mut c_void, |
40 | | handler: SSLExtensionHandler, |
41 | | handler_arg: *mut c_void, |
42 | | )); |
43 | | |
44 | | experimental_api!(SSL_CallExtensionWriterOnEchInner( |
45 | | fd: *mut PRFileDesc, |
46 | | enabled: PRBool, |
47 | | )); |
48 | | |
/// Outcome of an extension writer callback (`ExtensionHandler::write`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionWriterResult {
    /// Include the extension; the value is the number of bytes written into
    /// the supplied buffer and is reported back to NSS as the output length.
    Write(usize),
    /// Report `PR_FALSE` to NSS, i.e. do not add the extension to this message.
    Skip,
}
53 | | |
54 | | pub enum ExtensionHandlerResult { |
55 | | Ok, |
56 | | Alert(crate::constants::Alert), |
57 | | } |
58 | | |
59 | | pub trait ExtensionHandler { |
60 | | /// Write an extension to the given buffer. |
61 | | /// NSS will call back when it needs an extension. |
62 | | /// Supply the bytes of the extension (without a type and length); |
63 | | /// the default implementation writes a zero-length extension |
64 | | /// to both the `ClientHello` and `EncryptedExtensions` message. |
65 | | /// |
66 | | /// The value of `ch_outer` is only relevant when ECH is enabled; |
67 | | /// it will be `false` when ECH is disabled or for the inner `ClientHello`. |
68 | | /// For ECH, where `msg == TLS_HS_CLIENT_HELLO`, |
69 | | /// you can write different values to the inner and outer extensions; |
70 | | /// if they are different, NSS won't compress them. |
71 | 0 | fn write( |
72 | 0 | &mut self, |
73 | 0 | msg: HandshakeMessage, |
74 | 0 | _ch_outer: bool, |
75 | 0 | _d: &mut [u8], |
76 | 0 | ) -> ExtensionWriterResult { |
77 | 0 | match msg { |
78 | 0 | TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionWriterResult::Write(0), |
79 | 0 | _ => ExtensionWriterResult::Skip, |
80 | | } |
81 | 0 | } |
82 | | |
83 | 0 | fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult { |
84 | 0 | match msg { |
85 | 0 | TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionHandlerResult::Ok, |
86 | 0 | _ => ExtensionHandlerResult::Alert(110), // unsupported_extension |
87 | | } |
88 | 0 | } |
89 | | } |
90 | | |
91 | | type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>; |
92 | | |
93 | | pub struct ExtensionTracker { |
94 | | extension: Extension, |
95 | | handler: Pin<Box<BoxedExtensionHandler>>, |
96 | | } |
97 | | |
98 | | impl ExtensionTracker { |
99 | | // Technically the as_mut() call here is the only unsafe bit, |
100 | | // but don't call this function lightly. |
101 | 3.22k | unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T |
102 | 3.22k | where |
103 | 3.22k | F: FnOnce(&mut dyn ExtensionHandler) -> T, |
104 | | { |
105 | 3.22k | let rc = unsafe { arg.cast::<BoxedExtensionHandler>().as_mut().unwrap() }; |
106 | 3.22k | f(&mut *rc.borrow_mut()) |
107 | 3.22k | } <nss_rs::ext::ExtensionTracker>::wrap_handler_call::<<nss_rs::ext::ExtensionTracker>::extension_writer::{closure#2}, i32>Line | Count | Source | 101 | 2.74k | unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T | 102 | 2.74k | where | 103 | 2.74k | F: FnOnce(&mut dyn ExtensionHandler) -> T, | 104 | | { | 105 | 2.74k | let rc = unsafe { arg.cast::<BoxedExtensionHandler>().as_mut().unwrap() }; | 106 | 2.74k | f(&mut *rc.borrow_mut()) | 107 | 2.74k | } |
<nss_rs::ext::ExtensionTracker>::wrap_handler_call::<<nss_rs::ext::ExtensionTracker>::extension_handler::{closure#0}, i32>Line | Count | Source | 101 | 484 | unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T | 102 | 484 | where | 103 | 484 | F: FnOnce(&mut dyn ExtensionHandler) -> T, | 104 | | { | 105 | 484 | let rc = unsafe { arg.cast::<BoxedExtensionHandler>().as_mut().unwrap() }; | 106 | 484 | f(&mut *rc.borrow_mut()) | 107 | 484 | } |
|
108 | | |
109 | 2.74k | unsafe extern "C" fn extension_writer( |
110 | 2.74k | _fd: *mut PRFileDesc, |
111 | 2.74k | message: SSLHandshakeType::Type, |
112 | 2.74k | data: *mut u8, |
113 | 2.74k | len: *mut c_uint, |
114 | 2.74k | max_len: c_uint, |
115 | 2.74k | arg: *mut c_void, |
116 | 2.74k | ) -> PRBool { |
117 | | // The input message type is larger than the `u8` range of `SSLHandshakeType`. |
118 | | // The only valid value outside that range is for ECH outer ClientHello, |
119 | | // which we need to have special handling for. |
120 | 2.74k | let (msg, ch_outer) = HandshakeMessage::try_from(message).map_or_else( |
121 | 0 | |_| { |
122 | 0 | debug_assert_eq!(message, SSLHandshakeType::ssl_hs_ech_outer_client_hello); |
123 | 0 | (TLS_HS_CLIENT_HELLO, true) |
124 | 0 | }, |
125 | 2.74k | |msg| (msg, false), |
126 | | ); |
127 | 2.74k | let d = unsafe { std::slice::from_raw_parts_mut(data, max_len as usize) }; |
128 | | // provided by NSS for writing the output length. |
129 | | unsafe { |
130 | 2.74k | Self::wrap_handler_call(arg, |handler| match handler.write(msg, ch_outer, d) { |
131 | 1.77k | ExtensionWriterResult::Write(sz) => { |
132 | 1.77k | *len = c_uint::try_from(sz).expect("integer overflow from extension writer"); |
133 | 1.77k | 1 |
134 | | } |
135 | 968 | ExtensionWriterResult::Skip => 0, |
136 | 2.74k | }) |
137 | | } |
138 | 2.74k | } |
139 | | |
140 | 484 | unsafe extern "C" fn extension_handler( |
141 | 484 | _fd: *mut PRFileDesc, |
142 | 484 | message: SSLHandshakeType::Type, |
143 | 484 | data: *const u8, |
144 | 484 | len: c_uint, |
145 | 484 | alert: *mut SSLAlertDescription, |
146 | 484 | arg: *mut c_void, |
147 | 484 | ) -> SECStatus { |
148 | 484 | let d = unsafe { null_safe_slice(data, len) }; |
149 | | // provided by NSS for writing the alert description. |
150 | | unsafe { |
151 | 484 | Self::wrap_handler_call(arg, |handler| { |
152 | | // Cast is safe here because the message type is always part of the enum |
153 | | #[allow( |
154 | | clippy::allow_attributes, |
155 | | clippy::cast_possible_truncation, |
156 | | clippy::cast_sign_loss, |
157 | | reason = "Cast is safe here because the message type is always part of the enum." |
158 | | )] |
159 | 484 | match handler.handle(message as HandshakeMessage, d) { |
160 | 484 | ExtensionHandlerResult::Ok => SECSuccess, |
161 | 0 | ExtensionHandlerResult::Alert(a) => { |
162 | 0 | *alert = a; |
163 | 0 | SECFailure |
164 | | } |
165 | | } |
166 | 484 | }) |
167 | | } |
168 | 484 | } |
169 | | |
170 | | /// Use the provided handler to manage an extension. This is quite unsafe. |
171 | | /// |
172 | | /// # Safety |
173 | | /// |
174 | | /// The holder of this `ExtensionTracker` needs to ensure that it lives at |
175 | | /// least as long as the file descriptor, as NSS provides no way to remove |
176 | | /// an extension handler once it is configured. |
177 | | /// |
178 | | /// # Errors |
179 | | /// |
180 | | /// If the underlying NSS API fails to register a handler. |
181 | 2.57k | pub unsafe fn new( |
182 | 2.57k | fd: *mut PRFileDesc, |
183 | 2.57k | extension: Extension, |
184 | 2.57k | handler: Rc<RefCell<dyn ExtensionHandler>>, |
185 | 2.57k | ) -> Res<Self> { |
186 | | unsafe { |
187 | | // The ergonomics here aren't great for users of this API, but it's |
188 | | // horrific here. The pinned outer box gives us a stable pointer to the inner |
189 | | // box. This is the pointer that is passed to NSS. |
190 | | // |
191 | | // The inner box points to the reference-counted object. This inner box is |
192 | | // what we end up with a reference to in callbacks. That extra wrapper around |
193 | | // the Rc avoid any touching of reference counts in callbacks, which would |
194 | | // inevitably lead to leaks as we don't control how many times the callback |
195 | | // is invoked. |
196 | | // |
197 | | // This way, only this "outer" code deals with the reference count. |
198 | 2.57k | let mut tracker = Self { |
199 | 2.57k | extension, |
200 | 2.57k | handler: Box::pin(Box::new(handler)), |
201 | 2.57k | }; |
202 | 2.57k | SSL_InstallExtensionHooks( |
203 | 2.57k | fd, |
204 | 2.57k | extension, |
205 | 2.57k | Some(Self::extension_writer), |
206 | 2.57k | as_c_void(&mut tracker.handler), |
207 | 2.57k | Some(Self::extension_handler), |
208 | 2.57k | as_c_void(&mut tracker.handler), |
209 | 0 | )?; |
210 | 2.57k | Ok(tracker) |
211 | | } |
212 | 2.57k | } |
213 | | } |
214 | | |
215 | | impl Debug for ExtensionTracker { |
216 | 0 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { |
217 | 0 | write!(f, "ExtensionTracker: {:?}", self.extension) |
218 | 0 | } |
219 | | } |