Coverage Report

Created: 2020-11-21 08:34

/src/botan/build/include/botan/internal/tls_handshake_io.h
Line
Count
Source (jump to first uncovered line)
1
/*
2
* TLS Handshake Serialization
3
* (C) 2012,2014 Jack Lloyd
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#ifndef BOTAN_TLS_HANDSHAKE_IO_H_
9
#define BOTAN_TLS_HANDSHAKE_IO_H_
10
11
#include <botan/tls_magic.h>
12
#include <botan/tls_version.h>
13
#include <functional>
14
#include <vector>
15
#include <deque>
16
#include <map>
17
#include <set>
18
#include <utility>
19
20
namespace Botan {
21
22
namespace TLS {
23
24
class Handshake_Message;
25
26
/**
27
* Handshake IO Interface
28
*/
29
class Handshake_IO
30
   {
31
   public:
32
      virtual Protocol_Version initial_record_version() const = 0;
33
34
      virtual std::vector<uint8_t> send(const Handshake_Message& msg) = 0;
35
36
      virtual std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) = 0;
37
38
      virtual bool timeout_check() = 0;
39
40
      virtual std::vector<uint8_t> format(
41
         const std::vector<uint8_t>& handshake_msg,
42
         Handshake_Type handshake_type) const = 0;
43
44
      virtual void add_record(const uint8_t record[],
45
                              size_t record_len,
46
                              Record_Type type,
47
                              uint64_t sequence_number) = 0;
48
49
      /**
50
      * Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available
51
      */
52
      virtual std::pair<Handshake_Type, std::vector<uint8_t>>
53
         get_next_record(bool expecting_ccs) = 0;
54
55
47.6k
      Handshake_IO() = default;
56
57
      Handshake_IO(const Handshake_IO&) = delete;
58
59
      Handshake_IO& operator=(const Handshake_IO&) = delete;
60
61
47.6k
      virtual ~Handshake_IO() = default;
62
   };
63
64
/**
65
* Handshake IO for stream-based handshakes
66
*/
67
class Stream_Handshake_IO final : public Handshake_IO
68
   {
69
   public:
70
      typedef std::function<void (uint8_t, const std::vector<uint8_t>&)> writer_fn;
71
72
47.0k
      explicit Stream_Handshake_IO(writer_fn writer) : m_send_hs(writer) {}
73
74
      Protocol_Version initial_record_version() const override;
75
76
0
      bool timeout_check() override { return false; }
77
78
      std::vector<uint8_t> send(const Handshake_Message& msg) override;
79
80
      std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) override;
81
82
      std::vector<uint8_t> format(
83
         const std::vector<uint8_t>& handshake_msg,
84
         Handshake_Type handshake_type) const override;
85
86
      void add_record(const uint8_t record[],
87
                      size_t record_len,
88
                      Record_Type type,
89
                      uint64_t sequence_number) override;
90
91
      std::pair<Handshake_Type, std::vector<uint8_t>>
92
         get_next_record(bool expecting_ccs) override;
93
   private:
94
      std::deque<uint8_t> m_queue;
95
      writer_fn m_send_hs;
96
   };
97
98
/**
99
* Handshake IO for datagram-based handshakes
100
*/
101
class Datagram_Handshake_IO final : public Handshake_IO
102
   {
103
   public:
104
      typedef std::function<void (uint16_t, uint8_t, const std::vector<uint8_t>&)> writer_fn;
105
106
      Datagram_Handshake_IO(writer_fn writer,
107
                            class Connection_Sequence_Numbers& seq,
108
                            uint16_t mtu, uint64_t initial_timeout_ms, uint64_t max_timeout_ms) :
109
         m_seqs(seq),
110
         m_flights(1),
111
         m_initial_timeout(initial_timeout_ms),
112
         m_max_timeout(max_timeout_ms),
113
         m_send_hs(writer),
114
         m_mtu(mtu)
115
578
         {}
116
117
      Protocol_Version initial_record_version() const override;
118
119
      bool timeout_check() override;
120
121
      std::vector<uint8_t> send(const Handshake_Message& msg) override;
122
123
      std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) override;
124
125
      std::vector<uint8_t> format(
126
         const std::vector<uint8_t>& handshake_msg,
127
         Handshake_Type handshake_type) const override;
128
129
      void add_record(const uint8_t record[],
130
                      size_t record_len,
131
                      Record_Type type,
132
                      uint64_t sequence_number) override;
133
134
      std::pair<Handshake_Type, std::vector<uint8_t>>
135
         get_next_record(bool expecting_ccs) override;
136
   private:
137
      void retransmit_flight(size_t flight);
138
      void retransmit_last_flight();
139
140
      std::vector<uint8_t> format_fragment(
141
         const uint8_t fragment[],
142
         size_t fragment_len,
143
         uint16_t frag_offset,
144
         uint16_t msg_len,
145
         Handshake_Type type,
146
         uint16_t msg_sequence) const;
147
148
      std::vector<uint8_t> format_w_seq(
149
         const std::vector<uint8_t>& handshake_msg,
150
         Handshake_Type handshake_type,
151
         uint16_t msg_sequence) const;
152
153
      std::vector<uint8_t> send_message(uint16_t msg_seq, uint16_t epoch,
154
                                     Handshake_Type msg_type,
155
                                     const std::vector<uint8_t>& msg);
156
157
      class Handshake_Reassembly final
158
         {
159
         public:
160
            void add_fragment(const uint8_t fragment[],
161
                              size_t fragment_length,
162
                              size_t fragment_offset,
163
                              uint16_t epoch,
164
                              uint8_t msg_type,
165
                              size_t msg_length);
166
167
            bool complete() const;
168
169
0
            uint16_t epoch() const { return m_epoch; }
170
171
            std::pair<Handshake_Type, std::vector<uint8_t>> message() const;
172
         private:
173
            uint8_t m_msg_type = HANDSHAKE_NONE;
174
            size_t m_msg_length = 0;
175
            uint16_t m_epoch = 0;
176
177
            // vector<bool> m_seen;
178
            // vector<uint8_t> m_fragments
179
            std::map<size_t, uint8_t> m_fragments;
180
            std::vector<uint8_t> m_message;
181
         };
182
183
      struct Message_Info final
184
         {
185
         Message_Info(uint16_t e, Handshake_Type mt, const std::vector<uint8_t>& msg) :
186
0
            epoch(e), msg_type(mt), msg_bits(msg) {}
187
188
0
         Message_Info() : epoch(0xFFFF), msg_type(HANDSHAKE_NONE) {}
189
190
         uint16_t epoch;
191
         Handshake_Type msg_type;
192
         std::vector<uint8_t> msg_bits;
193
         };
194
195
      class Connection_Sequence_Numbers& m_seqs;
196
      std::map<uint16_t, Handshake_Reassembly> m_messages;
197
      std::set<uint16_t> m_ccs_epochs;
198
      std::vector<std::vector<uint16_t>> m_flights;
199
      std::map<uint16_t, Message_Info> m_flight_data;
200
201
      uint64_t m_initial_timeout = 0;
202
      uint64_t m_max_timeout = 0;
203
204
      uint64_t m_last_write = 0;
205
      uint64_t m_next_timeout = 0;
206
207
      uint16_t m_in_message_seq = 0;
208
      uint16_t m_out_message_seq = 0;
209
210
      writer_fn m_send_hs;
211
      uint16_t m_mtu;
212
   };
213
214
}
215
216
}
217
218
#endif