TcpMockServer.java

package redis.clients.jedis.util.server;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.CommandArguments;
import redis.clients.jedis.Protocol;
import redis.clients.jedis.args.Rawable;
import redis.clients.jedis.commands.ProtocolCommand;
import redis.clients.jedis.util.RedisInputStream;
import redis.clients.jedis.util.RedisOutputStream;
import redis.clients.jedis.util.SafeEncoder;

import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A simple TCP mock server for testing Redis push notifications and timeout behavior. This server
 * can accept connections and send predefined responses including push messages.
 */
public class TcpMockServer {
  private final AtomicBoolean running = new AtomicBoolean(false);
  private final ExecutorService executor = Executors.newCachedThreadPool();
  private final Map<String, ClientHandler> connectedClients = new ConcurrentHashMap<>();
  Logger logger = LoggerFactory.getLogger(TcpMockServer.class);
  private ServerSocket serverSocket;
  private int port;
  private CommandHandler commandHandler;

  /**
   * Start the server on an available port
   */
  public void start() throws IOException {
    start(0); // Use any available port
  }

  /**
   * Start the server on a specific port
   */
  public void start(int port) throws IOException {
    serverSocket = new ServerSocket(port);
    this.port = serverSocket.getLocalPort();
    running.set(true);

    executor.submit(() -> {
      while (running.get() && !serverSocket.isClosed()) {
        try {
          Socket clientSocket = serverSocket.accept();
          executor.submit(new ClientHandler(clientSocket));
        } catch (IOException e) {
          if (running.get()) {
            logger.error("Error accepting client connection: " + e.getMessage());
          }
        }
      }
    });
  }

  /**
   * Stop the server and close all active connections
   */
  public void stop() throws IOException {
    running.set(false);

    // Close all active client connections first
    closeAllActiveConnections();

    // Close the server socket
    if (serverSocket != null && !serverSocket.isClosed()) {
      serverSocket.close();
    }
    executor.shutdownNow();
  }

  /**
   * Get the port the server is running on
   */
  public int getPort() {
    return port;
  }

  /**
   * Check if the server is running
   */
  public boolean isRunning() {
    return running.get() && serverSocket != null && !serverSocket.isClosed();
  }

  /**
   * Get the number of connected clients
   */
  public int getConnectedClientCount() {
    return connectedClients.size();
  }

  /**
   * Generic method to send a push message to all connected clients.
   * @param pushType the type of push message (e.g., "MIGRATING", "MIGRATED")
   * @param args optional arguments for the push message
   */
  public void sendPushMessageToAll(String pushType, String... args) {
    connectedClients.values().forEach(client -> client.sendPushMessage(pushType, args));
  }

  /**
   * Send a raw RESP3 message to all connected clients. This allows sending properly formatted push
   * messages.
   * @param rawMessage the raw RESP3 protocol message to send
   */
  public void sendRawPushMessageToAll(String rawMessage) {
    connectedClients.values().forEach(client -> client.sendRawPushMessage(rawMessage));
  }

  /**
   * Get the current command handler.
   * @return The current command handler, or null if none is set
   */
  public CommandHandler getCommandHandler() {
    return commandHandler;
  }

  /**
   * Set a custom command handler for processing Redis commands.
   * @param commandHandler The command handler to use, or null to use only built-in handlers
   */
  public void setCommandHandler(CommandHandler commandHandler) {
    this.commandHandler = commandHandler;
  }

  /**
   * Close all active client connections
   */
  private void closeAllActiveConnections() {
    // Create a copy of the values to avoid ConcurrentModificationException
    java.util.List<ClientHandler> clientsToClose = new java.util.ArrayList<>(
        connectedClients.values());

    for (ClientHandler client : clientsToClose) {
      try {
        client.forceClose();
      } catch (Exception e) {
        logger.error("Error closing client connection: " + e.getMessage());
      }
    }

    // Clear the map
    connectedClients.clear();
  }

  /**
   * Static registry of built-in command responses (shared across all client handlers). Commands are
   * stored as CommandKey (command + optional subcommand) for smart lookup.
   */
  private static final java.util.Map<CommandKey, String> BUILTIN_RESPONSES;

  static {
    java.util.Map<CommandKey, String> responses = new java.util.HashMap<>();

    // RESP3 HELLO response - version 7.4.0 to support client-side caching
    responses.put(new CommandKey(Protocol.Command.HELLO),
      "%7\r\n" + "$6\r\nserver\r\n$5\r\nredis\r\n" + "$7\r\nversion\r\n$5\r\n7.4.0\r\n"
          + "$5\r\nproto\r\n:3\r\n" + "$2\r\nid\r\n:1\r\n" + "$4\r\nmode\r\n$10\r\nstandalone\r\n"
          + "$4\r\nrole\r\n$6\r\nmaster\r\n" + "$7\r\nmodules\r\n*0\r\n");

    responses.put(new CommandKey(Protocol.Command.PING), "+PONG\r\n");

    // CLIENT subcommands
    responses.put(new CommandKey(Protocol.Command.CLIENT, Protocol.Keyword.SETNAME), "+OK\r\n");
    responses.put(new CommandKey(Protocol.Command.CLIENT, Protocol.Keyword.SETINFO), "+OK\r\n");
    responses.put(new CommandKey("CLIENT", "TRACKING"), "+OK\r\n");

    BUILTIN_RESPONSES = java.util.Collections.unmodifiableMap(responses);
  }

