AggregationTest.java

package redis.clients.jedis.modules.search;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.fail;
import static redis.clients.jedis.util.RedisConditions.ModuleVersion.SEARCH_MOD_VER_80M3;

import io.redis.test.annotations.SinceRedisVersion;
import io.redis.test.utils.RedisVersion;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedClass;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import redis.clients.jedis.RedisProtocol;
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.modules.RedisModuleCommandsTestBase;
import redis.clients.jedis.search.*;
import redis.clients.jedis.search.aggr.*;
import redis.clients.jedis.search.schemafields.NumericField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.util.RedisConditions;
import redis.clients.jedis.util.RedisVersionUtil;

@ParameterizedClass
@MethodSource("redis.clients.jedis.commands.CommandsTestsParameters#respVersions")
public class AggregationTest extends RedisModuleCommandsTestBase {

  private static final String index = "aggbindex";

  @BeforeAll
  public static void prepare() {
    RedisModuleCommandsTestBase.prepare();
  }

  public AggregationTest(RedisProtocol redisProtocol) {
    super(redisProtocol);
  }

  private void addDocument(Document doc) {
    String key = doc.getId();
    Map<String, String> map = new LinkedHashMap<>();
    doc.getProperties().forEach(entry -> map.put(entry.getKey(), String.valueOf(entry.getValue())));
    client.hset(key, map);
  }

  private void addDocument(String key, Map<String, Object> objMap) {
    Map<String, String> strMap = new HashMap<>();
    objMap.entrySet().forEach(entry -> strMap.put(entry.getKey(), String.valueOf(entry.getValue())));
    client.hset(key, strMap);
  }

  @Test
  public void testAggregations() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
//    client.addDocument(new Document("data1").set("name", "abc").set("count", 10));
//    client.addDocument(new Document("data2").set("name", "def").set("count", 5));
//    client.addDocument(new Document("data3").set("name", "def").set("count", 25));
    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    AggregationBuilder r = new AggregationBuilder()
        .groupBy("@name", Reducers.sum("@count").as("sum"))
        .sortBy(10, SortedField.desc("@sum"));

    // actual search
    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(2, res.getTotalResults());

    Row r1 = res.getRow(0);
    assertNotNull(r1);
    assertEquals("def", r1.getString("name"));
    assertEquals(30, r1.getLong("sum"));
    assertEquals(30., r1.getDouble("sum"), 0);

    assertEquals(0L, r1.getLong("nosuchcol"));
    assertEquals(0.0, r1.getDouble("nosuchcol"), 0);
    assertEquals("", r1.getString("nosuchcol"));

