AuthenticationFilter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server.security;
import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.PrestoException;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.HttpHeaders;
import javax.inject.Inject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT;
import static com.google.common.io.ByteStreams.copy;
import static com.google.common.io.ByteStreams.nullOutputStream;
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE;
import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8;
import static java.util.Collections.enumeration;
import static java.util.Collections.list;
import static java.util.Objects.requireNonNull;
import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED;
/**
 * Servlet filter that authenticates requests with the configured
 * {@link Authenticator}s and, on success, lets the registered
 * {@link ClientRequestFilter}s contribute extra request headers before the
 * request is passed down the filter chain.
 *
 * <p>HTTP header names are case-insensitive, so every header-name comparison
 * performed by this class (block list, conflict detection and lookups on the
 * wrapped request) ignores case.
 */
public class AuthenticationFilter
        implements Filter
{
    private static final String HTTPS_PROTOCOL = "https";

    // Headers that client request filters are never allowed to supply.
    private static final List<String> HEADERS_BLOCK_LIST = ImmutableList.of(
            "X-Presto-Transaction-Id",
            "X-Presto-Started-Transaction-Id",
            "X-Presto-Clear-Transaction-Id",
            "X-Presto-Trace-Token");

    private final List<Authenticator> authenticators;
    private final boolean allowForwardedHttps;
    private final ClientRequestFilterManager clientRequestFilterManager;

    /**
     * @param authenticators authenticators to try, in order; must not be null
     * @param securityConfig source of the "allow forwarded https" setting; must not be null
     * @param clientRequestFilterManager supplier of the client request filters; must not be null
     */
    @Inject
    public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager)
    {
        this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null"));
        this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();
        this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
    }

    @Override
    public void init(FilterConfig filterConfig) {}

    @Override
    public void destroy() {}

    /**
     * Authenticates the request and forwards it down the chain.
     *
     * <p>Requests that do not support authentication (see
     * {@link #doesRequestSupportAuthentication}) are forwarded untouched.
     * Otherwise each authenticator is tried in order; the first one that
     * succeeds determines the principal. If all of them fail, a 401 response
     * is written containing the collected error messages and
     * {@code WWW-Authenticate} challenges.
     *
     * @throws PrestoException if a client request filter supplies a blocked or conflicting header
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain nextFilter)
            throws IOException, ServletException
    {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        // skip authentication if non-secure or not configured
        if (!doesRequestSupportAuthentication(request)) {
            nextFilter.doFilter(request, response);
            return;
        }

        // try to authenticate, collecting errors and authentication headers
        // (insertion-ordered sets: de-duplicate while keeping authenticator order)
        Set<String> messages = new LinkedHashSet<>();
        Set<String> authenticateHeaders = new LinkedHashSet<>();

        for (Authenticator authenticator : authenticators) {
            Principal principal;
            try {
                principal = authenticator.authenticate(request);
            }
            catch (AuthenticationException e) {
                if (e.getMessage() != null) {
                    messages.add(e.getMessage());
                }
                e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
                continue;
            }

            // authentication succeeded
            HttpServletRequest wrappedRequest = mergeExtraHeaders(request, principal);
            nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response);
            return;
        }

        // authentication failed
        skipRequestBody(request);

        for (String value : authenticateHeaders) {
            response.addHeader(WWW_AUTHENTICATE, value);
        }

        if (messages.isEmpty()) {
            messages.add("Unauthorized");
        }

        // The error string is used by clients for exception messages and
        // is presented to the end user, thus it should be a single line.
        String error = Joiner.on(" | ").join(messages);

        // Clients should use the response body rather than the HTTP status
        // message (which does not exist with HTTP/2), but the status message
        // still needs to be sent for compatibility with existing clients.
        response.setStatus(SC_UNAUTHORIZED, error);
        response.setContentType(PLAIN_TEXT_UTF_8.toString());
        try (PrintWriter writer = response.getWriter()) {
            writer.write(error);
        }
    }

    /**
     * Wraps the request so that it also exposes the extra headers contributed
     * by the registered client request filters.
     *
     * <p>A filter is consulted only when at least one of its declared extra
     * header keys is missing from the request. A contributed header is added
     * only when the request does not already carry it and the filter declared
     * it in its extra header keys.
     *
     * @param request the authenticated request
     * @param principal the authenticated principal, passed to each filter
     * @return the original request when no filters are registered, otherwise a
     *         {@link ModifiedHttpServletRequest} exposing the added headers
     * @throws PrestoException with {@code HEADER_MODIFICATION_ATTEMPT} if a filter supplies a
     *         block-listed header or a header already added by another filter
     *         (header names are compared ignoring case)
     */
    public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principal principal)
    {
        List<ClientRequestFilter> clientRequestFilters = clientRequestFilterManager.getClientRequestFilters();
        if (clientRequestFilters.isEmpty()) {
            return request;
        }

        ImmutableMap.Builder<String, String> extraHeadersMapBuilder = ImmutableMap.builder();
        Set<String> addedHeaders = new HashSet<>();

        for (ClientRequestFilter requestFilter : clientRequestFilters) {
            boolean headersPresent = requestFilter.getExtraHeaderKeys().stream()
                    .allMatch(headerName -> request.getHeader(headerName) != null);
            if (headersPresent) {
                // the request already carries everything this filter would add
                continue;
            }

            Map<String, String> extraHeaderValueMap = requestFilter.getExtraHeaders(principal);
            for (Map.Entry<String, String> extraHeaderEntry : extraHeaderValueMap.entrySet()) {
                String headerKey = extraHeaderEntry.getKey();

                // Compare ignoring case: otherwise "x-presto-transaction-id"
                // would slip past a block list spelled "X-Presto-Transaction-Id".
                if (containsIgnoreCase(HEADERS_BLOCK_LIST, headerKey)) {
                    throw new PrestoException(HEADER_MODIFICATION_ATTEMPT,
                            "Modification attempt detected: The header " + headerKey + " is not allowed to be modified. The following headers cannot be modified: " +
                                    String.join(", ", HEADERS_BLOCK_LIST));
                }
                if (containsIgnoreCase(addedHeaders, headerKey)) {
                    throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Header conflict detected: " + headerKey + " already added by another filter.");
                }
                if (request.getHeader(headerKey) == null && requestFilter.getExtraHeaderKeys().contains(headerKey)) {
                    extraHeadersMapBuilder.put(headerKey, extraHeaderEntry.getValue());
                    addedHeaders.add(headerKey);
                }
            }
        }
        return new ModifiedHttpServletRequest(request, extraHeadersMapBuilder.build());
    }

    /**
     * Returns true when authenticators are configured and the request is
     * secure, or, if forwarded HTTPS is allowed, its
     * {@code X-Forwarded-Proto} header equals "https" (ignoring case).
     */
    private boolean doesRequestSupportAuthentication(HttpServletRequest request)
    {
        if (authenticators.isEmpty()) {
            return false;
        }
        if (request.isSecure()) {
            return true;
        }
        if (allowForwardedHttps) {
            return Strings.nullToEmpty(request.getHeader(HttpHeaders.X_FORWARDED_PROTO)).equalsIgnoreCase(HTTPS_PROTOCOL);
        }
        return false;
    }

    /**
     * Wraps the request so that {@code getUserPrincipal()} returns the given principal.
     */
    private static ServletRequest withPrincipal(HttpServletRequest request, Principal principal)
    {
        requireNonNull(principal, "principal is null");
        return new HttpServletRequestWrapper(request)
        {
            @Override
            public Principal getUserPrincipal()
            {
                return principal;
            }
        };
    }

    private static void skipRequestBody(HttpServletRequest request)
            throws IOException
    {
        // If we send the challenge without consuming the body of the request,
        // the server will close the connection after sending the response.
        // The client may interpret this as a failed request and not resend the
        // request with the authentication header. We can avoid this behavior
        // in the client by reading and discarding the entire body of the
        // unauthenticated request before sending the response.
        try (InputStream inputStream = request.getInputStream()) {
            copy(inputStream, nullOutputStream());
        }
    }

    /**
     * Returns true if {@code headerNames} contains {@code headerName}, ignoring case.
     * A null {@code headerName} never matches.
     */
    private static boolean containsIgnoreCase(Iterable<String> headerNames, String headerName)
    {
        for (String candidate : headerNames) {
            if (candidate.equalsIgnoreCase(headerName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Request wrapper that overlays a fixed set of custom headers on top of
     * the wrapped request. Custom headers take precedence over headers of the
     * wrapped request, and header names are matched ignoring case.
     */
    public static class ModifiedHttpServletRequest
            extends HttpServletRequestWrapper
    {
        private final Map<String, String> customHeaders;

        /**
         * @param request the request to wrap
         * @param headers custom headers to expose; must not be null (copied defensively)
         */
        public ModifiedHttpServletRequest(HttpServletRequest request, Map<String, String> headers)
        {
            super(request);
            this.customHeaders = ImmutableMap.copyOf(requireNonNull(headers, "headers is null"));
        }

        @Override
        public String getHeader(String name)
        {
            String customValue = findCustomHeader(name);
            if (customValue != null) {
                return customValue;
            }
            return super.getHeader(name);
        }

        @Override
        public Enumeration<String> getHeaderNames()
        {
            ImmutableSet.Builder<String> names = ImmutableSet.<String>builder()
                    .addAll(customHeaders.keySet());
            for (String name : list(super.getHeaderNames())) {
                // a custom header shadows the wrapped request's header of the
                // same name, so do not report it twice under another casing
                if (!containsIgnoreCase(customHeaders.keySet(), name)) {
                    names.add(name);
                }
            }
            return enumeration(names.build());
        }

        @Override
        public Enumeration<String> getHeaders(String name)
        {
            String customValue = findCustomHeader(name);
            if (customValue != null) {
                return enumeration(ImmutableList.of(customValue));
            }
            return super.getHeaders(name);
        }

        /**
         * Looks up a custom header by name, ignoring case.
         *
         * @return the custom value, or null when no custom header matches
         *         (values in the immutable map are never null)
         */
        private String findCustomHeader(String name)
        {
            if (name == null) {
                return null;
            }
            // fast path: exact spelling
            String exactMatch = customHeaders.get(name);
            if (exactMatch != null) {
                return exactMatch;
            }
            for (Map.Entry<String, String> entry : customHeaders.entrySet()) {
                if (entry.getKey().equalsIgnoreCase(name)) {
                    return entry.getValue();
                }
            }
            return null;
        }
    }
}