Coverage Report

Created: 2020-11-21 08:34

/src/botan/src/lib/tls/tls_handshake_io.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake IO
3
* (C) 2012,2014,2015 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/internal/tls_handshake_io.h>
9
#include <botan/internal/tls_record.h>
10
#include <botan/internal/tls_seq_numbers.h>
11
#include <botan/tls_messages.h>
12
#include <botan/exceptn.h>
13
#include <botan/internal/loadstor.h>
14
#include <chrono>
15
16
namespace Botan {
17
18
namespace TLS {
19
20
namespace {
21
22
/*
* Decode a 24-bit big-endian unsigned integer from q[0..2].
*
* Used to read the length, fragment offset and fragment length fields
* of DTLS handshake headers.
*/
inline size_t load_be24(const uint8_t q[3])
   {
   const size_t high = static_cast<size_t>(q[0]) << 16;
   const size_t mid  = static_cast<size_t>(q[1]) << 8;
   const size_t low  = static_cast<size_t>(q[2]);
   return high | mid | low;
   }
29
30
/*
* Encode the low 24 bits of val into out[0..2], most significant byte
* first. Any bits above bit 23 are silently discarded.
*/
void store_be24(uint8_t out[3], size_t val)
   {
   const uint32_t v32 = static_cast<uint32_t>(val);
   out[0] = static_cast<uint8_t>(v32 >> 16);
   out[1] = static_cast<uint8_t>(v32 >> 8);
   out[2] = static_cast<uint8_t>(v32);
   }
36
37
/*
* Current reading of std::chrono::steady_clock, in whole milliseconds
* since that clock's (unspecified) epoch. Only differences between two
* readings are meaningful.
*/
uint64_t steady_clock_ms()
   {
   using std::chrono::milliseconds;
   using std::chrono::steady_clock;

   const auto since_epoch = steady_clock::now().time_since_epoch();
   return std::chrono::duration_cast<milliseconds>(since_epoch).count();
   }
42
43
}
44
45
/*
* Record-layer version used for the first records of a stream (TLS)
* handshake: always TLS 1.0.
*/
Protocol_Version Stream_Handshake_IO::initial_record_version() const
   {
   const Protocol_Version initial(Protocol_Version::TLS_V10);
   return initial;
   }
49
50
void Stream_Handshake_IO::add_record(const uint8_t record[],
51
                                     size_t record_len,
52
                                     Record_Type record_type, uint64_t)
53
82.0k
   {
54
82.0k
   if(record_type == HANDSHAKE)
55
70.8k
      {
56
70.8k
      m_queue.insert(m_queue.end(), record, record + record_len);
57
70.8k
      }
58
11.1k
   else if(record_type == CHANGE_CIPHER_SPEC)
59
11.1k
      {
60
11.1k
      if(record_len != 1 || record[0] != 1)
61
124
         throw Decoding_Error("Invalid ChangeCipherSpec");
62
63
      // Pretend it's a regular handshake message of zero length
64
11.0k
      const uint8_t ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 };
65
11.0k
      m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs));
66
11.0k
      }
67
0
   else
68
0
      throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing");
69
82.0k
   }
70
71
std::pair<Handshake_Type, std::vector<uint8_t>>
72
Stream_Handshake_IO::get_next_record(bool)
73
137k
   {
74
137k
   if(m_queue.size() >= 4)
75
88.7k
      {
76
88.7k
      const size_t length = 4 + make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
77
78
88.7k
      if(m_queue.size() >= length)
79
61.2k
         {
80
61.2k
         Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
81
82
61.2k
         if(type == HANDSHAKE_NONE)
83
9
            throw Decoding_Error("Invalid handshake message type");
84
85
61.2k
         std::vector<uint8_t> contents(m_queue.begin() + 4,
86
61.2k
                                       m_queue.begin() + length);
87
88
61.2k
         m_queue.erase(m_queue.begin(), m_queue.begin() + length);
89
90
61.2k
         return std::make_pair(type, contents);
91
61.2k
         }
92
88.7k
      }
93
94
75.8k
   return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
95
75.8k
   }
96
97
std::vector<uint8_t>
98
Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
99
                            Handshake_Type type) const
