1
#include "source/extensions/quic/connection_id_generator/quic_lb/quic_lb.h"
2

            
3
#include "envoy/server/transport_socket_config.h"
4

            
5
#include "source/common/common/base64.h"
6
#include "source/common/config/datasource.h"
7
#include "source/common/network/socket_option_impl.h"
8
#include "source/common/quic/envoy_quic_utils.h"
9

            
10
#include "quiche/quic/load_balancer/load_balancer_encoder.h"
11

            
12
namespace Envoy {
13
namespace Quic {
14
namespace Extensions {
15
namespace ConnectionIdGenerator {
16
namespace QuicLb {
17

            
18
// Binds a generator to one worker.
//
// @param tls       slot holding the per-thread encoder state used to produce connection IDs.
// @param worker_id id of the owning worker; it is appended to every connection ID produced by
//                  this generator (see appendRoutingId()).
QuicLbConnectionIdGenerator::QuicLbConnectionIdGenerator(
    ThreadLocal::TypedSlot<ThreadLocalData>& tls, uint32_t worker_id)
    : tls_slot_(tls), worker_id_(worker_id) {
  // The id is written into a single byte of the connection ID, so it has to fit in one.
  ASSERT(worker_id <= UINT8_MAX,
         "worker id constraint should have been validated in Factory::create");
}
24

            
25
// Produces a fresh connection ID from the thread-local encoder and tags it with this worker's
// routing byte. The previous connection ID (the unnamed parameter) is not consulted.
absl::optional<quic::QuicConnectionId>
QuicLbConnectionIdGenerator::GenerateNextConnectionId(const quic::QuicConnectionId&) {
  // The encoder refuses to hand out "next" connection IDs when encryption is off, because the
  // result would be linkable to the previous one.
  //
  // This path is currently unsafe and intended for testing; GenerateConnectionId() is called
  // directly to bypass that restriction.
  quic::QuicConnectionId encoded_cid = tls_slot_->encoder_.GenerateConnectionId();
  return appendRoutingId(encoded_cid);
}
35

            
36
// Asks the thread-local encoder whether `original` should be replaced for `version`. When the
// encoder supplies a replacement, the worker routing byte is appended to it; otherwise no
// replacement is returned.
absl::optional<quic::QuicConnectionId>
QuicLbConnectionIdGenerator::MaybeReplaceConnectionId(const quic::QuicConnectionId& original,
                                                      const quic::ParsedQuicVersion& version) {
  auto replacement = tls_slot_->encoder_.MaybeReplaceConnectionId(original, version);
  if (!replacement.has_value()) {
    return absl::nullopt;
  }

  return appendRoutingId(*replacement);
}
46

            
47
89
uint8_t QuicLbConnectionIdGenerator::ConnectionIdLength(uint8_t first_byte) const {
48
89
  return tls_slot_->encoder_.ConnectionIdLength(first_byte) + sizeof(WorkerRoutingIdValue);
49
89
}
50

            
51
// Returns a copy of `new_connection_id` with this worker's id appended as one trailing byte.
//
// If the extended id would not fit in the scratch buffer, an ENVOY_BUG is reported and an empty
// optional is returned.
absl::optional<quic::QuicConnectionId>
QuicLbConnectionIdGenerator::appendRoutingId(quic::QuicConnectionId& new_connection_id) {
  // TODO(ggreenway): the thread ID should be encoded and protected in such a way that the value
  // does not require decrypting the CID, but that a passive observer can not easily link
  // thread IDs for different CIDs on the same connection. See
  // https://datatracker.ietf.org/doc/html/draft-ietf-quic-load-balancers#name-server-process-demultiplexi.

  // The routing value is stored with a plain single-byte write below; a wider type would need
  // a memcpy because the destination may be unaligned.
  static_assert(
      sizeof(WorkerRoutingIdValue) == sizeof(uint8_t),
      "Below line needs memcpy due to possibly unaligned data if the size is not a single byte");

  uint8_t cid_bytes[quic::kQuicMaxConnectionIdWithLengthPrefixLength];

  const auto encoded_length = new_connection_id.length();
  const uint16_t total_length = encoded_length + sizeof(WorkerRoutingIdValue);
  if (total_length > sizeof(cid_bytes)) {
    IS_ENVOY_BUG("Connection id long");
    return absl::nullopt;
  }

  // Copy the encoder's bytes, then stamp the worker id as the trailing byte. This adds a small
  // amount of linkability.
  memcpy(cid_bytes, new_connection_id.data(), encoded_length); // NOLINT(safe-memcpy)
  cid_bytes[encoded_length] = static_cast<WorkerRoutingIdValue>(worker_id_);

  quic::QuicConnectionId routed_cid(absl::Span<const uint8_t>(cid_bytes, total_length));
  ENVOY_LOG_MISC(trace, "generating new connection id for worker_id {}, len {} {}", worker_id_,
                 total_length, routed_cid.ToString());

  return routed_cid;
}
84

            
85
// Captures the static parts of the configuration (encryption mode, nonce length, server id)
// and constructs the encoder. The encoder is not given a key or version here; that happens in
// updateKeyAndVersion().
QuicLbConnectionIdGenerator::ThreadLocalData::ThreadLocalData(
    const envoy::extensions::quic::connection_id_generator::quic_lb::v3::Config& config,
    absl::string_view server_id)
    : encoder_(*quic::QuicRandom::GetInstance(),
               /* visitor */ nullptr,
               /* len_self_encoded */ true),
      unencrypted_mode_(config.unencrypted_mode()),
      nonce_length_bytes_(config.nonce_length_bytes()),
      server_id_(server_id) {}
92

            
93
// Factory function for ThreadLocalData. As written it always succeeds; the StatusOr return type
// lets callers treat construction uniformly with other fallible setup steps.
absl::StatusOr<std::shared_ptr<QuicLbConnectionIdGenerator::ThreadLocalData>>
QuicLbConnectionIdGenerator::ThreadLocalData::create(
    const envoy::extensions::quic::connection_id_generator::quic_lb::v3::Config& config,
    const absl::string_view server_id) {
  return std::shared_ptr<ThreadLocalData>(new ThreadLocalData(config, server_id));
}
103

            
104
// Builds a quic::LoadBalancerConfig from the stored parameters plus the supplied `key` and
// `version`, and installs it on this thread's encoder.
//
// `key` is only used when the generator is not in unencrypted mode.
// Returns InvalidArgument if the config cannot be created or the encoder rejects it.
absl::Status
QuicLbConnectionIdGenerator::ThreadLocalData::updateKeyAndVersion(absl::string_view key,
                                                                  uint8_t version) {
  absl::optional<quic::LoadBalancerConfig> candidate;
  if (!unencrypted_mode_) {
    candidate =
        quic::LoadBalancerConfig::Create(version, server_id_.length(), nonce_length_bytes_, key);
  } else {
    candidate = quic::LoadBalancerConfig::CreateUnencrypted(version, server_id_.length(),
                                                            nonce_length_bytes_);
  }

  if (!candidate.has_value()) {
    return absl::InvalidArgumentError("error generating quic::LoadBalancerConfig");
  }

  if (!encoder_.UpdateConfig(*candidate, server_id_)) {
    return absl::InvalidArgumentError("Error setting configuration of quic-lb CID encoder");
  }

  return absl::OkStatus();
}
126

            
127
namespace {
128

            
129
// Resolves the generic-secret provider named by `config`.
//
// A static provider is looked up by name unless `config` carries an SDS source, in which case
// a dynamic provider is found or created (registered against `init_manager`).
Secret::GenericSecretConfigProviderSharedPtr
secretsProvider(const envoy::extensions::transport_sockets::tls::v3::SdsSecretConfig& config,
                Server::Configuration::ServerFactoryContext& server_context,
                Init::Manager& init_manager) {
  if (!config.has_sds_config()) {
    return server_context.secretManager().findStaticGenericSecretProvider(config.name());
  }

  return server_context.secretManager().findOrCreateGenericSecretProvider(
      config.sds_config(), config.name(), server_context, init_manager);
}
140

            
141
// Validated contents of the quic-lb secret: the encryption key and the configuration version
// that are pushed to every worker's encoder.
struct KeyAndVersion {
  // Raw key bytes read from the secret's 'encryption_key' entry.
  std::string encryption_key_;
  // Value read from the secret's 'configuration_version' entry. Zero-initialized so a
  // default-constructed instance never carries an indeterminate value.
  uint8_t configuration_version_{0};
};
145

            
146
// Extracts and validates the 'encryption_key' and 'configuration_version' entries of `secret`.
//
// Checks performed, in order:
//   - 'encryption_key' exists, can be read, and is exactly quic::kLoadBalancerKeyLen bytes;
//   - 'configuration_version' exists, can be read, is exactly one byte, and that byte is less
//     than quic::kNumLoadBalancerConfigs.
//
// @return the parsed values, InvalidArgument describing the first failed check, or the error
//         returned by Config::DataSource::read.
absl::StatusOr<KeyAndVersion> getAndValidateKeyAndVersion(
    const envoy::extensions::transport_sockets::tls::v3::GenericSecret& secret, Api::Api& api) {
  const auto& secrets = secret.secrets();

  auto key_it = secrets.find("encryption_key");
  if (key_it == secrets.end()) {
    return absl::InvalidArgumentError("Missing 'encryption_key'");
  }

  auto key_or_result = Config::DataSource::read(key_it->second, false, api);
  RETURN_IF_NOT_OK_REF(key_or_result.status());

  // Deliberately non-const and moved out of the StatusOr: the key is handed to the result at
  // the end of the function, and a const local would turn that std::move into a copy.
  std::string key = std::move(key_or_result).value();
  if (key.size() != quic::kLoadBalancerKeyLen) {
    return absl::InvalidArgumentError(
        fmt::format("'encryption_key' length was {}, but it must be length {}", key.size(),
                    quic::kLoadBalancerKeyLen));
  }

  auto version_it = secrets.find("configuration_version");
  if (version_it == secrets.end()) {
    return absl::InvalidArgumentError("Missing 'configuration_version'");
  }

  auto version_or_result = Config::DataSource::read(version_it->second, false, api);
  RETURN_IF_NOT_OK_REF(version_or_result.status());

  // The version is carried as a single raw byte, not as a decimal string.
  if (version_or_result.value().size() != sizeof(uint8_t)) {
    return absl::InvalidArgumentError(
        fmt::format("'configuration_version' length was {}, but it must be length 1 byte",
                    version_or_result.value().size()));
  }
  const uint8_t version = version_or_result.value().data()[0];
  if (version >= quic::kNumLoadBalancerConfigs) {
    return absl::InvalidArgumentError(
        fmt::format("'configuration_version' was {}, but must be less than {}", version,
                    quic::kNumLoadBalancerConfigs));
  }

  return KeyAndVersion{std::move(key), version};
}
187

            
188
} // namespace
189

            
190
// Stores a copy of the extension configuration. The remaining members are populated by
// Factory::create().
Factory::Factory(
    const envoy::extensions::quic::connection_id_generator::quic_lb::v3::Config& config)
    : config_(config) {
}
193

            
194
// Validates `config` and builds a fully wired Factory.
//
// Steps:
//   1. Reject concurrency values that the single-byte worker id scheme is not allowed to use.
//   2. Read (and optionally base64-decode) the server id and validate its length.
//   3. Dry-run the encoder configuration with a placeholder key so configuration errors surface
//      at load time.
//   4. Install per-thread encoder state and hook up the secret provider so key/version updates
//      are propagated to every thread.
//
// @return the factory, or the first validation/read error encountered.
absl::StatusOr<std::unique_ptr<Factory>>
Factory::create(const envoy::extensions::quic::connection_id_generator::quic_lb::v3::Config& config,
                Server::Configuration::FactoryContext& context) {
  // The worker ID is stored in a single byte in the connection ID, so restrict use to concurrency
  // values compatible with this scheme.
  // TODO(ggreenway): use additional bytes for higher concurrency.
  // NOTE(review): the guard rejects 255 and above, so the largest accepted concurrency is 254.
  // The message previously claimed a limit of 256; it now states the bound that is enforced.
  if (context.serverFactoryContext().options().concurrency() >= UINT8_MAX) {
    return absl::InvalidArgumentError("envoy.quic.connection_id_generator.quic_lb cannot be used "
                                      "with a concurrency greater than 254");
  }

  std::unique_ptr<Factory> ret(new Factory(config));

  auto server_id_or_result =
      Config::DataSource::read(config.server_id(), false, context.serverFactoryContext().api());
  RETURN_IF_NOT_OK_REF(server_id_or_result.status());

  // The StatusOr is not used again, so take its payload instead of copying it.
  std::string server_id = std::move(server_id_or_result).value();

  if (config.server_id_base64_encoded()) {
    server_id = Base64::decodeWithoutPadding(server_id);
  }

  if (config.expected_server_id_length() > 0 &&
      config.expected_server_id_length() != server_id.size()) {
    return absl::InvalidArgumentError(
        fmt::format("'expected_server_id_length' {} does not match actual 'server_id' length {}",
                    config.expected_server_id_length(), server_id.size()));
  }

  // Besides the server id and nonce, every connection ID carries one leading
  // length-and-version byte and the trailing worker routing value.
  constexpr auto cid_length_and_version_prefix = sizeof(uint8_t);
  constexpr auto fixed_components_len =
      cid_length_and_version_prefix + sizeof(WorkerRoutingIdValue);
  if ((server_id.size() + config.nonce_length_bytes() + fixed_components_len) >
      quic::kQuicMaxConnectionIdWithLengthPrefixLength) {
    // Keeps the "18" in the message below in sync with the constants it is derived from.
    static_assert(quic::kQuicMaxConnectionIdWithLengthPrefixLength - fixed_components_len == 18);
    return absl::InvalidArgumentError(fmt::format(
        "'server_id' length ({}) and 'nonce_length_bytes' ({}) combined must be 18 bytes or less.",
        server_id.size(), config.nonce_length_bytes()));
  }

  // Create a test instance using all the same parameters, but with a fake key (because we don't
  // have the real key yet) to surface any errors while we're still in the config loading stage.
  {
    auto test_instance_or_result =
        QuicLbConnectionIdGenerator::ThreadLocalData::create(config, server_id);
    RETURN_IF_NOT_OK_REF(test_instance_or_result.status());
    std::string test_key(quic::kLoadBalancerKeyLen, '0');
    auto result = test_instance_or_result.value()->updateKeyAndVersion(test_key, 0);
    RETURN_IF_NOT_OK_REF(result);
  }

  ENVOY_LOG_MISC(debug, "Configuring quic-lb with server_id length {} and nonce length {}",
                 server_id.length(), config.nonce_length_bytes());

  ret->tls_slot_ = ThreadLocal::TypedSlot<QuicLbConnectionIdGenerator::ThreadLocalData>::makeUnique(
      context.serverFactoryContext().threadLocal());

  // using InitializeCb = std::function<std::shared_ptr<T>(Event::Dispatcher & dispatcher)>;
  // Both values are captured by copy so the initializer does not refer to this function's
  // parameters or locals after it returns.
  ret->tls_slot_->set(
      [config, server_id](
          Event::Dispatcher&) -> std::shared_ptr<QuicLbConnectionIdGenerator::ThreadLocalData> {
        auto result = QuicLbConnectionIdGenerator::ThreadLocalData::create(config, server_id);
        ASSERT(result.status().ok()); // Configuration was validated above.
        return result.value();
      });

  ret->secrets_provider_ = secretsProvider(config.encryption_parameters(),
                                           context.serverFactoryContext(), context.initManager());
  if (ret->secrets_provider_ == nullptr) {
    return absl::InvalidArgumentError("invalid encryption_parameters config");
  }

  // Reject malformed secrets before they are accepted by the provider.
  ret->secrets_provider_validation_callback_handle_ = ret->secrets_provider_->addValidationCallback(
      [&api = context.serverFactoryContext().api()](
          const envoy::extensions::transport_sockets::tls::v3::GenericSecret& secret)
          -> absl::Status { return getAndValidateKeyAndVersion(secret, api).status(); });

  // Push accepted secrets to every thread. `factory` refers to the heap object owned by `ret`,
  // which also owns the callback handle.
  ret->secrets_provider_update_callback_handle_ = ret->secrets_provider_->addUpdateCallback(
      [&factory = *ret, &api = context.serverFactoryContext().api()]() -> absl::Status {
        return factory.updateSecret(api);
      });

  // If the secret is already available, apply it immediately rather than waiting for an update.
  if (ret->secrets_provider_->secret()) {
    auto status = ret->updateSecret(context.serverFactoryContext().api());
    RETURN_IF_NOT_OK_REF(status);
  }

  return ret;
}
283

            
284
14
// Reads the current secret from the provider, validates it, and schedules an encoder update
// with the new key and version on every thread.
//
// @return NotFound if the provider has no secret, the validation error if the secret is
//         malformed, otherwise OK.
absl::Status Factory::updateSecret(Api::Api& api) {
  const auto* current_secret = secrets_provider_->secret();
  if (current_secret == nullptr) {
    return absl::NotFoundError("secret update callback called with empty secret");
  }

  auto parsed_or_error = getAndValidateKeyAndVersion(*current_secret, api);
  RETURN_IF_NOT_OK_REF(parsed_or_error.status());

  // The parsed values are copied into the closure so each thread works from its own data.
  tls_slot_->runOnAllThreads(
      [key_and_version = parsed_or_error.value()](
          OptRef<QuicLbConnectionIdGenerator::ThreadLocalData> thread_data) {
        ASSERT(thread_data.has_value(),
               "Guaranteed if `set()` was previously called on the tls slot");

        auto update_status = thread_data->updateKeyAndVersion(
            key_and_version.encryption_key_, key_and_version.configuration_version_);

        // Because all parameters were validated earlier, it should not be possible for this
        // to fail.
        ENVOY_BUG(update_status.ok(), "quic_lb unexpected error in updating configuration; old "
                                      "configuration will still be used");
      });

  return absl::OkStatus();
}
309

            
310
23
// Creates the generator for the worker identified by `worker_id`. All generators share this
// factory's thread-local slot.
QuicConnectionIdGeneratorPtr Factory::createQuicConnectionIdGenerator(uint32_t worker_id) {
  auto generator = std::make_unique<QuicLbConnectionIdGenerator>(*tls_slot_, worker_id);
  return generator;
}
313

            
314
// Returns a socket option that attaches a classic-BPF reuseport program steering packets to
// workers, or nullptr when SO_ATTACH_REUSEPORT_CBPF is not available on this platform.
//
// `concurrency` is embedded as the constant operand of two instructions in the program.
Network::Socket::OptionConstSharedPtr
Factory::createCompatibleLinuxBpfSocketOption(uint32_t concurrency) {
#if defined(SO_ATTACH_REUSEPORT_CBPF) && defined(__linux__)
  filter_ = {
      // This was generated by running `./compile_bpf.sh` in this directory, using route.bpf.
      {0x80, 0, 0, 0000000000},  {0x07, 0, 0, 0000000000},   {0x35, 0, 21, 0x00000009},
      {0x30, 0, 0, 0000000000},  {0x54, 0, 0, 0x00000080},   {0x15, 0, 9, 0000000000},
      {0x30, 0, 0, 0x00000001},  {0x54, 0, 0, 0x0000001f},   {0x04, 0, 0, 0x00000003},
      {0x2d, 14, 0, 0000000000}, {0x14, 0, 0, 0x00000001},   {0x07, 0, 0, 0000000000},
      {0x50, 0, 0, 0000000000},  {0x35, 10, 0, concurrency}, {0x16, 0, 0, 0000000000},
      {0x80, 0, 0, 0000000000},  {0x35, 0, 7, 0x0000000e},   {0x30, 0, 0, 0x00000005},
      {0x04, 0, 0, 0x00000006},  {0x2d, 4, 0, 0000000000},   {0x14, 0, 0, 0x00000001},
      {0x07, 0, 0, 0000000000},  {0x50, 0, 0, 0000000000},   {0x05, 0, 0, 0x00000001},
      {0x20, 0, 0, 0xfffff020},  {0x94, 0, 0, concurrency},  {0x16, 0, 0, 0000000000},
  };

  // The option only carries a view of `prog_`, which in turn points at `filter_`. Both are
  // member variables so the program data stays alive until the option is applied.
  prog_.filter = filter_.data();
  prog_.len = filter_.size();

  const absl::string_view program_bytes(reinterpret_cast<char*>(&prog_), sizeof(prog_));
  return std::make_shared<Network::SocketOptionImpl>(
      envoy::config::core::v3::SocketOption::STATE_BOUND, ENVOY_ATTACH_REUSEPORT_CBPF,
      program_bytes);
#else
  UNREFERENCED_PARAMETER(concurrency);
  return nullptr;
#endif
}
342

            
343
// Userspace counterpart of the reuseport BPF program: picks the worker that should handle
// `packet` by inspecting the QUIC header.
//
// - Long header (top bit of the first byte set): the last byte of the destination connection
//   ID, modulo `concurrency`.
// - Short header: the byte following the quic-lb encoded connection ID, which holds the worker
//   id; it must be less than `concurrency`.
//
// @param packet        the received datagram.
// @param concurrency   number of workers. Taken as a 32-bit value so that the caller's
//                      uint32_t is not truncated (256 would otherwise wrap to 0).
// @param default_value returned whenever the packet is too short or the worker id is invalid.
// @return the selected worker index, or `default_value`.
static uint32_t bpfEquivalentFunction(const Buffer::Instance& packet, uint32_t concurrency,
                                      uint32_t default_value) {
  // Guard the modulo below; with no workers there is nothing meaningful to select.
  if (concurrency == 0) {
    return default_value;
  }

  const uint64_t packet_length = packet.length();
  if (packet_length < 9) {
    ENVOY_LOG_MISC(trace, "packet length < 9: {}", packet_length);
    return default_value;
  }

  uint8_t first_octet;
  packet.copyOut(0, sizeof(first_octet), &first_octet);

  if (first_octet & 0x80) {
    // IETF QUIC long header.
    // The connection id length is the 6th byte.
    // The connection id starts from 7th byte.
    // Minimum length of a long header packet is 14.
    constexpr size_t kCIDLenOffset = 5;
    constexpr size_t kCIDOffset = 6;

    if (packet_length < 14) {
      ENVOY_LOG_MISC(trace, "long header packet length less than 14: {}", packet_length);
      return default_value;
    }

    uint8_t connection_id_snippet;

    uint8_t id_len;
    packet.copyOut(kCIDLenOffset, sizeof(id_len), &id_len);

    if (packet_length < (kCIDOffset + id_len)) {
      ENVOY_LOG_MISC(trace, "long header packet {} length less than CID length {}", packet_length,
                     id_len);
      return default_value;
    }

    // Read the final byte of the connection ID.
    packet.copyOut(kCIDOffset + id_len - sizeof(connection_id_snippet),
                   sizeof(connection_id_snippet), &connection_id_snippet);
    ENVOY_LOG_MISC(trace, "long header for worker {}", connection_id_snippet % concurrency);
    return connection_id_snippet % concurrency;
  } else {
    // IETF QUIC short header.
    // The connection id starts from 2nd byte.
    // All short headers will have a CID generated by quic-lb.
    constexpr size_t kCIDOffset = 1;

    uint8_t config_version_and_length;
    packet.copyOut(kCIDOffset, sizeof(config_version_and_length), &config_version_and_length);

    // This length does not include the initial length byte or the trailing worker-id bytes.
    const uint8_t encrypted_cid_length = config_version_and_length & quic::kLoadBalancerLengthMask;

    const uint8_t worker_id_offset =
        kCIDOffset + sizeof(encrypted_cid_length) + encrypted_cid_length;

    WorkerRoutingIdValue worker_id;
    if (packet_length < (worker_id_offset + sizeof(worker_id))) {
      ENVOY_LOG_MISC(trace, "short header packet length {} shorter than encoded length {}",
                     packet_length, encrypted_cid_length);
      return default_value;
    }

    packet.copyOut(worker_id_offset, sizeof(worker_id), &worker_id);
    if (worker_id >= concurrency) {
      ENVOY_LOG_MISC(trace, "short header unexpected value {} >= {}", worker_id, concurrency);
      return default_value;
    }

    ENVOY_LOG_MISC(trace, "short header for worker {}", worker_id);
    return worker_id;
  }
}
414

            
415
// Returns a callable that selects a worker for a packet in userspace by delegating to
// bpfEquivalentFunction(). `concurrency` is captured by value, so the selector does not refer
// back to this factory.
QuicConnectionIdWorkerSelector
Factory::getCompatibleConnectionIdWorkerSelector(uint32_t concurrency) {
  auto selector = [concurrency](const Buffer::Instance& packet,
                                uint32_t default_value) -> uint32_t {
    return bpfEquivalentFunction(packet, concurrency, default_value);
  };
  return selector;
}
421

            
422
} // namespace QuicLb
423
} // namespace ConnectionIdGenerator
424
} // namespace Extensions
425
} // namespace Quic
426
} // namespace Envoy