Coverage Report

Created: 2020-03-26 13:53

/src/botan/src/lib/tls/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
#include <botan/internal/tls_record.h>
10
#include <botan/internal/tls_seq_numbers.h>
11
#include <botan/tls_messages.h>
12
#include <botan/exceptn.h>
13
#include <botan/loadstor.h>
14
#include <chrono>
15
16
namespace Botan {
17
18
namespace TLS {
19
20
namespace {
21
22
/*
* Decode a 3-byte big-endian unsigned integer (as used for TLS/DTLS
* handshake length and offset fields).
*/
inline size_t load_be24(const uint8_t q[3])
   {
   size_t value = 0;
   for(size_t i = 0; i != 3; ++i)
      value = (value << 8) | q[i];
   return value;
   }
29
30
/*
* Encode the low 24 bits of val into out[0..2], most significant byte first.
* Higher bits of val are discarded.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   const uint32_t v32 = static_cast<uint32_t>(val);
   for(size_t i = 0; i != 3; ++i)
      out[i] = static_cast<uint8_t>(v32 >> (8 * (2 - i)));
   }
36
37
/*
* Current reading of std::chrono::steady_clock, expressed in milliseconds
* since the clock's epoch. Only differences between readings are meaningful.
*/
uint64_t steady_clock_ms()
   {
   typedef std::chrono::steady_clock clock_type;
   const clock_type::duration since_epoch = clock_type::now().time_since_epoch();
   return std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch).count();
   }
42
43
}
44
45
/*
* Record-layer version used for the first records of a stream (TLS) handshake.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const
   {
   const Protocol_Version initial = Protocol_Version::TLS_V10;
   return initial;
   }
49
50
void Stream_Handshake_IO::add_record(const uint8_t record[],
51
                                     size_t record_len,
52
                                     Record_Type record_type, uint64_t)
53
81.6k
   {
54
81.6k
   if(record_type == HANDSHAKE)
55
73.0k
      {
56
73.0k
      m_queue.insert(m_queue.end(), record, record + record_len);
57
73.0k
      }
58
8.58k
   else if(record_type == CHANGE_CIPHER_SPEC)
59
8.58k
      {
60
8.58k
      if(record_len != 1 || record[0] != 1)
61
114
         throw Decoding_Error("Invalid ChangeCipherSpec");
62
8.46k
63
8.46k
      // Pretend it's a regular handshake message of zero length
64
8.46k
      const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
65
8.46k
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
66
8.46k
      }
67
0
   else
68
0
      throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
69
81.6k
   }
70
71
std::pair<Handshake_Type, std::vector<uint8_t>>
72
Stream_Handshake_IO::get_next_record(bool)
73
133k
   {
74
133k
   if(m_queue.size() >= 4)
75
88.7k
      {
76
88.7k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
77
88.7k
78
88.7k
      if(m_queue.size() >= length)
79
57.8k
         {
80
57.8k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
81
57.8k
82
57.8k
         if(type == HANDSHAKE_NONE)
83
7
            throw Decoding_Error("Invalid handshake message type");
84
57.7k
85
57.7k
         std::vector<uint8_t> contents(m_queue.begin() + 4,
86
57.7k
                                       m_queue.begin() + length);
87
57.7k
88
57.7k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
89
57.7k
90
57.7k
         return std::make_pair(type, contents);
91
57.7k
         }
92
88.7k
      }
93
75.3k
94
75.3k
   return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
95
75.3k
   }
96
97
std::vector<uint8_t>
98
Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
99
                            Handshake_Type type) const
100
141k
   {
101
141k
   std::vector<uint8_t> send_buf(4 + msg.size());
102
141k
103
141k
   const size_t buf_size = msg.size();
104
141k
105
141k
   send_buf[0] = static_cast<uint8_t>(type);
106
141k
107
141k
   store_be24(&send_buf[1], buf_size);
108
141k
109
141k
   if (msg.size() > 0)
110
116k
      {
111
116k
      copy_mem(&send_buf[4], msg.data(), msg.size());
112
116k
      }
113
141k
114
141k
   return send_buf;
115
141k
   }
116
117
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
118
0
   {
119
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
120
0
   }
121
122
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
123
86.5k
   {
124
86.5k
   const std::vector<uint8_t> msg_bits = msg.serialize();
125
86.5k
126
86.5k
   if(msg.type() == HANDSHAKE_CCS)
127
1.31k
      {
128
1.31k
      m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
129
1.31k
      return std::vector<uint8_t>(); // not included in handshake hashes
130
1.31k
      }
131
85.2k
132
85.2k
   const std::vector<uint8_t> buf = format(msg_bits, msg.type());
133
85.2k
   m_send_hs(HANDSHAKE, buf);
134
85.2k
   return buf;
135
85.2k
   }
136
137
/*
* Record-layer version used for the first records of a datagram (DTLS) handshake.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const
   {
   const Protocol_Version initial = Protocol_Version::DTLS_V10;
   return initial;
   }
141
142
void Datagram_Handshake_IO::retransmit_last_flight()
143
0
   {
144
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145
0
   retransmit_flight(flight_idx);
146
0
   }
147
148
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
149
0
   {
150
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
151
0
152
0
   BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
153
0
154
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
155
0
156
0
   for(auto msg_seq : flight)
157
0
      {
158
0
      auto& msg = m_flight_data[msg_seq];
159
0
160
0
      if(msg.epoch != epoch)
161
0
         {
162
0
         // Epoch gap: insert the CCS
163
0
         std::vector<uint8_t> ccs(1, 1);
164
0
         m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
165
0
         }
166
0
167
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
168
0
      epoch = msg.epoch;
169
0
      }
170
0
   }
171
172
bool Datagram_Handshake_IO::timeout_check()
173
0
   {
174
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
175
0
      {
176
0
      /*
177
0
      If we haven't written anything yet obviously no timeout.
178
0
      Also no timeout possible if we are mid-flight,
179
0
      */
180
0
      return false;
181
0
      }
