// TrackingConnectionPool.java

package redis.clients.jedis.mcf;

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import redis.clients.jedis.Connection;
import redis.clients.jedis.ConnectionFactory;
import redis.clients.jedis.ConnectionPool;
import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.DefaultJedisSocketFactory;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.csc.CacheConnection;
import redis.clients.jedis.exceptions.JedisConnectionException;

/**
 * A {@link ConnectionPool} that keeps track of every connection it has handed out, plus every
 * connection its factory is still initializing, so that {@link #forceDisconnect()} can forcibly
 * tear all of them down.
 * <p>
 * Instances are created through {@link #builder()}, or copied from an existing pool (same host,
 * client config and pool config) through {@link #from(TrackingConnectionPool)}.
 */
public class TrackingConnectionPool extends ConnectionPool {

  /**
   * Connection factory that can be switched into a "fail fast" mode in which it refuses to create
   * new connections. It also remembers connections whose initialization is in progress so that
   * they can be force-disconnected before they ever reach the pool.
   */
  private static class FailFastConnectionFactory extends ConnectionFactory {
    /** When {@code true}, {@link #makeObject()} throws instead of handing out a connection. */
    private volatile boolean failFast = false;
    /** Connections currently inside {@code initializeFromClientConfig()}. */
    private final Set<Connection> factoryTrackedObjects = ConcurrentHashMap.newKeySet();

    public FailFastConnectionFactory(ConnectionFactory.Builder factoryBuilder,
        JedisClientConfig clientConfig) {
      super(factoryBuilder
          .connectionBuilder(createCustomConnectionBuilder(factoryBuilder, clientConfig)));
    }

    /**
     * Creates the connection builder used by this factory: a {@link CacheConnection} builder when
     * the factory builder carries a cache, a plain {@link Connection} builder otherwise. Both use
     * the factory builder's socket factory and the given client config.
     */
    private static Connection.Builder createCustomConnectionBuilder(
        ConnectionFactory.Builder factoryBuilder, JedisClientConfig clientConfig) {
      Connection.Builder connBuilder = factoryBuilder.getCache() == null ? Connection.builder()
          : CacheConnection.builder(factoryBuilder.getCache());

      return connBuilder.socketFactory(factoryBuilder.getJedisSocketFactory())
          .clientConfig(clientConfig);
    }

    /**
     * Creates and initializes a new connection, unless fail-fast mode is on.
     *
     * @throws JedisConnectionException if fail-fast mode is (or becomes) active, or if creation or
     *           initialization fails; non-Jedis exceptions are wrapped
     */
    @Override
    public PooledObject<Connection> makeObject() throws Exception {
      if (failFast) {
        throw new JedisConnectionException("Failed to create connection!");
      }
      try {
        PooledObject<Connection> object = super.makeObject();
        // Track the connection only while it is being initialized, so that forceDisconnect()
        // can interrupt a slow handshake that has not yet reached the pool's borrowed set.
        factoryTrackedObjects.add(object.getObject());
        try {
          object.getObject().initializeFromClientConfig();
        } finally {
          factoryTrackedObjects.remove(object.getObject());
        }
        // Fail-fast may have been switched on while we were initializing; do not hand out a
        // connection that the pool is in the middle of tearing down.
        // This can make a marginal improvement on fast failover duration!
        if (failFast) {
          object.getObject().close();
          throw new JedisConnectionException("Failed to create connection!");
        }
        return object;
      } catch (JedisConnectionException e) {
        throw e;
      } catch (Exception e) {
        throw new JedisConnectionException(e);
      }
    }

    /**
     * Force-disconnects every connection currently being initialized by this factory. Failures
     * are logged and do not stop the remaining connections from being processed.
     */
    public void forceDisconnect() {
      for (Connection connection : factoryTrackedObjects) {
        try {
          connection.forceDisconnect();
        } catch (Exception e) {
          log.warn("Error while force disconnecting connection: {}",
            connection.toIdentityString(), e);
        }
      }
    }

  }

  /**
   * Builder for {@link TrackingConnectionPool}. If the client config or pool config is left unset,
   * {@link #build()} fills in a default one.
   */
  public static class Builder {
    private HostAndPort hostAndPort;
    private JedisClientConfig clientConfig;
    private GenericObjectPoolConfig<Connection> poolConfig;

    public Builder hostAndPort(HostAndPort hostAndPort) {
      this.hostAndPort = hostAndPort;
      return this;
    }

    public Builder clientConfig(JedisClientConfig clientConfig) {
      this.clientConfig = clientConfig;
      return this;
    }

    public Builder poolConfig(GenericObjectPoolConfig<Connection> poolConfig) {
      this.poolConfig = poolConfig;
      return this;
    }

    /** Applies defaults for unset configs and creates the pool. */
    public TrackingConnectionPool build() {
      applyDefaults();
      return new TrackingConnectionPool(this);
    }

    private void applyDefaults() {
      if (clientConfig == null) {
        clientConfig = DefaultJedisClientConfig.builder().build();
      }
      if (poolConfig == null) {
        poolConfig = new GenericObjectPoolConfig<>();
      }
    }
  }

  private static final Logger log = LoggerFactory.getLogger(TrackingConnectionPool.class);

