Coverage Report

Created: 2025-12-31 06:25

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spdm-rs/spdmlib/src/responder/psk_exchange_rsp.rs
Line
Count
Source
1
// Copyright (c) 2020 Intel Corporation
2
//
3
// SPDX-License-Identifier: Apache-2.0 or MIT
4
5
use crate::common::opaque::SpdmOpaqueStruct;
6
use crate::common::SMVersionSelOpaque;
7
use crate::common::SecuredMessageVersion;
8
use crate::common::SpdmCodec;
9
use crate::common::SpdmConnectionState;
10
use crate::common::SpdmTransportEncap;
11
use crate::crypto;
12
use crate::error::SpdmResult;
13
use crate::error::SPDM_STATUS_CRYPTO_ERROR;
14
use crate::error::SPDM_STATUS_INVALID_MSG_FIELD;
15
use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL;
16
use crate::error::SPDM_STATUS_INVALID_STATE_PEER;
17
use crate::message::*;
18
use crate::protocol::*;
19
use crate::responder::*;
20
use crate::watchdog::start_watchdog;
21
use config::MAX_SPDM_PSK_CONTEXT_SIZE;
22
extern crate alloc;
23
use crate::secret;
24
use alloc::boxed::Box;
25
use core::convert::TryFrom;
26
use core::ops::DerefMut;
27
28
impl ResponderContext {
29
0
    pub fn handle_spdm_psk_exchange<'a>(
30
0
        &mut self,
31
0
        bytes: &[u8],
32
0
        writer: &'a mut Writer,
33
0
    ) -> (SpdmResult, Option<&'a [u8]>) {
34
0
        let mut target_session_id = None;
35
0
        let (result, rsp_slice) =
36
0
            self.write_spdm_psk_exchange_response(bytes, writer, &mut target_session_id);
37
0
        if result.is_err() {
38
0
            if let Some(session_id) = target_session_id {
39
0
                if let Some(session) = self.common.get_session_via_id(session_id) {
40
0
                    session.teardown();
41
0
                }
42
0
            }
43
0
        }
44
45
0
        (result, rsp_slice)
46
0
    }
47
48
0
    pub fn write_spdm_psk_exchange_response<'a>(
49
0
        &mut self,
50
0
        bytes: &[u8],
51
0
        writer: &'a mut Writer,
52
0
        target_session_id: &mut Option<u32>,
53
0
    ) -> (SpdmResult, Option<&'a [u8]>) {
54
0
        if self.common.runtime_info.get_connection_state().get_u8()
55
0
            < SpdmConnectionState::SpdmConnectionNegotiated.get_u8()
56
        {
57
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnexpectedRequest, 0, writer);
58
0
            return (
59
0
                Err(SPDM_STATUS_INVALID_STATE_PEER),
60
0
                Some(writer.used_slice()),
61
0
            );
62
0
        }
63
0
        let mut reader = Reader::init(bytes);
64
0
        let message_header = SpdmMessageHeader::read(&mut reader);
65
0
        if let Some(message_header) = message_header {
66
0
            if message_header.version != self.common.negotiate_info.spdm_version_sel {
67
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorVersionMismatch, 0, writer);
68
0
                return (
69
0
                    Err(SPDM_STATUS_INVALID_MSG_FIELD),
70
0
                    Some(writer.used_slice()),
71
0
                );
72
0
            }
73
0
            if message_header.version < SpdmVersion::SpdmVersion11 {
74
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorUnsupportedRequest, 0, writer);
75
0
                return (
76
0
                    Err(SPDM_STATUS_INVALID_MSG_FIELD),
77
0
                    Some(writer.used_slice()),
78
0
                );
79
0
            }
80
        } else {
81
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer);
82
0
            return (
83
0
                Err(SPDM_STATUS_INVALID_MSG_FIELD),
84
0
                Some(writer.used_slice()),
85
0
            );
86
        }
87
88
0
        self.common
89
0
            .reset_buffer_via_request_code(SpdmRequestResponseCode::SpdmRequestPskExchange, None);
