1
#include "cilium/websocket_config.h"
2

            
3
#include <openssl/digest.h>
4
#include <openssl/sha.h>
5

            
6
#include <chrono>
7
#include <cstdint>
8
#include <string>
9
#include <vector>
10

            
11
#include "envoy/buffer/buffer.h"
12
#include "envoy/common/exception.h"
13
#include "envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.pb.h"
14
#include "envoy/extensions/request_id/uuid/v3/uuid.pb.h"
15
#include "envoy/server/factory_context.h"
16
#include "envoy/stats/stats_macros.h"
17

            
18
#include "source/common/buffer/buffer_impl.h"
19
#include "source/common/common/assert.h"
20
#include "source/common/common/base64.h"
21
#include "source/common/http/request_id_extension_impl.h"
22
#include "source/common/protobuf/utility.h"
23

            
24
#include "absl/strings/ascii.h"
25
#include "absl/strings/string_view.h"
26
#include "cilium/accesslog.h"
27
#include "cilium/api/accesslog.pb.h"
28
#include "cilium/api/websocket.pb.h"
29
#include "cilium/websocket_protocol.h"
30

            
31
namespace Envoy {
32
namespace Cilium {
33
namespace WebSocket {
34

            
35
23
std::vector<uint8_t> Config::getSha1Digest(const Buffer::Instance& buffer) {
36
23
  std::vector<uint8_t> digest(SHA_DIGEST_LENGTH);
37
23
  bssl::ScopedEVP_MD_CTX ctx;
38
23
  auto rc = EVP_DigestInit(ctx.get(), EVP_sha1());
39
23
  RELEASE_ASSERT(rc == 1, "Failed to init digest context");
40
23
  for (const auto& slice : buffer.getRawSlices()) {
41
23
    rc = EVP_DigestUpdate(ctx.get(), slice.mem_, slice.len_);
42
23
    RELEASE_ASSERT(rc == 1, "Failed to update digest");
43
23
  }
44
23
  rc = EVP_DigestFinal(ctx.get(), digest.data(), nullptr);
45
23
  RELEASE_ASSERT(rc == 1, "Failed to finalize digest");
46
23
  return digest;
47
23
}
48

            
49
Config::Config(Server::Configuration::FactoryContext& context, bool client,
50
               const std::string& access_log_path, const std::string& host, const std::string& path,
51
               const std::string& key, const std::string& version, const std::string& origin,
52
               const ProtobufWkt::Duration& handshake_timeout,
53
               const ProtobufWkt::Duration& ping_interval, bool ping_when_idle)
54
24
    : time_source_(context.serverFactoryContext().timeSource()),
55
24
      dispatcher_(context.serverFactoryContext().mainThreadDispatcher()),
56
24
      stats_{ALL_WEBSOCKET_STATS(POOL_COUNTER_PREFIX(context.scope(), "websocket"))},
57
24
      random_(context.serverFactoryContext().api().randomGenerator()), client_(client),
58
24
      host_(absl::AsciiStrToLower(host)), path_(absl::AsciiStrToLower(path)), key_(key),
59
24
      version_(absl::AsciiStrToLower(version)), origin_(absl::AsciiStrToLower(origin)),
60
24
      handshake_timeout_(std::chrono::seconds(5)), ping_interval_(std::chrono::milliseconds(0)),
61
24
      ping_when_idle_(ping_when_idle), access_log_(nullptr) {
62
24
  envoy::extensions::filters::network::http_connection_manager::v3::RequestIDExtension x_rid_config;
63
24
  x_rid_config.mutable_typed_config()->PackFrom(
64
24
      envoy::extensions::request_id::uuid::v3::UuidRequestIdConfig());
65
24
  auto extension_or_error = Http::RequestIDExtensionFactory::fromProto(x_rid_config, context);
66
24
  THROW_IF_NOT_OK_REF(extension_or_error.status());
67
24
  request_id_extension_ = extension_or_error.value();
68

            
69
  // Base64 encode the given/expected key, if any.
70
24
  if (!key_.empty()) {
71
    key_ = Base64::encode(key_.data(), key_.length());
72
  }
73

            
74
24
  if (!access_log_path.empty()) {
75
24
    access_log_ = AccessLog::open(access_log_path, time_source_);
76
24
  }
77

            
78
24
  const uint64_t timeout = DurationUtil::durationToMilliseconds(handshake_timeout);
79
24
  if (timeout > 0) {
80
    handshake_timeout_ = std::chrono::milliseconds(timeout);
81
  }
82

            
83
24
  const uint64_t interval = DurationUtil::durationToMilliseconds(ping_interval);
84
24
  if (interval > 0) {
85
11
    ping_interval_ = std::chrono::milliseconds(interval);
86
13
  } else if (ping_when_idle_) {
87
    throw EnvoyException(
88
        "cilium.network.websocket: ping_when_idle requires ping_interval to be set.");
89
  }
90
24
}
91

            
92
Config::Config(const ::cilium::WebSocketClient& config,
93
               Server::Configuration::FactoryContext& context)
94
11
    : Config(context, true /* client */, config.access_log_path(), config.host(), config.path(),
95
11
             config.key(), config.version(), config.origin(), config.handshake_timeout(),
96
11
             config.ping_interval(), config.ping_when_idle()) {
97
  // Client defaults
98
11
  if (host_.empty()) {
99
    throw EnvoyException("cilium.network.websocket.client: host must be non-empty.");
100
  }
101

            
102
11
  if (path_.empty()) {
103
11
    path_ = "/";
104
11
  }
105
11
  if (version_.empty()) {
106
11
    version_ = "13";
107
11
  }
108
11
  if (key_.empty()) {
109
11
    uint64_t random[2]; // 16 bytes
110
22
    for (unsigned long& i : random) {
111
22
      i = random_.random();
112
22
    }
113
11
    key_ = Base64::encode(reinterpret_cast<char*>(random), sizeof(random));
114
11
  }
115
11
}
116

            
117
Config::Config(const ::cilium::WebSocketServer& config,
118
               Server::Configuration::FactoryContext& context)
119
13
    : Config(context, false /* server */, config.access_log_path(), config.host(), config.path(),
120
13
             config.key(), config.version(), config.origin(), config.handshake_timeout(),
121
13
             config.ping_interval(), config.ping_when_idle()) {}
122

            
123
// Compute expected key response
124
23
std::string Config::keyResponse(absl::string_view key) {
125
23
  Buffer::OwnedImpl buf(key.data(), key.length());
126
23
  buf.add(WEBSOCKET_GUID);
127
23
  auto sha1 = getSha1Digest(buf);
128
23
  return Base64::encode(reinterpret_cast<char*>(sha1.data()), sha1.size());
129
23
}
130

            
131
48
void Config::log(AccessLog::Entry& entry, ::cilium::EntryType type) {
132
48
  if (access_log_) {
133
48
    access_log_->log(entry, type);
134
48
  }
135
48
}
136

            
137
} // namespace WebSocket
138
} // namespace Cilium
139
} // namespace Envoy