/src/spdm-rs/spdmlib/src/requester/psk_exchange_req.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2020 Intel Corporation |
2 | | // |
3 | | // SPDX-License-Identifier: Apache-2.0 or MIT |
4 | | |
5 | | use config::MAX_SPDM_PSK_CONTEXT_SIZE; |
6 | | |
7 | | use crate::crypto; |
8 | | use crate::error::{ |
9 | | SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, |
10 | | SPDM_STATUS_INVALID_PARAMETER, SPDM_STATUS_SESSION_NUMBER_EXCEED, SPDM_STATUS_VERIF_FAIL, |
11 | | }; |
12 | | use crate::error::{SPDM_STATUS_BUFFER_FULL, SPDM_STATUS_INVALID_STATE_LOCAL}; |
13 | | use crate::message::*; |
14 | | use crate::protocol::*; |
15 | | use crate::requester::*; |
16 | | extern crate alloc; |
17 | | use core::ops::DerefMut; |
18 | | |
19 | | impl RequesterContext { |
20 | | #[maybe_async::maybe_async] |
21 | | pub async fn send_receive_spdm_psk_exchange( |
22 | | &mut self, |
23 | | measurement_summary_hash_type: SpdmMeasurementSummaryHashType, |
24 | | psk_hint: Option<&SpdmPskHintStruct>, |
25 | 0 | ) -> SpdmResult<u32> { |
26 | 0 | info!("send spdm psk exchange\n"); |
27 | | |
28 | 0 | let psk_hint = if let Some(hint) = psk_hint { |
29 | 0 | hint.clone() |
30 | | } else { |
31 | 0 | SpdmPskHintStruct::default() |
32 | | }; |
33 | | |
34 | 0 | self.common |
35 | 0 | .reset_buffer_via_request_code(SpdmRequestResponseCode::SpdmRequestPskExchange, None); |
36 | 0 |
|
37 | 0 | let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; |
38 | 0 | let half_session_id = self.common.get_next_half_session_id(true)?; |
39 | 0 | let send_used = self.encode_spdm_psk_exchange( |
40 | 0 | half_session_id, |
41 | 0 | measurement_summary_hash_type, |
42 | 0 | &psk_hint, |
43 | 0 | &mut send_buffer, |
44 | 0 | )?; |
45 | | |
46 | 0 | self.send_message(None, &send_buffer[..send_used], false) |
47 | 0 | .await?; |
48 | | |
49 | | // Receive |
50 | 0 | let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; |
51 | 0 | let receive_used = self |
52 | 0 | .receive_message(None, &mut receive_buffer, false) |
53 | 0 | .await?; |
54 | | |
55 | 0 | let mut target_session_id = None; |
56 | 0 | if let Err(e) = self.handle_spdm_psk_exchange_response( |
57 | 0 | half_session_id, |
58 | 0 | measurement_summary_hash_type, |
59 | 0 | &psk_hint, |
60 | 0 | &send_buffer[..send_used], |
61 | 0 | &receive_buffer[..receive_used], |
62 | 0 | &mut target_session_id, |
63 | 0 | ) { |
64 | 0 | if let Some(session_id) = target_session_id { |
65 | 0 | if let Some(session) = self.common.get_session_via_id(session_id) { |
66 | 0 | session.teardown(); |
67 | 0 | } |
68 | 0 | } |
69 | | |
70 | 0 | Err(e) |
71 | | } else { |
72 | 0 | Ok(target_session_id.unwrap()) |
73 | | } |
74 | 0 | } |
75 | | |
76 | | pub fn encode_spdm_psk_exchange( |
77 | | &mut self, |
78 | | half_session_id: u16, |
79 | | measurement_summary_hash_type: SpdmMeasurementSummaryHashType, |
80 | | psk_hint: &SpdmPskHintStruct, |
81 | | buf: &mut [u8], |
82 | | ) -> SpdmResult<usize> { |
83 | | let mut writer = Writer::init(buf); |
84 | | |
85 | | let mut psk_context = [0u8; MAX_SPDM_PSK_CONTEXT_SIZE]; |
86 | | crypto::rand::get_random(&mut psk_context)?; |
87 | | |
88 | | let mut secured_message_version_list = SecuredMessageVersionList { |
89 | | version_count: 0, |
90 | | versions_list: [SecuredMessageVersion::default(); MAX_SECURE_SPDM_VERSION_COUNT], |
91 | | }; |
92 | | |
93 | | for local_version in self.common.config_info.secure_spdm_version.iter().flatten() { |
94 | | secured_message_version_list.versions_list |
95 | | [secured_message_version_list.version_count as usize] = *local_version; |
96 | | secured_message_version_list.version_count += 1; |
97 | | } |
98 | | |
99 | | let opaque = SpdmOpaqueStruct::from_sm_supported_ver_list_opaque( |
100 | | &mut self.common, |
101 | | &SMSupportedVerListOpaque { |
102 | | secured_message_version_list, |
103 | | }, |
104 | | )?; |
105 | | |
106 | | let request = SpdmMessage { |
107 | | header: SpdmMessageHeader { |
108 | | version: self.common.negotiate_info.spdm_version_sel, |
109 | | request_response_code: SpdmRequestResponseCode::SpdmRequestPskExchange, |
110 | | }, |
111 | | payload: SpdmMessagePayload::SpdmPskExchangeRequest(SpdmPskExchangeRequestPayload { |
112 | | measurement_summary_hash_type, |
113 | | req_session_id: half_session_id, |
114 | | psk_hint: psk_hint.clone(), |
115 | | psk_context: SpdmPskContextStruct { |
116 | | data_size: self.common.negotiate_info.base_hash_sel.get_size(), |
117 | | data: psk_context, |
118 | | }, |
119 | | opaque, |
120 | | }), |
121 | | }; |
122 | | request.spdm_encode(&mut self.common, &mut writer) |
123 | | } |
124 | | |
125 | | pub fn handle_spdm_psk_exchange_response( |
126 | | &mut self, |
127 | | half_session_id: u16, |
128 | | measurement_summary_hash_type: SpdmMeasurementSummaryHashType, |
129 | | psk_hint: &SpdmPskHintStruct, |
130 | | send_buffer: &[u8], |
131 | | receive_buffer: &[u8], |
132 | | target_session_id: &mut Option<u32>, |
133 | | ) -> SpdmResult<u32> { |
134 | | self.common.runtime_info.need_measurement_summary_hash = (measurement_summary_hash_type |
135 | | == SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeTcb) |
136 | | || (measurement_summary_hash_type |
137 | | == SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeAll); |
138 | | |
139 | | let mut reader = Reader::init(receive_buffer); |
140 | | match SpdmMessageHeader::read(&mut reader) { |
141 | | Some(message_header) => { |
142 | | if message_header.version != self.common.negotiate_info.spdm_version_sel { |
143 | | return Err(SPDM_STATUS_INVALID_MSG_FIELD); |
144 | | } |
145 | | match message_header.request_response_code { |
146 | | SpdmRequestResponseCode::SpdmResponsePskExchangeRsp => { |
147 | | let psk_exchange_rsp = SpdmPskExchangeResponsePayload::spdm_read( |
148 | | &mut self.common, |
149 | | &mut reader, |
150 | | ); |
151 | | let receive_used = reader.used(); |
152 | | if let Some(psk_exchange_rsp) = psk_exchange_rsp { |
153 | | debug!("!!! psk_exchange rsp : {:02x?}\n", psk_exchange_rsp); |
154 | | |
155 | | // create session structure |
156 | | let base_hash_algo = self.common.negotiate_info.base_hash_sel; |
157 | | let dhe_algo = self.common.negotiate_info.dhe_sel; |
158 | | let aead_algo = self.common.negotiate_info.aead_sel; |
159 | | let key_schedule_algo = self.common.negotiate_info.key_schedule_sel; |
160 | | let sequence_number_count = { |
161 | | let mut transport_encap = self.common.transport_encap.lock(); |
162 | | let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = |
163 | | transport_encap.deref_mut(); |
164 | | transport_encap.get_sequence_number_count() |
165 | | }; |
166 | | let max_random_count = { |
167 | | let mut transport_encap = self.common.transport_encap.lock(); |
168 | | let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) = |
169 | | transport_encap.deref_mut(); |
170 | | transport_encap.get_max_random_count() |
171 | | }; |
172 | | |
173 | | let secure_spdm_version_sel = psk_exchange_rsp |
174 | | .opaque |
175 | | .req_get_dmtf_secure_spdm_version_selection(&mut self.common) |
176 | | .ok_or(SPDM_STATUS_INVALID_MSG_FIELD)?; |
177 | | |
178 | | let session_id = ((psk_exchange_rsp.rsp_session_id as u32) << 16) |
179 | | + half_session_id as u32; |
180 | | *target_session_id = Some(session_id); |
181 | | let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; |
182 | | let message_a = self.common.runtime_info.message_a.clone(); |
183 | | |
184 | | let session = self |
185 | | .common |
186 | | .get_next_avaiable_session() |
187 | | .ok_or(SPDM_STATUS_SESSION_NUMBER_EXCEED)?; |
188 | | |
189 | | session.setup(session_id)?; |
190 | | |
191 | | session.set_use_psk(true); |
192 | | |
193 | | session.set_crypto_param( |
194 | | base_hash_algo, |
195 | | dhe_algo, |
196 | | aead_algo, |
197 | | key_schedule_algo, |
198 | | ); |
199 | | session.set_transport_param(sequence_number_count, max_random_count); |
200 | | |
201 | | session.runtime_info.psk_hint = Some(psk_hint.clone()); |
202 | | session.runtime_info.message_a = message_a; |
203 | | session.runtime_info.rsp_cert_hash = None; |
204 | | session.runtime_info.req_cert_hash = None; |
205 | | |
206 | | // create transcript |
207 | | let base_hash_size = |
208 | | self.common.negotiate_info.base_hash_sel.get_size() as usize; |
209 | | let temp_receive_used = receive_used - base_hash_size; |
210 | | |
211 | | self.common.append_message_k(session_id, send_buffer)?; |
212 | | self.common.append_message_k( |
213 | | session_id, |
214 | | &receive_buffer[..temp_receive_used], |
215 | | )?; |
216 | | |
217 | | let session = self |
218 | | .common |
219 | | .get_immutable_session_via_id(session_id) |
220 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
221 | | |
222 | | // generate the handshake secret (including finished_key) before verify HMAC |
223 | | let th1 = self.common.calc_req_transcript_hash( |
224 | | true, |
225 | | INVALID_SLOT, |
226 | | false, |
227 | | session, |
228 | | )?; |
229 | | debug!("!!! th1 : {:02x?}\n", th1.as_ref()); |
230 | | |
231 | | let session = self |
232 | | .common |
233 | | .get_session_via_id(session_id) |
234 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
235 | | session.generate_handshake_secret(spdm_version_sel, &th1)?; |
236 | | |
237 | | let session = self |
238 | | .common |
239 | | .get_immutable_session_via_id(session_id) |
240 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
241 | | |
242 | | // verify HMAC with finished_key |
243 | | let transcript_hash = self.common.calc_req_transcript_hash( |
244 | | true, |
245 | | INVALID_SLOT, |
246 | | false, |
247 | | session, |
248 | | )?; |
249 | | |
250 | | let session = self |
251 | | .common |
252 | | .get_immutable_session_via_id(session_id) |
253 | | .ok_or(SPDM_STATUS_INVALID_PARAMETER)?; |
254 | | |
255 | | if session |
256 | | .verify_hmac_with_response_finished_key( |
257 | | transcript_hash.as_ref(), |
258 | | &psk_exchange_rsp.verify_data, |
259 | | ) |
260 | | .is_err() |
261 | | { |
262 | | error!("verify_hmac_with_response_finished_key fail"); |
263 | | let session = self |
264 | | .common |
265 | | .get_session_via_id(session_id) |
266 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
267 | | session.teardown(); |
268 | | return Err(SPDM_STATUS_VERIF_FAIL); |
269 | | } else { |
270 | | info!("verify_hmac_with_response_finished_key pass"); |
271 | | } |
272 | | |
273 | | // append verify_data after TH1 |
274 | | if self |
275 | | .common |
276 | | .append_message_k(session_id, psk_exchange_rsp.verify_data.as_ref()) |
277 | | .is_err() |
278 | | { |
279 | | let session = self |
280 | | .common |
281 | | .get_session_via_id(session_id) |
282 | | .ok_or(SPDM_STATUS_INVALID_PARAMETER)?; |
283 | | session.teardown(); |
284 | | return Err(SPDM_STATUS_BUFFER_FULL); |
285 | | } |
286 | | |
287 | | let session = self |
288 | | .common |
289 | | .get_session_via_id(session_id) |
290 | | .ok_or(SPDM_STATUS_INVALID_PARAMETER)?; |
291 | | session.set_session_state( |
292 | | crate::common::session::SpdmSessionState::SpdmSessionHandshaking, |
293 | | ); |
294 | | |
295 | | let session = self |
296 | | .common |
297 | | .get_immutable_session_via_id(session_id) |
298 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
299 | | let psk_without_context = self |
300 | | .common |
301 | | .negotiate_info |
302 | | .rsp_capabilities_sel |
303 | | .contains(SpdmResponseCapabilityFlags::PSK_CAP_WITHOUT_CONTEXT); |
304 | | if psk_without_context { |
305 | | // generate the data secret directly to skip PSK_FINISH |
306 | | let th2 = self.common.calc_req_transcript_hash( |
307 | | true, |
308 | | INVALID_SLOT, |
309 | | false, |
310 | | session, |
311 | | )?; |
312 | | |
313 | | debug!("!!! th2 : {:02x?}\n", th2.as_ref()); |
314 | | |
315 | | let session = self |
316 | | .common |
317 | | .get_session_via_id(session_id) |
318 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
319 | | session.generate_data_secret(spdm_version_sel, &th2)?; |
320 | | session.set_session_state( |
321 | | crate::common::session::SpdmSessionState::SpdmSessionEstablished, |
322 | | ); |
323 | | } |
324 | | |
325 | | let session = self |
326 | | .common |
327 | | .get_session_via_id(session_id) |
328 | | .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; |
329 | | session.secure_spdm_version_sel = secure_spdm_version_sel; |
330 | | session.heartbeat_period = psk_exchange_rsp.heartbeat_period; |
331 | | |
332 | | Ok(session_id) |
333 | | } else { |
334 | | error!("!!! psk_exchange : fail !!!\n"); |
335 | | Err(SPDM_STATUS_INVALID_MSG_FIELD) |
336 | | } |
337 | | } |
338 | | SpdmRequestResponseCode::SpdmResponseError => { |
339 | | let status = self.spdm_handle_error_response_main( |
340 | | None, |
341 | | receive_buffer, |
342 | | SpdmRequestResponseCode::SpdmRequestPskExchange, |
343 | | SpdmRequestResponseCode::SpdmResponsePskExchangeRsp, |
344 | | ); |
345 | | match status { |
346 | | Err(status) => Err(status), |
347 | | Ok(()) => Err(SPDM_STATUS_ERROR_PEER), |
348 | | } |
349 | | } |
350 | | _ => Err(SPDM_STATUS_ERROR_PEER), |
351 | | } |
352 | | } |
353 | | None => Err(SPDM_STATUS_INVALID_MSG_FIELD), |
354 | | } |
355 | | } |
356 | | } |