182
0
183
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
184
0
185
0
   if(ms_since_write < m_next_timeout)
186
0
      return false;
187
0
188
0
   retransmit_last_flight();
189
0
190
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
191
0
   return true;
192
0
   }
193
194
void Datagram_Handshake_IO::add_record(const uint8_t record[],
195
                                       size_t record_len,
196
                                       Record_Type record_type,
197
                                       uint64_t record_sequence)
198
772
   {
199
772
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
200
772
201
772
   if(record_type == CHANGE_CIPHER_SPEC)
202
21
      {
203
21
      if(record_len != 1 || record[0] != 1)
204
10
         throw Decoding_Error("Invalid ChangeCipherSpec");
205
11
206
11
      // TODO: check this is otherwise empty
207
11
      m_ccs_epochs.insert(epoch);
208
11
      return;
209
11
      }
210
751
211
751
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
212
751
213
3.21k
   while(record_len)
214
2.69k
      {
215
2.69k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
216
177
         return; // completely bogus? at least degenerate/weird
217
2.51k
218
2.51k
      const uint8_t msg_type = record[0];
219
2.51k
      const size_t msg_len = load_be24(&record[1]);
220
2.51k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
221
2.51k
      const size_t fragment_offset = load_be24(&record[6]);
222
2.51k
      const size_t fragment_length = load_be24(&record[9]);
223
2.51k
224
2.51k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
225
2.51k
226
2.51k
      if(record_len < total_size)
227
53
         throw Decoding_Error("Bad lengths in DTLS header");
228
2.46k
229
2.46k
      if(message_seq >= m_in_message_seq)
230
2.46k
         {
231
2.46k
         m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
232
2.46k
                                              fragment_length,
233
2.46k
                                              fragment_offset,
234
2.46k
                                              epoch,
235
2.46k
                                              msg_type,
236
2.46k
                                              msg_len);
237
2.46k
         }
238
0
      else
239
0
         {
240
0
         // TODO: detect retransmitted flight
241
0
         }
242
2.46k
243
2.46k
      record += total_size;
244
2.46k
      record_len -= total_size;
245
2.46k
      }
