PublishSubscribeCommandsTest.java

package redis.clients.jedis.commands.jedis;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasItems;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static redis.clients.jedis.Protocol.Command.CLIENT;

import java.io.IOException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedClass;
import org.junit.jupiter.params.provider.MethodSource;

import redis.clients.jedis.BinaryJedisPubSub;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPubSub;
import redis.clients.jedis.RedisProtocol;
import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.util.SafeEncoder;

/**
 * Integration tests for Redis publish/subscribe commands (SUBSCRIBE, PSUBSCRIBE, PUBLISH,
 * PUBSUB CHANNELS/NUMPAT/NUMSUB and PING inside a subscription), exercised through both
 * {@link JedisPubSub} and {@link BinaryJedisPubSub}.
 *
 * <p>The class is instantiated once per protocol returned by
 * {@code CommandsTestsParameters#respVersions}. Most tests subscribe on the shared
 * {@code jedis} connection and publish from a separate connection once the subscription
 * has been confirmed in a callback.
 */
@ParameterizedClass
@MethodSource("redis.clients.jedis.commands.CommandsTestsParameters#respVersions")
public class PublishSubscribeCommandsTest extends JedisCommandsTestBase {

  public PublishSubscribeCommandsTest(RedisProtocol protocol) {
    super(protocol);
  }

  /**
   * Publishes {@code message} to {@code channel} from a fresh connection on a background
   * thread, so it can be called from inside subscription callbacks without touching the
   * subscribed connection.
   *
   * @param channel channel to publish to
   * @param message payload to publish
   */
  private void publishOne(final String channel, final String message) {
    Thread t = new Thread(() -> {
      try (Jedis j = createJedis()) {
        j.publish(channel, message);
        j.disconnect();
      } catch (Exception ignored) {
        // Best effort: a failed publish shows up as the subscriber never receiving the message.
      }
    });
    t.start();
  }

