ProxyProtocolReadListener.java
package io.undertow.server.protocol.proxy;
import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.server.DelegateOpenListener;
import io.undertow.server.OpenListener;
import io.undertow.util.NetworkUtils;
import io.undertow.util.PooledAdaptor;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.StreamConnection;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.PushBackStreamSourceConduit;
import org.xnio.ssl.SslConnection;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Implementation of version 1 and 2 of the proxy protocol (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
 * <p>
 * Even though it is not required by the spec this implementation provides a stateful parser, that can handle
 * fragmentation of the proxy header across multiple reads: all parsing progress is kept in fields of this listener,
 * so when a read ends part-way through the header the next read resumes exactly where the previous one stopped.
 * <p>
 * Once the header is complete the connection is handed to the configured {@link OpenListener} (wrapped in TLS when
 * an {@link UndertowXnioSsl} is configured), reporting the addresses from the header as its peer and local address.
 * Any bytes that followed the header in the same read are passed on as well.
 *
 * @author Stuart Douglas
 * @author Ulrich Herberg
 */
class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel> {

    /** Maximum length in bytes of a V1 header line (including CRLF); also the minimum accepted buffer size. */
    private static final int MAX_HEADER_LENGTH = 107;
    /** Prefix of every V1 header. */
    private static final byte[] NAME = "PROXY ".getBytes(StandardCharsets.US_ASCII);
    private static final String UNKNOWN = "UNKNOWN";
    private static final String TCP4 = "TCP4";
    private static final String TCP6 = "TCP6";
    /** Signature that starts every V2 header. */
    private static final byte[] SIG = new byte[] {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};

    /** Length of the fixed part of a V2 header: signature, ver_cmd, fam and the 16-bit address block length. */
    private static final int V2_FIXED_HEADER_LENGTH = 16;
    /** V2 address block length for TCP over IPv4: two 4-byte addresses followed by two 2-byte ports. */
    private static final int V2_IPV4_ADDRESS_LENGTH = 12;
    /** V2 address block length for TCP over IPv6: two 16-byte addresses followed by two 2-byte ports. */
    private static final int V2_IPV6_ADDRESS_LENGTH = 36;

    /** Possible values of {@link #headerVersion}. */
    private static final int VERSION_UNDETERMINED = 0;
    private static final int VERSION_1 = 1;
    private static final int VERSION_2 = 2;

    private final StreamConnection streamConnection;
    private final OpenListener openListener;
    private final UndertowXnioSsl ssl;
    private final ByteBufferPool bufferPool;
    private final OptionMap sslOptionMap;
    /** Current V1 token being accumulated. */
    private final StringBuilder stringBuilder = new StringBuilder();
    /** Number of header bytes consumed so far, across all reads. */
    private int byteCount;
    /** Header version, determined once from the very first byte of the connection. */
    private int headerVersion = VERSION_UNDETERMINED;
    private String protocol;
    private InetAddress sourceAddress;
    private InetAddress destAddress;
    private int sourcePort = -1;
    private int destPort = -1;
    private boolean carriageReturnSeen = false;
    private boolean parsingUnknown = false;
    /** V2 ver_cmd, fam and length bytes (header offsets 12 to 15), collected across reads. */
    private final byte[] v2FixedFields = new byte[V2_FIXED_HEADER_LENGTH - SIG.length];
    /** Leading bytes of the V2 address block that are decoded; anything beyond them (e.g. TLVs) is skipped. */
    private final byte[] v2AddressBlock = new byte[V2_IPV6_ADDRESS_LENGTH];

    ProxyProtocolReadListener(StreamConnection streamConnection, OpenListener openListener, UndertowXnioSsl ssl, ByteBufferPool bufferPool, OptionMap sslOptionMap) {
        this.streamConnection = streamConnection;
        this.openListener = openListener;
        this.ssl = ssl;
        this.bufferPool = bufferPool;
        this.sslOptionMap = sslOptionMap;
        if (bufferPool.getBufferSize() < MAX_HEADER_LENGTH) {
            throw UndertowMessages.MESSAGES.bufferPoolTooSmall(MAX_HEADER_LENGTH);
        }
    }

    /**
     * Reads whatever is available and feeds it to the header parser. The connection is closed on end of stream
     * or on any parse/IO failure; a read that ends mid-header simply returns and parsing resumes on the next event.
     */
    @Override
    public void handleEvent(StreamSourceChannel streamSourceChannel) {
        PooledByteBuffer buffer = bufferPool.allocate();
        // set to false once the buffer (holding data past the header) is handed on with the connection
        AtomicBoolean freeBuffer = new AtomicBoolean(true);
        try {
            int res = streamSourceChannel.read(buffer.getBuffer());
            if (res == -1) {
                IoUtils.safeClose(streamConnection);
                return;
            } else if (res == 0) {
                return;
            }
            ByteBuffer data = buffer.getBuffer();
            data.flip();
            if (!data.hasRemaining()) {
                return;
            }
            if (headerVersion == VERSION_UNDETERMINED) {
                // The first byte of the connection tells V2 (binary signature) and V1 ("PROXY ") apart. It is only
                // examined once: later reads continue the header where the previous read stopped.
                byte firstByte = data.get();
                byteCount++;
                if (firstByte == SIG[0]) {
                    headerVersion = VERSION_2;
                } else if (firstByte == NAME[0]) {
                    headerVersion = VERSION_1;
                } else {
                    throw UndertowMessages.MESSAGES.invalidProxyHeader();
                }
            }
            if (headerVersion == VERSION_2) {
                parseProxyProtocolV2(buffer, freeBuffer);
            } else {
                parseProxyProtocolV1(buffer, freeBuffer);
            }
        } catch (IOException e) {
            UndertowLogger.REQUEST_IO_LOGGER.ioException(e);
            IoUtils.safeClose(streamConnection);
        } catch (Throwable e) {
            UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(e));
            IoUtils.safeClose(streamConnection);
        } finally {
            if (freeBuffer.get()) {
                buffer.close();
            }
        }
    }

    /**
     * Continues parsing a binary V2 header. Returns without accepting the connection when the buffer runs out
     * before the header (including the full address block announced by its length field) has been consumed.
     */
    private void parseProxyProtocolV2(PooledByteBuffer buffer, AtomicBoolean freeBuffer) throws IOException {
        ByteBuffer data = buffer.getBuffer();
        //first we verify that we have the correct protocol
        while (byteCount < SIG.length) {
            if (!data.hasRemaining()) {
                return; // wait for more data
            }
            if (data.get() != SIG[byteCount]) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            byteCount++;
        }
        while (byteCount < V2_FIXED_HEADER_LENGTH) {
            if (!data.hasRemaining()) {
                return; // wait for more data
            }
            v2FixedFields[byteCount - SIG.length] = data.get();
            byteCount++;
        }

        byte verCmd = v2FixedFields[0];
        byte fam = v2FixedFields[1];
        int len = ((v2FixedFields[2] & 0xff) << 8) | (v2FixedFields[3] & 0xff); // unsigned, network byte order
        int command = verCmd & 0x0F;
        if ((verCmd & 0xF0) != 0x20) { // expect version 2
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        int decodedLength;
        switch (command) {
            case 0x00: // LOCAL command: family and address block are ignored
                decodedLength = 0;
                break;
            case 0x01: // PROXY command
                if (fam == 0x11) { // TCP over IPv4
                    decodedLength = V2_IPV4_ADDRESS_LENGTH;
                } else if (fam == 0x21) { // TCP over IPv6
                    decodedLength = V2_IPV6_ADDRESS_LENGTH;
                } else { // other families (AF_UNIX, UNSPEC, UDP) not supported
                    throw UndertowMessages.MESSAGES.invalidProxyHeader();
                }
                if (len < decodedLength) {
                    throw UndertowMessages.MESSAGES.invalidProxyHeader();
                }
                break;
            default:
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }

        // Address block: keep the bytes we decode, skip whatever follows them (e.g. TLVs) without buffering it.
        int consumed = byteCount - V2_FIXED_HEADER_LENGTH;
        if (consumed < decodedLength) {
            int n = Math.min(decodedLength - consumed, data.remaining());
            data.get(v2AddressBlock, consumed, n);
            consumed += n;
        }
        if (consumed >= decodedLength && consumed < len) {
            int n = Math.min(len - consumed, data.remaining());
            data.position(data.position() + n);
            consumed += n;
        }
        byteCount = V2_FIXED_HEADER_LENGTH + consumed;
        if (consumed < len) {
            return; // rest of the address block has not arrived yet
        }

        if (command == 0x00) {
            acceptWithRemainingData(null, null, buffer, freeBuffer);
            return;
        }
        int addressLength = fam == 0x11 ? 4 : 16;
        sourceAddress = decodeV2Address(0, addressLength);
        destAddress = decodeV2Address(addressLength, addressLength);
        sourcePort = decodeV2Port(2 * addressLength);
        destPort = decodeV2Port(2 * addressLength + 2);
        acceptWithRemainingData(new InetSocketAddress(sourceAddress, sourcePort),
                new InetSocketAddress(destAddress, destPort), buffer, freeBuffer);
    }

    /** Copies {@code length} bytes at {@code offset} of the collected V2 address block into an address. */
    private InetAddress decodeV2Address(int offset, int length) throws IOException {
        byte[] bytes = new byte[length];
        System.arraycopy(v2AddressBlock, offset, bytes, 0, length);
        return InetAddress.getByAddress(bytes);
    }

    /** Reads the unsigned big-endian 16-bit port at {@code offset} of the collected V2 address block. */
    private int decodeV2Port(int offset) {
        return ((v2AddressBlock[offset] & 0xff) << 8) | (v2AddressBlock[offset + 1] & 0xff);
    }

    /**
     * Continues parsing a text V1 header, one character at a time, until the terminating CRLF. Returns without
     * accepting the connection when the buffer runs out first.
     */
    private void parseProxyProtocolV1(PooledByteBuffer buffer, AtomicBoolean freeBuffer) throws IOException {
        while (buffer.getBuffer().hasRemaining()) {
            char c = (char) buffer.getBuffer().get();
            if (byteCount < NAME.length) {
                //first we verify that we have the correct protocol
                if (c != NAME[byteCount]) {
                    throw UndertowMessages.MESSAGES.invalidProxyHeader();
                }
            } else {
                if (parsingUnknown) {
                    //we are parsing the UNKNOWN protocol
                    //we just ignore everything till \r\n
                    if (c == '\r') {
                        carriageReturnSeen = true;
                    } else if (c == '\n') {
                        if (!carriageReturnSeen) {
                            throw UndertowMessages.MESSAGES.invalidProxyHeader();
                        }
                        //we are done
                        acceptWithRemainingData(null, null, buffer, freeBuffer);
                        return;
                    } else if (carriageReturnSeen) {
                        throw UndertowMessages.MESSAGES.invalidProxyHeader();
                    }
                } else if (carriageReturnSeen) {
                    if (c != '\n') {
                        throw UndertowMessages.MESSAGES.invalidProxyHeader();
                    }
                    //we are done
                    acceptWithRemainingData(new InetSocketAddress(sourceAddress, sourcePort),
                            new InetSocketAddress(destAddress, destPort), buffer, freeBuffer);
                    return;
                } else switch (c) {
                    case ' ':
                        //we have a space
                        if (sourcePort != -1 || stringBuilder.length() == 0) {
                            //header was invalid, either we are expecting a \r or a \n, or the previous character was a space
                            throw UndertowMessages.MESSAGES.invalidProxyHeader();
                        } else if (protocol == null) {
                            protocol = stringBuilder.toString();
                            stringBuilder.setLength(0);
                            if (protocol.equals(UNKNOWN)) {
                                parsingUnknown = true;
                            } else if (!protocol.equals(TCP4) && !protocol.equals(TCP6)) {
                                throw UndertowMessages.MESSAGES.invalidProxyHeader();
                            }
                        } else if (sourceAddress == null) {
                            try {
                                sourceAddress = parseAddress(stringBuilder.toString(), protocol);
                            } finally {
                                stringBuilder.setLength(0);
                            }
                        } else if (destAddress == null) {
                            try {
                                destAddress = parseAddress(stringBuilder.toString(), protocol);
                            } finally {
                                stringBuilder.setLength(0);
                            }
                        } else {
                            try {
                                sourcePort = parsePort(stringBuilder.toString());
                            } finally {
                                stringBuilder.setLength(0);
                            }
                        }
                        break;
                    case '\r':
                        if (destPort == -1 && sourcePort != -1 && stringBuilder.length() > 0) {
                            try {
                                destPort = parsePort(stringBuilder.toString());
                            } finally {
                                stringBuilder.setLength(0);
                            }
                            carriageReturnSeen = true;
                        } else if (protocol == null && UNKNOWN.equals(stringBuilder.toString())) {
                            // "PROXY UNKNOWN\r\n": no address information follows
                            stringBuilder.setLength(0);
                            parsingUnknown = true;
                            carriageReturnSeen = true;
                        } else {
                            // a \r anywhere else (including after an unrecognised protocol token) is invalid
                            stringBuilder.setLength(0);
                            throw UndertowMessages.MESSAGES.invalidProxyHeader();
                        }
                        break;
                    case '\n':
                        stringBuilder.setLength(0);
                        throw UndertowMessages.MESSAGES.invalidProxyHeader();
                    default:
                        stringBuilder.append(c);
                }
            }
            byteCount++;
            if (byteCount == MAX_HEADER_LENGTH) {
                throw UndertowMessages.MESSAGES.headerSizeToLarge();
            }
        }
    }

    /**
     * Parses a V1 port token, which must be 1 to 5 ASCII decimal digits with a value no greater than 65535.
     * Signs and other characters accepted by {@link Integer#parseInt} are rejected.
     */
    private static int parsePort(String portString) throws IOException {
        if (portString.isEmpty() || portString.length() > 5) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        for (int i = 0; i < portString.length(); i++) {
            char ch = portString.charAt(i);
            if (ch < '0' || ch > '9') {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
        }
        int port = Integer.parseInt(portString);
        if (port > 65535) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        return port;
    }

    /**
     * Accepts the connection once the header is complete. Bytes left in the buffer after the header belong to
     * the proxied protocol, so in that case ownership of the buffer is passed on with the connection.
     */
    private void acceptWithRemainingData(SocketAddress source, SocketAddress dest, PooledByteBuffer buffer, AtomicBoolean freeBuffer) {
        if (buffer.getBuffer().hasRemaining()) {
            freeBuffer.set(false);
            proxyAccept(source, dest, buffer);
        } else {
            proxyAccept(source, dest, null);
        }
    }

    /**
     * Hands the connection to the open listener. When {@code source} is non-null the connection is wrapped so it
     * reports the addresses from the proxy header; when TLS is configured it is additionally wrapped in TLS.
     *
     * @param additionalData bytes read after the header, or {@code null}
     */
    private void proxyAccept(SocketAddress source, SocketAddress dest, PooledByteBuffer additionalData) {
        StreamConnection streamConnection = this.streamConnection;
        if (source != null) {
            streamConnection = new AddressWrappedConnection(streamConnection, source, dest);
        }
        if (ssl != null) {
            //we need to apply the additional data before the SSL wrapping
            if (additionalData != null) {
                PushBackStreamSourceConduit conduit = new PushBackStreamSourceConduit(streamConnection.getSourceChannel().getConduit());
                conduit.pushBack(new PooledAdaptor(additionalData));
                streamConnection.getSourceChannel().setConduit(conduit);
            }
            SslConnection sslConnection = ssl.wrapExistingConnection(streamConnection, sslOptionMap == null ? OptionMap.EMPTY : sslOptionMap, false);
            streamConnection = sslConnection;
            callOpenListener(streamConnection, null);
        } else {
            callOpenListener(streamConnection, additionalData);
        }
    }

    /**
     * Invokes the open listener. A {@link DelegateOpenListener} receives the leftover buffer directly; any other
     * listener gets it pushed back onto the connection's source conduit so it is read first.
     */
    private void callOpenListener(StreamConnection streamConnection, final PooledByteBuffer buffer) {
        if (openListener instanceof DelegateOpenListener) {
            ((DelegateOpenListener) openListener).handleEvent(streamConnection, buffer);
        } else {
            if (buffer != null) {
                PushBackStreamSourceConduit conduit = new PushBackStreamSourceConduit(streamConnection.getSourceChannel().getConduit());
                conduit.pushBack(new PooledAdaptor(buffer));
                streamConnection.getSourceChannel().setConduit(conduit);
            }
            openListener.handleEvent(streamConnection);
        }
    }

    /** Parses a V1 address token as IPv4 for {@code TCP4}, otherwise as IPv6. */
    static InetAddress parseAddress(String addressString, String protocol) throws IOException {
        if (protocol.equals(TCP4)) {
            return NetworkUtils.parseIpv4Address(addressString);
        } else {
            return NetworkUtils.parseIpv6Address(addressString);
        }
    }

    /**
     * Connection that reuses the delegate's source and sink conduits but reports the addresses from the proxy
     * header as its peer ({@code source}) and local ({@code dest}) address.
     */
    private static final class AddressWrappedConnection extends StreamConnection {

        private final StreamConnection delegate;
        private final SocketAddress source;
        private final SocketAddress dest;

        AddressWrappedConnection(StreamConnection delegate, SocketAddress source, SocketAddress dest) {
            super(delegate.getIoThread());
            this.delegate = delegate;
            this.source = source;
            this.dest = dest;
            setSinkConduit(delegate.getSinkChannel().getConduit());
            setSourceConduit(delegate.getSourceChannel().getConduit());
        }

        @Override
        protected void notifyWriteClosed() {
            IoUtils.safeClose(delegate.getSinkChannel());
        }

        @Override
        protected void notifyReadClosed() {
            IoUtils.safeClose(delegate.getSourceChannel());
        }

        @Override
        public SocketAddress getPeerAddress() {
            return source;
        }

        @Override
        public SocketAddress getLocalAddress() {
            return dest;
        }
    }
}