IcebergRestCatalogServlet.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.iceberg.rest;
import com.facebook.airlift.log.Logger;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.iceberg.exceptions.RESTException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.io.CharStreams;
import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod;
import org.apache.iceberg.rest.RESTCatalogAdapter.Route;
import org.apache.iceberg.rest.responses.ErrorResponse;
import org.apache.iceberg.util.Pair;

import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
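/**
 * The IcebergRestCatalogServlet provides a servlet implementation used in combination with a
 * {@link RESTCatalogAdapter} to proxy the REST Spec to any Catalog implementation.
 * Requests whose {@code Authorization} header carries a user-session JWT are rejected with a
 * "User not authorized" error unless the token's claims match the fixed values checked by
 * {@code isAuthorizedRestUserSessionToken}.
 */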
public class IcebergRestCatalogServlet
        extends HttpServlet
{
    private static final Logger LOG = Logger.get(IcebergRestCatalogServlet.class);

    // Headers attached to every response; every body this servlet writes is JSON.
    private static final Map<String, String> RESPONSE_HEADERS =
            ImmutableMap.of(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());

    // Literal prefix stripped from the Authorization header value before JWT parsing.
    private static final String SESSION_TOKEN_PREFIX = "Bearer token-exchange-token:sub=";

    private final RESTCatalogAdapter restCatalogAdapter;

    /**
     * Creates a servlet that forwards every routed request to the given adapter.
     *
     * @param restCatalogAdapter adapter that executes REST catalog requests
     */
    public IcebergRestCatalogServlet(RESTCatalogAdapter restCatalogAdapter)
    {
        this.restCatalogAdapter = restCatalogAdapter;
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        execute(ServletRequestContext.from(request), response);
    }

    @Override
    protected void doHead(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        execute(ServletRequestContext.from(request), response);
    }

    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        execute(ServletRequestContext.from(request), response);
    }

    @Override
    protected void doDelete(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        execute(ServletRequestContext.from(request), response);
    }

    /**
     * Dispatches a parsed request to the catalog adapter and writes the JSON result.
     * <p>
     * Routing errors and authorization failures are written as an {@link ErrorResponse} whose
     * code is also used as the HTTP status. Exceptions thrown by the adapter are logged and
     * turn the status into 500.
     *
     * @param context the parsed request
     * @param response the servlet response to populate
     * @throws IOException if writing an error body fails
     */
    protected void execute(ServletRequestContext context, HttpServletResponse response)
            throws IOException
    {
        response.setStatus(HttpServletResponse.SC_OK);
        RESPONSE_HEADERS.forEach(response::setHeader);

        // An unroutable request carries no headers to authorize, so its error takes precedence.
        Optional<ErrorResponse> error = context.error();
        if (!error.isPresent()) {
            error = authorizationError(context.headers());
        }
        if (error.isPresent()) {
            // Use the error's own code (400 for routing, 403 for authorization) as the HTTP status.
            response.setStatus(error.get().code());
            RESTObjectMapper.mapper().writeValue(response.getWriter(), error.get());
            return;
        }

        try {
            Object responseBody =
                    restCatalogAdapter.execute(
                            context.method(),
                            context.path(),
                            context.queryParams(),
                            context.body(),
                            context.route().responseClass(),
                            context.headers(),
                            handle(response));
            if (responseBody != null) {
                RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody);
            }
        }
        catch (RESTException e) {
            logRestException(context, e);
            response.setStatus(SC_INTERNAL_SERVER_ERROR);
        }
        catch (Exception e) {
            LOG.error(e, "Unexpected exception when processing REST request");
            response.setStatus(SC_INTERNAL_SERVER_ERROR);
        }
    }

    /**
     * Logs a {@link RESTException} raised by the adapter.
     * <p>
     * Missing tables or views on the load routes are logged as warnings without a stack trace,
     * because most of them occur immediately before a create request. Everything else is logged
     * as an error.
     */
    private static void logRestException(ServletRequestContext context, RESTException e)
    {
        String message = e.getMessage();
        if (message != null && context.route() == Route.LOAD_TABLE && message.contains("NoSuchTableException")) {
            LOG.warn("Table at endpoint %s does not exist", context.path());
        }
        else if (message != null && context.route() == Route.LOAD_VIEW && message.contains("NoSuchViewException")) {
            LOG.warn("View at endpoint %s does not exist", context.path());
        }
        else {
            LOG.error(e, "Error processing REST request at endpoint %s", context.path());
        }
    }

    /**
     * Checks the request's Authorization header, if present, against the session-token rules.
     *
     * @param headers request headers; may be null
     * @return a 403 error when the header carries a session token that is not authorized or
     *     cannot be validated; empty otherwise
     */
    private Optional<ErrorResponse> authorizationError(Map<String, String> headers)
    {
        String token = findHeader(headers, HttpHeaders.AUTHORIZATION);
        if (token == null) {
            return Optional.empty();
        }

        boolean authorized;
        try {
            authorized = !isRestUserSessionToken(token) || isAuthorizedRestUserSessionToken(token);
        }
        catch (JwtException | IllegalArgumentException e) {
            // The token was not rejected as malformed, yet it could not be parsed into claims
            // (e.g. nothing left after the prefix, or a signed/expired JWT). Reject it rather than
            // letting the exception escape as an unstructured server error.
            LOG.warn("Rejecting session token that could not be validated: %s", e.getClass().getSimpleName());
            authorized = false;
        }

        if (authorized) {
            return Optional.empty();
        }
        return Optional.of(ErrorResponse.builder()
                .responseCode(HttpServletResponse.SC_FORBIDDEN)
                .withMessage("User not authorized")
                .build());
    }

    /**
     * Looks up a header value by name. It tries an exact match first, then falls back to a
     * case-insensitive scan, because HTTP header names are case-insensitive.
     *
     * @return the header value, or null if the header is absent or {@code headers} is null
     */
    private static String findHeader(Map<String, String> headers, String name)
    {
        if (headers == null) {
            return null;
        }
        String value = headers.get(name);
        if (value != null) {
            return value;
        }
        return headers.entrySet().stream()
                .filter(entry -> name.equalsIgnoreCase(entry.getKey()))
                .map(Map.Entry::getValue)
                .findFirst()
                .orElse(null);
    }

    /**
     * Returns an error callback for the adapter. The callback sets the response status to the
     * error's code and writes the error as JSON.
     *
     * @param response the response to write errors to
     * @return a consumer that rethrows write failures as {@link UncheckedIOException}
     */
    protected Consumer<ErrorResponse> handle(HttpServletResponse response)
    {
        return (errorResponse) -> {
            response.setStatus(errorResponse.code());
            try {
                RESTObjectMapper.mapper().writeValue(response.getWriter(), errorResponse);
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };
    }

    /**
     * Strips the session-token prefix from the header value and parses the remainder as an
     * unsigned claims JWT (jjwt {@code parseClaimsJwt}).
     *
     * @param token raw Authorization header value
     * @return the token's claims
     * @throws JwtException if the remainder is not a parseable unsigned claims JWT
     * @throws IllegalArgumentException if nothing remains after stripping the prefix
     */
    protected Claims getTokenClaims(String token)
    {
        // replace() matches literally; the prefix is not meant as a regular expression.
        String jwt = token.replace(SESSION_TOKEN_PREFIX, "");
        return Jwts.parserBuilder().build().parseClaimsJwt(jwt).getBody();
    }

    /**
     * Reports whether the header value carries a JWT at all.
     *
     * @param token raw Authorization header value
     * @return false when the value is not a well-formed JWT, true when it parses
     * @throws JwtException for JWTs that are well-formed but cannot be parsed here
     */
    protected boolean isRestUserSessionToken(String token)
    {
        try {
            getTokenClaims(token);
            return true;
        }
        catch (MalformedJwtException mje) {
            // Not a json web token
            return false;
        }
    }

    /**
     * Reports whether the session token's claims match the single accepted identity.
     *
     * @param jwt raw Authorization header value carrying a session token
     * @return true only when subject, issuer, "user" and "source" claims all match
     */
    protected boolean isAuthorizedRestUserSessionToken(String jwt)
    {
        Claims jwtClaims = getTokenClaims(jwt);
        // Constant-first equals(): a missing claim means "not authorized" instead of an NPE.
        return "user".equals(jwtClaims.getSubject()) &&
                "testversion".equals(jwtClaims.getIssuer()) &&
                "user".equals(jwtClaims.get("user")) &&
                "test".equals(jwtClaims.get("source"));
    }

    /**
     * Immutable view of one servlet request: either a routed request (method, route, path,
     * headers, query parameters, decoded body) or a routing error.
     */
    public static class ServletRequestContext
    {
        private final HTTPMethod method;
        private final Route route;
        private final String path;
        private final Map<String, String> headers;
        private final Map<String, String> queryParams;
        private final Object body;
        private final ErrorResponse errorResponse;

        // Error context: no route information; headers and query params are empty, never null.
        private ServletRequestContext(ErrorResponse errorResponse)
        {
            this(null, null, null, ImmutableMap.of(), ImmutableMap.of(), null, errorResponse);
        }

        private ServletRequestContext(
                HTTPMethod method,
                Route route,
                String path,
                Map<String, String> headers,
                Map<String, String> queryParams,
                Object body)
        {
            this(method, route, path, headers, queryParams, body, null);
        }

        private ServletRequestContext(
                HTTPMethod method,
                Route route,
                String path,
                Map<String, String> headers,
                Map<String, String> queryParams,
                Object body,
                ErrorResponse errorResponse)
        {
            this.method = method;
            this.route = route;
            this.path = path;
            this.headers = headers;
            this.queryParams = queryParams;
            this.body = body;
            this.errorResponse = errorResponse;
        }

        /**
         * Parses a servlet request into a context.
         * <p>
         * When no route matches, the result is an error context carrying a 400 "BadRequestException".
         *
         * @param request the incoming request
         * @return the parsed context
         * @throws IOException if reading the request body fails
         */
        static ServletRequestContext from(HttpServletRequest request)
                throws IOException
        {
            HTTPMethod method = HTTPMethod.valueOf(request.getMethod());
            // Strip the leading '/' before route lookup.
            String path = request.getRequestURI().substring(1);
            Pair<Route, Map<String, String>> routeContext = Route.from(method, path);

            if (routeContext == null) {
                return new ServletRequestContext(
                        ErrorResponse.builder()
                                .responseCode(400)
                                .withType("BadRequestException")
                                .withMessage(format("No route for request: %s %s", method, path))
                                .build());
            }

            Route route = routeContext.first();
            Object requestBody = null;
            if (route.requestClass() != null) {
                requestBody =
                        RESTObjectMapper.mapper().readValue(request.getReader(), route.requestClass());
            }
            else if (route == Route.TOKENS) {
                // The token route has no JSON request class; its body is decoded as form data.
                try (Reader reader = new InputStreamReader(request.getInputStream(), StandardCharsets.UTF_8)) {
                    requestBody = RESTUtil.decodeFormData(CharStreams.toString(reader));
                }
            }

            // Only the first value of a repeated query parameter is kept.
            Map<String, String> queryParams =
                    request.getParameterMap().entrySet().stream()
                            .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue()[0]));

            Map<String, String> headers =
                    Collections.list(request.getHeaderNames()).stream()
                            .collect(Collectors.toMap(Function.identity(), request::getHeader));

            return new ServletRequestContext(method, route, path, headers, queryParams, requestBody);
        }

        public HTTPMethod method()
        {
            return method;
        }

        public Route route()
        {
            return route;
        }

        public String path()
        {
            return path;
        }

        public Map<String, String> headers()
        {
            return headers;
        }

        public Map<String, String> queryParams()
        {
            return queryParams;
        }

        public Object body()
        {
            return body;
        }

        public Optional<ErrorResponse> error()
        {
            return Optional.ofNullable(errorResponse);
        }
    }
}