    Row r2 = res.getRow(1);
    assertNotNull(r2);
    assertEquals("abc", r2.getString("name"));
    assertEquals(10, r2.getLong("sum"));
  }

  @Test
  public void testAggregations2() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);

    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    AggregationBuilder r = new AggregationBuilder()
        .groupBy("@name", Reducers.sum("@count").as("sum"))
        .sortBy(10, SortedField.desc("@sum"));

    // actual search
    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(2, res.getTotalResults());

    List<Row> rows = res.getRows();
    assertEquals("def", rows.get(0).get("name"));
    assertEquals("30", rows.get(0).get("sum"));
    assertNull(rows.get(0).get("nosuchcol"));

    assertEquals("abc", rows.get(1).get("name"));
    assertEquals("10", rows.get(1).get("sum"));
  }

  @Test
  public void testAggregations2Profile() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);

    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    AggregationBuilder aggr = new AggregationBuilder()
        .groupBy("@name", Reducers.sum("@count").as("sum"))
        .sortBy(10, SortedField.desc("@sum"));

    Map.Entry<AggregationResult, ProfilingInfo> reply
        = client.ftProfileAggregate(index, FTProfileParams.profileParams(), aggr);

    // actual search
    AggregationResult result = reply.getKey();
    assertEquals(2, result.getTotalResults());

    List<Row> rows = result.getRows();
    assertEquals("def", rows.get(0).get("name"));
    assertEquals("30", rows.get(0).get("sum"));
    assertNull(rows.get(0).get("nosuchcol"));

    assertEquals("abc", rows.get(1).get("name"));
    assertEquals("10", rows.get(1).get("sum"));

    // profile
    Object profileObject = reply.getValue().getProfilingInfo();
    if (protocol != RedisProtocol.RESP3) {
      assertThat(profileObject, Matchers.isA(List.class));
      if (RedisVersionUtil.getRedisVersion(client).isGreaterThanOrEqualTo(RedisVersion.V8_0_0_PRE)) {
        assertThat((List<Object>) profileObject, Matchers.hasItems("Shards", "Coordinator"));
      }
    } else {
      assertThat(profileObject, Matchers.isA(Map.class));
      if (RedisVersionUtil.getRedisVersion(client).isGreaterThanOrEqualTo(RedisVersion.V8_0_0_PRE)) {
        assertThat(((Map<String, Object>) profileObject).keySet(), Matchers.hasItems("Shards", "Coordinator"));
      }
    }
  }

  @Test
  public void testAggregationBuilderVerbatim() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
    addDocument(new Document("data1").set("name", "hello kitty"));

    AggregationBuilder r = new AggregationBuilder("kitti");

    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(1, res.getTotalResults());

    r = new AggregationBuilder("kitti")
            .verbatim();

    res = client.ftAggregate(index, r);
    assertEquals(0, res.getTotalResults());
  }

  @Test
  @SinceRedisVersion(value = "7.4.0", message = "ADDSCORES")
  public void testAggregationBuilderAddScores() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("age");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
    addDocument(new Document("data1").set("name", "Adam").set("age", 33));
    addDocument(new Document("data2").set("name", "Sara").set("age", 44));

    AggregationBuilder r = new AggregationBuilder("sara").addScores()
        .apply("@__score * 100", "normalized_score").dialect(3);

    AggregationResult res = client.ftAggregate(index, r);
    if (RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3)) {
      // Default scorer is BM25
      assertEquals(0.6931, res.getRow(0).getDouble("__score"), 0.0001);
      assertEquals(69.31, res.getRow(0).getDouble("normalized_score"), 0.01);
    } else {
      // Default scorer is TF-IDF
      assertEquals(2, res.getRow(0).getLong("__score"));
      assertEquals(200, res.getRow(0).getLong("normalized_score"));
    }
  }

  @Test
  public void testAggregationBuilderTimeout() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    AggregationBuilder r = new AggregationBuilder()
            .groupBy("@name", Reducers.sum("@count").as("sum"))
            .timeout(5000);

    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(2, res.getTotalResults());
  }

  @Test
  public void testAggregationBuilderParamsDialect() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    Map<String, Object> params = new HashMap<>();
    params.put("name", "abc");

    AggregationBuilder r = new AggregationBuilder("$name")
            .groupBy("@name", Reducers.sum("@count").as("sum"))
            .params(params)
            .dialect(2); // From documentation - To use PARAMS, DIALECT must be set to 2

    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(1, res.getTotalResults());

    Row r1 = res.getRow(0);
    assertNotNull(r1);
    assertEquals("abc", r1.getString("name"));
    assertEquals(10, r1.getLong("sum"));
  }

  @Test
  public void testApplyAndFilterAggregations() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("subj1");
    sc.addSortableNumericField("subj2");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
