TestExtractSpatialInnerJoin.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.geospatial;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.rule.ExtractSpatialJoins.ExtractSpatialInnerJoin;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.util.Arrays;
import java.util.Map;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.geospatial.SphericalGeographyType.SPHERICAL_GEOGRAPHY;
import static com.facebook.presto.geospatial.type.GeometryType.GEOMETRY;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.spatialJoin;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
public class TestExtractSpatialInnerJoin
extends BaseRuleTest
{
// Translates SQL text into RowExpressions for the filter predicates used by these tests; assigned in setupTranslator().
private TestingRowExpressionTranslator sqlToRowExpressionTranslator;
// Explicit no-argument constructor with an empty body: nothing runs here beyond the implicit superclass constructor call.
public TestExtractSpatialInnerJoin()
{
}
@BeforeClass
public void setupTranslator()
{
    // Build the SQL-to-RowExpression translator on top of the rule tester's metadata.
    RuleTester ruleTester = tester();
    sqlToRowExpressionTranslator = new TestingRowExpressionTranslator(ruleTester.getMetadata());
}
@Test
public void testDoesNotFire()
{
    // Scalar expression: the polygon comes from a WKT literal and the left join input has no columns
    String scalarFilter = "ST_Contains(ST_GeometryFromText('POLYGON ((0 0, 0 0, 0 0, 0 0))'), b)";
    Map<String, Type> scalarTypes = ImmutableMap.of("b", GEOMETRY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(scalarFilter, scalarTypes),
                    p.join(INNER,
                            p.values(),
                            p.values(p.variable("b", GEOMETRY)))))
            .doesNotFire();

    // The OR and NOT cases below use the same column-to-type mapping
    Map<String, Type> wktAndPointTypes = ImmutableMap.of("wkt", VARCHAR, "point", GEOMETRY, "name_1", BIGINT, "name_2", BIGINT);

    // Spatial predicate appears as an operand of OR
    String disjunctionFilter = "ST_Contains(ST_GeometryFromText(wkt), point) OR name_1 != name_2";
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(disjunctionFilter, wktAndPointTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR), p.variable("name_1")),
                            p.values(p.variable("point", GEOMETRY), p.variable("name_2")))))
            .doesNotFire();

    // Spatial predicate is wrapped in NOT
    String negationFilter = "NOT ST_Contains(ST_GeometryFromText(wkt), point)";
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(negationFilter, wktAndPointTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR), p.variable("name_1")),
                            p.values(p.variable("point", GEOMETRY), p.variable("name_2")))))
            .doesNotFire();

    // Distance compared with ">" instead of "<" or "<="
    String greaterThanFilter = "ST_Distance(a, b) > 5";
    Map<String, Type> geometryPairTypes = ImmutableMap.of("a", GEOMETRY, "b", GEOMETRY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(greaterThanFilter, geometryPairTypes),
                    p.join(INNER,
                            p.values(p.variable("a", GEOMETRY)),
                            p.values(p.variable("b", GEOMETRY)))))
            .doesNotFire();
}
@Test(enabled = false)
public void testSphericalGeographiesDoesNotFire()
{
    // TODO enable once #13133 is merged
    // ST_Contains over two spherical geography columns, one from each join side
    String containsFilter = "ST_Contains(polygon, point)";
    Map<String, Type> sphericalTypes = ImmutableMap.of("polygon", SPHERICAL_GEOGRAPHY, "point", SPHERICAL_GEOGRAPHY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(containsFilter, sphericalTypes),
                    p.join(INNER,
                            p.values(p.variable("polygon", SPHERICAL_GEOGRAPHY)),
                            p.values(p.variable("point", SPHERICAL_GEOGRAPHY)))))
            .doesNotFire();
}
@Test
public void testDistanceQueries()
{
    // Radius expressions shared by the groups of cases below
    String decimalRadius = "decimal '1.2'";
    String scaledColumnRadius = "2 * r";
    String literalRadius = "5";
    String latitudeDependentRadius = "500 / (111000 * cos(lat_b))";

    // Radius is the plain column r
    testSimpleDistanceQuery("ST_Distance(a, b) <= r", "ST_Distance(a, b) <= r");
    testSimpleDistanceQuery("ST_Distance(b, a) <= r", "ST_Distance(b, a) <= r");
    testSimpleDistanceQuery("r >= ST_Distance(a, b)", "ST_Distance(a, b) <= r");
    testSimpleDistanceQuery("r >= ST_Distance(b, a)", "ST_Distance(b, a) <= r");
    testSimpleDistanceQuery("ST_Distance(a, b) < r", "ST_Distance(a, b) < r");
    testSimpleDistanceQuery("ST_Distance(b, a) < r", "ST_Distance(b, a) < r");
    testSimpleDistanceQuery("r > ST_Distance(a, b)", "ST_Distance(a, b) < r");
    testSimpleDistanceQuery("r > ST_Distance(b, a)", "ST_Distance(b, a) < r");
    testSimpleDistanceQuery("ST_Distance(a, b) <= r AND name_a != name_b", "ST_Distance(a, b) <= r AND name_a != name_b");
    testSimpleDistanceQuery("r > ST_Distance(a, b) AND name_a != name_b", "ST_Distance(a, b) < r AND name_a != name_b");

    // Radius is a decimal literal
    testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= decimal '1.2'", "ST_Distance(a, b) <= radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(b, a) <= decimal '1.2'", "ST_Distance(b, a) <= radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("decimal '1.2' >= ST_Distance(a, b)", "ST_Distance(a, b) <= radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("decimal '1.2' >= ST_Distance(b, a)", "ST_Distance(b, a) <= radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(a, b) < decimal '1.2'", "ST_Distance(a, b) < radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(b, a) < decimal '1.2'", "ST_Distance(b, a) < radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("decimal '1.2' > ST_Distance(a, b)", "ST_Distance(a, b) < radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("decimal '1.2' > ST_Distance(b, a)", "ST_Distance(b, a) < radius", decimalRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= decimal '1.2' AND name_a != name_b", "ST_Distance(a, b) <= radius AND name_a != name_b", decimalRadius);
    testRadiusExpressionInDistanceQuery("decimal '1.2' > ST_Distance(a, b) AND name_a != name_b", "ST_Distance(a, b) < radius AND name_a != name_b", decimalRadius);

    // Radius is an arithmetic expression over the column r
    testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= 2 * r", "ST_Distance(a, b) <= radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(b, a) <= 2 * r", "ST_Distance(b, a) <= radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("2 * r >= ST_Distance(a, b)", "ST_Distance(a, b) <= radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("2 * r >= ST_Distance(b, a)", "ST_Distance(b, a) <= radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(a, b) < 2 * r", "ST_Distance(a, b) < radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(b, a) < 2 * r", "ST_Distance(b, a) < radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("2 * r > ST_Distance(a, b)", "ST_Distance(a, b) < radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("2 * r > ST_Distance(b, a)", "ST_Distance(b, a) < radius", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("ST_Distance(a, b) <= 2 * r AND name_a != name_b", "ST_Distance(a, b) <= radius AND name_a != name_b", scaledColumnRadius);
    testRadiusExpressionInDistanceQuery("2 * r > ST_Distance(a, b) AND name_a != name_b", "ST_Distance(a, b) < radius AND name_a != name_b", scaledColumnRadius);

    // Both points are ST_Point(...) expressions; radius is an integer literal
    testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 5", "ST_Distance(point_a, point_b) <= radius", literalRadius);
    testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) <= 5", "ST_Distance(point_b, point_a) <= radius", literalRadius);
    testPointExpressionsInDistanceQuery("5 >= ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) <= radius", literalRadius);
    testPointExpressionsInDistanceQuery("5 >= ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) <= radius", literalRadius);
    testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) < 5", "ST_Distance(point_a, point_b) < radius", literalRadius);
    testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) < 5", "ST_Distance(point_b, point_a) < radius", literalRadius);
    testPointExpressionsInDistanceQuery("5 > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) < radius", literalRadius);
    testPointExpressionsInDistanceQuery("5 > ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) < radius", literalRadius);
    testPointExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 5 AND name_a != name_b", "ST_Distance(point_a, point_b) <= radius AND name_a != name_b", literalRadius);
    testPointExpressionsInDistanceQuery("5 > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) AND name_a != name_b", "ST_Distance(point_a, point_b) < radius AND name_a != name_b", literalRadius);

    // Both points are ST_Point(...) expressions; radius is an expression over lat_b
    testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 500 / (111000 * cos(lat_b))", "ST_Distance(point_a, point_b) <= radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) <= 500 / (111000 * cos(lat_b))", "ST_Distance(point_b, point_a) <= radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) >= ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) <= radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) >= ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) <= radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) < 500 / (111000 * cos(lat_b))", "ST_Distance(point_a, point_b) < radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a)) < 500 / (111000 * cos(lat_b))", "ST_Distance(point_b, point_a) < radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b))", "ST_Distance(point_a, point_b) < radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) > ST_Distance(ST_Point(lng_b, lat_b), ST_Point(lng_a, lat_a))", "ST_Distance(point_b, point_a) < radius", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) <= 500 / (111000 * cos(lat_b)) AND name_a != name_b", "ST_Distance(point_a, point_b) <= radius AND name_a != name_b", latitudeDependentRadius);
    testPointAndRadiusExpressionsInDistanceQuery("500 / (111000 * cos(lat_b)) > ST_Distance(ST_Point(lng_a, lat_a), ST_Point(lng_b, lat_b)) AND name_a != name_b", "ST_Distance(point_a, point_b) < radius AND name_a != name_b", latitudeDependentRadius);
}
/**
 * Applies the rule to {@code filter} placed over an inner join of (a, name_a) with (b, name_b, r)
 * and asserts that the result is a spatial join with {@code newFilter} directly over the two value nodes.
 *
 * @param filter SQL text of the filter predicate above the join
 * @param newFilter SQL text of the predicate expected on the resulting spatial join
 */
