Coverage Report

Created: 2025-04-11 06:34

/src/botan/src/lib/tls/tls12/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
10
#include <botan/exceptn.h>
11
#include <botan/tls_messages.h>
12
#include <botan/internal/loadstor.h>
13
#include <botan/internal/tls_record.h>
14
#include <botan/internal/tls_seq_numbers.h>
15
#include <chrono>
16
17
namespace Botan::TLS {
18
19
namespace {
20
21
6.80k
/*
* Decode a 24-bit big-endian unsigned integer from the three bytes at q
*/
inline size_t load_be24(const uint8_t q[3]) {
   const uint32_t high = q[0];
   const uint32_t middle = q[1];
   const uint32_t low = q[2];
   return static_cast<size_t>((high << 16) | (middle << 8) | low);
}
24
25
107k
/*
* Write the low 24 bits of val to out[0..2] in big-endian order.
* Any bits of val above bit 23 are discarded.
*/
void store_be24(uint8_t out[3], size_t val) {
   const uint32_t v32 = static_cast<uint32_t>(val);
   out[0] = static_cast<uint8_t>(v32 >> 16);
   out[1] = static_cast<uint8_t>(v32 >> 8);
   out[2] = static_cast<uint8_t>(v32);
}
30
31
37
/*
* Current reading of the monotonic std::chrono::steady_clock, in whole
* milliseconds since that clock's (unspecified) epoch.
*/
uint64_t steady_clock_ms() {
   const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
   const auto as_ms = std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch);
   return as_ms.count();
}
35
36
}  // namespace
37
38
28.6k
/*
* Version placed on the initial records of a stream (TLS) handshake:
* unconditionally TLS 1.2.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const {
   return Protocol_Version(Protocol_Version::TLS_V12);
}
41
42
void Stream_Handshake_IO::add_record(const uint8_t record[],
43
                                     size_t record_len,
44
                                     Record_Type record_type,
45
45.7k
                                     uint64_t /*sequence_number*/) {
46
45.7k
   if(record_type == Record_Type::Handshake) {
47
44.8k
      m_queue.insert(m_queue.end(), record, record + record_len);
48
44.8k
   } else if(record_type == Record_Type::ChangeCipherSpec) {
49
808
      if(record_len != 1 || record[0] != 1) {
50
45
         throw Decoding_Error("Invalid ChangeCipherSpec");
51
45
      }
52
53
      // Pretend it's a regular handshake message of zero length
54
763
      const uint8_t ccs_hs[] = {static_cast<uint8_t>(Handshake_Type::HandshakeCCS), 0, 0, 0};
55
763
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
56
763
   } else {
57
0
      throw Decoding_Error("Unknown message type " + std::to_string(static_cast<size_t>(record_type)) +
58
0
                           " in handshake processing");
59
0
   }
60
45.7k
}
61
62
77.9k
std::pair<Handshake_Type, std::vector<uint8_t>> Stream_Handshake_IO::get_next_record(bool /*expecting_ccs*/) {
63
77.9k
   if(m_queue.size() >= 4) {
64
47.3k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
65
66
47.3k
      if(m_queue.size() >= length) {
67
36.6k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
68
69
36.6k
         if(type == Handshake_Type::None) {
70
6
            throw Decoding_Error("Invalid handshake message type");
71
6
         }
72
73
36.6k
         std::vector<uint8_t> contents(m_queue.begin() + 4, m_queue.begin() + length);
74
75
36.6k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
76
77
36.6k
         return std::make_pair(type, contents);
78
36.6k
      }
79
47.3k
   }
80
81
41.3k
   return std::make_pair(Handshake_Type::None, std::vector<uint8_t>());
82
77.9k
}
83
84
107k
std::vector<uint8_t> Stream_Handshake_IO::format(const std::vector<uint8_t>& msg, Handshake_Type type) const {
85
107k
   std::vector<uint8_t> send_buf(4 + msg.size());
86
87
107k
   const size_t buf_size = msg.size();
88
89
107k
   send_buf[0] = static_cast<uint8_t>(type);
90
91
107k
   store_be24(&send_buf[1], buf_size);
92
93
107k
   if(!msg.empty()) {
94
86.2k
      copy_mem(&send_buf[4], msg.data(), msg.size());
95
86.2k
   }
96
97
107k
   return send_buf;
98
107k
}
99
100
0
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/) {
101
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
102
0
}
103
104
72.0k
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg) {
105
72.0k
   const std::vector<uint8_t> msg_bits = msg.serialize();
106
107
72.0k
   if(msg.type() == Handshake_Type::HandshakeCCS) {
108
244
      m_send_hs(Record_Type::ChangeCipherSpec, msg_bits);
109
244
      return std::vector<uint8_t>();  // not included in handshake hashes
110
244
   }
111
112
71.7k
   auto buf = format(msg_bits, msg.wire_type());
113
71.7k
   m_send_hs(Record_Type::Handshake, buf);
114
71.7k
   return buf;
115
72.0k
}
116
117
805
/*
* Version placed on the initial records of a datagram (DTLS) handshake:
* unconditionally DTLS 1.2.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const {
   return Protocol_Version(Protocol_Version::DTLS_V12);
}
120
121
0
void Datagram_Handshake_IO::retransmit_last_flight() {
122
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
123
0
   retransmit_flight(flight_idx);
124
0
}
125
126
0
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx) {
127
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
128
129
0
   BOTAN_ASSERT(!flight.empty(), "Nonempty flight to retransmit");
130
131
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
132
133
0
   for(auto msg_seq : flight) {
134
0
      auto& msg = m_flight_data[msg_seq];
135
136
0
      if(msg.epoch != epoch) {
137
         // Epoch gap: insert the CCS
138
0
         std::vector<uint8_t> ccs(1, 1);
139
0
         m_send_hs(epoch, Record_Type::ChangeCipherSpec, ccs);
140
0
      }
141
142
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
143
0
      epoch = msg.epoch;
144
0
   }
145
0
}
146
147
36
/*
* The datagram implementation always reports that no further handshake
* data is pending.
*/
bool Datagram_Handshake_IO::have_more_data() const {
   const bool pending = false;
   return pending;
}
150
151
0
bool Datagram_Handshake_IO::timeout_check() {
152
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) {
153
      /*
154
      If we haven't written anything yet obviously no timeout.
155
      Also no timeout possible if we are mid-flight,
156
      */
157
0
      return false;
158
0
   }
