Coverage Report

Created: 2020-05-23 13:54

/src/botan/src/lib/tls/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
#include <botan/internal/tls_record.h>
10
#include <botan/internal/tls_seq_numbers.h>
11
#include <botan/tls_messages.h>
12
#include <botan/exceptn.h>
13
#include <botan/loadstor.h>
14
#include <chrono>
15
16
namespace Botan {
17
18
namespace TLS {
19
20
namespace {
21
22
/*
* Decode a 24-bit big-endian integer from q[0..2] (q[0] is the most
* significant byte).
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const uint32_t high = static_cast<uint32_t>(q[0]) << 16;
   const uint32_t middle = static_cast<uint32_t>(q[1]) << 8;
   const uint32_t low = static_cast<uint32_t>(q[2]);
   return high | middle | low;
   }
29
30
void store_be24(uint8_t out[3], size_t val)
31
136k
   {
32
136k
   out[0] = get_byte(1, static_cast<uint32_t>(val));
33
136k
   out[1] = get_byte(2, static_cast<uint32_t>(val));
34
136k
   out[2] = get_byte(3, static_cast<uint32_t>(val));
35
136k
   }
36
37
/*
* Current reading of the monotonic steady_clock in milliseconds. Only the
* difference between two readings is meaningful.
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::milliseconds;
   const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
   return std::chrono::duration_cast<milliseconds>(since_epoch).count();
   }
42
43
}
44
45
/*
* Record-layer version used for the initial records of a stream (TLS)
* connection: always TLS 1.0.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const
   {
   return Protocol_Version(Protocol_Version::TLS_V10);
   }
49
50
void Stream_Handshake_IO::add_record(const uint8_t record[],
51
                                     size_t record_len,
52
                                     Record_Type record_type, uint64_t)
53
79.1k
   {
54
79.1k
   if(record_type == HANDSHAKE)
55
71.2k
      {
56
71.2k
      m_queue.insert(m_queue.end(), record, record + record_len);
57
71.2k
      }
58
7.88k
   else if(record_type == CHANGE_CIPHER_SPEC)
59
7.88k
      {
60
7.88k
      if(record_len != 1 || record[0] != 1)
61
118
         throw Decoding_Error("Invalid ChangeCipherSpec");
62
7.76k
63
7.76k
      // Pretend it's a regular handshake message of zero length
64
7.76k
      const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
65
7.76k
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
66
7.76k
      }
67
0
   else
68
0
      throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
69
79.1k
   }
70
71
std::pair<Handshake_Type, std::vector<uint8_t>>
72
Stream_Handshake_IO::get_next_record(bool)
73
128k
   {
74
128k
   if(m_queue.size() >= 4)
75
84.3k
      {
76
84.3k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
77
84.3k
78
84.3k
      if(m_queue.size() >= length)
79
55.1k
         {
80
55.1k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
81
55.1k
82
55.1k
         if(type == HANDSHAKE_NONE)
83
6
            throw Decoding_Error("Invalid handshake message type");
84
55.1k
85
55.1k
         std::vector<uint8_t> contents(m_queue.begin() + 4,
86
55.1k
                                       m_queue.begin() + length);
87
55.1k
88
55.1k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
89
55.1k
90
55.1k
         return std::make_pair(type, contents);
91
55.1k
         }
92
84.3k
      }
93
73.3k
94
73.3k
   return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
95
73.3k
   }
96
97
std::vector<uint8_t>
98
Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
99
                            Handshake_Type type) const
100
136k
   {
101
136k
   std::vector<uint8_t> send_buf(4 + msg.size());
102
136k
103
136k
   const size_t buf_size = msg.size();
104
136k
105
136k
   send_buf[0] = static_cast<uint8_t>(type);
106
136k
107
136k
   store_be24(&send_buf[1], buf_size);
108
136k
109
136k
   if (msg.size() > 0)
110
111k
      {
111
111k
      copy_mem(&send_buf[4], msg.data(), msg.size());
112
111k
      }
113
136k
114
136k
   return send_buf;
115
136k
   }
116
117
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
118
0
   {
119
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
120
0
   }
121
122
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
123
83.6k
   {
124
83.6k
   const std::vector<uint8_t> msg_bits = msg.serialize();
125
83.6k
126
83.6k
   if(msg.type() == HANDSHAKE_CCS)
127
1.20k
      {
128
1.20k
      m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
129
1.20k
      return std::vector<uint8_t>(); // not included in handshake hashes
130
1.20k
      }
131
82.4k
132
82.4k
   const std::vector<uint8_t> buf = format(msg_bits, msg.type());
133
82.4k
   m_send_hs(HANDSHAKE, buf);
134
82.4k
   return buf;
135
82.4k
   }
136
137
/*
* Record-layer version used for the initial records of a datagram (DTLS)
* connection: always DTLS 1.0.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const
   {
   return Protocol_Version(Protocol_Version::DTLS_V10);
   }
141
142
void Datagram_Handshake_IO::retransmit_last_flight()
143
0
   {
144
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145
0
   retransmit_flight(flight_idx);
146
0
   }
147
148
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
149
0
   {
150
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
151
0
152
0
   BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
153
0
154
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
155
0
156
0
   for(auto msg_seq : flight)
157
0
      {
158
0
      auto& msg = m_flight_data[msg_seq];
159
0
160
0
      if(msg.epoch != epoch)
161
0
         {
162
0
         // Epoch gap: insert the CCS
163
0
         std::vector<uint8_t> ccs(1, 1);
164
0
         m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
165
0
         }
166
0
167
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
168
0
      epoch = msg.epoch;
169
0
      }
170
0
   }
171
172
bool Datagram_Handshake_IO::timeout_check()
173
0
   {
174
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
175
0
      {
176
0
      /*
177
0
      If we haven't written anything yet obviously no timeout.
178
0
      Also no timeout possible if we are mid-flight,
179
0
      */
180
0
      return false;
181
0
      }
182
0
183
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
184
0
185
0
   if(ms_since_write < m_next_timeout)
186
0
      return false;
187
0
188
0
   retransmit_last_flight();
189
0
190
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
191
0
   return true;
192
0
   }
193
194
void Datagram_Handshake_IO::add_record(const uint8_t record[],
195
                                       size_t record_len,
196
                                       Record_Type record_type,
197
                                       uint64_t record_sequence)
198
759
   {
199
759
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
200
759
201
759
   if(record_type == CHANGE_CIPHER_SPEC)
202
21
      {
203
21
      if(record_len != 1 || record[0] != 1)
204
8
         throw Decoding_Error("Invalid ChangeCipherSpec");
205
13
206
13
      // TODO: check this is otherwise empty
207
13
      m_ccs_epochs.insert(epoch);
208
13
      return;
209
13
      }
210
738
211
738
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
212
738
213
2.97k
   while(record_len)
214
2.46k
      {
215
2.46k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
216
173
         return; // completely bogus? at least degenerate/weird
217
2.29k
218
2.29k
      const uint8_t msg_type = record[0];
219
2.29k
      const size_t msg_len = load_be24(&record[1]);
220
2.29k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
221
2.29k
      const size_t fragment_offset = load_be24(&record[6]);
222
2.29k
      const size_t fragment_length = load_be24(&record[9]);
223
2.29k
224
2.29k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
225
2.29k
226
2.29k
      if(record_len < total_size)
227
52
         throw Decoding_Error("Bad lengths in DTLS header");
228
2.23k
229
2.23k
      if(message_seq >= m_in_message_seq)
230
2.23k
         {
231
2.23k
         m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
232
2.23k
                                              fragment_length,
233
2.23k
                                              fragment_offset,
234
2.23k
                                              epoch,
235
2.23k
                                              msg_type,
236
2.23k
                                              msg_len);
237
2.23k
         }
238
0
      else
239
0
         {
240
0
         // TODO: detect retransmitted flight
241
0
         }
242
2.23k
243
2.23k
      record += total_size;
244
2.23k
      record_len -= total_size;
245
2.23k
      }
246
738
   }