246
751
   }
247
248
std::pair<Handshake_Type, std::vector<uint8_t>>
249
Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
250
629
   {
251
629
   // Expecting a message means the last flight is concluded
252
629
   if(!m_flights.rbegin()->empty())
253
0
      m_flights.push_back(std::vector<uint16_t>());
254
629
255
629
   if(expecting_ccs)
256
0
      {
257
0
      if(!m_messages.empty())
258
0
         {
259
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
260
0
261
0
         if(m_ccs_epochs.count(current_epoch))
262
0
            return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
263
0
         }
264
0
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
265
0
      }
266
629
267
629
   auto i = m_messages.find(m_in_message_seq);
268
629
269
629
   if(i == m_messages.end() || !i->second.complete())
270
564
      {
271
564
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
272
564
      }
273
65
274
65
   m_in_message_seq += 1;
275
65
276
65
   return i->second.message();
277
65
   }
278
279
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
280
   const uint8_t fragment[],
281
   size_t fragment_length,
282
   size_t fragment_offset,
283
   uint16_t epoch,
284
   uint8_t msg_type,
285
   size_t msg_length)
286
2.46k
   {
287
2.46k
   if(complete())
288
1.16k
      return; // already have entire message, ignore this
289
1.29k
290
1.29k
   if(m_msg_type == HANDSHAKE_NONE)
291
1.07k
      {
292
1.07k
      m_epoch = epoch;
293
1.07k
      m_msg_type = msg_type;
294
1.07k
      m_msg_length = msg_length;
295
1.07k
      }
296
1.29k
297
1.29k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
298
54
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
299
1.23k
300
1.23k
   if(fragment_offset > m_msg_length)
301
13
      throw Decoding_Error("Fragment offset past end of message");
302
1.22k
303
1.22k
   if(fragment_offset + fragment_length > m_msg_length)
304
13
      throw Decoding_Error("Fragment overlaps past end of message");
305
1.21k
306
1.21k
   if(fragment_offset == 0 && fragment_length == m_msg_length)
307
383
      {
308
383
      m_fragments.clear();
309
383
      m_message.assign(fragment, fragment+fragment_length);
310
383
      }
311
830
   else
312
830
      {
313
830
      /*
314
830
      * FIXME. This is a pretty lame way to do defragmentation, huge
315
830
      * overhead with a tree node per byte.
316
830
      *
317
830
      * Also should confirm that all overlaps have no changes,
318
830
      * otherwise we expose ourselves to the classic fingerprinting
319
830
      * and IDS evasion attacks on IP fragmentation.
320
830
      */
321
24.1k
      for(size_t i = 0; i != fragment_length; ++i)
322
23.2k
         m_fragments[fragment_offset+i] = fragment[i];
323
830
324
830
      if(m_fragments.size() == m_msg_length)
325
41
         {
326
41
         m_message.resize(m_msg_length);
327
2.11k
         for(size_t i = 0; i != m_msg_length; ++i)
328
2.07k
            m_message[i] = m_fragments[i];
329
41
         m_fragments.clear();
330
41
         }
331
830
      }
332
1.21k
   }
333
334
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
335
2.69k
   {
336
2.69k
   return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
337
2.69k
   }
338
339
std::pair<Handshake_Type, std::vector<uint8_t>>
340
Datagram_Handshake_IO::Handshake_Reassembly::message() const
341
65
   {
342
65
   if(!complete())
343
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
344
65
345
65
   return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
346
65
   }
347
348
std::vector<uint8_t>
349
Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
350
                                       size_t frag_len,