private void testSimpleDistanceQuery(String filter, String newFilter)
{
    Map<String, Type> columnTypes = ImmutableMap.of("a", GEOMETRY, "b", GEOMETRY, "name_a", BIGINT, "name_b", BIGINT, "r", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(filter, columnTypes),
                    p.join(INNER,
                            p.values(p.variable("a", GEOMETRY), p.variable("name_a")),
                            p.values(p.variable("b", GEOMETRY), p.variable("name_b"), p.variable("r")))))
            .matches(spatialJoin(
                    newFilter,
                    values(ImmutableMap.of("a", 0, "name_a", 1)),
                    values(ImmutableMap.of("b", 0, "name_b", 1, "r", 2))));
}
/**
 * Applies the rule to {@code filter} placed over an inner join of (a, name_a) with (b, name_b, r)
 * and asserts that the result is a spatial join with {@code newFilter} whose right input is a
 * projection assigning {@code radiusExpression} to "radius".
 *
 * @param filter SQL text of the filter predicate above the join
 * @param newFilter SQL text of the predicate expected on the resulting spatial join
 * @param radiusExpression SQL text expected for the projected "radius" symbol
 */
private void testRadiusExpressionInDistanceQuery(String filter, String newFilter, String radiusExpression)
{
    Map<String, Type> columnTypes = ImmutableMap.of("a", GEOMETRY, "b", GEOMETRY, "name_a", BIGINT, "name_b", BIGINT, "r", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(filter, columnTypes),
                    p.join(INNER,
                            p.values(p.variable("a", GEOMETRY), p.variable("name_a")),
                            p.values(p.variable("b", GEOMETRY), p.variable("name_b"), p.variable("r")))))
            .matches(spatialJoin(
                    newFilter,
                    values(ImmutableMap.of("a", 0, "name_a", 1)),
                    project(
                            ImmutableMap.of("radius", expression(radiusExpression)),
                            values(ImmutableMap.of("b", 0, "name_b", 1, "r", 2)))));
}
/**
 * Applies the rule to {@code filter} placed over an inner join of (lat_a, lng_a, name_a) with
 * (lat_b, lng_b, name_b) and asserts that the result is a spatial join with {@code newFilter}
 * where each side projects its ST_Point(...) expression and the right side additionally
 * projects {@code radiusExpression} as "radius".
 *
 * @param filter SQL text of the filter predicate above the join
 * @param newFilter SQL text of the predicate expected on the resulting spatial join
 * @param radiusExpression SQL text expected for the projected "radius" symbol
 */