  /**
   * Simple implementation of ProtocolCommand for unknown commands.
   */
  private static class SimpleProtocolCommand implements ProtocolCommand {
    private final byte[] raw;
    private final int hashCode;

    public SimpleProtocolCommand(String command) {
      this.raw = SafeEncoder.encode(command);
      this.hashCode = Arrays.hashCode(raw);
    }

    @Override
    public byte[] getRaw() {
      return raw;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (!(o instanceof ProtocolCommand)) return false;
      ProtocolCommand that = (ProtocolCommand) o;
      return Arrays.equals(raw, that.getRaw());
    }

    @Override
    public int hashCode() {
      return hashCode;
    }

    @Override
    public String toString() {
      return SafeEncoder.encode(raw);
    }
  }

  /**
   * Key for command lookup in the registry. Supports command + optional subcommand.
   */
  private static class CommandKey {
    private final String command;
    private final String subcommand;
    private final int hashCode;

    public CommandKey(ProtocolCommand command) {
      this(command, null);
    }

    public CommandKey(ProtocolCommand command, redis.clients.jedis.args.Rawable subcommand) {
      this.command = SafeEncoder.encode(command.getRaw()).toUpperCase();
      this.subcommand = subcommand != null ? SafeEncoder.encode(subcommand.getRaw()).toUpperCase()
          : null;
      this.hashCode = java.util.Objects.hash(this.command, this.subcommand);
    }

    public CommandKey(String command, String subcommand) {
      this.command = command.toUpperCase();
      this.subcommand = subcommand != null ? subcommand.toUpperCase() : null;
      this.hashCode = java.util.Objects.hash(this.command, this.subcommand);
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (!(o instanceof CommandKey)) return false;
      CommandKey that = (CommandKey) o;
      return command.equals(that.command) && java.util.Objects.equals(subcommand, that.subcommand);
    }

    @Override
    public int hashCode() {
      return hashCode;
    }

    @Override
    public String toString() {
      return subcommand != null ? command + " " + subcommand : command;
    }
  }

  /**
   * Client handler for each connection
   */
  private class ClientHandler implements Runnable {
    private final Socket clientSocket;
    private final String clientId;
    private RedisOutputStream outputStream;
    private volatile boolean connected = true;
    private final Object outputLock = new Object(); // Lock to prevent interleaving

    public ClientHandler(Socket clientSocket) {
      this.clientSocket = clientSocket;
      this.clientId = clientSocket.getRemoteSocketAddress().toString();
    }

    @Override
    public void run() {
      try (RedisInputStream rin = new RedisInputStream(clientSocket.getInputStream());
          RedisOutputStream out = new RedisOutputStream(clientSocket.getOutputStream())) {

        this.outputStream = out;
        connectedClients.put(clientId, this);

        Object input;
        while (connected && !clientSocket.isClosed()) {
          try {
            input = Protocol.read(rin);
            if (input == null) {
              connected = false;
              break;
            }

            // Deserialize into CommandArguments
            List<byte[]> rawArgs = (List<byte[]>) input;
            CommandArguments commandArgs = deserializeToCommandArguments(rawArgs);

            // Process command with custom handler or built-in responses
            processCommand(commandArgs);
          } catch (IOException e) {
            logger.debug("Client " + clientId + " disconnected: " + e.getMessage());
            connected = false;
            break;
          } catch (Exception e) {
            logger.debug("Client " + clientId + " connection error: " + e.getMessage());
            connected = false;
            break;
          }
        }
      } catch (IOException e) {
        logger.error("Error handling client: " + e.getMessage());
      } finally {
        cleanup();
      }
    }

    /**
     * Deserialize raw byte arrays into CommandArguments. First element is the command, rest are
     * arguments.
     */
    private CommandArguments deserializeToCommandArguments(List<byte[]> rawArgs) {
      if (rawArgs == null || rawArgs.isEmpty()) {
        throw new IllegalArgumentException("Empty command");
      }

      // First element is the command - try to match it to a known ProtocolCommand
      String cmdString = SafeEncoder.encode(rawArgs.get(0)).toUpperCase();
      ProtocolCommand command = findProtocolCommand(cmdString);

      // If no known command found, create a simple wrapper
      if (command == null) {
        command = new SimpleProtocolCommand(cmdString);
      }

      // Create CommandArguments with the command
      CommandArguments commandArgs = new CommandArguments(command);

      // Add remaining arguments
      for (int i = 1; i < rawArgs.size(); i++) {
        commandArgs.add(rawArgs.get(i));
      }

      return commandArgs;
    }