//    client.addDocument(new Document("data1").set("name", "abc").set("subj1", 20).set("subj2", 70));
//    client.addDocument(new Document("data2").set("name", "def").set("subj1", 60).set("subj2", 40));
//    client.addDocument(new Document("data3").set("name", "ghi").set("subj1", 50).set("subj2", 80));
//    client.addDocument(new Document("data4").set("name", "abc").set("subj1", 30).set("subj2", 20));
//    client.addDocument(new Document("data5").set("name", "def").set("subj1", 65).set("subj2", 45));
//    client.addDocument(new Document("data6").set("name", "ghi").set("subj1", 70).set("subj2", 70));
    addDocument(new Document("data1").set("name", "abc").set("subj1", 20).set("subj2", 70));
    addDocument(new Document("data2").set("name", "def").set("subj1", 60).set("subj2", 40));
    addDocument(new Document("data3").set("name", "ghi").set("subj1", 50).set("subj2", 80));
    addDocument(new Document("data4").set("name", "abc").set("subj1", 30).set("subj2", 20));
    addDocument(new Document("data5").set("name", "def").set("subj1", 65).set("subj2", 45));
    addDocument(new Document("data6").set("name", "ghi").set("subj1", 70).set("subj2", 70));

    AggregationBuilder r = new AggregationBuilder().apply("(@subj1+@subj2)/2", "attemptavg")
        .groupBy("@name", Reducers.avg("@attemptavg").as("avgscore"))
        .filter("@avgscore>=50")
        .sortBy(10, SortedField.asc("@name"));

    // actual search
    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(3, res.getTotalResults());

    Row r1 = res.getRow(0);
    assertNotNull(r1);
    assertEquals("def", r1.getString("name"));
    assertEquals(52.5, r1.getDouble("avgscore"), 0);

    Row r2 = res.getRow(1);
    assertNotNull(r2);
    assertEquals("ghi", r2.getString("name"));
    assertEquals(67.5, r2.getDouble("avgscore"), 0);
  }

  @Test
  public void load() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("subj1");
    sc.addSortableNumericField("subj2");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
