Coverage Report

Created: 2022-06-23 06:44

/src/botan/src/lib/tls/tls12/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
#include <botan/internal/tls_record.h>
10
#include <botan/internal/tls_seq_numbers.h>
11
#include <botan/tls_messages.h>
12
#include <botan/exceptn.h>
13
#include <botan/internal/loadstor.h>
14
#include <chrono>
15
16
namespace Botan::TLS {
17
18
namespace {
19
20
/*
* Decode a 24-bit unsigned big-endian integer stored in q[0..2]
* (q[0] is the most significant byte).
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const size_t high = static_cast<size_t>(q[0]) << 16;
   const size_t middle = static_cast<size_t>(q[1]) << 8;
   const size_t low = static_cast<size_t>(q[2]);
   return high | middle | low;
   }
27
28
/*
* Write the low 24 bits of val into out[0..2], most significant byte first.
* Any bits of val above bit 23 are discarded.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   const uint32_t v32 = static_cast<uint32_t>(val);
   out[0] = static_cast<uint8_t>(v32 >> 16);
   out[1] = static_cast<uint8_t>(v32 >> 8);
   out[2] = static_cast<uint8_t>(v32);
   }
34
35
/*
* Current reading of std::chrono::steady_clock, truncated to whole
* milliseconds since that clock's (unspecified) epoch. Callers in this
* file only use differences between two readings.
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::milliseconds;
   using std::chrono::steady_clock;

   const auto since_epoch = steady_clock::now().time_since_epoch();
   return std::chrono::duration_cast<milliseconds>(since_epoch).count();
   }
40
41
}
42
43
/*
* Record-layer version used for the first records of a stream (TLS)
* handshake: always TLS 1.2.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const
   {
   return Protocol_Version(Protocol_Version::TLS_V12);
   }
47
48
void Stream_Handshake_IO::add_record(const uint8_t record[],
49
                                     size_t record_len,
50
                                     Record_Type record_type, uint64_t /*sequence_number*/)
51
44.1k
   {
52
44.1k
   if(record_type == HANDSHAKE)
53
42.8k
      {
54
42.8k
      m_queue.insert(m_queue.end(), record, record + record_len);
55
42.8k
      }
56
1.31k
   else if(record_type == CHANGE_CIPHER_SPEC)
57
1.31k
      {
58
1.31k
      if(record_len != 1 || record[0] != 1)
59
59
         throw Decoding_Error("Invalid ChangeCipherSpec");
60
61
      // Pretend it's a regular handshake message of zero length
62
1.25k
      const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
63
1.25k
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
64
1.25k
      }
65
0
   else
66
0
      throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
67
44.1k
   }
68
69
std::pair<Handshake_Type, std::vector<uint8_t>>
70
Stream_Handshake_IO::get_next_record(bool /*expecting_ccs*/)
71
79.0k
   {
72
79.0k
   if(m_queue.size() >= 4)
73
55.2k
      {
74
55.2k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
75
76
55.2k
      if(m_queue.size() >= length)
77
37.6k
         {
78
37.6k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
79
80
37.6k
         if(type == HANDSHAKE_NONE)
81
7
            throw Decoding_Error("Invalid handshake message type");
82
83
37.6k
         std::vector<uint8_t> contents(m_queue.begin() + 4,
84
37.6k
                                       m_queue.begin() + length);
85
86
37.6k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
87
88
37.6k
         return std::make_pair(type, contents);
89
37.6k
         }
90
55.2k
      }
91
92
41.3k
   return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
93
79.0k
   }
94
95
std::vector<uint8_t>
96
Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
97
                            Handshake_Type type) const
98
99.3k
   {
99
99.3k
   std::vector<uint8_t> send_buf(4 + msg.size());
100
101
99.3k
   const size_t buf_size = msg.size();
102
103
99.3k
   send_buf[0] = static_cast<uint8_t>(type);
104
105
99.3k
   store_be24(&send_buf[1], buf_size);
106
107
99.3k
   if (!msg.empty())
108
79.5k
      {
109
79.5k
      copy_mem(&send_buf[4], msg.data(), msg.size());
110
79.5k
      }
111
112
99.3k
   return send_buf;
113
99.3k
   }
114
115
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
116
0
   {
117
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
118
0
   }
119
120
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
121
62.6k
   {
122
62.6k
   const std::vector<uint8_t> msg_bits = msg.serialize();
123
124
62.6k
   if(msg.type() == HANDSHAKE_CCS)
125
308
      {
126
308
      m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
127
308
      return std::vector<uint8_t>(); // not included in handshake hashes
128
308
      }
129
130
62.3k
   auto buf = format(msg_bits, msg.wire_type());
131
62.3k
   m_send_hs(HANDSHAKE, buf);
132
62.3k
   return buf;
133
62.6k
   }
