1
#include "source/extensions/common/proxy_protocol/proxy_protocol_header.h"
2

            
3
#include <sstream>
4

            
5
#include "envoy/buffer/buffer.h"
6
#include "envoy/network/address.h"
7

            
8
#include "source/common/network/address_impl.h"
9
#include "source/common/runtime/runtime_features.h"
10

            
11
namespace Envoy {
12
namespace Extensions {
13
namespace Common {
14
namespace ProxyProtocol {
15

            
16
void generateV1Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port,
17
                      uint32_t dst_port, Network::Address::IpVersion ip_version,
18
38
                      Buffer::Instance& out) {
19
38
  std::ostringstream stream;
20
38
  stream << PROXY_PROTO_V1_SIGNATURE;
21

            
22
38
  switch (ip_version) {
23
33
  case Network::Address::IpVersion::v4:
24
33
    stream << PROXY_PROTO_V1_AF_INET << " ";
25
33
    break;
26
5
  case Network::Address::IpVersion::v6:
27
5
    stream << PROXY_PROTO_V1_AF_INET6 << " ";
28
5
    break;
29
38
  }
30

            
31
38
  stream << src_addr << " ";
32
38
  stream << dst_addr << " ";
33
38
  stream << src_port << " ";
34
38
  stream << dst_port << "\r\n";
35

            
36
38
  out.add(stream.str());
37
38
}
38

            
39
// Convenience overload: appends a PROXY V1 header built from two IP addresses.
// The address family written into the line comes from `source_address.version()`;
// `dest_address` is assumed to be of the same family (not checked here).
void generateV1Header(const Network::Address::Ip& source_address,
                      const Network::Address::Ip& dest_address, Buffer::Instance& out) {
  const auto& src_text = source_address.addressAsString();
  const auto& dst_text = dest_address.addressAsString();
  generateV1Header(src_text, dst_text, source_address.port(), dest_address.port(),
                   source_address.version(), out);
}
44

            
45
void generateV2Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port,
46
                      uint32_t dst_port, Network::Address::IpVersion ip_version,
47
55
                      uint16_t extension_length, Buffer::Instance& out) {
48
55
  out.add(PROXY_PROTO_V2_SIGNATURE, PROXY_PROTO_V2_SIGNATURE_LEN);
49

            
50
55
  const uint8_t version_and_command = PROXY_PROTO_V2_VERSION << 4 | PROXY_PROTO_V2_ONBEHALF_OF;
51
55
  out.add(&version_and_command, 1);
52

            
53
55
  uint8_t address_family_and_protocol;
54
55
  switch (ip_version) {
55
31
  case Network::Address::IpVersion::v4:
56
31
    address_family_and_protocol = PROXY_PROTO_V2_AF_INET << 4;
57
31
    break;
58
24
  case Network::Address::IpVersion::v6:
59
24
    address_family_and_protocol = PROXY_PROTO_V2_AF_INET6 << 4;
60
24
    break;
61
55
  }
62
55
  address_family_and_protocol |= PROXY_PROTO_V2_TRANSPORT_STREAM;
63
55
  out.add(&address_family_and_protocol, 1);
64

            
65
  // Number of following bytes part of the header in V2 protocol.
66
55
  uint16_t addr_length;
67
55
  uint16_t addr_length_n; // Network byte order
68

            
69
55
  switch (ip_version) {
70
31
  case Network::Address::IpVersion::v4: {
71
31
    addr_length = PROXY_PROTO_V2_ADDR_LEN_INET + extension_length;
72
31
    addr_length_n = htons(addr_length);
73
31
    out.add(&addr_length_n, 2);
74
31
    const uint32_t net_src_addr =
75
31
        Network::Address::Ipv4Instance(src_addr, src_port).ip()->ipv4()->address();
76
31
    const uint32_t net_dst_addr =
77
31
        Network::Address::Ipv4Instance(dst_addr, dst_port).ip()->ipv4()->address();
78
31
    out.add(&net_src_addr, 4);
79
31
    out.add(&net_dst_addr, 4);
80
31
    break;
81
  }
82
24
  case Network::Address::IpVersion::v6: {
83
24
    addr_length = PROXY_PROTO_V2_ADDR_LEN_INET6 + extension_length;
84
24
    addr_length_n = htons(addr_length);
85
24
    out.add(&addr_length_n, 2);
86
24
    const absl::uint128 net_src_addr =
87
24
        Network::Address::Ipv6Instance(src_addr, src_port).ip()->ipv6()->address();
88
24
    const absl::uint128 net_dst_addr =
89
24
        Network::Address::Ipv6Instance(dst_addr, dst_port).ip()->ipv6()->address();
90
24
    out.add(&net_src_addr, 16);
91
24
    out.add(&net_dst_addr, 16);
92
24
    break;
93
  }
94
55
  }
95

            
96
55
  const uint16_t net_src_port = htons(static_cast<uint16_t>(src_port));
97
55
  const uint16_t net_dst_port = htons(static_cast<uint16_t>(dst_port));
98
55
  out.add(&net_src_port, 2);
99
55
  out.add(&net_dst_port, 2);
100
55
}
101

            
102
void generateV2Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port,
103
                      uint32_t dst_port, Network::Address::IpVersion ip_version,
104
5
                      Buffer::Instance& out) {
105
5
  generateV2Header(src_addr, dst_addr, src_port, dst_port, ip_version, 0, out);
106
5
}
107

            
108
// Convenience overload: appends a PROXY V2 header (no TLV extensions) built from two IP
// addresses. The address family comes from `source_address.version()`; `dest_address` is
// assumed to be of the same family (not checked here).
void generateV2Header(const Network::Address::Ip& source_address,
                      const Network::Address::Ip& dest_address, Buffer::Instance& out) {
  const auto& src_text = source_address.addressAsString();
  const auto& dst_text = dest_address.addressAsString();
  constexpr uint16_t no_extension_bytes = 0;
  generateV2Header(src_text, dst_text, source_address.port(), dest_address.port(),
                   source_address.version(), no_extension_bytes, out);
}
113

            
114
// Appends a PROXY protocol V2 header for `proxy_proto_data` to `out`, followed by TLVs.
//
// TLV selection:
//  - Every entry of `custom_tlvs` is considered for emission.
//  - Entries of `proxy_proto_data.tlv_vector_` are considered only when `pass_all_tlvs` is true
//    or their type is in `pass_through_tlvs`.
//  - With runtime feature "envoy.reloadable_features.proxy_protocol_allow_duplicate_tlvs"
//    enabled, a passed-through TLV is dropped if its type also appears in `custom_tlvs` (custom
//    TLVs win). Repeated types within `proxy_proto_data.tlv_vector_` are kept.
//  - With the feature disabled, only the first TLV of each type is kept (custom TLVs first), and
//    repeated types within `custom_tlvs` are ASSERTed against.
//
// The V2 16-bit length field counts the address block AND the TLVs. TLVs are therefore accepted
// greedily, in order, against the space left after the address block. Any TLV that would push
// the total past 65535 is skipped and a warning is logged.
//
// Returns false when the source or destination address is missing or not an IP (nothing is
// written to `out` in that case). Also returns false when at least one TLV was skipped for size,
// so the caller can count it. Otherwise returns true.
bool generateV2Header(const Network::ProxyProtocolData& proxy_proto_data, Buffer::Instance& out,
                      bool pass_all_tlvs, const absl::flat_hash_set<uint8_t>& pass_through_tlvs,
                      const std::vector<Envoy::Network::ProxyProtocolTLV>& custom_tlvs) {
  // Validate the addresses first. The address family decides how much of the length field is
  // left for TLVs.
  if (proxy_proto_data.src_addr_ == nullptr || proxy_proto_data.src_addr_->ip() == nullptr) {
    IS_ENVOY_BUG("Missing or incorrect source IP in proxy_proto_data_");
    return false;
  }
  if (proxy_proto_data.dst_addr_ == nullptr || proxy_proto_data.dst_addr_->ip() == nullptr) {
    IS_ENVOY_BUG("Missing or incorrect dest IP in proxy_proto_data_");
    return false;
  }
  const auto& src = *proxy_proto_data.src_addr_->ip();
  const auto& dst = *proxy_proto_data.dst_addr_->ip();

  std::vector<Envoy::Network::ProxyProtocolTLV> combined_tlv_vector;
  combined_tlv_vector.reserve(custom_tlvs.size() + proxy_proto_data.tlv_vector_.size());

  if (Runtime::runtimeFeatureEnabled(
          "envoy.reloadable_features.proxy_protocol_allow_duplicate_tlvs")) {
    absl::flat_hash_set<uint8_t> config_specified_types;
    for (const auto& tlv : custom_tlvs) {
      combined_tlv_vector.emplace_back(tlv);
      config_specified_types.insert(tlv.type);
    }

    // Combine TLVs from the proxy_proto_data with the custom TLVs.
    for (const auto& tlv : proxy_proto_data.tlv_vector_) {
      if (!pass_all_tlvs && !pass_through_tlvs.contains(tlv.type)) {
        // Skip any TLV that is not in the set of passthrough TLVs.
        continue;
      }
      if (!config_specified_types.contains(tlv.type)) {
        combined_tlv_vector.emplace_back(tlv);
      }
    }
  } else {
    absl::flat_hash_set<uint8_t> seen_types;
    for (const auto& tlv : custom_tlvs) {
      ASSERT(!seen_types.contains(tlv.type));
      combined_tlv_vector.emplace_back(tlv);
      seen_types.insert(tlv.type);
    }

    // Combine TLVs from the proxy_proto_data with the custom TLVs.
    for (const auto& tlv : proxy_proto_data.tlv_vector_) {
      if (!pass_all_tlvs && !pass_through_tlvs.contains(tlv.type)) {
        // Skip any TLV that is not in the set of passthrough TLVs.
        continue;
      }
      if (seen_types.contains(tlv.type)) {
        // Skip any duplicate TLVs from being added to the combined TLV vector.
        ENVOY_LOG_EVERY_POW_2_MISC(info, "Skipping duplicate TLV type {}", tlv.type);
        continue;
      }
      seen_types.insert(tlv.type);
      combined_tlv_vector.emplace_back(tlv);
    }
  }

  // The header's length field is (address block + TLVs) stored in 16 bits. Budget the TLVs
  // against what the address block leaves over, so the length generateV2Header() computes
  // cannot wrap around.
  const uint64_t address_block_length = src.version() == Network::Address::IpVersion::v4
                                            ? PROXY_PROTO_V2_ADDR_LEN_INET
                                            : PROXY_PROTO_V2_ADDR_LEN_INET6;
  const uint64_t max_extension_length =
      std::numeric_limits<uint16_t>::max() - address_block_length;

  // Filter out TLVs that would exceed the 65535 limit. final_tlvs points into
  // combined_tlv_vector, which is not modified after this point, so nothing is copied.
  uint64_t extension_length = 0;
  bool skipped_tlvs = false;
  std::vector<const Envoy::Network::ProxyProtocolTLV*> final_tlvs;
  final_tlvs.reserve(combined_tlv_vector.size());
  for (const auto& tlv : combined_tlv_vector) {
    const uint64_t new_size =
        extension_length + PROXY_PROTO_V2_TLV_TYPE_LENGTH_LEN + tlv.value.size();
    if (new_size > max_extension_length) {
      ENVOY_LOG_MISC(warn, "Skipping TLV type {} because adding it would exceed the 65535 limit.",
                     tlv.type);
      skipped_tlvs = true;
      continue;
    }
    extension_length = new_size;
    final_tlvs.push_back(&tlv);
  }

  ASSERT(extension_length <= max_extension_length);
  generateV2Header(src.addressAsString(), dst.addressAsString(), src.port(), dst.port(),
                   src.version(), static_cast<uint16_t>(extension_length), out);

  // Emit exactly the TLVs counted in extension_length. Writing the skipped ones as well would
  // make the payload disagree with the advertised header length.
  for (const auto* tlv : final_tlvs) {
    out.add(&tlv->type, 1);
    const uint16_t size = htons(static_cast<uint16_t>(tlv->value.size()));
    out.add(&size, sizeof(uint16_t));
    // A zero-length TLV has no value bytes; calling front() on an empty vector is UB.
    if (!tlv->value.empty()) {
      out.add(tlv->value.data(), tlv->value.size());
    }
  }

  // return true if no TLVs were skipped, otherwise false to increment the counter
  // in the upstream proxy protocol transport socket stats.
  return !skipped_tlvs;
}
204

            
205
// Appends a PROXY header for `connection` to `out`, using the version selected in `config`.
// The connection's remote address is written as the PROXY source and its local address as the
// destination. Both are dereferenced as IP addresses without a null check, so this assumes
// ip() is non-null for both. Versions other than V1/V2 write nothing.
void generateProxyProtoHeader(const envoy::config::core::v3::ProxyProtocolConfig& config,
                              const Network::Connection& connection, Buffer::Instance& out) {
  const auto& info = connection.connectionInfoProvider();
  const Network::Address::Ip& dest_address = *info.localAddress()->ip();
  const Network::Address::Ip& source_address = *info.remoteAddress()->ip();

  switch (config.version()) {
  case envoy::config::core::v3::ProxyProtocolConfig::V1:
    generateV1Header(source_address, dest_address, out);
    break;
  case envoy::config::core::v3::ProxyProtocolConfig::V2:
    generateV2Header(source_address, dest_address, out);
    break;
  default:
    // Unknown version: intentionally a no-op.
    break;
  }
}
217

            
218
8
// Appends a PROXY protocol V2 header with a zero (LOCAL) command nibble, a zero
// family/transport byte and a zero length field. No address block or TLVs follow.
void generateV2LocalHeader(Buffer::Instance& out) {
  out.add(PROXY_PROTO_V2_SIGNATURE, PROXY_PROTO_V2_SIGNATURE_LEN);
  // version<<4 | command(0), family/transport (0), 16-bit length (0; byte order irrelevant).
  const uint8_t local_trailer[] = {static_cast<uint8_t>(PROXY_PROTO_V2_VERSION << 4), 0x00, 0x00,
                                   0x00};
  out.add(local_trailer, sizeof(local_trailer));
}
223

            
224
} // namespace ProxyProtocol
225
} // namespace Common
226
} // namespace Extensions
227
} // namespace Envoy