1
#include "source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h"
2

            
3
#include <sstream>
4

            
5
#include "envoy/config/core/v3/proxy_protocol.pb.h"
6
#include "envoy/extensions/transport_sockets/proxy_protocol/v3/upstream_proxy_protocol.pb.h"
7
#include "envoy/extensions/transport_sockets/proxy_protocol/v3/upstream_proxy_protocol.pb.validate.h"
8
#include "envoy/network/transport_socket.h"
9

            
10
#include "source/common/buffer/buffer_impl.h"
11
#include "source/common/common/hex.h"
12
#include "source/common/common/scalar_to_byte_vector.h"
13
#include "source/common/common/utility.h"
14
#include "source/common/config/well_known_names.h"
15
#include "source/common/network/address_impl.h"
16
#include "source/common/protobuf/utility.h"
17
#include "source/extensions/common/proxy_protocol/proxy_protocol_header.h"
18

            
19
using envoy::config::core::v3::PerHostConfig;
20
using envoy::config::core::v3::ProxyProtocolConfig;
21
using envoy::config::core::v3::ProxyProtocolConfig_Version;
22
using envoy::config::core::v3::ProxyProtocolPassThroughTLVs;
23

            
24
namespace Envoy {
25
namespace Extensions {
26
namespace TransportSockets {
27
namespace ProxyProtocol {
28

            
29
// Wraps `transport_socket` so that a PROXY protocol header can be written ahead of its payload.
// Pass-through and added-TLV settings are extracted from `config` once, at construction time.
// `options` is a by-value sink parameter, so it is moved into `options_` rather than copied
// (copying a shared_ptr costs an extra atomic ref-count increment/decrement).
UpstreamProxyProtocolSocket::UpstreamProxyProtocolSocket(
    Network::TransportSocketPtr&& transport_socket,
    Network::TransportSocketOptionsConstSharedPtr options, ProxyProtocolConfig config,
    const UpstreamProxyProtocolStats& stats)
    : PassthroughSocket(std::move(transport_socket)), options_(std::move(options)),
      version_(config.version()), stats_(stats),
      pass_all_tlvs_(config.has_pass_through_tlvs() &&
                     config.pass_through_tlvs().match_type() ==
                         ProxyProtocolPassThroughTLVs::INCLUDE_ALL) {
  // With match_type INCLUDE only the listed TLV types are passed through. TLV types are stored
  // as a single byte, so the configured values are truncated to their low 8 bits.
  if (config.has_pass_through_tlvs() &&
      config.pass_through_tlvs().match_type() == ProxyProtocolPassThroughTLVs::INCLUDE) {
    for (const auto& tlv_type : config.pass_through_tlvs().tlv_type()) {
      pass_through_tlvs_.insert(static_cast<uint8_t>(tlv_type));
    }
  }
  // Pre-convert config-level TLVs once; buildCustomTLVs() copies them for each header.
  for (const auto& entry : config.added_tlvs()) {
    added_tlvs_.push_back(Network::ProxyProtocolTLV{
        static_cast<uint8_t>(entry.type()),
        std::vector<unsigned char>(entry.value().begin(), entry.value().end())});
  }
}
50

            
51
void UpstreamProxyProtocolSocket::setTransportSocketCallbacks(
52
43
    Network::TransportSocketCallbacks& callbacks) {
53
43
  transport_socket_->setTransportSocketCallbacks(callbacks);
54
43
  callbacks_ = &callbacks;
55
43
}
56

            
57
79
// Writes any pending PROXY protocol header before delegating `buffer` to the wrapped socket.
// The payload is only forwarded once the header has been fully flushed and the header write
// did not request a close; otherwise the header write's result is returned as-is.
Network::IoResult UpstreamProxyProtocolSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
  // Nothing (left) to prefix: pass straight through to the wrapped socket.
  if (header_buffer_.length() == 0) {
    return transport_socket_->doWrite(buffer, end_stream);
  }

  Network::IoResult header_result = writeHeader();
  const bool header_flushed = header_buffer_.length() == 0;
  if (!header_flushed || header_result.action_ != Network::PostIoAction::KeepOpen) {
    return header_result;
  }

