SearchDefaultDialectCommandsTestBase.java

package redis.clients.jedis.commands.unified.search;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.emptyOrNullString;
import static org.hamcrest.Matchers.not;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static redis.clients.jedis.util.AssertUtil.assertOK;

import java.util.*;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import redis.clients.jedis.EndpointConfig;
import redis.clients.jedis.Endpoints;
import redis.clients.jedis.RedisProtocol;
import redis.clients.jedis.UnifiedJedis;
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.search.*;
import redis.clients.jedis.search.schemafields.NumericField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.search.aggr.AggregationBuilder;
import redis.clients.jedis.search.aggr.AggregationResult;
import redis.clients.jedis.search.aggr.Reducers;
import redis.clients.jedis.search.aggr.Row;
import redis.clients.jedis.util.RedisVersionCondition;
import redis.clients.jedis.util.EnvCondition;

@Tag("search")
/**
 * Shared search-command tests that run against a client whose default search dialect has been
 * set with {@link UnifiedJedis#setDefaultSearchDialect(int)} to
 * {@code SearchProtocol.DEFAULT_DIALECT}.
 *
 * <p>Queries built without an explicit dialect rely on that client-side default. Queries that
 * pass {@code dialect(1)} explicitly check how the two dialects differ.
 *
 * <p>Subclasses provide the client through {@link #createTestClient()}. {@link #endpoint} must
 * be populated before the version condition is evaluated, for example through
 * {@link #prepareEndpoint()}.
 */
public abstract class SearchDefaultDialectCommandsTestBase {

  @RegisterExtension
  public RedisVersionCondition versionCondition = new RedisVersionCondition(() -> endpoint);

  @RegisterExtension
  public static EnvCondition envCondition = new EnvCondition();

  /** Name of the search index created (and dropped) by every test. */
  protected static final String INDEX = "dialect-INDEX";

  /** Endpoint used by the tests and by {@link #versionCondition}. */
  protected static EndpointConfig endpoint;

  /** Protocol the concrete subclass runs the tests with. */
  protected final RedisProtocol protocol;

  /** Client created fresh for each test in {@link #setUp()}. */
  protected UnifiedJedis jedis;

  /**
   * @param protocol the Redis protocol the concrete test variant uses
   */
  public SearchDefaultDialectCommandsTestBase(RedisProtocol protocol) {
    this.protocol = protocol;
  }

  /**
   * Creates the client used by a single test.
   *
   * @return a new client; it is closed in {@link #tearDown()}
   */
  protected abstract UnifiedJedis createTestClient();

  /** Resolves {@link #endpoint} to the {@code "modules-docker"} endpoint. */
  public static void prepareEndpoint() {
    endpoint = Endpoints.getRedisEndpoint("modules-docker");
  }

  /**
   * Creates a fresh client and resets server state. Afterwards the client's default search
   * dialect is {@code SearchProtocol.DEFAULT_DIALECT}.
   */
  @BeforeEach
  public void setUp() {
    jedis = createTestClient();
    // Best-effort removal of an index left over from a previous test. Dropping may fail when
    // the index does not exist, which is fine here.
    try {
      jedis.ftDropIndex(INDEX);
    } catch (Exception ignored) {
      // Index might not exist; nothing to clean up.
    }
    jedis.flushAll();
    jedis.setDefaultSearchDialect(SearchProtocol.DEFAULT_DIALECT);
  }

  /** Closes the per-test client, if one was created. */
  @AfterEach
  public void tearDown() throws Exception {
    if (jedis != null) {
      jedis.close();
    }
  }

  /**
   * Stores a document as a hash keyed by its id. Every property value is converted with
   * {@link String#valueOf(Object)}.
   */
  private void addDocument(Document doc) {
    String key = doc.getId();
    Map<String, String> map = new LinkedHashMap<>();
    doc.getProperties().forEach(entry -> map.put(entry.getKey(), String.valueOf(entry.getValue())));
    jedis.hset(key, map);
  }

  /** Parameterized numeric range query through {@link Query#addParam}, with no explicit dialect. */
  @Test
  public void testQueryParams() {
    Schema sc = new Schema().addNumericField("numval");
    assertOK(jedis.ftCreate(INDEX, IndexOptions.defaultOptions(), sc));

    jedis.hset("1", "numval", "1");
    jedis.hset("2", "numval", "2");
    jedis.hset("3", "numval", "3");

    Query query = new Query("@numval:[$min $max]").addParam("min", 1).addParam("max", 2);
    assertEquals(2, jedis.ftSearch(INDEX, query).getTotalResults());
  }

