Coverage Report

Created: 2021-04-07 06:07

/src/botan/src/lib/tls/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
#include <botan/internal/tls_record.h>
10
#include <botan/internal/tls_seq_numbers.h>
11
#include <botan/tls_messages.h>
12
#include <botan/exceptn.h>
13
#include <botan/internal/loadstor.h>
14
#include <chrono>
15
16
namespace Botan {
17
18
namespace TLS {
19
20
namespace {
21
22
/*
* Decode the three bytes q[0..2] as an unsigned 24-bit big-endian integer.
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const size_t high = static_cast<size_t>(q[0]) << 16;
   const size_t middle = static_cast<size_t>(q[1]) << 8;
   const size_t low = static_cast<size_t>(q[2]);
   return high | middle | low;
   }
29
30
/*
* Write the low 24 bits of val into out[0..2], most significant byte first.
* Any bits above bit 23 are dropped; callers are responsible for keeping
* val within the 24-bit range of the wire field.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   const uint32_t v32 = static_cast<uint32_t>(val);
   out[0] = static_cast<uint8_t>(v32 >> 16);
   out[1] = static_cast<uint8_t>(v32 >> 8);
   out[2] = static_cast<uint8_t>(v32);
   }
36
37
/*
* Current reading of std::chrono::steady_clock in whole milliseconds since
* the clock's (unspecified) epoch. Only differences between two readings
* are meaningful.
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::milliseconds;
   using std::chrono::steady_clock;

   const auto since_epoch = steady_clock::now().time_since_epoch();
   const auto as_ms = std::chrono::duration_cast<milliseconds>(since_epoch);
   return as_ms.count();
   }
42
43
}
44
45
/*
* Record-layer version used by stream (TLS) handshakes before a version
* has been negotiated: always TLS 1.2 here.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const
   {
   return Protocol_Version(Protocol_Version::TLS_V12);
   }
49
50
void Stream_Handshake_IO::add_record(const uint8_t record[],
51
                                     size_t record_len,
52
                                     Record_Type record_type, uint64_t)
53
85.7k
   {
54
85.7k
   if(record_type == HANDSHAKE)
55
84.7k
      {
56
84.7k
      m_queue.insert(m_queue.end(), record, record + record_len);
57
84.7k
      }
58
997
   else if(record_type == CHANGE_CIPHER_SPEC)
59
997
      {
60
997
      if(record_len != 1 || record[0] != 1)
61
61
         throw Decoding_Error("Invalid ChangeCipherSpec");
62
63
      // Pretend it's a regular handshake message of zero length
64
936
      const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
65
936
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
66
936
      }
67
0
   else
68
0
      throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
69
85.7k
   }
70
71
std::pair<Handshake_Type, std::vector<uint8_t>>
72
Stream_Handshake_IO::get_next_record(bool)
73
125k
   {
74
125k
   if(m_queue.size() >= 4)
75
66.6k
      {
76
66.6k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
77
78
66.6k
      if(m_queue.size() >= length)
79
43.2k
         {
80
43.2k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
81
82
43.2k
         if(type == HANDSHAKE_NONE)
83
6
            throw Decoding_Error("Invalid handshake message type");
84
85
43.2k
         std::vector<uint8_t> contents(m_queue.begin() + 4,
86
43.2k
                                       m_queue.begin() + length);
87
88
43.2k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
89
90
43.2k
         return std::make_pair(type, contents);
91
43.2k
         }
92
66.6k
      }
93
94
82.1k
   return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
95
82.1k
   }
96
97
std::vector<uint8_t>
98
Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
99
                            Handshake_Type type) const
100
117k
   {
101
117k
   std::vector<uint8_t> send_buf(4 + msg.size());
102
103
117k
   const size_t buf_size = msg.size();
104
105
117k
   send_buf[0] = static_cast<uint8_t>(type);
106
107
117k
   store_be24(&send_buf[1], buf_size);
108
109
117k
   if (msg.size() > 0)
110
95.6k
      {
111
95.6k
      copy_mem(&send_buf[4], msg.data(), msg.size());
112
95.6k
      }
113
114
117k
   return send_buf;
115
117k
   }
116
117
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
118
0
   {
119
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
120
0
   }
121
122
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
123
75.2k
   {
124
75.2k
   const std::vector<uint8_t> msg_bits = msg.serialize();
125
126
75.2k
   if(msg.type() == HANDSHAKE_CCS)
127
430
      {
128
430
      m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
129
430
      return std::vector<uint8_t>(); // not included in handshake hashes
130
430
      }
131
132
74.7k
   const std::vector<uint8_t> buf = format(msg_bits, msg.type());
133
74.7k
   m_send_hs(HANDSHAKE, buf);
134
74.7k
   return buf;
135
74.7k
   }
136
137
/*
* Record-layer version used by datagram (DTLS) handshakes before a version
* has been negotiated: always DTLS 1.2 here.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const
   {
   return Protocol_Version(Protocol_Version::DTLS_V12);
   }
141
142
void Datagram_Handshake_IO::retransmit_last_flight()
143
0
   {
144
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145
0
   retransmit_flight(flight_idx);
146
0
   }
147
148
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
149
0
   {
150
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
151
152
0
   BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
153
154
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
155
156
0
   for(auto msg_seq : flight)
157
0
      {
158
0
      auto& msg = m_flight_data[msg_seq];
159
160
0
      if(msg.epoch != epoch)
161
0
         {
162
         // Epoch gap: insert the CCS
163
0
         std::vector<uint8_t> ccs(1, 1);
164
0
         m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
165
0
         }
166
167
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
168
0
      epoch = msg.epoch;
169
0
      }
170
0
   }
171
172
bool Datagram_Handshake_IO::have_more_data() const
173
13
   {
174
13
   return false;
175
13
   }
176
177
bool Datagram_Handshake_IO::timeout_check()
178
0
   {
179
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
180
0
      {
181
      /*
182
      If we haven't written anything yet obviously no timeout.
183
      Also no timeout possible if we are mid-flight,
184
      */
185
0
      return false;
186
0
      }
187
188
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
189
190
0
   if(ms_since_write < m_next_timeout)
191
0
      return false;
192
193
0
   retransmit_last_flight();
194
195
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
196
0
   return true;
197
0
   }
198
199
void Datagram_Handshake_IO::add_record(const uint8_t record[],
200
                                       size_t record_len,
201
                                       Record_Type record_type,
202
                                       uint64_t record_sequence)
203
871
   {
204
871
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
205
206
871
   if(record_type == CHANGE_CIPHER_SPEC)
207
52
      {
208
52
      if(record_len != 1 || record[0] != 1)
209
14
         throw Decoding_Error("Invalid ChangeCipherSpec");
210
211
      // TODO: check this is otherwise empty
212
38
      m_ccs_epochs.insert(epoch);
213
38
      return;
214
38
      }
215
216
819
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
217
218
4.03k
   while(record_len)
219
3.51k
      {
220
3.51k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
221
218
         return; // completely bogus? at least degenerate/weird
222
223
3.29k
      const uint8_t msg_type = record[0];
224
3.29k
      const size_t msg_len = load_be24(&record[1]);
225
3.29k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
226
3.29k
      const size_t fragment_offset = load_be24(&record[6]);
227
3.29k
      const size_t fragment_length = load_be24(&record[9]);
228
229
3.29k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
230
231
3.29k
      if(record_len < total_size)
232
80
         throw Decoding_Error("Bad lengths in DTLS header");
233
234
3.21k
      if(message_seq >= m_in_message_seq)
235
3.21k
         {
236
3.21k
         m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
237
3.21k
                                              fragment_length,
238
3.21k
                                              fragment_offset,
239
3.21k
                                              epoch,
240
3.21k
                                              msg_type,
241
3.21k
                                              msg_len);
242
3.21k
         }
243
0
      else
244
0
         {
245
         // TODO: detect retransmitted flight
246
0
         }
247
248
3.21k
      record += total_size;
249
3.21k
      record_len -= total_size;
250
3.21k
      }
251
819
   }
