Coverage Report

Created: 2022-11-24 06:56

/src/botan/src/lib/tls/tls12/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
#include <botan/internal/tls_record.h>
10
#include <botan/internal/tls_seq_numbers.h>
11
#include <botan/tls_messages.h>
12
#include <botan/exceptn.h>
13
#include <botan/internal/loadstor.h>
14
#include <chrono>
15
16
namespace Botan::TLS {
17
18
namespace {
19
20
/*
* Decode a 24-bit big-endian unsigned integer from q[0..2]
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const size_t high = static_cast<size_t>(q[0]) << 16;
   const size_t middle = static_cast<size_t>(q[1]) << 8;
   const size_t low = static_cast<size_t>(q[2]);
   return high | middle | low;
   }
27
28
/*
* Encode the low 24 bits of val into out[0..2], most significant byte first.
* Any bits above bit 23 are discarded.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   out[0] = static_cast<uint8_t>((val >> 16) & 0xFF);
   out[1] = static_cast<uint8_t>((val >> 8) & 0xFF);
   out[2] = static_cast<uint8_t>(val & 0xFF);
   }
34
35
/*
* Current reading of std::chrono::steady_clock, in whole milliseconds
* since that clock's epoch. Only differences between readings are meaningful.
*/
uint64_t steady_clock_ms()
   {
   using namespace std::chrono;
   const auto since_epoch = steady_clock::now().time_since_epoch();
   return duration_cast<milliseconds>(since_epoch).count();
   }
40
41
}
42
43
/*
* Record-layer version used by stream (TLS) handshakes for the initial
* records: always TLS 1.2.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const
   {
   const Protocol_Version version = Protocol_Version::TLS_V12;
   return version;
   }
47
48
void Stream_Handshake_IO::add_record(const uint8_t record[],
49
                                     size_t record_len,
50
                                     Record_Type record_type, uint64_t /*sequence_number*/)
51
51.3k
   {
52
51.3k
   if(record_type == HANDSHAKE)
53
50.2k
      {
54
50.2k
      m_queue.insert(m_queue.end(), record, record + record_len);
55
50.2k
      }
56
1.04k
   else if(record_type == CHANGE_CIPHER_SPEC)
57
1.04k
      {
58
1.04k
      if(record_len != 1 || record[0] != 1)
59
45
         throw Decoding_Error("Invalid ChangeCipherSpec");
60
61
      // Pretend it's a regular handshake message of zero length
62
1.00k
      const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
63
1.00k
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
64
1.00k
      }
65
0
   else
66
0
      throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
67
51.3k
   }
68
69
std::pair<Handshake_Type, std::vector<uint8_t>>
70
Stream_Handshake_IO::get_next_record(bool /*expecting_ccs*/)
71
91.4k
   {
72
91.4k
   if(m_queue.size() >= 4)
73
59.7k
      {
74
59.7k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
75
76
59.7k
      if(m_queue.size() >= length)
77
45.2k
         {
78
45.2k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
79
80
45.2k
         if(type == HANDSHAKE_NONE)
81
4
            throw Decoding_Error("Invalid handshake message type");
82
83
45.2k
         std::vector<uint8_t> contents(m_queue.begin() + 4,
84
45.2k
                                       m_queue.begin() + length);
85
86
45.2k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
87
88
45.2k
         return std::make_pair(type, contents);
89
45.2k
         }
90
59.7k
      }
91
92
46.2k
   return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
93
91.4k
   }
94
95
std::vector<uint8_t>
96
Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
97
                            Handshake_Type type) const
98
117k
   {
99
117k
   std::vector<uint8_t> send_buf(4 + msg.size());
100
101
117k
   const size_t buf_size = msg.size();
102
103
117k
   send_buf[0] = static_cast<uint8_t>(type);
104
105
117k
   store_be24(&send_buf[1], buf_size);
106
107
117k
   if (!msg.empty())
108
94.7k
      {
109
94.7k
      copy_mem(&send_buf[4], msg.data(), msg.size());
110
94.7k
      }
111
112
117k
   return send_buf;
113
117k
   }
114
115
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
116
0
   {
117
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
118
0
   }
119
120
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
121
73.1k
   {
122
73.1k
   const std::vector<uint8_t> msg_bits = msg.serialize();
123
124
73.1k
   if(msg.type() == HANDSHAKE_CCS)
125
377
      {
126
377
      m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
127
377
      return std::vector<uint8_t>(); // not included in handshake hashes
128
377
      }
129
130
72.8k
   auto buf = format(msg_bits, msg.wire_type());
131
72.8k
   m_send_hs(HANDSHAKE, buf);
132
72.8k
   return buf;
133
73.1k
   }
