Coverage Report

Created: 2025-07-18 06:52

/src/spdm-rs/spdmlib/src/requester/psk_exchange_req.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2020 Intel Corporation
2
//
3
// SPDX-License-Identifier: Apache-2.0 or MIT
4
5
use config::MAX_SPDM_PSK_CONTEXT_SIZE;
6
7
use crate::crypto;
8
use crate::error::{
9
    SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD,
10
    SPDM_STATUS_INVALID_PARAMETER, SPDM_STATUS_SESSION_NUMBER_EXCEED, SPDM_STATUS_VERIF_FAIL,
11
};
12
use crate::error::{SPDM_STATUS_BUFFER_FULL, SPDM_STATUS_INVALID_STATE_LOCAL};
13
use crate::message::*;
14
use crate::protocol::*;
15
use crate::requester::*;
16
extern crate alloc;
17
use core::ops::DerefMut;
18
19
impl RequesterContext {
20
    #[maybe_async::maybe_async]
21
    pub async fn send_receive_spdm_psk_exchange(
22
        &mut self,
23
        measurement_summary_hash_type: SpdmMeasurementSummaryHashType,
24
        psk_hint: Option<&SpdmPskHintStruct>,
25
0
    ) -> SpdmResult<u32> {
26
0
        info!("send spdm psk exchange\n");
27
28
0
        let psk_hint = if let Some(hint) = psk_hint {
29
0
            hint.clone()
30
        } else {
31
0
            SpdmPskHintStruct::default()
32
        };
33
34
0
        self.common
35
0
            .reset_buffer_via_request_code(SpdmRequestResponseCode::SpdmRequestPskExchange, None);
36
0
37
0
        let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE];
38
0
        let half_session_id = self.common.get_next_half_session_id(true)?;
39
0
        let send_used = self.encode_spdm_psk_exchange(
40
0
            half_session_id,
41
0
            measurement_summary_hash_type,
42
0
            &psk_hint,
43
0
            &mut send_buffer,
44
0
        )?;
45
46
0
        self.send_message(None, &send_buffer[..send_used], false)
47
0
            .await?;
48
49
        // Receive
50
0
        let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE];
51
0
        let receive_used = self
52
0
            .receive_message(None, &mut receive_buffer, false)
53
0
            .await?;
54
55
0
        let mut target_session_id = None;
56
0
        if let Err(e) = self.handle_spdm_psk_exchange_response(
57
0
            half_session_id,
58
0
            measurement_summary_hash_type,
59
0
            &psk_hint,
60
0
            &send_buffer[..send_used],
61
0
            &receive_buffer[..receive_used],
62
0
            &mut target_session_id,
63
0
        ) {
64
0
            if let Some(session_id) = target_session_id {
65
0
                if let Some(session) = self.common.get_session_via_id(session_id) {
66
0
                    session.teardown();
67
0
                }
68
0
            }
69
70
0
            Err(e)
71
        } else {
72
0
            Ok(target_session_id.unwrap())
73
        }
74
0
    }
75
76
    pub fn encode_spdm_psk_exchange(
77
        &mut self,
78
        half_session_id: u16,
79
        measurement_summary_hash_type: SpdmMeasurementSummaryHashType,
80
        psk_hint: &SpdmPskHintStruct,
81
        buf: &mut [u8],
82
    ) -> SpdmResult<usize> {
83
        let mut writer = Writer::init(buf);
84
85
        let mut psk_context = [0u8; MAX_SPDM_PSK_CONTEXT_SIZE];
86
        crypto::rand::get_random(&mut psk_context)?;
87
88
        let mut secured_message_version_list = SecuredMessageVersionList {
89
            version_count: 0,
90
            versions_list: [SecuredMessageVersion::default(); MAX_SECURE_SPDM_VERSION_COUNT],
91
        };
92
93
        for local_version in self.common.config_info.secure_spdm_version.iter().flatten() {
94
            secured_message_version_list.versions_list
95
                [secured_message_version_list.version_count as usize] = *local_version;
96
            secured_message_version_list.version_count += 1;
97
        }
98
99
        let opaque = SpdmOpaqueStruct::from_sm_supported_ver_list_opaque(
100
            &mut self.common,
101
            &SMSupportedVerListOpaque {
102
                secured_message_version_list,
103
            },
104
        )?;
105
106
        let request = SpdmMessage {
107
            header: SpdmMessageHeader {
108
                version: self.common.negotiate_info.spdm_version_sel,
109
                request_response_code: SpdmRequestResponseCode::SpdmRequestPskExchange,
110
            },
111
            payload: SpdmMessagePayload::SpdmPskExchangeRequest(SpdmPskExchangeRequestPayload {
112
                measurement_summary_hash_type,
113
                req_session_id: half_session_id,
114
                psk_hint: psk_hint.clone(),
115
                psk_context: SpdmPskContextStruct {
116
                    data_size: self.common.negotiate_info.base_hash_sel.get_size(),
117
                    data: psk_context,
118
                },
119
                opaque,
120
            }),
121
        };
122
        request.spdm_encode(&mut self.common, &mut writer)
123
    }