252
253
std::pair<Handshake_Type, std::vector<uint8_t>>
254
Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
255
703
   {
256
   // Expecting a message means the last flight is concluded
257
703
   if(!m_flights.rbegin()->empty())
258
0
      m_flights.push_back(std::vector<uint16_t>());
259
260
703
   if(expecting_ccs)
261
0
      {
262
0
      if(!m_messages.empty())
263
0
         {
264
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
265
266
0
         if(m_ccs_epochs.count(current_epoch))
267
0
            return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
268
0
         }
269
0
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
270
0
      }
271
272
703
   auto i = m_messages.find(m_in_message_seq);
273
274
703
   if(i == m_messages.end() || !i->second.complete())
275
662
      {
276
662
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
277
662
      }
278
279
41
   m_in_message_seq += 1;
280
281
41
   return i->second.message();
282
41
   }
283
284
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
285
   const uint8_t fragment[],
286
   size_t fragment_length,
287
   size_t fragment_offset,
288
   uint16_t epoch,
289
   uint8_t msg_type,
290
   size_t msg_length)
291
3.21k
   {
292
3.21k
   if(complete())
293
1.67k
      return; // already have entire message, ignore this
294
295
1.54k
   if(m_msg_type == HANDSHAKE_NONE)
296
1.34k
      {
297
1.34k
      m_epoch = epoch;
298
1.34k
      m_msg_type = msg_type;
299
1.34k
      m_msg_length = msg_length;
300
1.34k
      }
301
302
1.54k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
303
57
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
304
305
1.48k
   if(fragment_offset > m_msg_length)
306
14
      throw Decoding_Error("Fragment offset past end of message");
307
308
1.46k
   if(fragment_offset + fragment_length > m_msg_length)
309
3
      throw Decoding_Error("Fragment overlaps past end of message");
310
311
1.46k
   if(fragment_offset == 0 && fragment_length == m_msg_length)
312
534
      {
313
534
      m_fragments.clear();
314
534
      m_message.assign(fragment, fragment+fragment_length);
315
534
      }
316
932
   else
317
932
      {
318
      /*
319
      * FIXME. This is a pretty lame way to do defragmentation, huge
320
      * overhead with a tree node per byte.
321
      *
322
      * Also should confirm that all overlaps have no changes,
323
      * otherwise we expose ourselves to the classic fingerprinting
324
      * and IDS evasion attacks on IP fragmentation.
325
      */
326
15.4k
      for(size_t i = 0; i != fragment_length; ++i)
327
14.5k
         m_fragments[fragment_offset+i] = fragment[i];
328
329
932
      if(m_fragments.size() == m_msg_length)
330
39
         {
331
39
         m_message.resize(m_msg_length);
332
936
         for(size_t i = 0; i != m_msg_length; ++i)
333
897
            m_message[i] = m_fragments[i];
334
39
         m_fragments.clear();
335
39
         }
336
932
      }
337
1.46k
   }