90
91
0
        let psk_exchange_req =
92
0
            SpdmPskExchangeRequestPayload::spdm_read(&mut self.common, &mut reader);
93
94
0
        let mut return_opaque = SpdmOpaqueStruct::default();
95
96
        let measurement_summary_hash;
97
        let psk_hint;
98
0
        if let Some(psk_exchange_req) = &psk_exchange_req {
99
0
            debug!("!!! psk_exchange req : {:02x?}\n", psk_exchange_req);
100
101
0
            if (psk_exchange_req.measurement_summary_hash_type
102
0
                == SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeTcb)
103
0
                || (psk_exchange_req.measurement_summary_hash_type
104
0
                    == SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeAll)
105
            {
106
0
                self.common.runtime_info.need_measurement_summary_hash = true;
107
0
                let measurement_summary_hash_res =
108
0
                    secret::measurement::generate_measurement_summary_hash(
109
0
                        self.common.negotiate_info.spdm_version_sel,
110
0
                        self.common.negotiate_info.base_hash_sel,
111
0
                        self.common.negotiate_info.measurement_specification_sel,
112
0
                        self.common.negotiate_info.measurement_hash_sel,
113
0
                        psk_exchange_req.measurement_summary_hash_type,
114
                    );
115
0
                if measurement_summary_hash_res.is_none() {
116
0
                    self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
117
0
                    return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
118
0
                }
119
0
                measurement_summary_hash = measurement_summary_hash_res.unwrap();
120
0
                if measurement_summary_hash.data_size == 0 {
121
0
                    self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
122
0
                    return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
123
0
                }
124
0
            } else {
125
0
                self.common.runtime_info.need_measurement_summary_hash = false;
126
0
                measurement_summary_hash = SpdmDigestStruct::default();
127
0
            }
128
129
0
            psk_hint = psk_exchange_req.psk_hint.clone();
130
131
0
            if let Some(secured_message_version_list) = psk_exchange_req
132
0
                .opaque
133
0
                .rsp_get_dmtf_supported_secure_spdm_version_list(&mut self.common)
134
            {
135
0
                if secured_message_version_list.version_count
136
0
                    > crate::common::opaque::MAX_SECURE_SPDM_VERSION_COUNT as u8
137
                {
138
0
                    self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer);
139
0
                    return (
140
0
                        Err(SPDM_STATUS_INVALID_MSG_FIELD),
141
0
                        Some(writer.used_slice()),
142
0
                    );
143
0
                }
144
145
0
                let mut selected_version: Option<SecuredMessageVersion> = None;
146
0
                for index in 0..secured_message_version_list.version_count as usize {
147
0
                    for local_version in
148
0
                        self.common.config_info.secure_spdm_version.iter().flatten()
149
                    {
150
0
                        if secured_message_version_list.versions_list[index] == *local_version {
151
0
                            selected_version = Some(*local_version);
152
0
                        }
153
                    }
154
                }
155
156
0
                if let Some(selected_version) = selected_version {
157
0
                    if let Ok(opaque) = SpdmOpaqueStruct::from_sm_version_sel_opaque(
158
0
                        &mut self.common,
159
0
                        &SMVersionSelOpaque {
160
0
                            secured_message_version: selected_version,
161
0
                        },
162
0
                    ) {
163
0
                        return_opaque = opaque;
164
0
                    } else {
165
0
                        self.write_spdm_error(
166
0
                            SpdmErrorCode::SpdmErrorUnsupportedRequest,
167
                            0,
168
0
                            writer,
169
                        );
170
0
                        return (
171
0
                            Err(SPDM_STATUS_INVALID_MSG_FIELD),
172
0
                            Some(writer.used_slice()),
173
0
                        );
174
                    }
175
                } else {
176
0
                    error!("secure message version not selected!");
177
0
                    self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer);
178
0
                    return (
179
0
                        Err(SPDM_STATUS_INVALID_MSG_FIELD),
180
0
                        Some(writer.used_slice()),
181
0
                    );
182
                }
183
0
            }