159
160
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
161
162
0
   if(ms_since_write < m_next_timeout) {
163
0
      return false;
164
0
   }
165
166
0
   retransmit_last_flight();
167
168
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
169
0
   return true;
170
0
}
171
172
void Datagram_Handshake_IO::add_record(const uint8_t record[],
173
                                       size_t record_len,
174
                                       Record_Type record_type,
175
1.13k
                                       uint64_t record_sequence) {
176
1.13k
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
177
178
1.13k
   if(record_type == Record_Type::ChangeCipherSpec) {
179
40
      if(record_len != 1 || record[0] != 1) {
180
11
         throw Decoding_Error("Invalid ChangeCipherSpec");
181
11
      }
182
183
      // TODO: check this is otherwise empty
184
29
      m_ccs_epochs.insert(epoch);
185
29
      return;
186
40
   }
187
188
1.09k
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
189
190
3.30k
   while(record_len) {
191
2.49k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN) {
192
228
         return;  // completely bogus? at least degenerate/weird
193
228
      }
194
195
2.26k
      const Handshake_Type msg_type = static_cast<Handshake_Type>(record[0]);
196
2.26k
      const size_t msg_len = load_be24(&record[1]);
197
2.26k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
198
2.26k
      const size_t fragment_offset = load_be24(&record[6]);
199
2.26k
      const size_t fragment_length = load_be24(&record[9]);
200
201
2.26k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
202
203
2.26k
      if(record_len < total_size) {
204
57
         throw Decoding_Error("Bad lengths in DTLS header");
205
57
      }
206
207
2.21k
      if(message_seq >= m_in_message_seq) {
208
1.89k
         m_messages[message_seq].add_fragment(
209
1.89k
            &record[DTLS_HANDSHAKE_HEADER_LEN], fragment_length, fragment_offset, epoch, msg_type, msg_len);
210
1.89k
      } else {
211
         // TODO: detect retransmitted flight
212
313
      }
213
214
2.21k
      record += total_size;
215
2.21k
      record_len -= total_size;
216
2.21k
   }
