PubSubHelpers.java
package redis.clients.jedis.util;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import java.nio.ByteBuffer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import redis.clients.jedis.BinaryJedisPubSub;
import redis.clients.jedis.JedisPubSub;
/** Test utilities for asserting on Pub/Sub notifications. */
public final class PubSubHelpers {
public static final long DEFAULT_AWAIT_MILLIS = 5_000L;
private static final long SUBSCRIBED_AWAIT_SECONDS = 5L;
private PubSubHelpers() {
}
/** Asserts that the subscriber became active within 5 seconds. */
public static void awaitSubscribed(CountDownLatch subscribed) throws InterruptedException {
assertThat("subscriber did not become active",
subscribed.await(SUBSCRIBED_AWAIT_SECONDS, TimeUnit.SECONDS), equalTo(true));
}
public static byte[] concat(byte[]... parts) {
int total = 0;
for (byte[] p : parts)
total += p.length;
byte[] out = new byte[total];
int off = 0;
for (byte[] p : parts) {
System.arraycopy(p, 0, out, off, p.length);
off += p.length;
}
return out;
}
public static final class Notification {
public final String pattern;
public final String channel;
public final String message;
public Notification(String pattern, String channel, String message) {
this.pattern = pattern;
this.channel = channel;
this.message = message;
}
}
/** Routes each received message into a per-channel queue. */
public static final class CapturingPubSub extends JedisPubSub {
private final ConcurrentMap<String, BlockingQueue<Notification>> byChannel = new ConcurrentHashMap<>();
public final CountDownLatch subscribed = new CountDownLatch(1);
@Override
public void onSubscribe(String channel, int subscribedChannels) {
subscribed.countDown();
}
@Override
public void onPSubscribe(String pattern, int subscribedChannels) {
subscribed.countDown();
}
@Override
public void onMessage(String channel, String message) {
queueFor(channel).add(new Notification(null, channel, message));
}
@Override
public void onPMessage(String pattern, String channel, String message) {
queueFor(channel).add(new Notification(pattern, channel, message));
}
private BlockingQueue<Notification> queueFor(String channel) {
return byChannel.computeIfAbsent(channel, k -> new LinkedBlockingQueue<>());
}
public Notification expectMessageOn(String channel) throws InterruptedException {
return expectMessageOn(channel, DEFAULT_AWAIT_MILLIS);
}
public Notification expectMessageOn(String channel, long timeoutMillis)
throws InterruptedException {
Notification n = queueFor(channel).poll(timeoutMillis, TimeUnit.MILLISECONDS);
if (n == null)
throw new AssertionError("did not receive notification on channel: " + channel);
return n;
}
public void expectNoMessageOn(String channel, long timeout, TimeUnit unit)
throws InterruptedException {
Notification n = queueFor(channel).poll(timeout, unit);
if (n != null) {
throw new AssertionError(
"expected no message on channel '" + channel + "' but received: " + n.message);
}
}
}
/** Routes each received message into a per-channel queue (binary overload). */
public static final class CapturingBinaryPubSub extends BinaryJedisPubSub {
private final ConcurrentMap<ByteBuffer, BlockingQueue<byte[]>> byChannel = new ConcurrentHashMap<>();
public final CountDownLatch subscribed = new CountDownLatch(1);
@Override
public void onSubscribe(byte[] channel, int subscribedChannels) {
subscribed.countDown();
}
@Override
public void onPSubscribe(byte[] pattern, int subscribedChannels) {
subscribed.countDown();
}
@Override
public void onMessage(byte[] channel, byte[] message) {
queueFor(channel).add(message);
}
@Override
public void onPMessage(byte[] pattern, byte[] channel, byte[] message) {
queueFor(channel).add(message);
}
private BlockingQueue<byte[]> queueFor(byte[] channel) {
return byChannel.computeIfAbsent(ByteBuffer.wrap(channel), k -> new LinkedBlockingQueue<>());
}
public byte[] expectMessageOn(byte[] channel) throws InterruptedException {
return expectMessageOn(channel, DEFAULT_AWAIT_MILLIS);
}
public byte[] expectMessageOn(byte[] channel, long timeoutMillis) throws InterruptedException {
byte[] msg = queueFor(channel).poll(timeoutMillis, TimeUnit.MILLISECONDS);
if (msg == null) throw new AssertionError("did not receive notification on expected channel");
return msg;
}
public void expectNoMessageOn(byte[] channel, long timeout, TimeUnit unit)
throws InterruptedException {
byte[] msg = queueFor(channel).poll(timeout, unit);
if (msg != null) {
throw new AssertionError(
"expected no message on the given channel but received " + msg.length + " bytes");
}
}
}
}