338
339
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
340
3.44k
   {
341
3.44k
   return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
342
3.44k
   }
343
344
std::pair<Handshake_Type, std::vector<uint8_t>>
345
Datagram_Handshake_IO::Handshake_Reassembly::message() const
346
41
   {
347
41
   if(!complete())
348
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
349
350
41
   return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
351
41
   }
352
353
std::vector<uint8_t>
354
Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
355
                                       size_t frag_len,
356
                                       uint16_t frag_offset,
357
                                       uint16_t msg_len,
358
                                       Handshake_Type type,
359
                                       uint16_t msg_sequence) const
360
13
   {
361
13
   std::vector<uint8_t> send_buf(12 + frag_len);
362
363
13
   send_buf[0] = static_cast<uint8_t>(type);
364
365
13
   store_be24(&send_buf[1], msg_len);
366
367
13
   store_be(msg_sequence, &send_buf[4]);
368
369
13
   store_be24(&send_buf[6], frag_offset);
370
13
   store_be24(&send_buf[9], frag_len);
371
372
13
   if (frag_len > 0)
373
9
      {
374
9
      copy_mem(&send_buf[12], fragment, frag_len);
375
9
      }
376
377
13
   return send_buf;
378
13
   }
379
380
std::vector<uint8_t>
381
Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
382
                                    Handshake_Type type,
383
                                    uint16_t msg_sequence) const
384
13
   {
385
13
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
386
13
   }
387
388
std::vector<uint8_t>
389
Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
390
                              Handshake_Type type) const
391
13
   {
392
13
   return format_w_seq(msg, type, m_in_message_seq - 1);
393
13
   }
394
395
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
396
0
   {
397
0
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
398
0
   }
399
400
std::vector<uint8_t>
401
Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
402
0
   {
403
0
   const std::vector<uint8_t> msg_bits = msg.serialize();
404
0
   const Handshake_Type msg_type = msg.type();
405
406
0
   if(msg_type == HANDSHAKE_CCS)
407
0
      {
408
0
      m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
409
0
      return std::vector<uint8_t>(); // not included in handshake hashes
410
0
      }
411
0
   else if(msg_type == HELLO_VERIFY_REQUEST)
412
0
      {
413
      // This message is not included in the handshake hashes
414
0
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
415
0
      m_out_message_seq += 1;
416
0
      return std::vector<uint8_t>();
417
0
      }
418
419
   // Note: not saving CCS, instead we know it was there due to change in epoch
420
0
   m_flights.rbegin()->push_back(m_out_message_seq);
421
0
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
422
423
0
   m_out_message_seq += 1;
424
0
   m_last_write = steady_clock_ms();
425
0
   m_next_timeout = m_initial_timeout;
426
427
0
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
428
0
   }
429
430
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
431
                                                      uint16_t epoch,
432
                                                      Handshake_Type msg_type,
433
                                                      const std::vector<uint8_t>& msg_bits)
434
0
   {
435
0
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
436
437
0
   const std::vector<uint8_t> no_fragment =
438
0
      format_w_seq(msg_bits, msg_type, msg_seq);
439
440
0
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
441
0
      {
442
0
      m_send_hs(epoch, HANDSHAKE, no_fragment);
443
0
      }
444
0
   else
445
0
      {
446
0
      size_t frag_offset = 0;
447
448
      /**
449
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
450
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
451
      * over-estimate here. When CBC ciphers are removed this can be reduced
452
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
453
      * per-record nonce.
454
      */
455
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
456
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
457
458
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead))
459
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
460
461
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
462
463
0
      while(frag_offset != msg_bits.size())
464
0
         {
465
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
466
467
0
         const std::vector<uint8_t> frag =
468
0
            format_fragment(&msg_bits[frag_offset],
469
0
                            frag_len,
470
0
                            static_cast<uint16_t>(frag_offset),
471
0
                            static_cast<uint16_t>(msg_bits.size()),
472
0
                            msg_type,
473
0
                            msg_seq);
474
475
0
         m_send_hs(epoch, HANDSHAKE, frag);
476
477
0
         frag_offset += frag_len;
478
0
         }
479
0
      }
480
481
0
   return no_fragment;
482
0
   }
483
484
}
485
}