  private final HostAndPort hostAndPort;
  private final JedisClientConfig clientConfig;
  private final GenericObjectPoolConfig<Connection> poolConfig;
  /** Number of threads currently inside {@link #getResource()}. */
  private final AtomicInteger numWaiters = new AtomicInteger();
  /** Connections handed out by {@link #getResource()} and not yet returned. */
  private final Set<Connection> poolTrackedObjects = ConcurrentHashMap.newKeySet();

  public static Builder builder() {
    return new Builder();
  }

  /** Only reachable through {@link Builder#build()}, which has already applied defaults. */
  private TrackingConnectionPool(Builder builder) {
    super(createFailFastFactory(builder), builder.poolConfig);

    this.hostAndPort = builder.hostAndPort;
    this.clientConfig = builder.clientConfig;
    this.poolConfig = builder.poolConfig;
    this.attachAuthenticationListener(builder.clientConfig.getAuthXManager());
  }

  /** Creates the fail-fast factory wired to the builder's host/port and client config. */
  private static FailFastConnectionFactory createFailFastFactory(Builder poolBuilder) {
    ConnectionFactory.Builder factoryBuilder = ConnectionFactory.builder()
        .clientConfig(poolBuilder.clientConfig).socketFactory(
          new DefaultJedisSocketFactory(poolBuilder.hostAndPort, poolBuilder.clientConfig));
    return new FailFastConnectionFactory(factoryBuilder, poolBuilder.clientConfig);
  }

  /**
   * Creates a fresh pool with the same host/port, client config and pool config as
   * {@code existing}.
   */
  public static TrackingConnectionPool from(TrackingConnectionPool existing) {
    return builder().hostAndPort(existing.hostAndPort).clientConfig(existing.clientConfig)
        .poolConfig(existing.poolConfig).build();
  }

  /**
   * Borrows a connection and records it as handed out. While the call is in progress the caller
   * is counted as a waiter, so that {@link #forceDisconnect()} does not finish while the caller
   * may still be creating or waiting for a connection.
   *
   * @throws JedisConnectionException wrapping the original failure if the pool has been closed
   */
  @Override
  public Connection getResource() {
    try {
      numWaiters.incrementAndGet();
      Connection conn = super.getResource();
      poolTrackedObjects.add(conn);
      return conn;
    } catch (Exception e) {
      if (this.isClosed()) {
        throw new JedisConnectionException("Pool is closed!", e);
      }
      throw e;
    } finally {
      numWaiters.decrementAndGet();
    }
  }

  /**
   * Stops tracking {@code resource} and returns it to the pool.
   * <p>
   * Untracking happens first. Once the connection is back in the idle set, another thread may
   * borrow and re-track it, and a late {@code remove} here would wrongly untrack that new borrow.
   * {@code null} is passed straight to the superclass because the concurrent tracking set rejects
   * null keys.
   */
  @Override
  public void returnResource(final Connection resource) {
    if (resource != null) {
      poolTrackedObjects.remove(resource);
    }
    super.returnResource(resource);
  }

  /**
   * Stops tracking {@code resource} and hands it back to the pool as broken. The ordering and
   * {@code null} handling are the same as in {@link #returnResource(Connection)}.
   */
  @Override
  public void returnBrokenResource(final Connection resource) {
    if (resource != null) {
      poolTrackedObjects.remove(resource);
    }
    super.returnBrokenResource(resource);
  }

  /**
   * Closes the pool and forcibly disconnects every connection it knows about: idle ones, borrowed
   * ones, and ones the factory is still initializing. Blocks until no thread is inside
   * {@link #getResource()} and no tracked connection reports itself as connected.
   * <p>
   * If the calling thread is interrupted while waiting, the wait continues and the thread's
   * interrupt status is restored before returning.
   */
  public void forceDisconnect() {
    FailFastConnectionFactory factory = (FailFastConnectionFactory) this.getFactory();
    // Switch the factory to fail-fast before closing, so threads already past the pool's open
    // check cannot create fresh connections while we tear everything down.
    factory.failFast = true;
    boolean interrupted = false;
    try {
      this.close();
      int numOfConnected = poolTrackedObjects.size();
      // We need to wait for all waiters to leave before we are done with disconnecting the
      // connections, since a user app thread might be either:
      // - in the middle of a factory call (create|init) and not yet show up in poolTrackedObjects
      // - blocked on an exhausted pool, waiting for resources to return back to the pool
      while (numWaiters.get() > 0 || numOfConnected > 0) {
        this.clear();
        factory.forceDisconnect();
        numOfConnected = 0;
        for (Connection connection : poolTrackedObjects) {
          try {
            if (connection.isConnected()) {
              numOfConnected++;
            }
            connection.forceDisconnect();
          } catch (Exception e) {
            log.warn("Error while force disconnecting connection: {}",
              connection.toIdentityString(), e);
          }
        }
        try {
          // This is just to yield the thread for a fair share of CPU.
          Thread.sleep(1);
        } catch (InterruptedException e) {
          // The teardown must still complete. Remember the interrupt and restore it at the end
          // instead of swallowing it.
          interrupted = true;
        }
      }
    } finally {
      factory.failFast = false;
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }
  }

  /** Destroys the pool and detaches its authentication listener. */
  @Override
  public void close() {
    this.destroy();
    this.detachAuthenticationListener();
  }
}