184
        } else {
185
0
            error!("!!! psk_exchange req : fail !!!\n");
186
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer);
187
0
            return (
188
0
                Err(SPDM_STATUS_INVALID_MSG_FIELD),
189
0
                Some(writer.used_slice()),
190
0
            );
191
        }
192
193
0
        let psk_without_context = self
194
0
            .common
195
0
            .negotiate_info
196
0
            .rsp_capabilities_sel
197
0
            .contains(SpdmResponseCapabilityFlags::PSK_CAP_WITHOUT_CONTEXT);
198
0
        let psk_context_size = if psk_without_context {
199
0
            0u16
200
        } else {
201
0
            MAX_SPDM_PSK_CONTEXT_SIZE as u16
202
        };
203
0
        let mut psk_context = [0u8; MAX_SPDM_PSK_CONTEXT_SIZE];
204
0
        if psk_without_context {
205
0
            let res = crypto::rand::get_random(&mut psk_context);
206
0
            if res.is_err() {
207
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
208
0
                return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
209
0
            }
210
0
        }
211
212
0
        let rsp_session_id = self.common.get_next_half_session_id(false);
213
0
        if rsp_session_id.is_err() {
214
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorSessionLimitExceeded, 0, writer);
215
0
            return (
216
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
217
0
                Some(writer.used_slice()),
218
0
            );
219
0
        }
220
0
        let rsp_session_id = rsp_session_id.unwrap();
221
222
        // create session structure
223
0
        let hash_algo = self.common.negotiate_info.base_hash_sel;
224
0
        let dhe_algo = self.common.negotiate_info.dhe_sel;
225
0
        let kem_algo = self.common.negotiate_info.kem_sel;
226
0
        let aead_algo = self.common.negotiate_info.aead_sel;
227
0
        let key_schedule_algo = self.common.negotiate_info.key_schedule_sel;
228
0
        let sequence_number_count = {
229
0
            let mut transport_encap = self.common.transport_encap.lock();
230
0
            let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) =
231
0
                transport_encap.deref_mut();
232
0
            transport_encap.get_sequence_number_count()
233
        };
234
0
        let max_random_count = {
235
0
            let mut transport_encap = self.common.transport_encap.lock();
236
0
            let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) =
237
0
                transport_encap.deref_mut();
238
0
            transport_encap.get_max_random_count()
239
        };
240
241
0
        let spdm_version_sel = self.common.negotiate_info.spdm_version_sel;
242
0
        let message_a = self.common.runtime_info.message_a.clone();
243
244
0
        let session = self.common.get_next_avaiable_session();
245
0
        if session.is_none() {
246
0
            error!("!!! too many sessions : fail !!!\n");
247
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorSessionLimitExceeded, 0, writer);
248
0
            return (
249
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
250
0
                Some(writer.used_slice()),
251
0
            );
252
0
        }
253
254
0
        let session = session.unwrap();
255
0
        let session_id =
256
0
            ((rsp_session_id as u32) << 16) + psk_exchange_req.unwrap().req_session_id as u32;
257
0
        *target_session_id = Some(session_id);
258
0
        session.setup(session_id).unwrap();
259
0
        session.set_use_psk(true);
260
261
0
        session.set_crypto_param(hash_algo, dhe_algo, kem_algo, aead_algo, key_schedule_algo);
262
0
        session.set_transport_param(sequence_number_count, max_random_count);
263
264
0
        session.runtime_info.psk_hint = Some(psk_hint);
265
0
        session.runtime_info.message_a = message_a;
266
0
        session.runtime_info.rsp_cert_hash = None;
267
0
        session.runtime_info.req_cert_hash = None;
268
269
0
        info!("send spdm psk_exchange rsp\n");
270
271
        // prepare response