134
135
/*
* Record-layer version used for the first records of a datagram (DTLS)
* handshake: always DTLS 1.2.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const
   {
   return Protocol_Version(Protocol_Version::DTLS_V12);
   }
139
140
void Datagram_Handshake_IO::retransmit_last_flight()
141
0
   {
142
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
143
0
   retransmit_flight(flight_idx);
144
0
   }
145
146
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
147
0
   {
148
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
149
150
0
   BOTAN_ASSERT(!flight.empty(), "Nonempty flight to retransmit");
151
152
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
153
154
0
   for(auto msg_seq : flight)
155
0
      {
156
0
      auto& msg = m_flight_data[msg_seq];
157
158
0
      if(msg.epoch != epoch)
159
0
         {
160
         // Epoch gap: insert the CCS
161
0
         std::vector<uint8_t> ccs(1, 1);
162
0
         m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
163
0
         }
164
165
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
166
0
      epoch = msg.epoch;
167
0
      }
168
0
   }
169
170
bool Datagram_Handshake_IO::have_more_data() const
171
25
   {
172
25
   return false;
173
25
   }
174
175
bool Datagram_Handshake_IO::timeout_check()
176
0
   {
177
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
178
0
      {
179
      /*
180
      If we haven't written anything yet obviously no timeout.
181
      Also no timeout possible if we are mid-flight,
182
      */
183
0
      return false;
184
0
      }
185
186
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
187
188
0
   if(ms_since_write < m_next_timeout)
189
0
      return false;
190
191
0
   retransmit_last_flight();
192
193
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
194
0
   return true;
195
0
   }
196
197
void Datagram_Handshake_IO::add_record(const uint8_t record[],
198
                                       size_t record_len,
199
                                       Record_Type record_type,
200
                                       uint64_t record_sequence)
201
950
   {
202
950
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
203
204
950
   if(record_type == CHANGE_CIPHER_SPEC)
205
48
      {
206
48
      if(record_len != 1 || record[0] != 1)
207
17
         throw Decoding_Error("Invalid ChangeCipherSpec");
208
209
      // TODO: check this is otherwise empty
210
31
      m_ccs_epochs.insert(epoch);
211
31
      return;
212
48
      }
213
214
902
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
215
216
2.43k
   while(record_len)
217
1.80k
      {
218
1.80k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
219
218
         return; // completely bogus? at least degenerate/weird
220
221
1.58k
      const uint8_t msg_type = record[0];
222
1.58k
      const size_t msg_len = load_be24(&record[1]);
223
1.58k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
224
1.58k
      const size_t fragment_offset = load_be24(&record[6]);
225
1.58k
      const size_t fragment_length = load_be24(&record[9]);
226
227
1.58k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
228
229
1.58k
      if(record_len < total_size)
230
48
         throw Decoding_Error("Bad lengths in DTLS header");
231
232
1.53k
      if(message_seq >= m_in_message_seq)
233
1.46k
         {
234
1.46k
         m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
235
1.46k
                                              fragment_length,
236
1.46k
                                              fragment_offset,
237
1.46k
                                              epoch,
238
1.46k
                                              msg_type,
239
1.46k
                                              msg_len);
240
1.46k
         }
241
71
      else
242
71
         {
243
         // TODO: detect retransmitted flight
244
71
         }
245
246
1.53k
      record += total_size;
247
1.53k
      record_len -= total_size;
248
1.53k
      }
249
902
   }
250
251
std::pair<Handshake_Type, std::vector<uint8_t>>
252
Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
253
826
   {
254
   // Expecting a message means the last flight is concluded
255
826
   if(!m_flights.rbegin()->empty())
256
3
      m_flights.push_back(std::vector<uint16_t>());
257
258
826
   if(expecting_ccs)
259
0
      {
260
0
      if(!m_messages.empty())
261
0
         {
262
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
263
264
0
         if(m_ccs_epochs.count(current_epoch))
265
0
            return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
266
0
         }
267
0
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
268
0
      }
269
270
826
   auto i = m_messages.find(m_in_message_seq);
271
272
826
   if(i == m_messages.end() || !i->second.complete())
273
754
      {
274
754
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
275
754
      }
276
277
72
   m_in_message_seq += 1;
278
279
72
   return i->second.message();
280
826
   }
281
282
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
283
   const uint8_t fragment[],
284
   size_t fragment_length,
285
   size_t fragment_offset,
286
   uint16_t epoch,
287
   uint8_t msg_type,
288
   size_t msg_length)