124
125
    pub fn handle_spdm_psk_exchange_response(
126
        &mut self,
127
        half_session_id: u16,
128
        measurement_summary_hash_type: SpdmMeasurementSummaryHashType,
129
        psk_hint: &SpdmPskHintStruct,
130
        send_buffer: &[u8],
131
        receive_buffer: &[u8],
132
        target_session_id: &mut Option<u32>,
133
    ) -> SpdmResult<u32> {
134
        self.common.runtime_info.need_measurement_summary_hash = (measurement_summary_hash_type
135
            == SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeTcb)
136
            || (measurement_summary_hash_type
137
                == SpdmMeasurementSummaryHashType::SpdmMeasurementSummaryHashTypeAll);
138
139
        let mut reader = Reader::init(receive_buffer);
140
        match SpdmMessageHeader::read(&mut reader) {
141
            Some(message_header) => {
142
                if message_header.version != self.common.negotiate_info.spdm_version_sel {
143
                    return Err(SPDM_STATUS_INVALID_MSG_FIELD);
144
                }
145
                match message_header.request_response_code {
146
                    SpdmRequestResponseCode::SpdmResponsePskExchangeRsp => {
147
                        let psk_exchange_rsp = SpdmPskExchangeResponsePayload::spdm_read(
148
                            &mut self.common,
149
                            &mut reader,
150
                        );
151
                        let receive_used = reader.used();
152
                        if let Some(psk_exchange_rsp) = psk_exchange_rsp {
153
                            debug!("!!! psk_exchange rsp : {:02x?}\n", psk_exchange_rsp);
154
155
                            // create session structure
156
                            let base_hash_algo = self.common.negotiate_info.base_hash_sel;
157
                            let dhe_algo = self.common.negotiate_info.dhe_sel;
158
                            let aead_algo = self.common.negotiate_info.aead_sel;
159
                            let key_schedule_algo = self.common.negotiate_info.key_schedule_sel;
160
                            let sequence_number_count = {
161
                                let mut transport_encap = self.common.transport_encap.lock();
162
                                let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) =
163
                                    transport_encap.deref_mut();
164
                                transport_encap.get_sequence_number_count()
165
                            };
166
                            let max_random_count = {
167
                                let mut transport_encap = self.common.transport_encap.lock();
168
                                let transport_encap: &mut (dyn SpdmTransportEncap + Send + Sync) =
169
                                    transport_encap.deref_mut();
170
                                transport_encap.get_max_random_count()
171
                            };
172
173
                            let secure_spdm_version_sel = psk_exchange_rsp
174
                                .opaque
175
                                .req_get_dmtf_secure_spdm_version_selection(&mut self.common)
176
                                .ok_or(SPDM_STATUS_INVALID_MSG_FIELD)?;
177
178
                            let session_id = ((psk_exchange_rsp.rsp_session_id as u32) << 16)
179
                                + half_session_id as u32;
180
                            *target_session_id = Some(session_id);
181
                            let spdm_version_sel = self.common.negotiate_info.spdm_version_sel;
182
                            let message_a = self.common.runtime_info.message_a.clone();
183
184
                            let session = self
185
                                .common
186
                                .get_next_avaiable_session()
187
                                .ok_or(SPDM_STATUS_SESSION_NUMBER_EXCEED)?;
188
189
                            session.setup(session_id)?;
190
191
                            session.set_use_psk(true);
192
193
                            session.set_crypto_param(
194
                                base_hash_algo,
195
                                dhe_algo,
196
                                aead_algo,
197
                                key_schedule_algo,
198
                            );
199
                            session.set_transport_param(sequence_number_count, max_random_count);
200
201
                            session.runtime_info.psk_hint = Some(psk_hint.clone());
202
                            session.runtime_info.message_a = message_a;
203
                            session.runtime_info.rsp_cert_hash = None;
204
                            session.runtime_info.req_cert_hash = None;
205
206
                            // create transcript
207
                            let base_hash_size =
208
                                self.common.negotiate_info.base_hash_sel.get_size() as usize;
209
                            let temp_receive_used = receive_used - base_hash_size;
210
211
                            self.common.append_message_k(session_id, send_buffer)?;
212
                            self.common.append_message_k(
213
                                session_id,
214
                                &receive_buffer[..temp_receive_used],
215
                            )?;
216
217
                            let session = self
218
                                .common
219
                                .get_immutable_session_via_id(session_id)
220
                                .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
221
222
                            // generate the handshake secret (including finished_key) before verify HMAC
223
                            let th1 = self.common.calc_req_transcript_hash(
224
                                true,
225
                                INVALID_SLOT,
226
                                false,
227
                                session,
228
                            )?;
229
                            debug!("!!! th1 : {:02x?}\n", th1.as_ref());
230
231
                            let session = self
232
                                .common
233
                                .get_session_via_id(session_id)
234
                                .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
235
                            session.generate_handshake_secret(spdm_version_sel, &th1)?;
236
237
                            let session = self
238
                                .common
239
                                .get_immutable_session_via_id(session_id)
240
                                .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
241
242
                            // verify HMAC with finished_key
243
                            let transcript_hash = self.common.calc_req_transcript_hash(
244
                                true,
245
                                INVALID_SLOT,
246
                                false,
247
                                session,
248
                            )?;
249
250
                            let session = self
251
                                .common
252
                                .get_immutable_session_via_id(session_id)
253
                                .ok_or(SPDM_STATUS_INVALID_PARAMETER)?;
254
255
                            if session
256
                                .verify_hmac_with_response_finished_key(
257
                                    transcript_hash.as_ref(),
258
                                    &psk_exchange_rsp.verify_data,
259
                                )
260
                                .is_err()
261
                            {
262
                                error!("verify_hmac_with_response_finished_key fail");
263
                                let session = self
264
                                    .common
265
                                    .get_session_via_id(session_id)
266
                                    .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
267
                                session.teardown();
268
                                return Err(SPDM_STATUS_VERIF_FAIL);
269
                            } else {
270
                                info!("verify_hmac_with_response_finished_key pass");
271
                            }
272
273
                            // append verify_data after TH1
274
                            if self
275
                                .common
276
                                .append_message_k(session_id, psk_exchange_rsp.verify_data.as_ref())
277
                                .is_err()
278
                            {
279
                                let session = self
280
                                    .common
281
                                    .get_session_via_id(session_id)
282
                                    .ok_or(SPDM_STATUS_INVALID_PARAMETER)?;
283
                                session.teardown();
284
                                return Err(SPDM_STATUS_BUFFER_FULL);
285
                            }
286
287
                            let session = self
288
                                .common
289
                                .get_session_via_id(session_id)
290
                                .ok_or(SPDM_STATUS_INVALID_PARAMETER)?;
291
                            session.set_session_state(
292
                                crate::common::session::SpdmSessionState::SpdmSessionHandshaking,
293
                            );
294
295
                            let session = self
296
                                .common
297
                                .get_immutable_session_via_id(session_id)
298
                                .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
299
                            let psk_without_context = self
300
                                .common
301
                                .negotiate_info
302
                                .rsp_capabilities_sel
303
                                .contains(SpdmResponseCapabilityFlags::PSK_CAP_WITHOUT_CONTEXT);
304
                            if psk_without_context {
305
                                // generate the data secret directly to skip PSK_FINISH
306
                                let th2 = self.common.calc_req_transcript_hash(
307
                                    true,
308
                                    INVALID_SLOT,
309
                                    false,
310
                                    session,
311
                                )?;
312
313
                                debug!("!!! th2 : {:02x?}\n", th2.as_ref());
314
315
                                let session = self
316
                                    .common
317
                                    .get_session_via_id(session_id)
318
                                    .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
319
                                session.generate_data_secret(spdm_version_sel, &th2)?;
320
                                session.set_session_state(
321
                                    crate::common::session::SpdmSessionState::SpdmSessionEstablished,
322
                                );
323
                            }
324
325
                            let session = self
326
                                .common
327
                                .get_session_via_id(session_id)
328
                                .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?;
329
                            session.secure_spdm_version_sel = secure_spdm_version_sel;
330
                            session.heartbeat_period = psk_exchange_rsp.heartbeat_period;
331
332
                            Ok(session_id)
333
                        } else {
334
                            error!("!!! psk_exchange : fail !!!\n");
335
                            Err(SPDM_STATUS_INVALID_MSG_FIELD)
336
                        }
337
                    }
338
                    SpdmRequestResponseCode::SpdmResponseError => {
339
                        let status = self.spdm_handle_error_response_main(
340
                            None,
341
                            receive_buffer,
342
                            SpdmRequestResponseCode::SpdmRequestPskExchange,
343
                            SpdmRequestResponseCode::SpdmResponsePskExchangeRsp,
344
                        );
345
                        match status {
346
                            Err(status) => Err(status),
347
                            Ok(()) => Err(SPDM_STATUS_ERROR_PEER),
348
                        }
349
                    }
350
                    _ => Err(SPDM_STATUS_ERROR_PEER),
351
                }
352
            }
353
            None => Err(SPDM_STATUS_INVALID_MSG_FIELD),
354
        }
355
    }
356
}