private void testPointExpressionsInDistanceQuery(String filter, String newFilter, String radiusExpression)
{
    Map<String, Type> coordinateTypes = buildBigIntTypeProviderMap("lat_a", "lng_a", "lat_b", "lng_b", "name_a", "name_b");
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(filter, coordinateTypes),
                    p.join(INNER,
                            p.values(p.variable("lat_a"), p.variable("lng_a"), p.variable("name_a")),
                            p.values(p.variable("lat_b"), p.variable("lng_b"), p.variable("name_b")))))
            .matches(spatialJoin(
                    newFilter,
                    project(
                            ImmutableMap.of("point_a", expression("ST_Point(lng_a, lat_a)")),
                            values(ImmutableMap.of("lat_a", 0, "lng_a", 1, "name_a", 2))),
                    project(
                            ImmutableMap.of("point_b", expression("ST_Point(lng_b, lat_b)")),
                            project(
                                    ImmutableMap.of("radius", expression(radiusExpression)),
                                    values(ImmutableMap.of("lat_b", 0, "lng_b", 1, "name_b", 2))))));
}
/**
 * Checks a distance predicate in which both points are ST_Point(...) expressions and the
 * radius is itself an expression.
 *
 * The input plan (inner join of (lat_a, lng_a, name_a) with (lat_b, lng_b, name_b)) and the
 * expected spatial join shape are the same as those asserted by
 * {@code testPointExpressionsInDistanceQuery}, so this method forwards to it instead of
 * carrying a second, identical copy of the plan and the match pattern.
 *
 * @param filter SQL text of the filter predicate above the join
 * @param newFilter SQL text of the predicate expected on the resulting spatial join
 * @param radiusExpression SQL text expected for the projected "radius" symbol
 */
