TestAsyncPageTransportServlet.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.fail;
/**
 * Tests for the request-URI parsing and request validation performed by
 * {@link AsyncPageTransportServlet#parseURI}.
 *
 * <p>The URIs exercised here have the shape
 * {@code /v1/task/async/<taskId>/results/<bufferId>/<token>}.
 */
@Test(singleThreaded = true)
public class TestAsyncPageTransportServlet
{
    /**
     * Test double that records the arguments passed to {@code processRequest}
     * instead of serving results. It also turns {@code reportFailure} into an
     * {@link IllegalArgumentException}, so tests can assert on parse failures.
     *
     * <p>Declared {@code static} because it never uses the enclosing test instance.
     */
    static class TestServlet
            extends AsyncPageTransportServlet
    {
        // Values captured from the most recent processRequest call.
        TaskId taskId;
        OutputBufferId bufferId;
        String requestURI;
        HttpServletRequest request;
        long token;

        /**
         * Parses {@code uri} with no request and no response.
         */
        void parse(String uri)
                throws IOException
        {
            parseURI(uri, null, null);
        }

        /**
         * Parses {@code uri} together with {@code request}, so that checks on the
         * request (e.g. its headers) are exercised. No response is supplied.
         */
        void parse(String uri, HttpServletRequest request)
                throws IOException
        {
            parseURI(uri, request, null);
        }

        @Override
        protected void processRequest(
                String requestURI,
                TaskId taskId,
                OutputBufferId bufferId,
                long token,
                HttpServletRequest request,
                HttpServletResponse response)
        {
            this.requestURI = requestURI;
            this.taskId = taskId;
            this.bufferId = bufferId;
            this.token = token;
            this.request = request;
        }

        /**
         * Throws instead of writing an error response, so failures surface as exceptions.
         */
        @Override
        protected void reportFailure(HttpServletResponse response, String message)
        {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Creates a fresh {@link TestServlet} and parses {@code uri} with it.
     * An unexpected {@link IOException} fails the test; its cause is kept.
     *
     * @param uri the request URI to parse
     * @return the servlet, holding whatever values {@code processRequest} captured
     */
    private static TestServlet parse(String uri)
    {
        TestServlet servlet = new TestServlet();
        try {
            servlet.parse(uri);
        }
        catch (IOException e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            fail(e.getMessage(), e);
        }
        return servlet;
    }

    /**
     * A well-formed URI is split into its task id, buffer id and token.
     */
    @Test
    public void testParsing()
    {
        TestServlet servlet = parse("/v1/task/async/0.1.2.3.4/results/456/789");
        // TestNG's assertEquals takes (actual, expected).
        assertEquals(servlet.taskId.toString(), "0.1.2.3.4");
        assertEquals(servlet.bufferId.toString(), "456");
        assertEquals(servlet.token, 789);
    }

    /**
     * Header keys and values containing a line feed or carriage return.
     */
    @DataProvider(name = "testSanitizationProvider")
    public Object[][] testSanitizationProvider()
    {
        return new Object[][] {
                {"ke\ny", "value"},
                {"key", "valu\ne"},
                {"ke\ry", "value"},
                {"key", "valu\re"}};
    }

    /**
     * A request whose header key or value contains CR/LF is rejected even though the URI itself is valid.
     */
    @Test(dataProvider = "testSanitizationProvider")
    public void testSanitization(String key, String value)
    {
        ListMultimap<String, String> headers = ImmutableListMultimap.of(key, value);
        HttpServletRequest request = new MockHttpServletRequest(headers, "", ImmutableMap.of());
        TestServlet servlet = new TestServlet();
        assertThrows(
                IllegalArgumentException.class,
                () -> servlet.parse("/v1/task/async/0.1.2.3.4/results/456/789", request));
    }

    /**
     * A URI missing the trailing token segment is rejected.
     */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testParseTooFewElements()
    {
        // Same well-formed task id as testParsing, so only the segment count differs from a valid URI.
        parse("/v1/task/async/0.1.2.3.4/results/456");
    }

    /**
     * A URI with an extra trailing segment is rejected.
     */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testParseTooManyElements()
    {
        // Same well-formed task id as testParsing, so only the segment count differs from a valid URI.
        parse("/v1/task/async/0.1.2.3.4/results/456/789/foo");
    }
}