    /**
     * Try to find a matching ProtocolCommand from Protocol.Command enum.
     */
    private ProtocolCommand findProtocolCommand(String cmdString) {
      try {
        return Protocol.Command.valueOf(cmdString);
      } catch (IllegalArgumentException e) {
        // Not a standard command, return null
        return null;
      }
    }

    /**
     * Process a command by first checking if a custom command handler is available, otherwise
     * falling back to predefined built-in responses.
     * @param commandArgs the command arguments
     * @throws IOException if writing the response fails
     */
    private void processCommand(CommandArguments commandArgs) throws IOException {
      String response = null;

      // First, try custom command handler if available
      if (commandHandler != null) {
        response = commandHandler.handleCommand(commandArgs, clientId);
      }

      // If no custom handler or it returned null, fall back to built-in responses
      if (response == null) {
        response = getBuiltinResponse(commandArgs);
      }

      // Write the response
      if (response != null) {
        writeResponse(response);
      } else {
        throw new RuntimeException("Unknown command: " + commandArgs.getCommand());
      }
    }

    /**
     * Synchronized method to write response to output stream. This ensures thread-safe access to
     * the non-thread-safe RedisOutputStream.
     */
    private void writeResponse(String response) throws IOException {
      synchronized (outputLock) {
        if (outputStream != null && connected) {
          outputStream.write(response.getBytes());
          outputStream.flush();
        }
      }
    }

    /**
     * Get the built-in response for a command from the response registry. Uses smart lookup: tries
     * command + first argument first, then just command.
     * @param commandArgs the command arguments
     * @return the response string, or null if no built-in response exists
     */
    private String getBuiltinResponse(CommandArguments commandArgs) {
      String cmdString = SafeEncoder.encode(commandArgs.getCommand().getRaw());
      String subcommand = null;

      // Extract subcommand if present (first argument)
      if (commandArgs.size() > 1) {
        subcommand = SafeEncoder.encode(commandArgs.get(1).getRaw());
      }

      // Try lookup with command + subcommand first (e.g., "CLIENT SETNAME")
      if (subcommand != null) {
        CommandKey key = new CommandKey(cmdString, subcommand);
        String response = BUILTIN_RESPONSES.get(key);
        if (response != null) {
          return response;
        }
      }

      // Fall back to command only (e.g., "CLIENT")
      CommandKey key = new CommandKey(cmdString, null);
      return BUILTIN_RESPONSES.get(key);
    }

    /**
     * Clean up client resources and remove from connected clients map
     */
    private void cleanup() {
      connected = false;
      connectedClients.remove(clientId);

      // Synchronize to ensure no push message is being sent while we clean up
      synchronized (outputLock) {
        outputStream = null;
      }

      try {
        if (clientSocket != null && !clientSocket.isClosed()) {
          clientSocket.close();
        }
      } catch (IOException e) {
        logger.error("Error closing client socket during cleanup: " + e.getMessage());
      }
    }

    /**
     * Generic method to send a push message to this client. According to RESP3 spec, push messages
     * may precede or follow command replies, but must not interleave with them. We use
     * synchronization to ensure this.
     * @param pushType the type of push message (e.g., "MIGRATING", "MIGRATED")
     * @param args optional arguments for the push message
     */
    public void sendPushMessage(String pushType, String... args) {
      try {
        StringBuilder pushMessage = new StringBuilder();

        // Calculate total number of elements (push type + arguments)
        int elementCount = 1 + args.length;
        pushMessage.append(">").append(elementCount).append("\r\n");

        // Add push type
        pushMessage.append("$").append(pushType.length()).append("\r\n").append(pushType)
            .append("\r\n");

        // Add arguments
        for (String arg : args) {
          pushMessage.append("$").append(arg.length()).append("\r\n").append(arg).append("\r\n");
        }

        // Use synchronized writeResponse method to prevent interleaving
        writeResponse(pushMessage.toString());

      } catch (IOException e) {
        logger.error("Error sending " + pushType + " push to " + clientId
            + " (client disconnected): " + e.getMessage());
        cleanup();
      }
    }

    /**
     * Send a raw RESP3 message to this client. This allows sending properly formatted push
     * messages..
     * @param rawMessage the raw RESP3 protocol message to send
     */
    public void sendRawPushMessage(String rawMessage) {
      try {
        // Use synchronized writeResponse method to prevent interleaving
        writeResponse(rawMessage);
      } catch (IOException e) {
        logger.error(
          "Error sending raw message to " + clientId + " (client disconnected): " + e.getMessage());
        cleanup();
      }
    }

    /**
     * Force close this client connection (used when server is shutting down)
     */
    public void forceClose() {
      connected = false;

      try {
        if (clientSocket != null && !clientSocket.isClosed()) {
          clientSocket.close();
        }
      } catch (IOException e) {
        logger.error("Error force closing client socket: " + e.getMessage());
      }

      // Remove from connected clients map
      connectedClients.remove(clientId);
      outputStream = null;
    }

  }

}