private void testPointAndRadiusExpressionsInDistanceQuery(String filter, String newFilter, String radiusExpression)
{
    testPointExpressionsInDistanceQuery(filter, newFilter, radiusExpression);
}
@Test
public void testSphericalGeographiesDoesFire()
{
    // Both ST_Distance arguments are spherical geography columns
    String columnDistanceFilter = "ST_Distance(a, b) < 5000";
    Map<String, Type> columnDistanceTypes = ImmutableMap.of("a", SPHERICAL_GEOGRAPHY, "b", SPHERICAL_GEOGRAPHY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(columnDistanceFilter, columnDistanceTypes),
                    p.join(INNER,
                            p.values(p.variable("a", SPHERICAL_GEOGRAPHY)),
                            p.values(p.variable("b", SPHERICAL_GEOGRAPHY)))))
            .matches(spatialJoin(
                    "ST_Distance(a, b) < radius",
                    values(ImmutableMap.of("a", 0)),
                    project(
                            ImmutableMap.of(
                                    "b", expression("b"),
                                    "radius", expression("BIGINT '5000'")),
                            values(ImmutableMap.of("b", 0)))));

    // First ST_Distance argument is a to_spherical_geography(...) expression
    String convertedOperandFilter = "ST_Distance(to_spherical_geography(ST_GeometryFromText(wkt)), point) < 5";
    Map<String, Type> convertedOperandTypes = ImmutableMap.of("wkt", VARCHAR, "point", SPHERICAL_GEOGRAPHY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(convertedOperandFilter, convertedOperandTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR)),
                            p.values(p.variable("point", SPHERICAL_GEOGRAPHY)))))
            .matches(spatialJoin(
                    "ST_Distance(to_spherical_geography, point) < radius",
                    project(
                            ImmutableMap.of("to_spherical_geography", expression("to_spherical_geography(ST_GeometryFromText(wkt))")),
                            values(ImmutableMap.of("wkt", 0))),
                    project(
                            ImmutableMap.of(
                                    "point", expression("point"),
                                    "radius", expression("INTEGER '5'")),
                            values(ImmutableMap.of("point", 0)))));
}
@Test
public void testConvertToSpatialJoin()
{
    // Both ST_Contains arguments are plain symbols
    String symbolsFilter = "ST_Contains(a, b)";
    Map<String, Type> symbolsTypes = ImmutableMap.of("a", GEOMETRY, "b", GEOMETRY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(symbolsFilter, symbolsTypes),
                    p.join(INNER,
                            p.values(p.variable("a", GEOMETRY)),
                            p.values(p.variable("b", GEOMETRY)))))
            .matches(spatialJoin(
                    "ST_Contains(a, b)",
                    values(ImmutableMap.of("a", 0)),
                    values(ImmutableMap.of("b", 0))));

    // AND of a non-spatial comparison and a spatial predicate
    String mixedConjunctionFilter = "name_1 != name_2 AND ST_Contains(a, b)";
    Map<String, Type> mixedConjunctionTypes = ImmutableMap.of("a", GEOMETRY, "b", GEOMETRY, "name_1", BIGINT, "name_2", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(mixedConjunctionFilter, mixedConjunctionTypes),
                    p.join(INNER,
                            p.values(p.variable("a", GEOMETRY), p.variable("name_1")),
                            p.values(p.variable("b", GEOMETRY), p.variable("name_2")))))
            .matches(spatialJoin(
                    "name_1 != name_2 AND ST_Contains(a, b)",
                    values(ImmutableMap.of("a", 0, "name_1", 1)),
                    values(ImmutableMap.of("b", 0, "name_2", 1))));

    // AND of two spatial predicates
    String spatialConjunctionFilter = "ST_Contains(a1, b1) AND ST_Contains(a2, b2)";
    Map<String, Type> spatialConjunctionTypes = ImmutableMap.of("a1", GEOMETRY, "a2", GEOMETRY, "b1", GEOMETRY, "b2", GEOMETRY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(spatialConjunctionFilter, spatialConjunctionTypes),
                    p.join(INNER,
                            p.values(p.variable("a1", GEOMETRY), p.variable("a2", GEOMETRY)),
                            p.values(p.variable("b1", GEOMETRY), p.variable("b2", GEOMETRY)))))
            .matches(spatialJoin(
                    "ST_Contains(a1, b1) AND ST_Contains(a2, b2)",
                    values(ImmutableMap.of("a1", 0, "a2", 1)),
                    values(ImmutableMap.of("b1", 0, "b2", 1))));
}
@Test
public void testPushDownFirstArgument()
{
    // First argument is an expression over the left input; it is expected as a projection below the spatial join
    String columnPointFilter = "ST_Contains(ST_GeometryFromText(wkt), point)";
    Map<String, Type> columnPointTypes = ImmutableMap.of("wkt", VARCHAR, "point", GEOMETRY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(columnPointFilter, columnPointTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR)),
                            p.values(p.variable("point", GEOMETRY)))))
            .matches(spatialJoin(
                    "ST_Contains(st_geometryfromtext, point)",
                    project(
                            ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")),
                            values(ImmutableMap.of("wkt", 0))),
                    values(ImmutableMap.of("point", 0))));

    // Second argument is the constant ST_Point(0, 0) and the right input has no columns
    String constantPointFilter = "ST_Contains(ST_GeometryFromText(wkt), ST_Point(0, 0))";
    Map<String, Type> constantPointTypes = ImmutableMap.of("wkt", VARCHAR);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(constantPointFilter, constantPointTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR)),
                            p.values())))
            .doesNotFire();
}
@Test
public void testPushDownSecondArgument()
{
    // Second argument is an expression over the right input; it is expected as a projection below the spatial join
    String columnPolygonFilter = "ST_Contains(polygon, ST_Point(lng, lat))";
    Map<String, Type> columnPolygonTypes = ImmutableMap.of("polygon", GEOMETRY, "lat", BIGINT, "lng", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(columnPolygonFilter, columnPolygonTypes),
                    p.join(INNER,
                            p.values(p.variable("polygon", GEOMETRY)),
                            p.values(p.variable("lat"), p.variable("lng")))))
            .matches(spatialJoin(
                    "ST_Contains(polygon, st_point)",
                    values(ImmutableMap.of("polygon", 0)),
                    project(
                            ImmutableMap.of("st_point", expression("ST_Point(lng, lat)")),
                            values(ImmutableMap.of("lat", 0, "lng", 1)))));

    // First argument is built from a WKT literal and the left input has no columns
    String constantPolygonFilter = "ST_Contains(ST_GeometryFromText('POLYGON ((0 0, 0 0, 0 0, 0 0))'), ST_Point(lng, lat))";
    Map<String, Type> constantPolygonTypes = ImmutableMap.of("lat", BIGINT, "lng", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(constantPolygonFilter, constantPolygonTypes),
                    p.join(INNER,
                            p.values(),
                            p.values(p.variable("lat"), p.variable("lng")))))
            .doesNotFire();
}
@Test
public void testPushDownBothArguments()
{
    // Each ST_Contains argument is an expression over one join side; both are expected as projections
    String bothExpressionsFilter = "ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))";
    Map<String, Type> bothExpressionsTypes = ImmutableMap.of("wkt", VARCHAR, "lat", BIGINT, "lng", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(bothExpressionsFilter, bothExpressionsTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR)),
                            p.values(p.variable("lat"), p.variable("lng")))))
            .matches(spatialJoin(
                    "ST_Contains(st_geometryfromtext, st_point)",
                    project(
                            ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")),
                            values(ImmutableMap.of("wkt", 0))),
                    project(
                            ImmutableMap.of("st_point", expression("ST_Point(lng, lat)")),
                            values(ImmutableMap.of("lat", 0, "lng", 1)))));
}
@Test
public void testPushDownOppositeOrder()
{
    // Same predicate as the both-arguments case, but the join inputs are swapped:
    // the point columns are on the left and the WKT column is on the right
    String swappedSidesFilter = "ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))";
    Map<String, Type> swappedSidesTypes = ImmutableMap.of("wkt", VARCHAR, "lat", BIGINT, "lng", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(swappedSidesFilter, swappedSidesTypes),
                    p.join(INNER,
                            p.values(p.variable("lat"), p.variable("lng")),
                            p.values(p.variable("wkt", VARCHAR)))))
            .matches(spatialJoin(
                    "ST_Contains(st_geometryfromtext, st_point)",
                    project(
                            ImmutableMap.of("st_point", expression("ST_Point(lng, lat)")),
                            values(ImmutableMap.of("lat", 0, "lng", 1))),
                    project(
                            ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")),
                            values(ImmutableMap.of("wkt", 0)))));
}
@Test
public void testPushDownAnd()
{
    // Non-spatial comparison ANDed with a spatial predicate whose arguments are both expressions
    String mixedConjunctionFilter = "name_1 != name_2 AND ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))";
    Map<String, Type> mixedConjunctionTypes = ImmutableMap.of("wkt", VARCHAR, "lat", BIGINT, "lng", BIGINT, "name_1", BIGINT, "name_2", BIGINT);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(mixedConjunctionFilter, mixedConjunctionTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt", VARCHAR), p.variable("name_1")),
                            p.values(p.variable("lat"), p.variable("lng"), p.variable("name_2")))))
            .matches(spatialJoin(
                    "name_1 != name_2 AND ST_Contains(st_geometryfromtext, st_point)",
                    project(
                            ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt)")),
                            values(ImmutableMap.of("wkt", 0, "name_1", 1))),
                    project(
                            ImmutableMap.of("st_point", expression("ST_Point(lng, lat)")),
                            values(ImmutableMap.of("lat", 0, "lng", 1, "name_2", 2)))));

    // Multiple spatial functions - only the first one is being processed
    String twoSpatialFilter = "ST_Contains(ST_GeometryFromText(wkt1), geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)";
    Map<String, Type> twoSpatialTypes = ImmutableMap.of("wkt1", VARCHAR, "wkt2", VARCHAR, "geometry1", GEOMETRY, "geometry2", GEOMETRY);
    assertRuleApplication()
            .on(p -> p.filter(
                    sqlToRowExpression(twoSpatialFilter, twoSpatialTypes),
                    p.join(INNER,
                            p.values(p.variable("wkt1", VARCHAR), p.variable("wkt2", VARCHAR)),
                            p.values(p.variable("geometry1", GEOMETRY), p.variable("geometry2", GEOMETRY)))))
            .matches(spatialJoin(
                    "ST_Contains(st_geometryfromtext, geometry1) AND ST_Contains(ST_GeometryFromText(wkt2), geometry2)",
                    project(
                            ImmutableMap.of("st_geometryfromtext", expression("ST_GeometryFromText(wkt1)")),
                            values(ImmutableMap.of("wkt1", 0, "wkt2", 1))),
                    values(ImmutableMap.of("geometry1", 0, "geometry2", 1))));
}
/**
 * Starts a rule assertion for a new ExtractSpatialInnerJoin instance built from the
 * rule tester's metadata, split manager and page source manager.
 *
 * @return the assertion object returned by the rule tester for that rule
 */
