Line data Source code
1 : #include "source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h" 2 : 3 : #include <sstream> 4 : 5 : #include "envoy/config/core/v3/proxy_protocol.pb.h" 6 : #include "envoy/network/transport_socket.h" 7 : 8 : #include "source/common/buffer/buffer_impl.h" 9 : #include "source/common/common/hex.h" 10 : #include "source/common/common/scalar_to_byte_vector.h" 11 : #include "source/common/common/utility.h" 12 : #include "source/common/network/address_impl.h" 13 : #include "source/extensions/common/proxy_protocol/proxy_protocol_header.h" 14 : 15 : using envoy::config::core::v3::ProxyProtocolConfig; 16 : using envoy::config::core::v3::ProxyProtocolConfig_Version; 17 : using envoy::config::core::v3::ProxyProtocolPassThroughTLVs; 18 : 19 : namespace Envoy { 20 : namespace Extensions { 21 : namespace TransportSockets { 22 : namespace ProxyProtocol { 23 : 24 0 : UpstreamProxyProtocolStats generateUpstreamProxyProtocolStats(Stats::Scope& stats_scope) { 25 0 : const char prefix[]{"upstream.proxyprotocol."}; 26 0 : return {ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(POOL_COUNTER_PREFIX(stats_scope, prefix))}; 27 0 : } 28 : 29 : UpstreamProxyProtocolSocket::UpstreamProxyProtocolSocket( 30 : Network::TransportSocketPtr&& transport_socket, 31 : Network::TransportSocketOptionsConstSharedPtr options, ProxyProtocolConfig config, 32 : Stats::Scope& scope) 33 : : PassthroughSocket(std::move(transport_socket)), options_(options), version_(config.version()), 34 : stats_(generateUpstreamProxyProtocolStats(scope)), 35 : pass_all_tlvs_(config.has_pass_through_tlvs() ? config.pass_through_tlvs().match_type() == 36 : ProxyProtocolPassThroughTLVs::INCLUDE_ALL 37 0 : : false) { 38 0 : if (config.has_pass_through_tlvs() && 39 0 : config.pass_through_tlvs().match_type() == ProxyProtocolPassThroughTLVs::INCLUDE) { 40 0 : for (const auto& tlv_type : config.pass_through_tlvs().tlv_type()) { 41 0 : pass_through_tlvs_.insert(0xFF & tlv_type); 42 0 : } 43 0 : } 44 0 : } 45 : 46 : void UpstreamProxyProtocolSocket::setTransportSocketCallbacks( 47 0 : Network::TransportSocketCallbacks& callbacks) { 48 0 : transport_socket_->setTransportSocketCallbacks(callbacks); 49 0 : callbacks_ = &callbacks; 50 0 : } 51 : 52 0 : Network::IoResult UpstreamProxyProtocolSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { 53 0 : if (header_buffer_.length() > 0) { 54 0 : auto header_res = writeHeader(); 55 0 : if (header_buffer_.length() == 0 && header_res.action_ == Network::PostIoAction::KeepOpen) { 56 0 : auto inner_res = transport_socket_->doWrite(buffer, end_stream); 57 0 : return {inner_res.action_, header_res.bytes_processed_ + inner_res.bytes_processed_, false}; 58 0 : } 59 0 : return header_res; 60 0 : } else { 61 0 : return transport_socket_->doWrite(buffer, end_stream); 62 0 : } 63 0 : } 64 : 65 0 : void UpstreamProxyProtocolSocket::generateHeader() { 66 0 : if (version_ == ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1) { 67 0 : generateHeaderV1(); 68 0 : } else { 69 0 : generateHeaderV2(); 70 0 : } 71 0 : } 72 : 73 0 : void UpstreamProxyProtocolSocket::generateHeaderV1() { 74 : // Default to local addresses. Used if no downstream connection exists or 75 : // downstream address info is not set e.g. health checks 76 0 : auto src_addr = callbacks_->connection().connectionInfoProvider().localAddress(); 77 0 : auto dst_addr = callbacks_->connection().connectionInfoProvider().remoteAddress(); 78 : 79 0 : if (options_ && options_->proxyProtocolOptions().has_value()) { 80 0 : const auto options = options_->proxyProtocolOptions().value(); 81 0 : src_addr = options.src_addr_; 82 0 : dst_addr = options.dst_addr_; 83 0 : } 84 : 85 0 : Common::ProxyProtocol::generateV1Header(*src_addr->ip(), *dst_addr->ip(), header_buffer_); 86 0 : } 87 : 88 : namespace { 89 0 : std::string toHex(const Buffer::Instance& buffer) { 90 0 : std::string bufferStr = buffer.toString(); 91 0 : return Hex::encode(reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.length()); 92 0 : } 93 : } // namespace 94 : 95 0 : void UpstreamProxyProtocolSocket::generateHeaderV2() { 96 0 : if (!options_ || !options_->proxyProtocolOptions().has_value()) { 97 0 : Common::ProxyProtocol::generateV2LocalHeader(header_buffer_); 98 0 : } else { 99 0 : const auto options = options_->proxyProtocolOptions().value(); 100 0 : if (!Common::ProxyProtocol::generateV2Header(options, header_buffer_, pass_all_tlvs_, 101 0 : pass_through_tlvs_)) { 102 : // There is a warn log in generateV2Header method. 103 0 : stats_.v2_tlvs_exceed_max_length_.inc(); 104 0 : } 105 : 106 0 : ENVOY_LOG(trace, "generated proxy protocol v2 header, length: {}, buffer: {}", 107 0 : header_buffer_.length(), toHex(header_buffer_)); 108 0 : } 109 0 : } 110 : 111 0 : Network::IoResult UpstreamProxyProtocolSocket::writeHeader() { 112 0 : Network::PostIoAction action = Network::PostIoAction::KeepOpen; 113 0 : uint64_t bytes_written = 0; 114 0 : do { 115 0 : if (header_buffer_.length() == 0) { 116 0 : break; 117 0 : } 118 : 119 0 : Api::IoCallUint64Result result = callbacks_->ioHandle().write(header_buffer_); 120 : 121 0 : if (result.ok()) { 122 0 : ENVOY_CONN_LOG(trace, "write returns: {}", callbacks_->connection(), result.return_value_); 123 0 : bytes_written += result.return_value_; 124 0 : } else { 125 0 : ENVOY_CONN_LOG(trace, "write error: {}", callbacks_->connection(), 126 0 : result.err_->getErrorDetails()); 127 0 : if (result.err_->getErrorCode() != Api::IoError::IoErrorCode::Again) { 128 0 : action = Network::PostIoAction::Close; 129 0 : } 130 0 : break; 131 0 : } 132 0 : } while (true); 133 : 134 0 : return {action, bytes_written, false}; 135 0 : } 136 : 137 0 : void UpstreamProxyProtocolSocket::onConnected() { 138 0 : generateHeader(); 139 0 : transport_socket_->onConnected(); 140 0 : } 141 : 142 : UpstreamProxyProtocolSocketFactory::UpstreamProxyProtocolSocketFactory( 143 : Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config, 144 : Stats::Scope& scope) 145 0 : : PassthroughFactory(std::move(transport_socket_factory)), config_(config), scope_(scope) {} 146 : 147 : Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportSocket( 148 : Network::TransportSocketOptionsConstSharedPtr options, 149 0 : Upstream::HostDescriptionConstSharedPtr host) const { 150 0 : auto inner_socket = transport_socket_factory_->createTransportSocket(options, host); 151 0 : if (inner_socket == nullptr) { 152 0 : return nullptr; 153 0 : } 154 0 : return std::make_unique<UpstreamProxyProtocolSocket>(std::move(inner_socket), options, config_, 155 0 : scope_); 156 0 : } 157 : 158 : void UpstreamProxyProtocolSocketFactory::hashKey( 159 0 : std::vector<uint8_t>& key, Network::TransportSocketOptionsConstSharedPtr options) const { 160 0 : PassthroughFactory::hashKey(key, options); 161 : // Proxy protocol options should only be included in the hash if the upstream 162 : // socket intends to use them. 163 0 : if (options) { 164 0 : const auto& proxy_protocol_options = options->proxyProtocolOptions(); 165 0 : if (proxy_protocol_options.has_value()) { 166 0 : pushScalarToByteVector( 167 0 : StringUtil::CaseInsensitiveHash()(proxy_protocol_options.value().asStringForHash()), key); 168 0 : } 169 0 : } 170 0 : } 171 : 172 : } // namespace ProxyProtocol 173 : } // namespace TransportSockets 174 : } // namespace Extensions 175 : } // namespace Envoy