272
0
        let response = SpdmMessage {
273
0
            header: SpdmMessageHeader {
274
0
                version: self.common.negotiate_info.spdm_version_sel,
275
0
                request_response_code: SpdmRequestResponseCode::SpdmResponsePskExchangeRsp,
276
0
            },
277
0
            payload: SpdmMessagePayload::SpdmPskExchangeResponse(SpdmPskExchangeResponsePayload {
278
0
                heartbeat_period: self.common.config_info.heartbeat_period,
279
0
                rsp_session_id,
280
0
                measurement_summary_hash,
281
0
                psk_context: SpdmPskContextStruct {
282
0
                    data_size: psk_context_size,
283
0
                    data: psk_context,
284
0
                },
285
0
                opaque: return_opaque,
286
0
                verify_data: SpdmDigestStruct {
287
0
                    data_size: self.common.get_hash_size(),
288
0
                    data: Box::new([0xcc; SPDM_MAX_HASH_SIZE]),
289
0
                },
290
0
            }),
291
0
        };
292
293
0
        let res = response.spdm_encode(&mut self.common, writer);
294
0
        if res.is_err() {
295
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
296
0
            return (
297
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
298
0
                Some(writer.used_slice()),
299
0
            );
300
0
        }
301
0
        let used = writer.used();
302
303
0
        let base_hash_size = self.common.get_hash_size() as usize;
304
0
        let temp_used = used - base_hash_size;
305
306
0
        if self
307
0
            .common
308
0
            .append_message_k(session_id, &bytes[..reader.used()])
309
0
            .is_err()
310
        {
311
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
312
0
            return (
313
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
314
0
                Some(writer.used_slice()),
315
0
            );
316
0
        }
317
0
        if self
318
0
            .common
319
0
            .append_message_k(session_id, &writer.used_slice()[..temp_used])
320
0
            .is_err()
321
        {
322
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
323
0
            return (
324
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
325
0
                Some(writer.used_slice()),
326
0
            );
327
0
        }
328
329
0
        let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) {
330
0
            session
331
        } else {
332
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
333
0
            return (
334
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
335
0
                Some(writer.used_slice()),
336
0
            );
337
        };
338
339
        // create session - generate the handshake secret (including finished_key)
340
0
        let th1 = self.common.calc_rsp_transcript_hash(true, false, session);
341
0
        if th1.is_err() {
342
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
343
0
            return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
344
0
        }
345
0
        let th1 = th1.unwrap();
346
0
        debug!("!!! th1 : {:02x?}\n", th1.as_ref());
347
348
0
        let session = if let Some(session) = self.common.get_session_via_id(session_id) {
349
0
            session
350
        } else {
351
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
352
0
            return (
353
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
354
0
                Some(writer.used_slice()),
355
0
            );
356
        };
357
0
        session.set_th1(th1.clone());
358
0
        if let Err(e) = session.generate_handshake_secret(spdm_version_sel, &th1) {
359
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
360
0
            return (Err(e), Some(writer.used_slice()));
361
0
        }
362
363
0
        let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) {
364
0
            session
365
        } else {
366
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
367
0
            return (
368
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
369
0
                Some(writer.used_slice()),
370
0
            );
371
        };
372
        // generate HMAC with finished_key
373
0
        let transcript_hash = self.common.calc_rsp_transcript_hash(true, false, session);
374
0
        if transcript_hash.is_err() {
375
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
376
0
            return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
377
0
        }
378
0
        let transcript_hash = transcript_hash.unwrap();
379
380
0
        let hmac = session.generate_hmac_with_response_finished_key(transcript_hash.as_ref());
381
0
        if hmac.is_err() {
382
0
            let session = if let Some(session) = self.common.get_session_via_id(session_id) {
383
0
                session
384
            } else {
385
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
386
0
                return (
387
0
                    Err(SPDM_STATUS_INVALID_STATE_LOCAL),
388
0
                    Some(writer.used_slice()),
389
0
                );
390
            };
391
0
            session.teardown();
392
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
393
0
            return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
394
0
        }
395
0
        let hmac = hmac.unwrap();
396
397
        // append verify_data after TH1
398
0
        if self
399
0
            .common
400
0
            .append_message_k(session_id, hmac.as_ref())
401
0
            .is_err()
