1
#pragma once
2

            
3
#include "envoy/config/core/v3/proxy_protocol.pb.h"
4
#include "envoy/network/connection.h"
5
#include "envoy/network/transport_socket.h"
6
#include "envoy/stats/stats.h"
7

            
8
#include "source/common/buffer/buffer_impl.h"
9
#include "source/common/common/logger.h"
10
#include "source/extensions/transport_sockets/common/passthrough.h"
11

            
12
using envoy::config::core::v3::ProxyProtocolConfig;
13
using envoy::config::core::v3::ProxyProtocolConfig_Version;
14

            
15
namespace Envoy {
16
namespace Extensions {
17
namespace TransportSockets {
18
namespace ProxyProtocol {
19

            
20
/**
 * X-macro listing every stat emitted by the upstream PROXY protocol transport socket. Each entry
 * is expanded with the macro the caller passes as COUNTER: UpstreamProxyProtocolStats uses
 * GENERATE_COUNTER_STRUCT to declare its members, and the factory uses POOL_COUNTER_PREFIX to
 * create the counters in a stats scope.
 *
 * v2_tlvs_exceed_max_length: judging by the name, counts PROXY protocol v2 headers whose TLVs
 * exceed the maximum encodable length. NOTE(review): confirm against the increment site in the .cc.
 *
 * Every line of the replacement list except the last must end in a backslash; any other line
 * placed inside the list ends the macro early.
 */
#define ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(COUNTER)                                         \
  /* Upstream events counter. */                                                                   \
  COUNTER(v2_tlvs_exceed_max_length)
23

            
24
/**
25
 * Wrapper struct for upstream ProxyProtocol stats. @see stats_macros.h
26
 */
27
struct UpstreamProxyProtocolStats {
28
  ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(GENERATE_COUNTER_STRUCT)
29
};
30

            
31
class UpstreamProxyProtocolSocket : public TransportSockets::PassthroughSocket,
32
                                    public Logger::Loggable<Logger::Id::connection> {
33
public:
34
  UpstreamProxyProtocolSocket(Network::TransportSocketPtr&& transport_socket,
35
                              Network::TransportSocketOptionsConstSharedPtr options,
36
                              ProxyProtocolConfig config, const UpstreamProxyProtocolStats& stats);
37

            
38
  void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override;
39
  Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
40
  void onConnected() override;
41

            
42
private:
43
  void generateHeader();
44
  void generateHeaderV1();
45
  void generateHeaderV2();
46
  // Combine host-level and config-level TLVs, with fallback if metadata fails to unpack.
47
  // Host-level has precedence over config-level TLVs.
48
  // If we fail to parse host metadata, we still read config TLVs.
49
  std::vector<Envoy::Network::ProxyProtocolTLV> buildCustomTLVs() const;
50
  Network::IoResult writeHeader();
51

            
52
  Network::TransportSocketOptionsConstSharedPtr options_;
53
  Network::TransportSocketCallbacks* callbacks_{};
54
  Buffer::OwnedImpl header_buffer_{};
55
  ProxyProtocolConfig_Version version_{ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1};
56
  const UpstreamProxyProtocolStats& stats_;
57
  const bool pass_all_tlvs_;
58
  absl::flat_hash_set<uint8_t> pass_through_tlvs_{};
59
  std::vector<Envoy::Network::ProxyProtocolTLV> added_tlvs_{};
60
};
61

            
62
class UpstreamProxyProtocolSocketFactory : public PassthroughFactory {
63
public:
64
  UpstreamProxyProtocolSocketFactory(
65
      Network::UpstreamTransportSocketFactoryPtr transport_socket_factory,
66
      ProxyProtocolConfig config, Stats::Scope& scope);
67

            
68
  // Network::UpstreamTransportSocketFactory
69
  Network::TransportSocketPtr
70
  createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options,
71
                        Upstream::HostDescriptionConstSharedPtr host) const override;
72
  void hashKey(std::vector<uint8_t>& key,
73
               Network::TransportSocketOptionsConstSharedPtr options) const override;
74

            
75
45
  static UpstreamProxyProtocolStats generateUpstreamProxyProtocolStats(Stats::Scope& stats_scope) {
76
45
    const char prefix[]{"upstream.proxyprotocol."};
77
45
    return {ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(POOL_COUNTER_PREFIX(stats_scope, prefix))};
78
45
  }
79

            
80
private:
81
  ProxyProtocolConfig config_;
82
  UpstreamProxyProtocolStats stats_;
83
};
84

            
85
} // namespace ProxyProtocol
86
} // namespace TransportSockets
87
} // namespace Extensions
88
} // namespace Envoy