  /** Subscribes to one channel, publishes to it once subscribed, and unsubscribes on receipt. */
  @Test
  public void subscribe() throws InterruptedException {
    jedis.subscribe(new JedisPubSub() {
      @Override
      public void onMessage(String channel, String message) {
        assertEquals("foo", channel);
        assertEquals("exit", message);
        unsubscribe();
      }

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        assertEquals("foo", channel);
        assertEquals(1, subscribedChannels);

        // now that I'm subscribed... publish
        publishOne("foo", "exit");
      }

      @Override
      public void onUnsubscribe(String channel, int subscribedChannels) {
        assertEquals("foo", channel);
        assertEquals(0, subscribedChannels);
      }
    }, "foo");
  }

  /** PUBSUB CHANNELS, issued from another connection, lists all three subscribed channels. */
  @Test
  public void pubSubChannels() {
    jedis.subscribe(new JedisPubSub() {
      private int count = 0;

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        count++;
        // All channels are subscribed
        if (count == 3) {
          List<String> activeChannels;
          try (Jedis otherJedis = createJedis()) {
            activeChannels = otherJedis.pubsubChannels();
          }
          // The server may report additional channels with subscribers (e.g. the
          // '__sentinel__:hello' channel when sentinels are attached), so only check that
          // ours are present rather than comparing the full list.
          assertThat(activeChannels, hasItems("testchan1", "testchan2", "testchan3"));
          unsubscribe();
        }
      }
    }, "testchan1", "testchan2", "testchan3");
  }

  /** PUBSUB CHANNELS with a glob pattern lists the matching subscribed channels. */
  @Test
  public void pubSubChannelsWithPattern() {
    jedis.subscribe(new JedisPubSub() {
      private int count = 0;

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        count++;
        // All channels are subscribed
        if (count == 3) {
          List<String> activeChannels;
          try (Jedis otherJedis = createJedis()) {
            activeChannels = otherJedis.pubsubChannels("test*");
          }
          assertThat(activeChannels, hasItems("testchan1", "testchan2", "testchan3"));
          unsubscribe();
        }
      }
    }, "testchan1", "testchan2", "testchan3");
  }

  /** A PING sent from inside a subscription produces an onPong callback. */
  @Test
  public void pubSubChannelWithPingPong() throws InterruptedException {
    final CountDownLatch latchUnsubscribed = new CountDownLatch(1);
    final CountDownLatch latchReceivedPong = new CountDownLatch(1);
    jedis.subscribe(new JedisPubSub() {

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        publishOne("testchan1", "hello");
      }

      @Override
      public void onMessage(String channel, String message) {
        this.ping();
      }

      @Override
      public void onPong(String pattern) {
        latchReceivedPong.countDown();
        unsubscribe();
      }

      @Override
      public void onUnsubscribe(String channel, int subscribedChannels) {
        latchUnsubscribed.countDown();
      }
    }, "testchan1");
    assertEquals(0L, latchReceivedPong.getCount());
    assertEquals(0L, latchUnsubscribed.getCount());
  }

  /** The argument of a PING sent inside a subscription is echoed back to onPong. */
  @Test
  public void pubSubChannelWithPingPongWithArgument() throws InterruptedException {
    final CountDownLatch latchUnsubscribed = new CountDownLatch(1);
    final CountDownLatch latchReceivedPong = new CountDownLatch(1);
    final List<String> pongPatterns = new ArrayList<>();
    jedis.subscribe(new JedisPubSub() {

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        publishOne("testchan1", "hello");
      }

      @Override
      public void onMessage(String channel, String message) {
        this.ping("hi!");
      }

      @Override
      public void onPong(String pattern) {
        pongPatterns.add(pattern);
        latchReceivedPong.countDown();
        unsubscribe();
      }

      @Override
      public void onUnsubscribe(String channel, int subscribedChannels) {
        latchUnsubscribed.countDown();
      }
    }, "testchan1");

    assertEquals(0L, latchReceivedPong.getCount());
    assertEquals(0L, latchUnsubscribed.getCount());
    assertEquals(Collections.singletonList("hi!"), pongPatterns);
  }

  /**
   * PUBSUB NUMPAT counts distinct patterns: "test*" is subscribed twice but counted once.
   * A subscribe confirmation is still expected for each of the three arguments.
   */
  @Test
  public void pubSubNumPat() {
    jedis.psubscribe(new JedisPubSub() {
      private int count = 0;

      @Override
      public void onPSubscribe(String pattern, int subscribedChannels) {
        count++;
        if (count == 3) {
          Long numPatterns;
          try (Jedis otherJedis = createJedis()) {
            numPatterns = otherJedis.pubsubNumPat();
          }
          assertEquals(Long.valueOf(2L), numPatterns);
          punsubscribe();
        }
      }

    }, "test*", "test*", "chan*");
  }

  /** PUBSUB NUMSUB reports one subscriber for each of the two subscribed channels. */
  @Test
  public void pubSubNumSub() {
    final Map<String, Long> expectedNumSub = new HashMap<>();
    expectedNumSub.put("testchannel2", 1L);
    expectedNumSub.put("testchannel1", 1L);
    jedis.subscribe(new JedisPubSub() {
      private int count = 0;

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        count++;
        if (count == 2) {
          Map<String, Long> numSub;
          try (Jedis otherJedis = createJedis()) {
            numSub = otherJedis.pubsubNumSub("testchannel1", "testchannel2");
          }
          assertEquals(expectedNumSub, numSub);
          unsubscribe();
        }
      }
    }, "testchannel1", "testchannel2");
  }

  /** Subscribes to two channels and unsubscribes from each one when its message arrives. */
  @Test
  public void subscribeMany() throws UnknownHostException, IOException, InterruptedException {
    jedis.subscribe(new JedisPubSub() {
      @Override
      public void onMessage(String channel, String message) {
        unsubscribe(channel);
      }

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        publishOne(channel, "exit");
      }

    }, "foo", "bar");
  }

  /** Pattern-subscribes to "foo.*" and receives a message published to "foo.bar". */
  @Test
  public void psubscribe() throws UnknownHostException, IOException, InterruptedException {
    jedis.psubscribe(new JedisPubSub() {
      @Override
      public void onPSubscribe(String pattern, int subscribedChannels) {
        assertEquals("foo.*", pattern);
        assertEquals(1, subscribedChannels);
        publishOne("foo.bar", "exit");

      }

      @Override
      public void onPUnsubscribe(String pattern, int subscribedChannels) {
        assertEquals("foo.*", pattern);
        assertEquals(0, subscribedChannels);
      }

      @Override
      public void onPMessage(String pattern, String channel, String message) {
        assertEquals("foo.*", pattern);
        assertEquals("foo.bar", channel);
        assertEquals("exit", message);
        punsubscribe();
      }
    }, "foo.*");
  }

  /** Pattern-subscribes to two patterns and punsubscribes each one when it matches a message. */
  @Test
  public void psubscribeMany() throws UnknownHostException, IOException, InterruptedException {
    jedis.psubscribe(new JedisPubSub() {
      @Override
      public void onPSubscribe(String pattern, int subscribedChannels) {
        publishOne(pattern.replace("*", "123"), "exit");
      }

      @Override
      public void onPMessage(String pattern, String channel, String message) {
        punsubscribe(pattern);
      }
    }, "foo.*", "bar.*");
  }

  /** Adds a channel and a pattern subscription from inside a callback of an existing subscription. */
  @Test
  public void subscribeLazily() throws UnknownHostException, IOException, InterruptedException {
    final JedisPubSub pubsub = new JedisPubSub() {
      @Override
      public void onMessage(String channel, String message) {
        unsubscribe(channel);
      }

      @Override
      public void onSubscribe(String channel, int subscribedChannels) {
        publishOne(channel, "exit");
        if (!channel.equals("bar")) {
          this.subscribe("bar");
          this.psubscribe("bar.*");
        }
      }

      @Override
      public void onPSubscribe(String pattern, int subscribedChannels) {
        publishOne(pattern.replace("*", "123"), "exit");
      }

      @Override
      public void onPMessage(String pattern, String channel, String message) {
        punsubscribe(pattern);
      }
    };

    jedis.subscribe(pubsub, "foo");
  }

  /** Binary variant of {@code subscribe()}. */
  @Test
  public void binarySubscribe() throws UnknownHostException, IOException, InterruptedException {
    jedis.subscribe(new BinaryJedisPubSub() {
      @Override
      public void onMessage(byte[] channel, byte[] message) {
        assertArrayEquals(SafeEncoder.encode("foo"), channel);
        assertArrayEquals(SafeEncoder.encode("exit"), message);
        unsubscribe();
      }

      @Override
      public void onSubscribe(byte[] channel, int subscribedChannels) {
        assertArrayEquals(SafeEncoder.encode("foo"), channel);
        assertEquals(1, subscribedChannels);
        publishOne(SafeEncoder.encode(channel), "exit");
      }

      @Override
      public void onUnsubscribe(byte[] channel, int subscribedChannels) {
        assertArrayEquals(SafeEncoder.encode("foo"), channel);
        assertEquals(0, subscribedChannels);
      }
    }, SafeEncoder.encode("foo"));
  }

  /** Binary variant of {@code subscribeMany()}. */
  @Test
  public void binarySubscribeMany() throws UnknownHostException, IOException, InterruptedException {
    jedis.subscribe(new BinaryJedisPubSub() {
      @Override
      public void onMessage(byte[] channel, byte[] message) {
        unsubscribe(channel);
      }

      @Override
      public void onSubscribe(byte[] channel, int subscribedChannels) {
        publishOne(SafeEncoder.encode(channel), "exit");
      }
    }, SafeEncoder.encode("foo"), SafeEncoder.encode("bar"));
  }

  /** Binary variant of {@code psubscribe()}. */
  @Test
  public void binaryPsubscribe() throws UnknownHostException, IOException, InterruptedException {
    jedis.psubscribe(new BinaryJedisPubSub() {
      @Override
      public void onPSubscribe(byte[] pattern, int subscribedChannels) {
        assertArrayEquals(SafeEncoder.encode("foo.*"), pattern);
        assertEquals(1, subscribedChannels);
        publishOne(SafeEncoder.encode(pattern).replace("*", "bar"), "exit");
      }

      @Override
      public void onPUnsubscribe(byte[] pattern, int subscribedChannels) {
        assertArrayEquals(SafeEncoder.encode("foo.*"), pattern);
        assertEquals(0, subscribedChannels);
      }

      @Override
      public void onPMessage(byte[] pattern, byte[] channel, byte[] message) {
        assertArrayEquals(SafeEncoder.encode("foo.*"), pattern);
        assertArrayEquals(SafeEncoder.encode("foo.bar"), channel);
        assertArrayEquals(SafeEncoder.encode("exit"), message);
        punsubscribe();
      }
    }, SafeEncoder.encode("foo.*"));
  }

  /** Binary variant of {@code psubscribeMany()}. */
  @Test
  public void binaryPsubscribeMany() throws UnknownHostException, IOException, InterruptedException {
    jedis.psubscribe(new BinaryJedisPubSub() {
      @Override
      public void onPSubscribe(byte[] pattern, int subscribedChannels) {
        publishOne(SafeEncoder.encode(pattern).replace("*", "123"), "exit");
      }

      @Override
      public void onPMessage(byte[] pattern, byte[] channel, byte[] message) {
        punsubscribe(pattern);
      }
    }, SafeEncoder.encode("foo.*"), SafeEncoder.encode("bar.*"));
  }

  /** Binary variant of {@code pubSubChannelWithPingPong()}. */
  @Test
  public void binaryPubSubChannelWithPingPong() throws InterruptedException {
    final CountDownLatch latchUnsubscribed = new CountDownLatch(1);
    final CountDownLatch latchReceivedPong = new CountDownLatch(1);

    jedis.subscribe(new BinaryJedisPubSub() {

      @Override
      public void onSubscribe(byte[] channel, int subscribedChannels) {
        publishOne("testchan1", "hello");
      }

      @Override
      public void onMessage(byte[] channel, byte[] message) {
        this.ping();
      }

      @Override
      public void onPong(byte[] pattern) {
        latchReceivedPong.countDown();
        unsubscribe();
      }

      @Override
      public void onUnsubscribe(byte[] channel, int subscribedChannels) {
        latchUnsubscribed.countDown();
      }
    }, SafeEncoder.encode("testchan1"));
    assertEquals(0L, latchReceivedPong.getCount());
    assertEquals(0L, latchUnsubscribed.getCount());
  }

  /** Binary variant of {@code pubSubChannelWithPingPongWithArgument()}. */
  @Test
  public void binaryPubSubChannelWithPingPongWithArgument() throws InterruptedException {
    final CountDownLatch latchUnsubscribed = new CountDownLatch(1);
    final CountDownLatch latchReceivedPong = new CountDownLatch(1);
    final List<byte[]> pongPatterns = new ArrayList<>();
    final byte[] pingMessage = SafeEncoder.encode("hi!");

    jedis.subscribe(new BinaryJedisPubSub() {

      @Override
      public void onSubscribe(byte[] channel, int subscribedChannels) {
        publishOne("testchan1", "hello");
      }

      @Override
      public void onMessage(byte[] channel, byte[] message) {
        this.ping(pingMessage);
      }

      @Override
      public void onPong(byte[] pattern) {
        pongPatterns.add(pattern);
        latchReceivedPong.countDown();
        unsubscribe();
      }

      @Override
      public void onUnsubscribe(byte[] channel, int subscribedChannels) {
        latchUnsubscribed.countDown();
      }
    }, SafeEncoder.encode("testchan1"));

    assertEquals(0L, latchReceivedPong.getCount());
    assertEquals(0L, latchUnsubscribed.getCount());
    // Check the size first so a missing pong fails with a clear assertion rather than an
    // IndexOutOfBoundsException.
    assertEquals(1, pongPatterns.size());
    assertArrayEquals(pingMessage, pongPatterns.get(0));
  }

  /** Binary variant of {@code subscribeLazily()}. */
  @Test
  public void binarySubscribeLazily() throws UnknownHostException, IOException,
      InterruptedException {
    final BinaryJedisPubSub pubsub = new BinaryJedisPubSub() {
      @Override
      public void onMessage(byte[] channel, byte[] message) {
        unsubscribe(channel);
      }

      @Override
      public void onSubscribe(byte[] channel, int subscribedChannels) {
        publishOne(SafeEncoder.encode(channel), "exit");

        if (!"bar".equals(SafeEncoder.encode(channel))) {
          this.subscribe(SafeEncoder.encode("bar"));
          this.psubscribe(SafeEncoder.encode("bar.*"));
        }
      }

      @Override
      public void onPSubscribe(byte[] pattern, int subscribedChannels) {
        publishOne(SafeEncoder.encode(pattern).replace("*", "123"), "exit");
      }

      @Override
      public void onPMessage(byte[] pattern, byte[] channel, byte[] message) {
        punsubscribe(pattern);
      }
    };

    jedis.subscribe(pubsub, SafeEncoder.encode("foo"));
  }

  /** Calling unsubscribe on a JedisPubSub that was never subscribed throws JedisException. */
  @Test
  public void unsubscribeWhenNotSusbscribed() throws InterruptedException {
    JedisPubSub pubsub = new JedisPubSub() {
    };
    assertThrows(JedisException.class, pubsub::unsubscribe);
  }

  /**
   * A subscriber that consumes too slowly while ~100M is published to its channel is expected
   * to be disconnected by the server, which surfaces as a JedisException from subscribe.
   *
   * <p>NOTE(review): relies on the test server's pubsub client-output-buffer-limit being
   * configured as described in the comment below; confirm against the test environment.
   */
  @Test
  public void handleClientOutputBufferLimitForSubscribeTooSlow() throws InterruptedException {
    assertThrows(JedisException.class, () -> {
      // Publisher connection; closed once the publisher thread has been stopped and joined.
      try (Jedis j = createJedis()) {
        final AtomicBoolean exit = new AtomicBoolean(false);

        final Thread t = new Thread(() -> {
          try {
            // The test server is expected to be configured with
            // client-output-buffer-limit pubsub 256k 128k 5
            // i.e. if a subscriber falls behind by more than 256k, or by more than
            // 128k continuously for 5 seconds, redis disconnects the subscriber.

            // Publish over 100M of data to exceed client-output-buffer-limit.
            String veryLargeString = makeLargeString(10485760);

            // 10M * 10 = 100M
            for (int i = 0; i < 10 && !exit.get(); i++) {
              j.publish("foo", veryLargeString);
            }

            j.disconnect();
          } catch (Exception ignored) {
            // Publishing may fail once the server starts dropping connections; the test
            // asserts on the subscriber side only.
          }
        });
        t.start();
        try {
          jedis.subscribe(new JedisPubSub() {
            @Override
            public void onMessage(String channel, String message) {
              try {
                // Wait 0.1 secs per message to slow down the subscriber so the
                // client output buffer limit is exceeded.
                Thread.sleep(100);
              } catch (InterruptedException e) {
                try {
                  t.join();
                } catch (InterruptedException ignored) {
                  // Already failing below; nothing more to do.
                }

                fail("Interrupted while slowing down the subscriber", e);
              }
            }
          }, "foo");
        } finally {
          // Stop the publisher thread. If an exception is thrown, the thread might
          // otherwise keep publishing.
          exit.set(true);
          if (t.isAlive()) {
            t.join();
          }
        }
      }
    });
  }

  /**
   * Builds a string of {@code size} characters cycling through 'a'..'z'.
   *
   * @param size number of characters to generate
   * @return the generated string
   */
  private String makeLargeString(int size) {
    // Presize: the caller asks for ~10M characters.
    StringBuilder sb = new StringBuilder(size);
    for (int i = 0; i < size; i++) {
      sb.append((char) ('a' + i % 26));
    }
    return sb.toString();
  }

  /**
   * With CLIENT TRACKING ON REDIRECT to itself in BCAST mode, the subscriber expects an
   * invalidation message for key "foo" after SET, and then a null-message invalidation
   * (after FLUSHALL) that ends the subscription. Skipped for RESP3.
   */
  @Test
  @Timeout(5)
  public void subscribeCacheInvalidateChannel() {
    assumeTrue(protocol != RedisProtocol.RESP3);

    final String cacheInvalidate = "__redis__:invalidate";
    final AtomicBoolean onMessage = new AtomicBoolean(false);
    final JedisPubSub pubsub = new JedisPubSub() {
      @Override public void onMessage(String channel, String message) {
        onMessage.set(true);
        assertEquals(cacheInvalidate, channel);
        if (message != null) {
          assertEquals("foo", message);
          consumeJedis(j -> j.flushAll());
        } else {
          unsubscribe(channel);
        }
      }

      @Override public void onSubscribe(String channel, int subscribedChannels) {
        assertEquals(cacheInvalidate, channel);
        consumeJedis(j -> j.set("foo", "bar"));
      }
    };

    try (Jedis subscriber = createJedis()) {
      long clientId = subscriber.clientId();
      subscriber.sendCommand(CLIENT, "TRACKING", "ON", "REDIRECT", Long.toString(clientId), "BCAST");
      subscriber.subscribe(pubsub, cacheInvalidate);
      assertTrue(onMessage.get(), "Subscriber didn't get any message.");
    }
  }

  /** Binary variant of {@code subscribeCacheInvalidateChannel()}. Skipped for RESP3. */
  @Test
  @Timeout(5)
  public void subscribeCacheInvalidateChannelBinary() {
    assumeTrue(protocol != RedisProtocol.RESP3);

    // SafeEncoder uses UTF-8; String.getBytes() would depend on the platform charset.
    final byte[] cacheInvalidate = SafeEncoder.encode("__redis__:invalidate");
    final AtomicBoolean onMessage = new AtomicBoolean(false);
    final BinaryJedisPubSub pubsub = new BinaryJedisPubSub() {
      @Override public void onMessage(byte[] channel, byte[] message) {
        onMessage.set(true);
        assertArrayEquals(cacheInvalidate, channel);
        if (message != null) {
          assertArrayEquals(SafeEncoder.encode("foo"), message);
          consumeJedis(j -> j.flushAll());
        } else {
          unsubscribe(channel);
        }
      }

      @Override public void onSubscribe(byte[] channel, int subscribedChannels) {
        assertArrayEquals(cacheInvalidate, channel);
        consumeJedis(j -> j.set(SafeEncoder.encode("foo"), SafeEncoder.encode("bar")));
      }
    };

    try (Jedis subscriber = createJedis()) {
      long clientId = subscriber.clientId();
      subscriber.sendCommand(CLIENT, "TRACKING", "ON", "REDIRECT", Long.toString(clientId), "BCAST");
      subscriber.subscribe(pubsub, cacheInvalidate);
      assertTrue(onMessage.get(), "Subscriber didn't get any message.");
    }
  }

  /**
   * Runs {@code consumer} on a background thread against its own short-lived connection.
   *
   * <p>A dedicated connection per call is used because these calls are fired from callbacks
   * while an earlier asynchronous call may still be using its connection; a single
   * {@link Jedis} instance must not be shared between threads.
   *
   * @param consumer commands to run
   */
  private void consumeJedis(Consumer<Jedis> consumer) {
    Thread t = new Thread(() -> {
      try (Jedis j = createJedis()) {
        consumer.accept(j);
      }
    });
    t.start();
  }
}