100
152k
   {
101
152k
   std::vector<uint8_t> send_buf(4 + msg.size());
102
103
152k
   const size_t buf_size = msg.size();
104
105
152k
   send_buf[0] = static_cast<uint8_t>(type);
106
107
152k
   store_be24(&send_buf[1], buf_size);
108
109
152k
   if (msg.size() > 0)
110
124k
      {
111
124k
      copy_mem(&send_buf[4], msg.data(), msg.size());
112
124k
      }
113
114
152k
   return send_buf;
115
152k
   }
116
117
std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
118
0
   {
119
0
   throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
120
0
   }
121
122
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
123
94.2k
   {
124
94.2k
   const std::vector<uint8_t> msg_bits = msg.serialize();
125
126
94.2k
   if(msg.type() == HANDSHAKE_CCS)
127
1.63k
      {
128
1.63k
      m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
129
1.63k
      return std::vector<uint8_t>(); // not included in handshake hashes
130
1.63k
      }
131
132
92.6k
   const std::vector<uint8_t> buf = format(msg_bits, msg.type());
133
92.6k
   m_send_hs(HANDSHAKE, buf);
134
92.6k
   return buf;
135
92.6k
   }
136
137
/*
* Record-layer version used for the first records of a datagram (DTLS)
* handshake: always DTLS 1.0.
*/
Protocol_Version Datagram_Handshake_IO::initial_record_version() const
   {
   const Protocol_Version initial(Protocol_Version::DTLS_V10);
   return initial;
   }
141
142
void Datagram_Handshake_IO::retransmit_last_flight()
143
0
   {
144
0
   const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
145
0
   retransmit_flight(flight_idx);
146
0
   }
147
148
void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
149
0
   {
150
0
   const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
151
152
0
   BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
153
154
0
   uint16_t epoch = m_flight_data[flight[0]].epoch;
155
156
0
   for(auto msg_seq : flight)
157
0
      {
158
0
      auto& msg = m_flight_data[msg_seq];
159
160
0
      if(msg.epoch != epoch)
161
0
         {
162
         // Epoch gap: insert the CCS
163
0
         std::vector<uint8_t> ccs(1, 1);
164
0
         m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs);
165
0
         }
166
167
0
      send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
168
0
      epoch = msg.epoch;
169
0
      }
170
0
   }
171
172
bool Datagram_Handshake_IO::timeout_check()
173
0
   {
174
0
   if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
175
0
      {
176
      /*
177
      If we haven't written anything yet obviously no timeout.
178
      Also no timeout possible if we are mid-flight,
179
      */
180
0
      return false;
181
0
      }
182
183
0
   const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
184
185
0
   if(ms_since_write < m_next_timeout)
186
0
      return false;
187
188
0
   retransmit_last_flight();
189
190
0
   m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
191
0
   return true;
192
0
   }
193
194
void Datagram_Handshake_IO::add_record(const uint8_t record[],
195
                                       size_t record_len,
196
                                       Record_Type record_type,
197
                                       uint64_t record_sequence)
198
839
   {
199
839
   const uint16_t epoch = static_cast<uint16_t>(record_sequence >> 48);
200
201
839
   if(record_type == CHANGE_CIPHER_SPEC)
202
28
      {
203
28
      if(record_len != 1 || record[0] != 1)
204
6
         throw Decoding_Error("Invalid ChangeCipherSpec");
205
206
      // TODO: check this is otherwise empty
207
22
      m_ccs_epochs.insert(epoch);
208
22
      return;
209
22
      }
210
211
811
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
212
213
3.96k
   while(record_len)
214
3.43k
      {
215
3.43k
      if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
216
222
         return; // completely bogus? at least degenerate/weird
217
218
3.20k
      const uint8_t msg_type = record[0];
219
3.20k
      const size_t msg_len = load_be24(&record[1]);
220
3.20k
      const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
221
3.20k
      const size_t fragment_offset = load_be24(&record[6]);
222
3.20k
      const size_t fragment_length = load_be24(&record[9]);
223
224
3.20k
      const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
225
226
3.20k
      if(record_len < total_size)
227
60
         throw Decoding_Error("Bad lengths in DTLS header");
228
229
3.14k
      if(message_seq >= m_in_message_seq)
230
3.14k
         {
231
3.14k
         m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
232
3.14k
                                              fragment_length,
233
3.14k
                                              fragment_offset,
234
3.14k
                                              epoch,
235
3.14k
                                              msg_type,
236
3.14k
                                              msg_len);
237
3.14k
         }
238
0
      else
239
0
         {
240
         // TODO: detect retransmitted flight
241
0
         }
242
243
3.14k
      record += total_size;
244
3.14k
      record_len -= total_size;
245
3.14k
      }