217
1.09k
}
218
219
1.01k
std::pair<Handshake_Type, std::vector<uint8_t>> Datagram_Handshake_IO::get_next_record(bool expecting_ccs) {
220
   // Expecting a message means the last flight is concluded
221
1.01k
   if(!m_flights.rbegin()->empty()) {
222
12
      m_flights.push_back(std::vector<uint16_t>());
223
12
   }
224
225
1.01k
   if(expecting_ccs) {
226
0
      if(!m_messages.empty()) {
227
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
228
229
0
         if(m_ccs_epochs.contains(current_epoch)) {
230
0
            return std::make_pair(Handshake_Type::HandshakeCCS, std::vector<uint8_t>());
231
0
         }
232
0
      }
233
0
      return std::make_pair(Handshake_Type::None, std::vector<uint8_t>());
234
0
   }
235
236
1.01k
   auto i = m_messages.find(m_in_message_seq);
237
238
1.01k
   if(i == m_messages.end() || !i->second.complete()) {
239
903
      return std::make_pair(Handshake_Type::None, std::vector<uint8_t>());
240
903
   }
241
242
108
   m_in_message_seq += 1;
243
244
108
   return i->second.message();
245
1.01k
}
246
247
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(const uint8_t fragment[],
248
                                                               size_t fragment_length,
249
                                                               size_t fragment_offset,
250
                                                               uint16_t epoch,
251
                                                               Handshake_Type msg_type,
252
1.89k
                                                               size_t msg_length) {
253
1.89k
   if(complete()) {
254
427
      return;  // already have entire message, ignore this
255
427
   }
256
257
1.47k
   if(m_msg_type == Handshake_Type::None) {
258
1.27k
      m_epoch = epoch;
259
1.27k
      m_msg_type = msg_type;
260
1.27k
      m_msg_length = msg_length;
261
1.27k
   }
262
263
1.47k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch) {
264
54
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
265
54
   }
266
267
1.41k
   if(fragment_offset > m_msg_length) {
268
11
      throw Decoding_Error("Fragment offset past end of message");
269
11
   }
270
271
1.40k
   if(fragment_offset + fragment_length > m_msg_length) {
272
14
      throw Decoding_Error("Fragment overlaps past end of message");
273
14
   }
274
275
1.39k
   if(fragment_offset == 0 && fragment_length == m_msg_length) {
276
304
      m_fragments.clear();
277
304
      m_message.assign(fragment, fragment + fragment_length);
278
1.08k
   } else {
279
      /*
280
      * FIXME. This is a pretty lame way to do defragmentation, huge
281
      * overhead with a tree node per byte.
282
      *
283
      * Also should confirm that all overlaps have no changes,
284
      * otherwise we expose ourselves to the classic fingerprinting
285
      * and IDS evasion attacks on IP fragmentation.
286
      */
287
19.1k
      for(size_t i = 0; i != fragment_length; ++i) {
288
18.1k
         m_fragments[fragment_offset + i] = fragment[i];
289
18.1k
      }
290
291
1.08k
      if(m_fragments.size() == m_msg_length) {
292
105
         m_message.resize(m_msg_length);
293
4.83k
         for(size_t i = 0; i != m_msg_length; ++i) {
294
4.72k
            m_message[i] = m_fragments[i];
295
4.72k
         }
296
105
         m_fragments.clear();
297
105
      }
298
1.08k
   }