247
248
std::pair<Handshake_Type, std::vector<uint8_t>>
249
Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
250
620
   {
251
620
   // Expecting a message means the last flight is concluded
252
620
   if(!m_flights.rbegin()->empty())
253
0
      m_flights.push_back(std::vector<uint16_t>());
254
620
255
620
   if(expecting_ccs)
256
0
      {
257
0
      if(!m_messages.empty())
258
0
         {
259
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
260
0
261
0
         if(m_ccs_epochs.count(current_epoch))
262
0
            return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
263
0
         }
264
0
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
265
0
      }
266
620
267
620
   auto i = m_messages.find(m_in_message_seq);
268
620
269
620
   if(i == m_messages.end() || !i->second.complete())
270
556
      {
271
556
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
272
556
      }
273
64
274
64
   m_in_message_seq += 1;
275
64
276
64
   return i->second.message();
277
64
   }
278
279
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
280
   const uint8_t fragment[],
281
   size_t fragment_length,
282
   size_t fragment_offset,
283
   uint16_t epoch,
284
   uint8_t msg_type,
285
   size_t msg_length)
286
2.23k
   {
287
2.23k
   if(complete())
288
1.06k
      return; // already have entire message, ignore this
289
1.17k
290
1.17k
   if(m_msg_type == HANDSHAKE_NONE)
291
982
      {
292
982
      m_epoch = epoch;
293
982
      m_msg_type = msg_type;
294
982
      m_msg_length = msg_length;
295
982
      }
296
1.17k
297
1.17k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
298
54
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
299
1.11k
300
1.11k
   if(fragment_offset > m_msg_length)
301
13
      throw Decoding_Error("Fragment offset past end of message");
302
1.10k
303
1.10k
   if(fragment_offset + fragment_length > m_msg_length)
304
12
      throw Decoding_Error("Fragment overlaps past end of message");
305
1.09k
306
1.09k
   if(fragment_offset == 0 && fragment_length == m_msg_length)
307
318
      {
308
318
      m_fragments.clear();
309
318
      m_message.assign(fragment, fragment+fragment_length);
310
318
      }
311
776
   else
312
776
      {
313
776
      /*
314
776
      * FIXME. This is a pretty lame way to do defragmentation, huge
315
776
      * overhead with a tree node per byte.
316
776
      *
317
776
      * Also should confirm that all overlaps have no changes,
318
776
      * otherwise we expose ourselves to the classic fingerprinting
319
776
      * and IDS evasion attacks on IP fragmentation.
320
776
      */
321
26.9k
      for(size_t i = 0; i != fragment_length; ++i)
322
26.1k
         m_fragments[fragment_offset+i] = fragment[i];
323
776
324
776
      if(m_fragments.size() == m_msg_length)
325
40
         {
326
40
         m_message.resize(m_msg_length);
327
2.05k
         for(size_t i = 0; i != m_msg_length; ++i)
328
2.01k
            m_message[i] = m_fragments[i];
329
40
         m_fragments.clear();
330
40
         }
331
776
      }
332
1.09k
   }