246
811
   }
247
248
std::pair<Handshake_Type, std::vector<uint8_t>>
249
Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
250
694
   {
251
   // Expecting a message means the last flight is concluded
252
694
   if(!m_flights.rbegin()->empty())
253
0
      m_flights.push_back(std::vector<uint16_t>());
254
255
694
   if(expecting_ccs)
256
0
      {
257
0
      if(!m_messages.empty())
258
0
         {
259
0
         const uint16_t current_epoch = m_messages.begin()->second.epoch();
260
261
0
         if(m_ccs_epochs.count(current_epoch))
262
0
            return std::make_pair(HANDSHAKE_CCS, std::vector<uint8_t>());
263
0
         }
264
0
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
265
0
      }
266
267
694
   auto i = m_messages.find(m_in_message_seq);
268
269
694
   if(i == m_messages.end() || !i->second.complete())
270
638
      {
271
638
      return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
272
638
      }
273
274
56
   m_in_message_seq += 1;
275
276
56
   return i->second.message();
277
56
   }
278
279
void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
280
   const uint8_t fragment[],
281
   size_t fragment_length,
282
   size_t fragment_offset,
283
   uint16_t epoch,
284
   uint8_t msg_type,
285
   size_t msg_length)
286
3.14k
   {
287
3.14k
   if(complete())
288
1.67k
      return; // already have entire message, ignore this
289
290
1.47k
   if(m_msg_type == HANDSHAKE_NONE)
291
1.21k
      {
292
1.21k
      m_epoch = epoch;
293
1.21k
      m_msg_type = msg_type;
294
1.21k
      m_msg_length = msg_length;
295
1.21k
      }
296
297
1.47k
   if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
298
49
      throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header");
299
300
1.42k
   if(fragment_offset > m_msg_length)
301
18
      throw Decoding_Error("Fragment offset past end of message");
302
303
1.40k
   if(fragment_offset + fragment_length > m_msg_length)
304
12
      throw Decoding_Error("Fragment overlaps past end of message");
305
306
1.39k
   if(fragment_offset == 0 && fragment_length == m_msg_length)
307
521
      {
308
521
      m_fragments.clear();
309
521
      m_message.assign(fragment, fragment+fragment_length);
310
521
      }
311
873
   else
312
873
      {
313
      /*
314
      * FIXME. This is a pretty lame way to do defragmentation, huge
315
      * overhead with a tree node per byte.
316
      *
317
      * Also should confirm that all overlaps have no changes,
318
      * otherwise we expose ourselves to the classic fingerprinting
319
      * and IDS evasion attacks on IP fragmentation.
320
      */
321
16.7k
      for(size_t i = 0; i != fragment_length; ++i)
322
15.8k
         m_fragments[fragment_offset+i] = fragment[i];
323
324
873
      if(m_fragments.size() == m_msg_length)
325
40
         {
326
40
         m_message.resize(m_msg_length);
327
1.77k
         for(size_t i = 0; i != m_msg_length; ++i)
328
1.73k
            m_message[i] = m_fragments[i];
329
40
         m_fragments.clear();
330
40
         }
331
873
      }
332
1.39k
   }
333
334
bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
335
3.37k
   {
336
3.37k
   return (m_msg_type != HANDSHAKE_NONE && m_message.size() == m_msg_length);
337
3.37k
   }
338
339
std::pair<Handshake_Type, std::vector<uint8_t>>
340
Datagram_Handshake_IO::Handshake_Reassembly::message() const
341
56
   {
342
56
   if(!complete())
343
0
      throw Internal_Error("Datagram_Handshake_IO - message not complete");
344
345
56
   return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
346
56
   }
347
348
std::vector<uint8_t>
349
Datagram_Handshake_IO::format_fragment(const uint8_t fragment[],
350
                                       size_t frag_len,
351
                                       uint16_t frag_offset,
352
                                       uint16_t msg_len,
353
                                       Handshake_Type type,
354
                                       uint16_t msg_sequence) const
355
13
   {
356
13
   std::vector<uint8_t> send_buf(12 + frag_len);
357
358
13
   send_buf[0] = static_cast<uint8_t>(type);
359
360
13
   store_be24(&send_buf[1], msg_len);
361
362
13
   store_be(msg_sequence, &send_buf[4]);
363
364
13
   store_be24(&send_buf[6], frag_offset);
365
13
   store_be24(&send_buf[9], frag_len);
366
367
13
   if (frag_len > 0)
368
8
      {
369
8
      copy_mem(&send_buf[12], fragment, frag_len);
370
8
      }
371
372
13
   return send_buf;
373
13
   }