289
1.46k
   {
290
1.46k
   if(complete())
291
329
      return; // already have entire message, ignore this
292
293
1.13k
   if(m_msg_type == HANDSHAKE_NONE)
294
974
      {
295
974
      m_epoch = epoch;
296
974
      m_msg_type = msg_type;
297
974
      m_msg_length = msg_length;
298
974
      }
299
300
1.13k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
301
57
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
302
303
1.07k
   if(fragment_offset > m_msg_length)
304
13
      throw Decoding_Error("Fragment offset past end of message");
305
306
1.06k
   if(fragment_offset + fragment_length > m_msg_length)
307
6
      throw Decoding_Error("Fragment overlaps past end of message");
308
309
1.05k
   if(fragment_offset == 0 && fragment_length == m_msg_length)
310
195
      {
311
195
      m_fragments.clear();
312
195
      m_message.assign(fragment, fragment+fragment_length);
313
195
      }
314
864
   else
315
864
      {
316
      /*
317
      * FIXME. This is a pretty lame way to do defragmentation, huge
318
      * overhead with a tree node per byte.
319
      *
320
      * Also should confirm that all overlaps have no changes,
321
      * otherwise we expose ourselves to the classic fingerprinting
322
      * and IDS evasion attacks on IP fragmentation.
323
      */
324
12.7k
      for(size_t i = 0; i != fragment_length; ++i)
325
11.8k
         m_fragments[fragment_offset+i] = fragment[i];
326
327
864
      if(m_fragments.size() == m_msg_length)
328
66
         {
329
66
         m_message.resize(m_msg_length);
330
2.17k
         for(size_t i = 0; i != m_msg_length; ++i)
331
2.11k
            m_message[i] = m_fragments[i];
332
66
         m_fragments.clear();
333
66
         }
334
864
      }
335
1.05k
   }
336
337
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
338
1.75k
   {
339
1.75k
   return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
340
1.75k
   }
341
342
std::pair<Handshake_Type, std::vector<uint8_t>>
343
Datagram_Handshake_IO::Handshake_Reassembly::message() const
344
72
   {
345
72
   if(!complete())
346
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
347
348
72
   return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
349
72
   }
350
351
std::vector<uint8_t>
352
Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
353
                                       size_t frag_len,
354
                                       uint16_t frag_offset,
355
                                       uint16_t msg_len,
356
                                       Handshake_Type type,
357
                                       uint16_t msg_sequence) const
358
50
   {
359
50
   std::vector<uint8_t> send_buf(12 + frag_len);
360
361
50
   send_buf[0] = static_cast<uint8_t>(type);
362
363
50
   store_be24(&send_buf[1], msg_len);
364
365
50
   store_be(msg_sequence, &send_buf[4]);
366
367
50
   store_be24(&send_buf[6], frag_offset);
368
50
   store_be24(&send_buf[9], frag_len);
369
370
50
   if (frag_len > 0)
371
43
      {
372
43
      copy_mem(&send_buf[12], fragment, frag_len);
373
43
      }
374
375
50
   return send_buf;
376
50
   }
377
378
std::vector<uint8_t>
379
Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
380
                                    Handshake_Type type,
381
                                    uint16_t msg_sequence) const
382
50
   {
383
50
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
384
50
   }
385
386
std::vector<uint8_t>
387
Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
388
                              Handshake_Type type) const
389
25
   {
390
25
   return format_w_seq(msg, type, m_in_message_seq - 1);
391
25
   }
392
393
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
394
25
   {
395
25
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
396
25
   }
397
398
std::vector<uint8_t>
399
Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
400
25
   {
401
25
   const std::vector<uint8_t> msg_bits = msg.serialize();
402
25
   const Handshake_Type msg_type = msg.type();
403
404
25
   if(msg_type == HANDSHAKE_CCS)
405
0
      {
406
0
      m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
407
0
      return std::vector<uint8_t>(); // not included in handshake hashes
408
0
      }
409
25
   else if(msg_type == HELLO_VERIFY_REQUEST)
410
14
      {
411
      // This message is not included in the handshake hashes
412
14
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
413
14
      m_out_message_seq += 1;
414
14
      return std::vector<uint8_t>();
415
14
      }
416
417
   // Note: not saving CCS, instead we know it was there due to change in epoch
418
11
   m_flights.rbegin()->push_back(m_out_message_seq);
419
11
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
420
421
11
   m_out_message_seq += 1;
422
11
   m_last_write = steady_clock_ms();
423
11
   m_next_timeout = m_initial_timeout;
424
425
11
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
426
25
   }
427
428
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
429
                                                      uint16_t epoch,
430
                                                      Handshake_Type msg_type,
431
                                                      const std::vector<uint8_t>& msg_bits)
432
25
   {
433
25
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
434
435
25
   auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq);
436
437
25
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
438
25
      {
439
25
      m_send_hs(epoch, HANDSHAKE, no_fragment);
440
25
      }
441
0
   else
442
0
      {
443
0
      size_t frag_offset = 0;
444
445
      /**
446
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
447
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
448
      * over-estimate here. When CBC ciphers are removed this can be reduced
449
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
450
      * per-record nonce.
451
      */
452
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
453
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
454
455
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead))
456
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
457
458
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
459
460
0
      while(frag_offset != msg_bits.size())
461
0
         {
462
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
463
464
0
         const std::vector<uint8_t> frag =
465
0
            format_fragment(&msg_bits[frag_offset],
466
0
                            frag_len,
467
0
                            static_cast<uint16_t>(frag_offset),
468
0
                            static_cast<uint16_t>(msg_bits.size()),
469
0
                            msg_type,
470
0
                            msg_seq);
471
472
0
         m_send_hs(epoch, HANDSHAKE, frag);
473
474
0
         frag_offset += frag_len;
475
0
         }
476
0
      }
477
478
25
   return no_fragment;
479
25
   }
480
481
}