333
334
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
335
2.47k
   {
336
2.47k
   return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
337
2.47k
   }
338
339
std::pair<Handshake_Type, std::vector<uint8_t>>
340
Datagram_Handshake_IO::Handshake_Reassembly::message() const
341
64
   {
342
64
   if(!complete())
343
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
344
64
345
64
   return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
346
64
   }
347
348
std::vector<uint8_t>
349
Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
350
                                       size_t frag_len,
351
                                       uint16_t frag_offset,
352
                                       uint16_t msg_len,
353
                                       Handshake_Type type,
354
                                       uint16_t msg_sequence) const
355
13
   {
356
13
   std::vector<uint8_t> send_buf(12 + frag_len);
357
13
358
13
   send_buf[0] = static_cast<uint8_t>(type);
359
13
360
13
   store_be24(&send_buf[1], msg_len);
361
13
362
13
   store_be(msg_sequence, &send_buf[4]);
363
13
364
13
   store_be24(&send_buf[6], frag_offset);
365
13
   store_be24(&send_buf[9], frag_len);
366
13
367
13
   if (frag_len > 0)
368
9
      {
369
9
      copy_mem(&send_buf[12], fragment, frag_len);
370
9
      }
371
13
372
13
   return send_buf;
373
13
   }
374
375
std::vector<uint8_t>
376
Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
377
                                    Handshake_Type type,