374
375
std::vector<uint8_t>
376
Datagram_Handshake_IO::format_w_seq(const std::vector<uint8_t>& msg,
377
                                    Handshake_Type type,
378
                                    uint16_t msg_sequence) const
379
13
   {
380
13
   return format_fragment(msg.data(), msg.size(), 0, static_cast<uint16_t>(msg.size()), type, msg_sequence);
381
13
   }
382
383
std::vector<uint8_t>
384
Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
385
                              Handshake_Type type) const
386
13
   {
387
13
   return format_w_seq(msg, type, m_in_message_seq - 1);
388
13
   }
389
390
std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
391
0
   {
392
0
   return this->send_under_epoch(msg, m_seqs.current_write_epoch());
393
0
   }
394
395
std::vector<uint8_t>
396
Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
397
0
   {
398
0
   const std::vector<uint8_t> msg_bits = msg.serialize();
399
0
   const Handshake_Type msg_type = msg.type();
400
401
0
   if(msg_type == HANDSHAKE_CCS)
402
0
      {
403
0
      m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
404
0
      return std::vector<uint8_t>(); // not included in handshake hashes
405
0
      }
406
0
   else if(msg_type == HELLO_VERIFY_REQUEST)
407
0
      {
408
      // This message is not included in the handshake hashes
409
0
      send_message(m_out_message_seq, epoch, msg_type, msg_bits);
410
0
      m_out_message_seq += 1;
411
0
      return std::vector<uint8_t>();
412
0
      }
413
414
   // Note: not saving CCS, instead we know it was there due to change in epoch
415
0
   m_flights.rbegin()->push_back(m_out_message_seq);
416
0
   m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
417
418
0
   m_out_message_seq += 1;
419
0
   m_last_write = steady_clock_ms();
420
0
   m_next_timeout = m_initial_timeout;
421
422
0
   return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
423
0
   }
424
425
std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
426
                                                      uint16_t epoch,
427
                                                      Handshake_Type msg_type,
428
                                                      const std::vector<uint8_t>& msg_bits)
429
0
   {
430
0
   const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
431
432
0
   const std::vector<uint8_t> no_fragment =
433
0
      format_w_seq(msg_bits, msg_type, msg_seq);
434
435
0
   if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
436
0
      {
437
0
      m_send_hs(epoch, HANDSHAKE, no_fragment);
438
0
      }
439
0
   else
440
0
      {
441
0
      size_t frag_offset = 0;
442
443
      /**
444
      * Largest possible overhead is for SHA-384 CBC ciphers, with 16 byte IV,
445
      * 16+ for padding and 48 bytes for MAC. 128 is probably a strict
446
      * over-estimate here. When CBC ciphers are removed this can be reduced
447
      * since AEAD modes have no padding, at most 16 byte mac, and smaller
448
      * per-record nonce.
449
      */
450
0
      const size_t ciphersuite_overhead = (epoch > 0) ? 128 : 0;
451
0
      const size_t header_overhead = DTLS_HEADER_SIZE + DTLS_HANDSHAKE_HEADER_LEN;
452
453
0
      if(m_mtu <= (header_overhead + ciphersuite_overhead))
454
0
         throw Invalid_Argument("DTLS MTU is too small to send headers");
455
456
0
      const size_t max_rec_size = m_mtu - (header_overhead + ciphersuite_overhead);
457
458
0
      while(frag_offset != msg_bits.size())
459
0
         {
460
0
         const size_t frag_len = std::min<size_t>(msg_bits.size() - frag_offset, max_rec_size);
461
462
0
         const std::vector<uint8_t> frag =
463
0
            format_fragment(&msg_bits[frag_offset],
464
0
                            frag_len,
465
0
                            static_cast<uint16_t>(frag_offset),
466
0
                            static_cast<uint16_t>(msg_bits.size()),
467
0
                            msg_type,
468
0
                            msg_seq);
469
470
0
         m_send_hs(epoch, HANDSHAKE, frag);
471
472
0
         frag_offset += frag_len;
473
0
         }
474
0
      }
475
476
0
   return no_fragment;
477
0
   }
478
479
}
480
}