134
135
/*
* Record-layer version used by datagram (DTLS) handshakes for the initial
* records: always DTLS 1.2.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const
   {
   const Protocol_Version version = Protocol_Version::DTLS_V12;
   return version;
   }
139
140
void Datagram_Handshake_IO::retransmit_last_flight()
141
0
   {
142
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
143
0
   retransmit_flight(flight_idx);
144
0
   }
145
146
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
147
0
   {
148
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
149
150
0
   BOTAN_ASSERT(!flight.empty(), "Nonempty flight to retransmit");
151
152
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
153
154
0
   for(auto msg_seq : flight)
155
0
      {
156
0
      auto& msg = m_flight_data[msg_seq];
157
158
0
      if(msg.epoch != epoch)
159
0
         {
160
         // Epoch gap: insert the CCS
161
0
         std::vector<uint8_t> ccs(1, 1);
162
0
         m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
163
0
         }
164
165
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
166
0
      epoch = msg.epoch;
167
0
      }
168
0
   }
169
170
bool Datagram_Handshake_IO::have_more_data() const
171
32
   {
172
32
   return false;
173
32
   }
174
175
bool Datagram_Handshake_IO::timeout_check()
176
0
   {
177
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
178
0
      {
179
      /*
180
      If we haven't written anything yet obviously no timeout.
181
      Also no timeout possible if we are mid-flight,
182
      */
183
0
      return false;
184
0
      }
185
186
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
187
188
0
   if(ms_since_write < m_next_timeout)
189
0
      return false;
190
191
0
   retransmit_last_flight();
192
193
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
194
0
   return true;
195
0
   }
196
197
void Datagram_Handshake_IO::add_record(const uint8_t record[],
198
                                       size_t record_len,
199
                                       Record_Type record_type,
200
                                       uint64_t record_sequence)
201
1.16k
   {
202
1.16k
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
203
204
1.16k
   if(record_type == CHANGE_CIPHER_SPEC)
205
34
      {
206
34
      if(record_len != 1 || record[0] != 1)
207
7
         throw Decoding_Error("Invalid ChangeCipherSpec");
208
209
      // TODO: check this is otherwise empty
210
27
      m_ccs_epochs.insert(epoch);
211
27
      return;
212
34
      }
213
214
1.13k
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
215
216
3.12k
   while(record_len)
217
2.27k
      {
218
2.27k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
219
217
         return; // completely bogus? at least degenerate/weird
220
221
2.06k
      const uint8_t msg_type = record[0];
222
2.06k
      const size_t msg_len = load_be24(&record[1]);
223
2.06k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
224
2.06k
      const size_t fragment_offset = load_be24(&record[6]);
225
2.06k
      const size_t fragment_length = load_be24(&record[9]);
226
227
2.06k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
228
229
2.06k
      if(record_len < total_size)
230
70
         throw Decoding_Error("Bad lengths in DTLS header");
231
232
1.99k
      if(message_seq >= m_in_message_seq)
233
1.76k
         {
234
1.76k
         m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
235
1.76k
                                              fragment_length,
236
1.76k
                                              fragment_offset,
237
1.76k
                                              epoch,
238
1.76k
                                              msg_type,
239
1.76k
                                              msg_len);
240
1.76k
         }
241
227
      else
242
227
         {
243
         // TODO: detect retransmitted flight
244
227
         }
245
246
1.99k
      record += total_size;
247
1.99k
      record_len -= total_size;
248
1.99k
      }
249
1.13k
   }
250
251
std::pair<Handshake_Type, std::vector<uint8_t>>
252
Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
253
1.02k
   {
254
   // Expecting a message means the last flight is concluded
255
1.02k
   if(!m_flights.rbegin()->empty())
256
12
      m_flights.push_back(std::vector<uint16_t>());
257
258
1.02k
   if(expecting_ccs)
259
0
      {
260
0
      if(!m_messages.empty())
261
0
         {
262
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
263
264
0
         if(m_ccs_epochs.count(current_epoch))
265
0
            return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
266
0
         }
267
0
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
268
0
      }
269
270
1.02k
   auto i = m_messages.find(m_in_message_seq);
271
272
1.02k
   if(i == m_messages.end() || !i->second.complete())
273
941
      {
274
941
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
275
941
      }
276
277
84
   m_in_message_seq += 1;
278
279
84
   return i->second.message();
280
1.02k
   }
281
282
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
283
   const uint8_t fragment[],
284
   size_t fragment_length,
285
   size_t fragment_offset,
286
   uint16_t epoch,
287
   uint8_t msg_type,
288
   size_t msg_length)
