UDFTest.java

package org.sqlite;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Offset.offset;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/** Tests User Defined Functions. */
public class UDFTest {
    private static int val = 0;
    private static final byte[] b1 = new byte[] {2, 5, -4, 8, -1, 3, -5};
    private static int gotTrigger = 0;

    private Connection conn;
    private Statement stat;

    @BeforeEach
    public void connect() throws Exception {
        conn = DriverManager.getConnection("jdbc:sqlite:");
        stat = conn.createStatement();
    }

    @AfterEach
    public void close() throws SQLException {
        stat.close();
        conn.close();
    }

    @Test
    public void calling() throws SQLException {
        Function.create(
                conn,
                "f1",
                new Function() {
                    @Override
                    public void xFunc() {
                        val = 4;
                    }
                });
        stat.executeQuery("select f1();").close();
        assertThat(val).isEqualTo(4);
    }

    @Test
    public void returning() throws SQLException {
        Function.create(
                conn,
                "f2",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(4);
                    }
                });
        ResultSet rs = stat.executeQuery("select f2();");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(4);
        rs.close();

        for (int i = 0; i < 20; i++) {
            rs = stat.executeQuery("select (f2() + " + i + ");");
            assertThat(rs.next()).isTrue();
            assertThat(rs.getInt(1)).isEqualTo(4 + i);
            rs.close();
        }
    }

    @Test
    public void accessArgs() throws SQLException {
        Function.create(
                conn,
                "f3",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(value_int(0));
                    }
                });
        for (int i = 0; i < 15; i++) {
            ResultSet rs = stat.executeQuery("select f3(" + i + ");");
            assertThat(rs.next()).isTrue();
            assertThat(rs.getInt(1)).isEqualTo(i);
            rs.close();
        }
    }

    @Test
    public void multipleArgs() throws SQLException {
        Function.create(
                conn,
                "f4",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        int ret = 0;
                        for (int i = 0; i < args(); i++) {
                            ret += value_int(i);
                        }
                        result(ret);
                    }
                });
        ResultSet rs = stat.executeQuery("select f4(2, 3, 9, -5);");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(9);
        rs.close();
        rs = stat.executeQuery("select f4(2);");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(2);
        rs.close();
        rs = stat.executeQuery("select f4(-3, -4, -5);");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(-12);
    }

    @Test
    public void returnTypes() throws SQLException {
        Function.create(
                conn,
                "f5",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result("Hello World");
                    }
                });
        ResultSet rs = stat.executeQuery("select f5();");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getString(1)).isEqualTo("Hello World");

        Function.create(
                conn,
                "f6",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(Long.MAX_VALUE);
                    }
                });
        rs.close();
        rs = stat.executeQuery("select f6();");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getLong(1)).isEqualTo(Long.MAX_VALUE);

        Function.create(
                conn,
                "f7",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(Double.MAX_VALUE);
                    }
                });
        rs.close();
        rs = stat.executeQuery("select f7();");
        assertThat(rs.next()).isTrue();
        assertThat(Double.MAX_VALUE).isCloseTo(rs.getDouble(1), offset(0.0001));

        Function.create(
                conn,
                "f8",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(b1);
                    }
                });
        rs.close();
        rs = stat.executeQuery("select f8();");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getBytes(1)).containsExactly(b1);
    }

    @Test
    public void returnArgInt() throws SQLException {
        Function.create(
                conn,
                "farg_int",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(value_int(0));
                    }
                });
        PreparedStatement prep = conn.prepareStatement("select farg_int(?);");
        prep.setInt(1, Integer.MAX_VALUE);
        ResultSet rs = prep.executeQuery();
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(Integer.MAX_VALUE);
        prep.close();
    }

    @Test
    public void returnArgLong() throws SQLException {
        Function.create(
                conn,
                "farg_long",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(value_long(0));
                    }
                });
        PreparedStatement prep = conn.prepareStatement("select farg_long(?);");
        prep.setLong(1, Long.MAX_VALUE);
        ResultSet rs = prep.executeQuery();
        assertThat(rs.next()).isTrue();
        assertThat(rs.getLong(1)).isEqualTo(Long.MAX_VALUE);
        prep.close();
    }

    @Test
    public void returnArgDouble() throws SQLException {
        Function.create(
                conn,
                "farg_doub",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(value_double(0));
                    }
                });
        PreparedStatement prep = conn.prepareStatement("select farg_doub(?);");
        prep.setDouble(1, Double.MAX_VALUE);
        ResultSet rs = prep.executeQuery();
        assertThat(rs.next()).isTrue();
        assertThat(Double.MAX_VALUE).isCloseTo(rs.getDouble(1), offset(0.0001));
        prep.close();
    }

    @Test
    public void returnArgBlob() throws SQLException {
        Function.create(
                conn,
                "farg_blob",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(value_blob(0));
                    }
                });
        PreparedStatement prep = conn.prepareStatement("select farg_blob(?);");
        prep.setBytes(1, b1);
        ResultSet rs = prep.executeQuery();
        assertThat(rs.next()).isTrue();
        assertThat(rs.getBytes(1)).containsExactly(b1);
        prep.close();
    }

    @Test
    public void returnArgString() throws SQLException {
        Function.create(
                conn,
                "farg_str",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(value_text(0));
                    }
                });
        PreparedStatement prep = conn.prepareStatement("select farg_str(?);");
        prep.setString(1, "Hello");
        ResultSet rs = prep.executeQuery();
        assertThat(rs.next()).isTrue();
        assertThat(rs.getString(1)).isEqualTo("Hello");
        prep.close();
    }

    @Test
    public void trigger() throws SQLException {
        Function.create(
                conn,
                "inform",
                new Function() {
                    @Override
                    protected void xFunc() throws SQLException {
                        gotTrigger = value_int(0);
                    }
                });
        stat.executeUpdate("create table trigtest (c1);");
        stat.executeUpdate(
                "create trigger trigt after insert on trigtest"
                        + " begin select inform(new.c1); end;");
        stat.executeUpdate("insert into trigtest values (5);");
        assertThat(gotTrigger).isEqualTo(5);
    }

    @Test
    public void aggregate() throws SQLException {
        Function.create(
                conn,
                "mySum",
                new Function.Aggregate() {
                    private int val = 0;

                    @Override
                    protected void xStep() throws SQLException {
                        for (int i = 0; i < args(); i++) {
                            val += value_int(i);
                        }
                    }

                    @Override
                    protected void xFinal() throws SQLException {
                        result(val);
                    }
                });
        stat.executeUpdate("create table t (c1);");
        stat.executeUpdate("insert into t values (5);");
        stat.executeUpdate("insert into t values (3);");
        stat.executeUpdate("insert into t values (8);");
        stat.executeUpdate("insert into t values (2);");
        stat.executeUpdate("insert into t values (7);");
        ResultSet rs = stat.executeQuery("select mySum(c1), sum(c1) from t;");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(rs.getInt(2));
    }

    @Test
    public void window() throws SQLException {
        Function.create(
                conn,
                "mySum",
                new Function.Window() {
                    private int val = 0;

                    @Override
                    protected void xStep() throws SQLException {
                        for (int i = 0; i < args(); i++) {
                            val += value_int(i);
                        }
                    }

                    @Override
                    protected void xInverse() throws SQLException {
                        for (int i = 0; i < args(); i++) {
                            val -= value_int(i);
                        }
                    }

                    @Override
                    protected void xValue() throws SQLException {
                        result(val);
                    }

                    @Override
                    protected void xFinal() throws SQLException {
                        result(val);
                    }
                });

        stat.executeUpdate("create table t (x);");
        stat.executeUpdate("insert into t values(1);");
        stat.executeUpdate("insert into t values(2);");
        stat.executeUpdate("insert into t values(3);");
        stat.executeUpdate("insert into t values(4);");
        stat.executeUpdate("insert into t values(5);");

        ResultSet rs =
                stat.executeQuery(
                        "select mySum(x) over (order by x rows between 1 preceding and 1 following) from t order by x;");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(3);
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(6);
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(9);
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(12);
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(9);
    }

    @Test
    public void destroy() throws SQLException {
        Function.create(
                conn,
                "f1",
                new Function() {
                    @Override
                    public void xFunc() {
                        val = 9;
                    }
                });
        stat.executeQuery("select f1();").close();
        assertThat(val).isEqualTo(9);

        Function.destroy(conn, "f1");
        Function.destroy(conn, "f1");
    }

    @Test
    public void manyfunctions() throws SQLException {
        Function.create(
                conn,
                "f1",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(1);
                    }
                });
        Function.create(
                conn,
                "f2",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(2);
                    }
                });
        Function.create(
                conn,
                "f3",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(3);
                    }
                });
        Function.create(
                conn,
                "f4",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(4);
                    }
                });
        Function.create(
                conn,
                "f5",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(5);
                    }
                });
        Function.create(
                conn,
                "f6",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(6);
                    }
                });
        Function.create(
                conn,
                "f7",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(7);
                    }
                });
        Function.create(
                conn,
                "f8",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(8);
                    }
                });
        Function.create(
                conn,
                "f9",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(9);
                    }
                });
        Function.create(
                conn,
                "f10",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(10);
                    }
                });
        Function.create(
                conn,
                "f11",
                new Function() {
                    @Override
                    public void xFunc() throws SQLException {
                        result(11);
                    }
                });

        ResultSet rs =
                stat.executeQuery(
                        "select f1() + f2() + f3() + f4() + f5() + f6()"
                                + " + f7() + f8() + f9() + f10() + f11();");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11);
        rs.close();
    }

    @Test
    public void multipleThreads() throws Exception {
        Function func =
                new Function() {
                    int sum = 0;

                    @Override
                    protected void xFunc() {
                        try {
                            sum += value_int(1);
                        } catch (SQLException e) {
                            e.printStackTrace();
                        }
                    }

                    @Override
                    public String toString() {
                        return String.valueOf(sum);
                    }
                };
        Function.create(conn, "func", func);
        stat.executeUpdate("create table foo (col integer);");
        stat.executeUpdate(
                "create trigger foo_trigger after insert on foo begin"
                        + " select func(new.rowid, new.col); end;");
        int times = 1000;
        List<Thread> threads = new LinkedList<>();
        for (int tn = 0; tn < times; tn++) {
            threads.add(
                    new Thread("func thread " + tn) {
                        @Override
                        public void run() {
                            try {
                                Statement s = conn.createStatement();
                                s.executeUpdate("insert into foo values (1);");
                                s.close();
                            } catch (SQLException e) {
                                e.printStackTrace();
                            }
                        }
                    });
        }
        for (Thread thread : threads) {
            thread.start();
        }
        for (Thread thread : threads) {
            thread.join();
        }

        // check that all of the threads successfully executed
        ResultSet rs = stat.executeQuery("select sum(col) from foo;");
        assertThat(rs.next()).isTrue();
        assertThat(rs.getInt(1)).isEqualTo(times);
        rs.close();

        // check that custom function was executed each time
        assertThat(Integer.parseInt(func.toString())).isEqualTo(times);
    }
}