378
                                    uint16_t msg_sequence) const
379
13
   {
380
13
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
381
13
   }
382
383
std::vector<uint8_t>
384
Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
385
                              Handshake_Type type) const
386
13
   {
387
13
   return format_w_seq(msg, type, m_in_message_seq - 1);
388
13
   }
389
390
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
391
0
   {
392
0
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
393
0
   }
394
395
std::vector<uint8_t>
396
Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
397
0
   {
398
0
   const std::vector<uint8_t> msg_bits = msg.serialize();
399
0
   const Handshake_Type msg_type = msg.type();
400
0
401
0
   if(msg_type == HANDSHAKE_CCS)
402
0
      {
403
0
      m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
404
0
      return std::vector<uint8_t>(); // not included in handshake hashes
405
0
      }
406
0
   else if(msg_type == HELLO_VERIFY_REQUEST)
407
0
      {
408
0
      // This message is not included in the handshake hashes
409
0
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
410
0
      m_out_message_seq += 1;
411
0
      return std::vector<uint8_t>();
412
0
      }
413
0
414
0
   // Note: not saving CCS, instead we know it was there due to change in epoch
415
0
   m_flights.rbegin()->push_back(m_out_message_seq);
416
0
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
417
0
418
0
   m_out_message_seq += 1;
419
0
   m_last_write = steady_clock_ms();
420
0
   m_next_timeout = m_initial_timeout;
421
0
422
0
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
423
0
   }
424
425
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
426
                                                      uint16_t epoch,
427
                                                      Handshake_Type msg_type,
428
                                                      const std::vector<uint8_t>& msg_bits)
429
0
   {
430
0
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
431
0
432
0
   const std::vector<uint8_t> no_fragment =
433
0
      format_w_seq(msg_bits, msg_type, msg_seq);
434
0
435
0
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
436
0
      {
437
0
      m_send_hs(epoch, HANDSHAKE, no_fragment);
438
0
      }
439
0
   else
440
0
      {
441
0
      size_t frag_offset = 0;
442
0
443
0
      /**
444
0
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
445
0
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
446
0
      * over-estimate here. When CBC ciphers are removed this can be reduced
447
0
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
448
0
      * per-record nonce.
449
0
      */
450
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
451
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
452
0
453
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead))
454
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
455
0
456
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
457
0
458
0
      while(frag_offset != msg_bits.size())
459
0
         {
460
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
461
0
462
0
         const std::vector<uint8_t> frag =
463
0
            format_fragment(&msg_bits[frag_offset],
464
0
                            frag_len,
465
0
                            static_cast<uint16_t>(frag_offset),
466
0
                            static_cast<uint16_t>(msg_bits.size()),
467
0
                            msg_type,
468
0
                            msg_seq);
469
0
470
0
         m_send_hs(epoch, HANDSHAKE, frag);
471
0
472
0
         frag_offset += frag_len;
473
0
         }
474
0
      }
475
0
476
0
   return no_fragment;
477
0
   }
478
479
}
480
}