351
                                       uint16_t frag_offset,
352
                                       uint16_t msg_len,
353
                                       Handshake_Type type,
354
                                       uint16_t msg_sequence) const
355
14
   {
356
14
   std::vector<uint8_t> send_buf(12 + frag_len);
357
14
358
14
   send_buf[0] = static_cast<uint8_t>(type);
359
14
360
14
   store_be24(&send_buf[1], msg_len);
361
14
362
14
   store_be(msg_sequence, &send_buf[4]);
363
14
364
14
   store_be24(&send_buf[6], frag_offset);
365
14
   store_be24(&send_buf[9], frag_len);
366
14
367
14
   if (frag_len > 0)
368
10
      {
369
10
      copy_mem(&send_buf[12], fragment, frag_len);
370
10
      }
371
14
372
14
   return send_buf;
373
14
   }
374
375
std::vector<uint8_t>
376
Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
377
                                    Handshake_Type type,
378
                                    uint16_t msg_sequence) const
379
14
   {
380
14
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
381
14
   }
382
383
std::vector<uint8_t>
384
Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
385
                              Handshake_Type type) const
386
14
   {
387
14
   return format_w_seq(msg, type, m_in_message_seq - 1);
388
14
   }
389
390
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
391
0
   {
392
0
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
393
0
   }
394
395
std::vector<uint8_t>
396
Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
397
0
   {
398
0
   const std::vector<uint8_t> msg_bits = msg.serialize();
399
0
   const Handshake_Type msg_type = msg.type();
400
0
401
0
   if(msg_type == HANDSHAKE_CCS)
402
0
      {
403
0
      m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
404
0
      return std::vector<uint8_t>(); // not included in handshake hashes
405
0
      }
406
0
   else if(msg_type == HELLO_VERIFY_REQUEST)
407
0
      {
408
0
      // This message is not included in the handshake hashes
409
0
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
410
0
      m_out_message_seq += 1;
411
0
      return std::vector<uint8_t>();
412
0
      }
413
0
414
0
   // Note: not saving CCS, instead we know it was there due to change in epoch
415
0
   m_flights.rbegin()->push_back(m_out_message_seq);
416
0
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
417
0
418
0
   m_out_message_seq += 1;
419
0
   m_last_write = steady_clock_ms();
420
0
   m_next_timeout = m_initial_timeout;
421
0
422
0
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
423
0
   }
424
425
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
426
                                                      uint16_t epoch,
427
                                                      Handshake_Type msg_type,
428
                                                      const std::vector<uint8_t>& msg_bits)
429
0
   {
430
0
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
431
0
432
0
   const std::vector<uint8_t> no_fragment =
433
0
      format_w_seq(msg_bits, msg_type, msg_seq);
434
0
435
0
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
436
0
      {
437
0
      m_send_hs(epoch, HANDSHAKE, no_fragment);
438
0
      }
439
0
   else
440
0
      {
441
0
      size_t frag_offset = 0;
442
0
443
0
      /**
444
0
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
445
0
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
446
0
      * over-estimate here. When CBC ciphers are removed this can be reduced
447
0
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
448
0
      * per-record nonce.
449
0
      */
450
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
451
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
452
0
453
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead))
454
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
455
0
456
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
457
0
458
0
      while(frag_offset != msg_bits.size())
459
0
         {
460
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
461
0
462
0
         const std::vector<uint8_t> frag =
463
0
            format_fragment(&msg_bits[frag_offset],
464
0
                            frag_len,
465
0
                            static_cast<uint16_t>(frag_offset),
466
0
                            static_cast<uint16_t>(msg_bits.size()),
467
0
                            msg_type,
468
0
                            msg_seq);
469
0
470
0
         m_send_hs(epoch, HANDSHAKE, frag);
471
0
472
0
         frag_offset += frag_len;
473
0
         }
474
0
      }
475
0
476
0
   return no_fragment;
477
0
   }
478
479
}
480
}