  // Header is out; report header and payload bytes together.
  Network::IoResult payload_result = transport_socket_->doWrite(buffer, end_stream);
  return {payload_result.action_,
          header_result.bytes_processed_ + payload_result.bytes_processed_, false};
}
69

            
70
44
void UpstreamProxyProtocolSocket::generateHeader() {
71
44
  if (version_ == ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1) {
72
18
    generateHeaderV1();
73
28
  } else {
74
26
    generateHeaderV2();
75
26
  }
76
44
}
77

            
78
18
void UpstreamProxyProtocolSocket::generateHeaderV1() {
79
  // Default to local addresses. Used if no downstream connection exists or
80
  // downstream address info is not set e.g. health checks
81
18
  auto src_addr = callbacks_->connection().connectionInfoProvider().localAddress();
82
18
  auto dst_addr = callbacks_->connection().connectionInfoProvider().remoteAddress();
83

            
84
18
  if (options_ && options_->proxyProtocolOptions().has_value()) {
85
9
    const auto options = options_->proxyProtocolOptions().value();
86
9
    src_addr = options.src_addr_;
87
9
    dst_addr = options.dst_addr_;
88
9
  }
89

            
90
18
  Common::ProxyProtocol::generateV1Header(*src_addr->ip(), *dst_addr->ip(), header_buffer_);
91
18
}
92

            
93
namespace {
94
std::string toHex(const Buffer::Instance& buffer) {
95
  std::string bufferStr = buffer.toString();
96
  return Hex::encode(reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.length());
97
}
98
} // namespace
99

            
100
26
void UpstreamProxyProtocolSocket::generateHeaderV2() {
101
26
  if (!options_ || !options_->proxyProtocolOptions().has_value()) {
102
2
    Common::ProxyProtocol::generateV2LocalHeader(header_buffer_);
103
24
  } else {
104
24
    std::vector<Envoy::Network::ProxyProtocolTLV> custom_tlvs = buildCustomTLVs();
105

            
106
24
    const auto options = options_->proxyProtocolOptions().value();
107
24
    if (!Common::ProxyProtocol::generateV2Header(options, header_buffer_, pass_all_tlvs_,
108
24
                                                 pass_through_tlvs_, custom_tlvs)) {
109
      // There is a warn log in generateV2Header method.
110
1
      stats_.v2_tlvs_exceed_max_length_.inc();
111
1
    }
112

            
113
24
    ENVOY_LOG(trace, "generated proxy protocol v2 header, length: {}, buffer: {}",
114
24
              header_buffer_.length(), toHex(header_buffer_));
115
24
  }
116
26
}
117

            
118
43
// Drains header_buffer_ into the connection's IO handle until it is empty or a write fails.
// Returns the number of header bytes written. The action is Close only for a write error other
// than Again (would-block); otherwise it is KeepOpen.
Network::IoResult UpstreamProxyProtocolSocket::writeHeader() {
  uint64_t total_written = 0;
  while (header_buffer_.length() != 0) {
    Api::IoCallUint64Result result = callbacks_->ioHandle().write(header_buffer_);
    if (!result.ok()) {
      ENVOY_CONN_LOG(trace, "write error: {}", callbacks_->connection(),
                     result.err_->getErrorDetails());
      // Again just means "try later"; any other error is fatal for the connection.
      const bool fatal = result.err_->getErrorCode() != Api::IoError::IoErrorCode::Again;
      return {fatal ? Network::PostIoAction::Close : Network::PostIoAction::KeepOpen,
              total_written, false};
    }
    ENVOY_CONN_LOG(trace, "write returns: {}", callbacks_->connection(), result.return_value_);
    total_written += result.return_value_;
  }
  return {Network::PostIoAction::KeepOpen, total_written, false};
}
143

            
144
44
void UpstreamProxyProtocolSocket::onConnected() {
145
44
  generateHeader();
146
44
  transport_socket_->onConnected();
147
44
}
148

            
149
// Wraps `transport_socket_factory` so that every socket it creates is prefixed with a PROXY
// protocol header built from `config`. Stats are created under `scope`.
// `config` is a by-value sink parameter, so it is moved into `config_` instead of deep-copying
// the protobuf message a second time.
UpstreamProxyProtocolSocketFactory::UpstreamProxyProtocolSocketFactory(
    Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config,
    Stats::Scope& scope)
    : PassthroughFactory(std::move(transport_socket_factory)), config_(std::move(config)),
      stats_(generateUpstreamProxyProtocolStats(scope)) {}
154

            
155
// Creates the inner transport socket and wraps it in an UpstreamProxyProtocolSocket.
// Returns nullptr if the inner factory returns nullptr.
Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportSocket(
    Network::TransportSocketOptionsConstSharedPtr options,
    Upstream::HostDescriptionConstSharedPtr host) const {
  // `host` is not needed after this call, so hand it over instead of copying the shared_ptr.
  auto inner_socket = transport_socket_factory_->createTransportSocket(options, std::move(host));
  if (inner_socket == nullptr) {
    return nullptr;
  }
  // Last use of the by-value `options`: move it to avoid another atomic ref-count bump.
  return std::make_unique<UpstreamProxyProtocolSocket>(std::move(inner_socket), std::move(options),
                                                       config_, stats_);
}
165

            
166
void UpstreamProxyProtocolSocketFactory::hashKey(
167
15
    std::vector<uint8_t>& key, Network::TransportSocketOptionsConstSharedPtr options) const {
168
15
  PassthroughFactory::hashKey(key, options);
169
  // Proxy protocol options should only be included in the hash if the upstream
170
  // socket intends to use them.
171
15
  if (options) {
172
15
    const auto& proxy_protocol_options = options->proxyProtocolOptions();
173
15
    if (proxy_protocol_options.has_value()) {
174
15
      pushScalarToByteVector(
175
15
          StringUtil::CaseInsensitiveHash()(proxy_protocol_options.value().asStringForHash()), key);
176
15
    }
177
15
  }
178
15
}
179

            
180
24
std::vector<Envoy::Network::ProxyProtocolTLV> UpstreamProxyProtocolSocket::buildCustomTLVs() const {
181
24
  std::vector<Envoy::Network::ProxyProtocolTLV> custom_tlvs;
182
24
  absl::flat_hash_set<uint8_t> host_level_tlv_types;
183

            
184
24
  const bool runtime_allow_duplicate_tlvs = Runtime::runtimeFeatureEnabled(
185
24
      "envoy.reloadable_features.proxy_protocol_allow_duplicate_tlvs");
186

            
187
  // Attempt to parse host-level TLVs first.
188
24
  const auto& upstream_info = callbacks_->connection().streamInfo().upstreamInfo();
189
24
  if (upstream_info && upstream_info->upstreamHost()) {
190
24
    auto metadata = upstream_info->upstreamHost()->metadata();
191
24
    if (metadata) {
192
8
      const auto filter_it = metadata->typed_filter_metadata().find(
193
8
          Envoy::Config::MetadataFilters::get().ENVOY_TRANSPORT_SOCKETS_PROXY_PROTOCOL);
194
8
      if (filter_it != metadata->typed_filter_metadata().end()) {
195
7
        PerHostConfig host_tlv_metadata;
196
7
        auto status = MessageUtil::unpackTo(filter_it->second, host_tlv_metadata);
197
7
        if (!status.ok()) {
198
1
          ENVOY_LOG(warn,
199
1
                    "Failed to unpack custom TLVs from upstream host metadata for host {}. "
200
1
                    "Error: {}. Will still use config-level TLVs.",
201
1
                    upstream_info->upstreamHost()->address()->asString(), status.message());
202
6
        } else {
203
          // Insert host-level TLVs.
204
6
          if (runtime_allow_duplicate_tlvs) {
205
9
            for (const auto& entry : host_tlv_metadata.added_tlvs()) {
206
9
              custom_tlvs.push_back(Network::ProxyProtocolTLV{
207
9
                  static_cast<uint8_t>(entry.type()),
208
9
                  std::vector<unsigned char>(entry.value().begin(), entry.value().end())});
209
9
              host_level_tlv_types.insert(entry.type());
210
9
            }
211
5
          } else {
212
2
            for (const auto& entry : host_tlv_metadata.added_tlvs()) {
213
2
              if (host_level_tlv_types.contains(entry.type())) {
214
1
                ENVOY_LOG_EVERY_POW_2_MISC(
215
1
                    info, "Skipping duplicate TLV type from host metadata {}", entry.type());
216
1
                continue;
217
1
              }
218
1
              custom_tlvs.push_back(Network::ProxyProtocolTLV{
219
1
                  static_cast<uint8_t>(entry.type()),
220
1
                  std::vector<unsigned char>(entry.value().begin(), entry.value().end())});
221
1
              host_level_tlv_types.insert(entry.type());
222
1
            }
223
1
          }
224
6
        }
225
7
      }
226
8
    }
227
24
  }
228

            
229
  // If host-level parse failed or was not present, we still read config-level TLVs.
230
24
  if (runtime_allow_duplicate_tlvs) {
231
23
    for (const auto& tlv : added_tlvs_) {
232
8
      if (!host_level_tlv_types.contains(tlv.type)) {
233
5
        custom_tlvs.push_back(tlv);
234
5
      }
235
8
    }
236
23
  } else {
237
3
    for (const auto& tlv : added_tlvs_) {
238
3
      if (host_level_tlv_types.contains(tlv.type)) {
239
1
        ENVOY_LOG_EVERY_POW_2_MISC(info, "Skipping duplicate TLV type from added_tlvs {}",
240
1
                                   tlv.type);
241
1
        continue;
242
1
      }
243
2
      custom_tlvs.push_back(tlv);
244
2
      host_level_tlv_types.insert(tlv.type);
245
2
    }
246
1
  }
247

            
248
24
  return custom_tlvs;
249
24
}
250

            
251
} // namespace ProxyProtocol
252
} // namespace TransportSockets
253
} // namespace Extensions
254
} // namespace Envoy