//    client.addDocument(new Document("data1").set("name", "abc").set("subj1", 20).set("subj2", 70));
//    client.addDocument(new Document("data2").set("name", "def").set("subj1", 60).set("subj2", 40));
    addDocument(new Document("data1").set("name", "abc").set("subj1", 20).set("subj2", 70));
    addDocument(new Document("data2").set("name", "def").set("subj1", 60).set("subj2", 40));

    AggregationBuilder builder = new AggregationBuilder()
        .load(FieldName.of("@subj1").as("a"), FieldName.of("@subj2").as("b"))
        .apply("(@a+@b)/2", "avg").sortByDesc("@avg");

    AggregationResult result = client.ftAggregate(index, builder);
    assertEquals(50.0, result.getRow(0).getDouble("avg"), 0d);
    assertEquals(45.0, result.getRow(1).getDouble("avg"), 0d);
  }

  @Test
  public void loadAll() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("subj1");
    sc.addSortableNumericField("subj2");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
    addDocument(new Document("data1").set("name", "abc").set("subj1", 20).set("subj2", 70));
    addDocument(new Document("data2").set("name", "def").set("subj1", 60).set("subj2", 40));

    AggregationBuilder builder = new AggregationBuilder()
        .loadAll()
        .apply("(@subj1+@subj2)/2", "avg").sortByDesc("@avg");

    AggregationResult result = client.ftAggregate(index, builder);
    assertEquals(50.0, result.getRow(0).getDouble("avg"), 0d);
    assertEquals(45.0, result.getRow(1).getDouble("avg"), 0d);
  }

  @Test
  public void cursor() {
    Schema sc = new Schema();
    sc.addSortableTextField("name", 1.0);
    sc.addSortableNumericField("count");
    client.ftCreate(index, IndexOptions.defaultOptions(), sc);
//    client.addDocument(new Document("data1").set("name", "abc").set("count", 10));
//    client.addDocument(new Document("data2").set("name", "def").set("count", 5));
//    client.addDocument(new Document("data3").set("name", "def").set("count", 25));
    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));

    AggregationBuilder r = new AggregationBuilder()
        .groupBy("@name", Reducers.sum("@count").as("sum"))
        .sortBy(10, SortedField.desc("@sum"))
        .cursor(1, 3000);

    // actual search
    AggregationResult res = client.ftAggregate(index, r);
    assertEquals(2, res.getTotalResults());

    Row row = res.getRow(0);
    assertNotNull(row);
    assertEquals("def", row.getString("name"));
    assertEquals(30, row.getLong("sum"));
    assertEquals(30., row.getDouble("sum"), 0);

    assertEquals(0L, row.getLong("nosuchcol"));
    assertEquals(0.0, row.getDouble("nosuchcol"), 0);
    assertEquals("", row.getString("nosuchcol"));

    res = client.ftCursorRead(index, res.getCursorId(), 1);
    Row row2 = res.getRow(0);
    assertNotNull(row2);
    assertEquals("abc", row2.getString("name"));
    assertEquals(10, row2.getLong("sum"));

    assertEquals("OK", client.ftCursorDel(index, res.getCursorId()));

    try {
      client.ftCursorRead(index, res.getCursorId(), 1);
      fail();
    } catch (JedisDataException e) {
      // ignore
    }
  }

  @Test
  public void aggregateIteration() {
    client.ftCreate(index, TextField.of("name").sortable(), NumericField.of("count"));

    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));
    addDocument(new Document("data4").set("name", "ghi").set("count", 15));
    addDocument(new Document("data5").set("name", "jkl").set("count", 20));

    AggregationBuilder agg = new AggregationBuilder()
        .groupBy("@name", Reducers.sum("@count").as("sum"))
        .sortBy(10, SortedField.desc("@sum"))
        .cursor(2, 10000);

    FtAggregateIteration rr = client.ftAggregateIteration(index, agg);
    int total = 0;
    while (!rr.isIterationCompleted()) {
      AggregationResult res = rr.nextBatch();
      int count = res.getRows().size();
      assertThat(count, Matchers.lessThanOrEqualTo(2));
      total += count;
    }
    assertEquals(4, total);
  }

  @Test
  public void aggregateIterationCollect() {
    client.ftCreate(index, TextField.of("name").sortable(), NumericField.of("count"));

    addDocument(new Document("data1").set("name", "abc").set("count", 10));
    addDocument(new Document("data2").set("name", "def").set("count", 5));
    addDocument(new Document("data3").set("name", "def").set("count", 25));
    addDocument(new Document("data4").set("name", "ghi").set("count", 15));
    addDocument(new Document("data5").set("name", "jkl").set("count", 20));

    AggregationBuilder agg = new AggregationBuilder()
        .groupBy("@name", Reducers.sum("@count").as("sum"))
        .sortBy(10, SortedField.desc("@sum"))
        .cursor(2, 10000);

    assertEquals(4, client.ftAggregateIteration(index, agg).collect(new ArrayList<>()).size());
  }

  @Test
  public void testWrongAggregation() throws InterruptedException {
    Schema sc = new Schema()
        .addTextField("title", 5.0)
        .addTextField("body", 1.0)
        .addTextField("state", 1.0)
        .addNumericField("price");

    client.ftCreate(index, IndexOptions.defaultOptions(), sc);

    // insert document(s)
    Map<String, Object> fields = new HashMap<>();
    fields.put("title", "hello world");
    fields.put("state", "NY");
    fields.put("body", "lorem ipsum");
    fields.put("price", "1337");
//    client.addDocument("doc1", fields);
    addDocument("doc1", fields);

    // wrong aggregation query
    AggregationBuilder builder = new AggregationBuilder("hello")
        .apply("@price/1000", "k")
        .groupBy("@state", Reducers.avg("@k").as("avgprice"))
        .filter("@avgprice>=2")
        .sortBy(10, SortedField.asc("@state"));

    try {
      client.ftAggregate(index, builder);
      fail();
    } catch (JedisDataException e) {
      // should throw JedisDataException on wrong aggregation query 
    }
  }
}