TestRestSqlFunctionExecutorRouting.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.functionNamespace.rest;

import com.facebook.airlift.http.client.HttpClient;
import com.facebook.airlift.http.client.Request;
import com.facebook.airlift.http.client.testing.TestingHttpClient;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.function.SqlFunctionResult;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.RemoteScalarFunctionImplementation;
import com.facebook.presto.spi.function.RestFunctionHandle;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.page.PagesSerde;
import com.google.common.collect.ImmutableList;
import com.google.common.net.MediaType;
import io.airlift.slice.DynamicSliceOutput;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.net.URI;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

import static com.facebook.airlift.http.client.HttpStatus.OK;
import static com.facebook.airlift.http.client.testing.TestingResponse.mockResponse;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.function.FunctionImplementationType.REST;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP;
import static com.facebook.presto.spi.page.PagesSerdeUtil.writeSerializedPage;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

@Test(singleThreaded = true)
public class TestRestSqlFunctionExecutorRouting
{
    private static final MediaType PRESTO_PAGES = MediaType.create("application", "X-presto-pages");

    // Value placed in every input page sent to the executor.
    private static final long INPUT_VALUE = 10L;
    // Value returned by the fake HTTP server for every request.
    private static final long EXPECTED_RESULT = 42L;

    private RestSqlFunctionExecutor executor;
    private AtomicReference<Request> capturedRequest;
    private RestBasedFunctionNamespaceManagerConfig config;

    /**
     * Builds a fresh executor before each test. Its HTTP client records the last request it
     * receives in {@link #capturedRequest} and always answers with a single-row BIGINT page
     * containing {@link #EXPECTED_RESULT}.
     */
    @BeforeMethod
    public void setup()
    {
        capturedRequest = new AtomicReference<>();
        config = new RestBasedFunctionNamespaceManagerConfig()
                .setRestUrl("http://default-server.example.com");

        // The response body is identical for every request, so serialize it once rather than per call.
        String responseBody = serializeSingleBigintPage(EXPECTED_RESULT);
        HttpClient httpClient = new TestingHttpClient(request -> {
            capturedRequest.set(request);
            return mockResponse(OK, PRESTO_PAGES, responseBody);
        });

        executor = new RestSqlFunctionExecutor(config, httpClient);
        executor.setBlockEncodingSerde(new BlockEncodingManager());
    }

    @Test
    public void testRoutesToDefaultServerWhenNoExecutionEndpoint()
            throws Exception
    {
        String functionName = "test.schema.function";
        // Use the constructor without an execution endpoint, so routing must fall back to the configured REST URL.
        RestFunctionHandle handle = new RestFunctionHandle(
                createFunctionId(functionName),
                "1.0",
                createSignature(functionName));

        SqlFunctionResult result = execute(handle);

        assertSingleExpectedResult(result);
        URI uri = capturedUri();
        assertEquals(uri.getHost(), "default-server.example.com");
        assertTrue(uri.getPath().contains("/v1/functions/schema/function/"));
    }

    @Test
    public void testRoutesToCustomExecutionEndpoint()
            throws Exception
    {
        RestFunctionHandle handle = createHandle("test.schema.function", URI.create("https://compute-cluster-1.example.com"));

        SqlFunctionResult result = execute(handle);

        assertSingleExpectedResult(result);
        URI uri = capturedUri();
        assertEquals(uri.getScheme(), "https");
        assertEquals(uri.getHost(), "compute-cluster-1.example.com");
        assertTrue(uri.getPath().contains("/v1/functions/schema/function/"));
    }

    @Test
    public void testMultipleFunctionsRouteToDifferentServers()
            throws Exception
    {
        // Different functions carrying different execution endpoints must each reach their own server.
        RestFunctionHandle handle1 = createHandle("test.schema.function1", URI.create("https://server1.example.com"));
        RestFunctionHandle handle2 = createHandle("test.schema.function2", URI.create("https://server2.example.com"));

        assertSingleExpectedResult(execute(handle1));
        assertEquals(capturedUri().getHost(), "server1.example.com");

        assertSingleExpectedResult(execute(handle2));
        assertEquals(capturedUri().getHost(), "server2.example.com");
    }