299
1.39k
}
300
301
2.30k
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const {
302
2.30k
   return (m_msg_type != Handshake_Type::None && m_message.size() == m_msg_length);
303
2.30k
}
304
305
108
std::pair<Handshake_Type, std::vector<uint8_t>> Datagram_Handshake_IO::Handshake_Reassembly::message() const {
306
108
   if(!complete()) {
307
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
308
0
   }
309
310
108
   return std::make_pair(m_msg_type, m_message);
311
108
}
312
313
std::vector<uint8_t> Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
314
                                                            size_t frag_len,
315
                                                            uint16_t frag_offset,
316
                                                            uint16_t msg_len,
317
                                                            Handshake_Type type,
318
87
                                                            uint16_t msg_sequence) const {
319
87
   std::vector<uint8_t> send_buf(12 + frag_len);
320
321
87
   send_buf[0] = static_cast<uint8_t>(type);
322
323
87
   store_be24(&send_buf[1], msg_len);
324
325
87
   store_be(msg_sequence, &send_buf[4]);
326
327
87
   store_be24(&send_buf[6], frag_offset);
328
87
   store_be24(&send_buf[9], frag_len);
329
330
87
   if(frag_len > 0) {
331
73
      copy_mem(&send_buf[12], fragment, frag_len);
332
73
   }
333
334
87
   return send_buf;
335
87
}
336
337
std::vector<uint8_t> Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
338
                                                         Handshake_Type type,
339
87
                                                         uint16_t msg_sequence) const {
340
87
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
341
87
}
342
343
36
std::vector<uint8_t> Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg, Handshake_Type type) const {
344
36
   return format_w_seq(msg, type, m_in_message_seq - 1);
345
36
}
346
347
51
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg) {
348
51
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
349
51
}
350
351
51
std::vector<uint8_t> Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch) {
352
51
   const std::vector<uint8_t> msg_bits = msg.serialize();
353
51
   const Handshake_Type msg_type = msg.type();
354
355
51
   if(msg_type == Handshake_Type::HandshakeCCS) {
356
0
      m_send_hs(epoch, Record_Type::ChangeCipherSpec, msg_bits);
357
0
      return std::vector<uint8_t>();  // not included in handshake hashes
358
51
   } else if(msg_type == Handshake_Type::HelloVerifyRequest) {
359
      // This message is not included in the handshake hashes
360
14
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
361
14
      m_out_message_seq += 1;
362
14
      return std::vector<uint8_t>();
363
14
   }
364
365
   // Note: not saving CCS, instead we know it was there due to change in epoch
366
37
   m_flights.rbegin()->push_back(m_out_message_seq);
367
37
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
368
369
37
   m_out_message_seq += 1;
370
37
   m_last_write = steady_clock_ms();
371
37
   m_next_timeout = m_initial_timeout;
372
373
37
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
374
51
}
375
376
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
377
                                                         uint16_t epoch,
378
                                                         Handshake_Type msg_type,
379
51
                                                         const std::vector<uint8_t>& msg_bits) {
380
51
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
381
382
51
   auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq);
383
384
51
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) {
385
51
      m_send_hs(epoch, Record_Type::Handshake, no_fragment);
386
51
   } else {
387
0
      size_t frag_offset = 0;
388
389
      /**
390
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
391
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
392
      * over-estimate here. When CBC ciphers are removed this can be reduced
393
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
394
      * per-record nonce.
395
      */
396
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
397
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
398
399
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead)) {
400
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
401
0
      }
402
403
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
404
405
0
      while(frag_offset != msg_bits.size()) {
406
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
407
408
0
         const std::vector<uint8_t> frag = format_fragment(&msg_bits[frag_offset],
409
0
                                                           frag_len,
410
0
                                                           static_cast<uint16_t>(frag_offset),
411
0
                                                           static_cast<uint16_t>(msg_bits.size()),
412
0
                                                           msg_type,
413
0
                                                           msg_seq);
414
415
0
         m_send_hs(epoch, Record_Type::Handshake, frag);
416
417
0
         frag_offset += frag_len;
418
0
      }
419
0
   }
420
421
51
   return no_fragment;
422
51
}
423
424
}  // namespace Botan::TLS