TestClientRequestFilterPlugin.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.server.MockHttpServletRequest;
import com.facebook.presto.server.security.AuthenticationFilter;
import com.facebook.presto.server.security.SecurityConfig;
import com.facebook.presto.server.testing.TestingPrestoServer;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.ClientRequestFilterFactory;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;
import javax.servlet.http.HttpServletRequest;
import java.security.Principal;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.testng.Assert.assertEquals;
public class TestClientRequestFilterPlugin
{
@Test
public void testCustomRequestFilterWithHeaders() throws Exception
{
MockHttpServletRequest request = new MockHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue"));
List<ClientRequestFilterFactory> requestFilterFactory = getClientRequestFilterFactory();
AuthenticationFilter filter = setupAuthenticationFilter(requestFilterFactory);
PrincipalStub testPrincipal = new PrincipalStub();
HttpServletRequest wrappedRequest = filter.mergeExtraHeaders(request, testPrincipal);
assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header"));
assertEquals("ExpectedExtraValue", wrappedRequest.getHeader("ExpectedExtraHeader"));
}
@Test(
expectedExceptions = RuntimeException.class,
expectedExceptionsMessageRegExp = "Modification attempt detected: The header X-Presto-Transaction-Id is not allowed to be modified. The following headers cannot be modified: " +
"X-Presto-Transaction-Id, X-Presto-Started-Transaction-Id, X-Presto-Clear-Transaction-Id, X-Presto-Trace-Token")
public void testCustomRequestFilterWithHeadersInBlockList() throws Exception
{
MockHttpServletRequest request = new MockHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue"));
List<ClientRequestFilterFactory> requestFilterFactory = getClientRequestFilterInBlockList();
AuthenticationFilter filter = setupAuthenticationFilter(requestFilterFactory);
PrincipalStub testPrincipal = new PrincipalStub();
filter.mergeExtraHeaders(request, testPrincipal);
}
@Test(
expectedExceptions = RuntimeException.class,
expectedExceptionsMessageRegExp = "Header conflict detected: ExpectedExtraValue already added by another filter.")
public void testCustomRequestFilterHandlesConflict() throws Exception
{
MockHttpServletRequest request = new MockHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue"));
List<ClientRequestFilterFactory> requestFilterFactory = getClientRequestFilterFactoryHandlesConflict();
AuthenticationFilter filter = setupAuthenticationFilter(requestFilterFactory);
PrincipalStub testPrincipal = new PrincipalStub();
filter.mergeExtraHeaders(request, testPrincipal);
}
private List<ClientRequestFilterFactory> getClientRequestFilterFactory()
{
return createFilterFactories(
new String[][] {
{"CustomModifier", "ExpectedExtraHeader", "ExpectedExtraValue"}
});
}
private List<ClientRequestFilterFactory> getClientRequestFilterInBlockList()
{
return createFilterFactories(
new String[][] {
{"BlockListModifier", "X-Presto-Transaction-Id", "CustomValue"}
});
}
private List<ClientRequestFilterFactory> getClientRequestFilterFactoryHandlesConflict()
{
return createFilterFactories(
new String[][] {
{"Filter1", "ExpectedExtraValue", "ExpectedExtraHeader_1"},
{"Filter2", "ExpectedExtraValue", "ExpectedExtraHeader_2"}
});
}
private AuthenticationFilter setupAuthenticationFilter(List<ClientRequestFilterFactory> requestFilterFactory) throws Exception
{
try (TestingPrestoServer testingPrestoServer = new TestingPrestoServer()) {
ClientRequestFilterManager clientRequestFilterManager = testingPrestoServer.getClientRequestFilterManager(requestFilterFactory);
List<Authenticator> authenticators = createAuthenticators();
SecurityConfig securityConfig = createSecurityConfig();
return new AuthenticationFilter(authenticators, securityConfig, clientRequestFilterManager);
}
}
private List<ClientRequestFilterFactory> createFilterFactories(String[][] filterConfigs)
{
ImmutableList.Builder<ClientRequestFilterFactory> factories = ImmutableList.builder();
for (String[] config : filterConfigs) {
factories.add(new GenericClientRequestFilterFactory(config[0], config[1], config[2]));
}
return factories.build();
}
private List<Authenticator> createAuthenticators()
{
return Collections.emptyList();
}
private SecurityConfig createSecurityConfig()
{
return new SecurityConfig() {
@Override
public boolean getAllowForwardedHttps()
{
return true;
}
};
}
static class GenericClientRequestFilterFactory
implements ClientRequestFilterFactory
{
private final String name;
private final String headerName;
private final String headerValue;
public GenericClientRequestFilterFactory(String name, String headerName, String headerValue)
{
this.name = name;
this.headerName = headerName;
this.headerValue = headerValue;
}
@Override
public String getName()
{
return name;
}
@Override
public ClientRequestFilter create()
{
return new CustomClientRequestFilter();
}
private class CustomClientRequestFilter
implements ClientRequestFilter
{
@Override
public Set<String> getExtraHeaderKeys()
{
return ImmutableSet.of(headerName);
}
@Override
public Map<String, String> getExtraHeaders(Principal principal)
{
return ImmutableMap.of(headerName, headerValue);
}
}
}
static class PrincipalStub
implements Principal
{
@Override
public String getName()
{
return "TestPrincipal";
}
}
}