    /**
     * Runs a single-argument BIGINT -> BIGINT call of {@code handle} through the executor and
     * waits for the result. The captured request is cleared first, so {@link #capturedUri()}
     * afterwards reflects only this call.
     */
    private SqlFunctionResult execute(RestFunctionHandle handle)
            throws Exception
    {
        RemoteScalarFunctionImplementation implementation = new RemoteScalarFunctionImplementation(handle, CPP, REST);
        capturedRequest.set(null);
        return executor.executeFunction(
                "test-source",
                implementation,
                singleBigintPage(INPUT_VALUE),
                ImmutableList.of(0),
                ImmutableList.of(BIGINT),
                BIGINT).get();
    }

    /** Returns the URI of the last request seen by the fake HTTP client, failing if none was sent. */
    private URI capturedUri()
    {
        Request request = capturedRequest.get();
        assertNotNull(request, "no HTTP request was sent");
        return request.getUri();
    }

    /** Asserts that the result has exactly one row holding the canned server value. */
    private static void assertSingleExpectedResult(SqlFunctionResult result)
    {
        assertEquals(result.getResult().getPositionCount(), 1);
        assertEquals(BIGINT.getLong(result.getResult(), 0), EXPECTED_RESULT);
    }

    /** Builds a handle for a BIGINT -> BIGINT function that is routed to {@code executionEndpoint}. */
    private static RestFunctionHandle createHandle(String functionName, URI executionEndpoint)
    {
        return new RestFunctionHandle(
                createFunctionId(functionName),
                "1.0",
                createSignature(functionName),
                Optional.of(executionEndpoint));
    }

    /** Builds the id of a function named {@code functionName} that takes one BIGINT argument. */
    private static SqlFunctionId createFunctionId(String functionName)
    {
        return new SqlFunctionId(
                QualifiedObjectName.valueOf(functionName),
                ImmutableList.of(parseTypeSignature("bigint")));
    }

    /** Builds a scalar BIGINT(BIGINT) signature for {@code functionName}. */
    private static Signature createSignature(String functionName)
    {
        return new Signature(
                QualifiedObjectName.valueOf(functionName),
                FunctionKind.SCALAR,
                parseTypeSignature("bigint"),
                ImmutableList.of(parseTypeSignature("bigint")));
    }

    /** Builds a one-column, one-row BIGINT page containing {@code value}. */
    private static Page singleBigintPage(long value)
    {
        PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT));
        pageBuilder.declarePosition();
        BIGINT.writeLong(pageBuilder.getBlockBuilder(0), value);
        return pageBuilder.build();
    }

    /**
     * Serializes a single-row BIGINT page into the body returned by the fake HTTP server.
     *
     * NOTE(review): the serialized page is binary but is carried through a String here because
     * the response is built with the String overload of mockResponse. This presumably survives
     * the round trip only while every serialized byte is ASCII (< 0x80), which holds for the
     * small values used in this test. Confirm this against mockResponse's encoding, and switch
     * to a byte[] response body if the page contents or serde settings (e.g. checksums) change.
     */
    private static String serializeSingleBigintPage(long value)
    {
        PagesSerde pagesSerde = new PagesSerde(new BlockEncodingManager(), Optional.empty(), Optional.empty(), Optional.empty());
        Page page = singleBigintPage(value);
        // Size the buffer from the page's byte size (not its row count); the output still grows if needed.
        DynamicSliceOutput sliceOutput = new DynamicSliceOutput((int) page.getSizeInBytes());
        writeSerializedPage(sliceOutput, pagesSerde.serialize(page));
        return sliceOutput.slice().toStringUtf8();
    }
}