402
        {
403
0
            let session = if let Some(session) = self.common.get_session_via_id(session_id) {
404
0
                session
405
            } else {
406
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
407
0
                return (
408
0
                    Err(SPDM_STATUS_INVALID_STATE_LOCAL),
409
0
                    Some(writer.used_slice()),
410
0
                );
411
            };
412
0
            session.teardown();
413
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
414
0
            return (
415
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
416
0
                Some(writer.used_slice()),
417
0
            );
418
0
        }
419
420
        // patch the message before send
421
0
        writer.mut_used_slice()[(used - base_hash_size)..used].copy_from_slice(hmac.as_ref());
422
0
        let heartbeat_period = self.common.config_info.heartbeat_period;
423
0
        let session = if let Some(session) = self.common.get_session_via_id(session_id) {
424
0
            session
425
        } else {
426
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
427
0
            return (
428
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
429
0
                Some(writer.used_slice()),
430
0
            );
431
        };
432
0
        session.set_session_state(crate::common::session::SpdmSessionState::SpdmSessionHandshaking);
433
434
0
        let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) {
435
0
            session
436
        } else {
437
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
438
0
            return (
439
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
440
0
                Some(writer.used_slice()),
441
0
            );
442
        };
443
444
0
        if psk_without_context {
445
            // generate the data secret directly to skip PSK_FINISH
446
0
            let th2 = self.common.calc_rsp_transcript_hash(true, false, session);
447
0
            if th2.is_err() {
448
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
449
0
                return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
450
0
            }
451
0
            let th2 = th2.unwrap();
452
0
            debug!("!!! th2 : {:02x?}\n", th2.as_ref());
453
0
            let spdm_version_sel = self.common.negotiate_info.spdm_version_sel;
454
0
            let heartbeat_period = {
455
0
                let session = if let Some(session) = self.common.get_session_via_id(session_id) {
456
0
                    session
457
                } else {
458
0
                    self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
459
0
                    return (
460
0
                        Err(SPDM_STATUS_INVALID_STATE_LOCAL),
461
0
                        Some(writer.used_slice()),
462
0
                    );
463
                };
464
0
                session.set_th2(th2.clone());
465
0
                if session
466
0
                    .generate_data_secret(spdm_version_sel, &th2)
467
0
                    .is_err()
468
                {
469
0
                    self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
470
0
                    return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice()));
471
0
                }
472
0
                session.set_session_state(
473
0
                    crate::common::session::SpdmSessionState::SpdmSessionEstablished,
474
                );
475
476
0
                session.heartbeat_period
477
            };
478
0
            if self
479
0
                .common
480
0
                .negotiate_info
481
0
                .req_capabilities_sel
482
0
                .contains(SpdmRequestCapabilityFlags::HBEAT_CAP)
483
0
                && self
484
0
                    .common
485
0
                    .negotiate_info
486
0
                    .rsp_capabilities_sel
487
0
                    .contains(SpdmResponseCapabilityFlags::HBEAT_CAP)
488
0
            {
489
0
                start_watchdog(session_id, heartbeat_period as u16 * 2);
490
0
            }
491
0
        }
492
493
0
        let session = if let Some(session) = self.common.get_session_via_id(session_id) {
494
0
            session
495
        } else {
496
0
            self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer);
497
0
            return (
498
0
                Err(SPDM_STATUS_INVALID_STATE_LOCAL),
499
0
                Some(writer.used_slice()),
500
0
            );
501
        };
502
0
        session.heartbeat_period = heartbeat_period;
503
0
        if return_opaque.data_size != 0 {
504
0
            session.secure_spdm_version_sel = if let Ok(ssvs) = SecuredMessageVersion::try_from(
505
0
                return_opaque.data[return_opaque.data_size as usize - 1],
506
0
            ) {
507
0
                ssvs
508
            } else {
509
0
                self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer);
510
0
                return (
511
0
                    Err(SPDM_STATUS_INVALID_MSG_FIELD),
512
0
                    Some(writer.used_slice()),
513
0
                );
514
            };
515
0
        }
516
517
0
        (Ok(()), Some(writer.used_slice()))
518
0
    }
519
}