private RuleAssert assertRuleApplication()
{
    RuleTester ruleTester = tester();
    ExtractSpatialInnerJoin rule = new ExtractSpatialInnerJoin(
            ruleTester.getMetadata(),
            ruleTester.getSplitManager(),
            ruleTester.getPageSourceManager());
    return ruleTester.assertThat(rule);
}
/**
 * Parses {@code sql} with PlanBuilder.expression and passes it, together with a TypeProvider
 * copied from {@code typeMap}, to the translator's translateAndOptimize.
 *
 * @param sql SQL text of the expression
 * @param typeMap type of each symbol name referenced by {@code sql}
 * @return the RowExpression produced by the translator
 */
private RowExpression sqlToRowExpression(String sql, Map<String, Type> typeMap)
{
    return this.sqlToRowExpressionTranslator.translateAndOptimize(
            PlanBuilder.expression(sql),
            TypeProvider.copyOf(typeMap));
}
/**
 * Builds an immutable map that assigns BIGINT to every given variable name.
 *
 * @param variables names to map
 * @return map from each name to BIGINT
 */
private static Map<String, Type> buildBigIntTypeProviderMap(String... variables)
{
    ImmutableMap.Builder<String, Type> bigintTypes = ImmutableMap.builder();
    for (String name : Arrays.asList(variables)) {
        bigintTypes.put(name, BIGINT);
    }
    return bigintTypes.build();
}
}