// QuerySessionSupplier.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.server;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.common.WarningHandlingLevel;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.execution.warnings.WarningCollectorFactory;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.server.security.SecurityConfig;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.sql.SqlEnvironmentConfig;
import com.facebook.presto.transaction.TransactionManager;

import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.Session.SessionBuilder;
import static com.facebook.presto.SystemSessionProperties.WARNING_HANDLING;
import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey;
import static com.facebook.presto.security.AccessControlUtils.checkPermissions;
import static com.facebook.presto.security.AccessControlUtils.getAuthorizedIdentity;
import static java.util.Map.Entry;
import static java.util.Objects.requireNonNull;

@ThreadSafe
public class QuerySessionSupplier
        implements SessionSupplier
{
    private final Logger log = Logger.get(QuerySessionSupplier.class);

    private final TransactionManager transactionManager;
    private final AccessControl accessControl;
    private final SessionPropertyManager sessionPropertyManager;
    private final Optional<TimeZoneKey> forcedSessionTimeZone;
    private final SecurityConfig securityConfig;

    @Inject
    public QuerySessionSupplier(
            TransactionManager transactionManager,
            AccessControl accessControl,
            SessionPropertyManager sessionPropertyManager,
            SqlEnvironmentConfig sqlEnvironmentConfig,
            SecurityConfig securityConfig)
    {
        this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
        this.accessControl = requireNonNull(accessControl, "accessControl is null");
        this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
        requireNonNull(sqlEnvironmentConfig, "sqlEnvironmentConfig is null");
        this.forcedSessionTimeZone = requireNonNull(sqlEnvironmentConfig.getForcedSessionTimeZone(), "forcedSessionTimeZone is null");
        this.securityConfig = requireNonNull(securityConfig, "securityConfig is null");
    }

    @Override
    public SessionBuilder createSessionBuilder(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory)
    {
        SessionBuilder sessionBuilder = Session.builder(sessionPropertyManager)
                .setQueryId(queryId)
                .setIdentity(authenticateIdentity(queryId, context))
                .setSource(context.getSource())
                .setCatalog(context.getCatalog())
                .setSchema(context.getSchema())
                .setRemoteUserAddress(context.getRemoteUserAddress())
                .setUserAgent(context.getUserAgent())
                .setClientInfo(context.getClientInfo())
                .setClientTags(context.getClientTags())
                .setTraceToken(context.getTraceToken())
                .setResourceEstimates(context.getResourceEstimates())
                .setTracer(context.getTracer())
                .setRuntimeStats(context.getRuntimeStats());

        if (forcedSessionTimeZone.isPresent()) {
            sessionBuilder.setTimeZoneKey(forcedSessionTimeZone.get());
        }
        else if (context.getTimeZoneId() != null) {
            sessionBuilder.setTimeZoneKey(getTimeZoneKey(context.getTimeZoneId()));
        }

        if (context.getLanguage() != null) {
            sessionBuilder.setLocale(Locale.forLanguageTag(context.getLanguage()));
        }

        for (Entry<String, String> entry : context.getSystemProperties().entrySet()) {
            sessionBuilder.setSystemProperty(entry.getKey(), entry.getValue());
        }
        for (Entry<String, Map<String, String>> catalogProperties : context.getCatalogSessionProperties().entrySet()) {
            String catalog = catalogProperties.getKey();
            for (Entry<String, String> entry : catalogProperties.getValue().entrySet()) {
                sessionBuilder.setCatalogSessionProperty(catalog, entry.getKey(), entry.getValue());
            }
        }

        for (Entry<String, String> preparedStatement : context.getPreparedStatements().entrySet()) {
            sessionBuilder.addPreparedStatement(preparedStatement.getKey(), preparedStatement.getValue());
        }

        if (context.supportClientTransaction()) {
            sessionBuilder.setClientTransactionSupport();
        }

        for (Entry<SqlFunctionId, SqlInvokedFunction> entry : context.getSessionFunctions().entrySet()) {
            sessionBuilder.addSessionFunction(entry.getKey(), entry.getValue());
        }

        // Put after setSystemProperty are called
        WarningCollector warningCollector = warningCollectorFactory.create(sessionBuilder.getSystemProperty(WARNING_HANDLING, WarningHandlingLevel.class));
        sessionBuilder.setWarningCollector(warningCollector);

        return sessionBuilder;
    }

    private Identity authenticateIdentity(QueryId queryId, SessionContext context)
    {
        checkPermissions(accessControl, securityConfig, queryId, context);
        Optional<AuthorizedIdentity> authorizedIdentity = context.getAuthorizedIdentity();
        authorizedIdentity = authorizedIdentity.isPresent() ? authorizedIdentity : getAuthorizedIdentity(accessControl, securityConfig, queryId, context);

        return authorizedIdentity.map(identity -> new Identity(
                context.getIdentity().getUser(),
                context.getIdentity().getPrincipal(),
                context.getIdentity().getRoles(),
                context.getIdentity().getExtraCredentials(),
                context.getIdentity().getExtraAuthenticators(),
                Optional.of(identity.getUserName()),
                identity.getReasonForSelect())).orElseGet(context::getIdentity);
    }
}