289
1.76k
   {
290
1.76k
   if(complete())
291
255
      return; // already have entire message, ignore this
292
293
1.51k
   if(m_msg_type == HANDSHAKE_NONE)
294
1.33k
      {
295
1.33k
      m_epoch = epoch;
296
1.33k
      m_msg_type = msg_type;
297
1.33k
      m_msg_length = msg_length;
298
1.33k
      }
299
300
1.51k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
301
64
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
302
303
1.44k
   if(fragment_offset > m_msg_length)
304
17
      throw Decoding_Error("Fragment offset past end of message");
305
306
1.42k
   if(fragment_offset + fragment_length > m_msg_length)
307
7
      throw Decoding_Error("Fragment overlaps past end of message");
308
309
1.42k
   if(fragment_offset == 0 && fragment_length == m_msg_length)
310
278
      {
311
278
      m_fragments.clear();
312
278
      m_message.assign(fragment, fragment+fragment_length);
313
278
      }
314
1.14k
   else
315
1.14k
      {
316
      /*
317
      * FIXME. This is a pretty lame way to do defragmentation, huge
318
      * overhead with a tree node per byte.
319
      *
320
      * Also should confirm that all overlaps have no changes,
321
      * otherwise we expose ourselves to the classic fingerprinting
322
      * and IDS evasion attacks on IP fragmentation.
323
      */
324
27.0k
      for(size_t i = 0; i != fragment_length; ++i)
325
25.8k
         m_fragments[fragment_offset+i] = fragment[i];
326
327
1.14k
      if(m_fragments.size() == m_msg_length)
328
90
         {
329
90
         m_message.resize(m_msg_length);
330
3.03k
         for(size_t i = 0; i != m_msg_length; ++i)
331
2.94k
            m_message[i] = m_fragments[i];
332
90
         m_fragments.clear();
333
90
         }
334
1.14k
      }
335
1.42k
   }
336
337
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
338
2.11k
   {
339
2.11k
   return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
340
2.11k
   }
341
342
std::pair<Handshake_Type, std::vector<uint8_t>>
343
Datagram_Handshake_IO::Handshake_Reassembly::message() const
344
84
   {
345
84
   if(!complete())
346
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
347
348
84
   return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
349
84
   }
350
351
std::vector<uint8_t>
352
Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
353
                                       size_t frag_len,
354
                                       uint16_t frag_offset,
355
                                       uint16_t msg_len,
356
                                       Handshake_Type type,
357
                                       uint16_t msg_sequence) const
358
83
   {
359
83
   std::vector<uint8_t> send_buf(12 + frag_len);
360
361
83
   send_buf[0] = static_cast<uint8_t>(type);
362
363
83
   store_be24(&send_buf[1], msg_len);
364
365
83
   store_be(msg_sequence, &send_buf[4]);
366
367
83
   store_be24(&send_buf[6], frag_offset);
368
83
   store_be24(&send_buf[9], frag_len);
369
370
83
   if (frag_len > 0)
371
68
      {
372
68
      copy_mem(&send_buf[12], fragment, frag_len);
373
68
      }
374
375
83
   return send_buf;
376
83
   }
377
378
std::vector<uint8_t>
379
Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
380
                                    Handshake_Type type,
381
                                    uint16_t msg_sequence) const
382
83
   {
383
83
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
384
83
   }
385
386
std::vector<uint8_t>
387
Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
388
                              Handshake_Type type) const
389
32
   {
390
32
   return format_w_seq(msg, type, m_in_message_seq - 1);
391
32
   }
392
393
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
394
51
   {
395
51
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
396
51
   }
397
398
std::vector<uint8_t>
399
Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
400
51
   {
401
51
   const std::vector<uint8_t> msg_bits = msg.serialize();
402
51
   const Handshake_Type msg_type = msg.type();
403
404
51
   if(msg_type == HANDSHAKE_CCS)
405
0
      {
406
0
      m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
407
0
      return std::vector<uint8_t>(); // not included in handshake hashes
408
0
      }
409
51
   else if(msg_type == HELLO_VERIFY_REQUEST)
410
13
      {
411
      // This message is not included in the handshake hashes
412
13
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
413
13
      m_out_message_seq += 1;
414
13
      return std::vector<uint8_t>();
415
13
      }
416
417
   // Note: not saving CCS, instead we know it was there due to change in epoch
418
38
   m_flights.rbegin()->push_back(m_out_message_seq);
419
38
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
420
421
38
   m_out_message_seq += 1;
422
38
   m_last_write = steady_clock_ms();
423
38
   m_next_timeout = m_initial_timeout;
424
425
38
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
426
51
   }
427
428
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
429
                                                      uint16_t epoch,
430
                                                      Handshake_Type msg_type,
431
                                                      const std::vector<uint8_t>& msg_bits)
432
51
   {
433
51
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
434
435
51
   auto no_fragment = format_w_seq(msg_bits, msg_type, msg_seq);
436
437
51
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
438
51
      {
439
51
      m_send_hs(epoch, HANDSHAKE, no_fragment);
440
51
      }
441
0
   else
442
0
      {
443
0
      size_t frag_offset = 0;
444
445
      /**
446
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
447
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
448
      * over-estimate here. When CBC ciphers are removed this can be reduced
449
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
450
      * per-record nonce.
451
      */
452
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
453
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
454
455
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead))
456
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
457
458
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
459
460
0
      while(frag_offset != msg_bits.size())
461
0
         {
462
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
463
464
0
         const std::vector<uint8_t> frag =
465
0
            format_fragment(&msg_bits[frag_offset],
466
0
                            frag_len,
467
0
                            static_cast<uint16_t>(frag_offset),
468
0
                            static_cast<uint16_t>(msg_bits.size()),
469
0
                            msg_type,
470
0
                            msg_seq);
471
472
0
         m_send_hs(epoch, HANDSHAKE, frag);
473
474
0
         frag_offset += frag_len;
475
0
         }
476
0
      }
477
478
51
   return no_fragment;
479
51
   }
480
481
}