  /**
   * Same range query through {@link FTSearchParams}. Parameters are supplied both one at a time
   * and as a map.
   */
  @Test
  public void testQueryParamsWithParams() {
    assertOK(jedis.ftCreate(INDEX, NumericField.of("numval")));

    jedis.hset("1", "numval", "1");
    jedis.hset("2", "numval", "2");
    jedis.hset("3", "numval", "3");

    assertEquals(2, jedis.ftSearch(INDEX, "@numval:[$min $max]",
      FTSearchParams.searchParams().addParam("min", 1).addParam("max", 2)).getTotalResults());

    Map<String, Object> paramValues = new HashMap<>();
    paramValues.put("min", 1);
    paramValues.put("max", 2);
    assertEquals(2,
      jedis
          .ftSearch(INDEX, "@numval:[$min $max]", FTSearchParams.searchParams().params(paramValues))
          .getTotalResults());
  }

  /**
   * Checks with FT.EXPLAIN which query strings are accepted under explicit dialect 1 and under
   * the client's default dialect.
   */
  @Test
  public void testDialectsWithFTExplain() throws Exception {
    Map<String, Object> attr = new HashMap<>();
    attr.put("TYPE", "FLOAT32");
    attr.put("DIM", 2);
    attr.put("DISTANCE_METRIC", "L2");

    Schema sc = new Schema().addFlatVectorField("v", attr).addTagField("title")
        .addTextField("t1", 1.0).addTextField("t2", 1.0).addNumericField("num");
    assertOK(jedis.ftCreate(INDEX, IndexOptions.defaultOptions(), sc));

    jedis.hset("1", "t1", "hello");

    String q = "(*)";
    Query query = new Query(q).dialect(1);
    assertSyntaxError(query, jedis); // dialect=1 throws syntax error
    query = new Query(q); // dialect=default=2 should return execution plan
    assertThat(jedis.ftExplain(INDEX, query), containsString("WILDCARD"));

    q = "$hello";
    query = new Query(q).dialect(1);
    assertSyntaxError(query, jedis); // dialect=1 throws syntax error
    // dialect=default=2 should return execution plan
    query = new Query(q).addParam("hello", "hello");
    assertThat(jedis.ftExplain(INDEX, query), not(emptyOrNullString()));

    q = "@title:(@num:[0 10])";
    query = new Query(q).dialect(1); // dialect=1 should return execution plan
    assertThat(jedis.ftExplain(INDEX, query), not(emptyOrNullString()));
    query = new Query(q); // dialect=default=2
    assertSyntaxError(query, jedis); // dialect=2 throws syntax error

    q = "@t1:@t2:@t3:hello";
    query = new Query(q).dialect(1); // dialect=1 should return execution plan
    assertThat(jedis.ftExplain(INDEX, query), not(emptyOrNullString()));
    query = new Query(q); // dialect=default=2
    assertSyntaxError(query, jedis); // dialect=2 throws syntax error

    q = "@title:{foo}}}}}";
    query = new Query(q).dialect(1); // dialect=1 should return execution plan
    assertThat(jedis.ftExplain(INDEX, query), not(emptyOrNullString()));
    query = new Query(q); // dialect=default=2
    assertSyntaxError(query, jedis); // dialect=2 throws syntax error
  }

  /**
   * Runs an aggregation whose query string is a parameter ({@code $name}) under the default
   * dialect. Only the document named {@code "abc"} should be grouped.
   */
  @Test
  public void testAggregationBuilderParamsDialect() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    // Assert creation so a failed FT.CREATE is reported here, not as a confusing result below.
    assertOK(jedis.ftCreate(INDEX, IndexOptions.defaultOptions(), sc));
    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    Map<String, Object> params = new HashMap<>();
    params.put("name", "abc");

    AggregationBuilder r = new AggregationBuilder("$name")
        .groupBy("@name", Reducers.sum("@count").as("sum")).params(params);

    AggregationResult res = jedis.ftAggregate(INDEX, r);
    assertEquals(1, res.getTotalResults());

    Row r1 = res.getRow(0);
    assertNotNull(r1);
    assertEquals("abc", r1.getString("name"));
    assertEquals(10, r1.getLong("sum"));
  }

  /** Checks that an explicit spell-check DIALECT of 0 is rejected by the server. */
  @Test
  public void dialectBoundSpellCheck() {
    // Assert creation so the expected error can only come from the spell-check call.
    assertOK(jedis.ftCreate(INDEX, TextField.of("t")));
    JedisDataException error = assertThrows(JedisDataException.class,
      () -> jedis.ftSpellCheck(INDEX, "Tooni toque kerfuffle",
        FTSpellCheckParams.spellCheckParams().dialect(0)));
    assertThat(error.getMessage(), containsString("DIALECT requires a non negative integer"));
  }

  /**
   * Asserts that FT.EXPLAIN of {@code query} on {@link #INDEX} fails with a syntax error.
   *
   * @param query the query expected to be rejected
   * @param client the client used to issue FT.EXPLAIN
   */
  private void assertSyntaxError(Query query, UnifiedJedis client) {
    JedisDataException error = assertThrows(JedisDataException.class,
      () -> client.ftExplain(INDEX, query));
    assertThat(error.getMessage(), containsString("Syntax error"));
  }
}