ClientMessage.java

// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2025 MariaDB Corporation Ab
package org.mariadb.jdbc.message;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import org.mariadb.jdbc.BasePreparedStatement;
import org.mariadb.jdbc.Statement;
import org.mariadb.jdbc.client.*;
import org.mariadb.jdbc.client.impl.StandardReadableByteBuf;
import org.mariadb.jdbc.client.result.CompleteResult;
import org.mariadb.jdbc.client.result.StreamingResult;
import org.mariadb.jdbc.client.result.UpdatableResult;
import org.mariadb.jdbc.client.socket.Reader;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.util.ClosableLock;
import org.mariadb.jdbc.client.util.Parameters;
import org.mariadb.jdbc.export.ExceptionFactory;
import org.mariadb.jdbc.message.server.ErrorPacket;
import org.mariadb.jdbc.message.server.OkPacket;
import org.mariadb.jdbc.util.constants.ServerStatus;

public interface ClientMessage {

  /**
   * Checks that a file requested by the server during a LOAD DATA LOCAL INFILE exchange
   * corresponds to the query the client actually sent.
   *
   * <p>The request is accepted when either:
   *
   * <ul>
   *   <li>{@code sql} contains {@code LOAD DATA|XML [LOW_PRIORITY|CONCURRENT] LOCAL INFILE
   *       '<fileName>'}, with backslashes of {@code fileName} doubled as they would appear inside
   *       an SQL string literal (matched case-insensitively), or
   *   <li>{@code sql} uses a {@code ?} placeholder for the file name and the first parameter's
   *       best-effort string value equals {@code '<fileName>'} (case-insensitive), or that value
   *       cannot be determined (best-effort value is {@code null}).
   * </ul>
   *
   * @param sql current command sql
   * @param parameters current command parameters, may be {@code null}
   * @param fileName file path requested by the server
   * @param context current connection context
   * @return true if the file name corresponds to the query and the query is a LOAD ... LOCAL
   *     INFILE command
   */
  static boolean validateLocalFileName(
      String sql, Parameters parameters, String fileName, Context context) {
    // Part shared by both patterns: any leading text (comment lines, /* */ blocks, other
    // characters) followed by "LOAD DATA|XML [LOW_PRIORITY|CONCURRENT] LOCAL INFILE".
    // NOTE(review): "[--]" is a one-character class matching '-', so the first alternative
    // matches whitespace followed by a single '-', not strictly a "-- " comment. Kept as-is
    // because changing it would change which queries are accepted.
    String loadLocalInfilePrefix =
        "^((\\s[--]|#).*(\\r\\n|\\r|\\n)|\\s*/\\*([^*]|\\*[^/])*\\*/|.)*"
            + "\\s*LOAD\\s+(DATA|XML)\\s+((LOW_PRIORITY|CONCURRENT)\\s+)?LOCAL\\s+INFILE\\s+";

    // file name as it would be written inside an SQL string literal (backslashes doubled)
    String escapedFileName = fileName.replace("\\", "\\\\");

    // 1. file name written literally in the query
    Pattern literalFilePattern =
        Pattern.compile(
            loadLocalInfilePrefix + "'" + Pattern.quote(escapedFileName) + "'",
            Pattern.CASE_INSENSITIVE);
    if (literalFilePattern.matcher(sql).find()) {
      return true;
    }

    // 2. file name bound through a '?' placeholder
    // NOTE(review): assumes the INFILE placeholder is the first parameter of the command.
    if (parameters != null && parameters.size() > 0) {
      Pattern placeholderPattern =
          Pattern.compile(loadLocalInfilePrefix + "\\?", Pattern.CASE_INSENSITIVE);
      if (placeholderPattern.matcher(sql).find()) {
        String paramString = parameters.get(0).bestEffortStringValue(context);
        if (paramString == null) {
          // parameter value cannot be rendered as a string: request accepted without comparison
          return true;
        }
        return paramString.equalsIgnoreCase("'" + escapedFileName + "'");
      }
    }
    return false;
  }

  /**
   * Encode client message to socket.
   *
   * @param writer socket writer
   * @param context connection context
   * @return number of client message written
   * @throws IOException if socket error occur
   * @throws SQLException if any issue occurs
   */
  int encode(Writer writer, Context context) throws IOException, SQLException;

  /**
   * Number of parameter rows, and therefore the expected number of results.
   *
   * @return batch update length (default 0)
   */
  default int batchUpdateLength() {
    return 0;
  }

  /**
   * Message description, used as the SQL attached to exceptions raised while reading the
   * response.
   *
   * @return description (default {@code null})
   */
  default String description() {
    return null;
  }

  /**
   * Indicates whether returned rows are encoded with the binary protocol.
   *
   * @return true if binary protocol is used (default false)
   */
  default boolean binaryProtocol() {
    return false;
  }

  /**
   * Indicates whether this message allows the server to skip resending column metadata, in which
   * case the metadata cached on the calling {@link BasePreparedStatement} is reused.
   *
   * @return true if metadata can be skipped (default false)
   */
  default boolean canSkipMeta() {
    return false;
  }

  /**
   * Default response packet parser.
   *
   * <p>Handles an OK packet (possibly triggering a redirection), an ERROR packet (thrown as an
   * {@link SQLException}), a LOCAL INFILE request (file content is sent, then the following
   * response is read), or a result set.
   *
   * @param stmt caller
   * @param fetchSize fetch size; non-zero returns a streaming result
   * @param maxRows maximum number of rows
   * @param resultSetConcurrency resultset concurrency
   * @param resultSetType resultset type
   * @param closeOnCompletion must close caller on result parsing end
   * @param reader packet reader
   * @param writer packet writer
   * @param context connection context
   * @param exceptionFactory connection exception factory
   * @param lock thread safe locks
   * @param traceEnable is logging trace enable
   * @param message client message, used to validate LOCAL INFILE file requests
   * @param redirectFct redirect consumer
   * @return results
   * @throws IOException if any socket error occurs
   * @throws SQLException for other kind of errors
   */
  default Completion readPacket(
      Statement stmt,
      int fetchSize,
      long maxRows,
      int resultSetConcurrency,
      int resultSetType,
      boolean closeOnCompletion,
      Reader reader,
      Writer writer,
      Context context,
      ExceptionFactory exceptionFactory,
      ClosableLock lock,
      boolean traceEnable,
      ClientMessage message,
      Consumer<String> redirectFct)
      throws IOException, SQLException {

    ReadableByteBuf buf = reader.readReusablePacket(traceEnable);

    switch (buf.getByte()) {

        // ---------------------------------------------------------------------------------------
        // OK response
        // ---------------------------------------------------------------------------------------
      case (byte) 0x00:
        OkPacket ok = OkPacket.parse(buf, context);
        // follow a pending redirection only when no transaction is open and no further result
        // is expected for this command
        if (context.getRedirectUrl() != null
            && (context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0
            && (context.getServerStatus() & ServerStatus.MORE_RESULTS_EXISTS) == 0) {
          redirectFct.accept(context.getRedirectUrl());
        }
        return ok;

        // ---------------------------------------------------------------------------------------
        // ERROR response
        // ---------------------------------------------------------------------------------------
      case (byte) 0xff:
        // NOTE(review): nothing in this branch changes the server status; presumably the
        // ErrorPacket constructor forces IN_TRANSACTION on the context so that a rollback/commit
        // is still issued, since the failing command may have started a transaction - confirm.
        ErrorPacket errorPacket = new ErrorPacket(buf, context);
        throw exceptionFactory
            .withSql(this.description())
            .create(
                errorPacket.getMessage(), errorPacket.getSqlState(), errorPacket.getErrorCode());

        // ---------------------------------------------------------------------------------------
        // LOCAL INFILE request
        // ---------------------------------------------------------------------------------------
      case (byte) 0xfb:
        buf.skip(1); // skip header
        SQLException exception = null;
        reader.getSequence().set(writer.getSequence());

        // a stream supplied by the message takes precedence over the file named by the server
        InputStream is = getLocalInfileInputStream();
        if (is == null) {
          String fileName = buf.readStringNullEnd();
          if (!message.validateLocalFileName(fileName, context)) {
            // the server asked for a file the query did not reference: do not send it
            exception =
                exceptionFactory
                    .withSql(this.description())
                    .create(
                        String.format(
                            "LOAD DATA LOCAL INFILE asked for file '%s' that doesn't correspond to"
                                + " initial query %s. Possible malicious proxy changing server"
                                + " answer ! Command interrupted",
                            fileName, this.description()),
                        "HY000");
          } else {
            try {
              is = new FileInputStream(fileName);
            } catch (FileNotFoundException f) {
              exception =
                  exceptionFactory
                      .withSql(this.description())
                      .create("Could not send file : " + f.getMessage(), "HY000", f);
            }
          }
        }

        // send stream content; the stream is always closed, and a close() failure no longer
        // masks an exception raised while sending
        if (is != null) {
          try (InputStream fileStream = is) {
            byte[] fileBuf = new byte[8192];
            int len;
            while ((len = fileStream.read(fileBuf)) > 0) {
              writer.writeBytes(fileBuf, 0, len);
              writer.flush();
            }
          }
        }

        // an empty packet ends the transfer; it is sent even when no file could be sent so the
        // server answers and the connection state stays consistent
        writer.writeEmptyPacket();
        Completion completion;
        try {
          completion =
              readPacket(
                  stmt,
                  fetchSize,
                  maxRows,
                  resultSetConcurrency,
                  resultSetType,
                  closeOnCompletion,
                  reader,
                  writer,
                  context,
                  exceptionFactory,
                  lock,
                  traceEnable,
                  message,
                  redirectFct);
        } catch (SQLException serverException) {
          // keep the local validation / file error instead of silently losing it
          if (exception != null) {
            serverException.addSuppressed(exception);
          }
          throw serverException;
        }
        if (exception != null) {
          throw exception;
        }
        return completion;

        // ---------------------------------------------------------------------------------------
        // ResultSet
        // ---------------------------------------------------------------------------------------
      default:
        int fieldCount = buf.readIntLengthEncodedNotNull();

        ColumnDecoder[] ci;
        boolean metaCacheable = context.canSkipMeta() && this.canSkipMeta();
        // when metadata may be skipped, a following 0 byte means the server did not resend it
        if (metaCacheable && buf.readByte() == 0) {
          ci = ((BasePreparedStatement) stmt).getMeta();
        } else {
          ci = new ColumnDecoder[fieldCount];
          for (int i = 0; i < fieldCount; i++) {
            ci[i] =
                context
                    .getColumnDecoderFunction()
                    .apply(new StandardReadableByteBuf(reader.readPacket(traceEnable)));
          }
          if (metaCacheable) {
            // metadata was resent (it may have changed): refresh the cached copy
            ((BasePreparedStatement) stmt).updateMeta(ci);
          }
        }

        // intermediate EOF, only present when EOF is not deprecated
        if (!context.isEofDeprecated()) {
          reader.readReusablePacket(traceEnable);
        }

        // read resultSet
        if (resultSetConcurrency == ResultSet.CONCUR_UPDATABLE) {
          return new UpdatableResult(
              stmt,
              binaryProtocol(),
              maxRows,
              ci,
              reader,
              context,
              resultSetType,
              closeOnCompletion,
              traceEnable);
        }

        if (fetchSize != 0) {
          // NOTE(review): MORE_RESULTS_EXISTS is cleared before handing the socket to the
          // streaming result; presumably the streaming result tracks further results itself.
          if ((context.getServerStatus() & ServerStatus.MORE_RESULTS_EXISTS) != 0) {
            context.setServerStatus(
                context.getServerStatus() & ~ServerStatus.MORE_RESULTS_EXISTS);
          }

          return new StreamingResult(
              stmt,
              binaryProtocol(),
              maxRows,
              ci,
              reader,
              context,
              fetchSize,
              lock,
              resultSetType,
              closeOnCompletion,
              traceEnable);
        } else {
          return new CompleteResult(
              stmt,
              binaryProtocol(),
              maxRows,
              ci,
              reader,
              context,
              resultSetType,
              closeOnCompletion,
              traceEnable,
              mightBeBulkResult());
        }
    }
  }

  /**
   * Input stream to send in answer to a LOAD DATA LOCAL INFILE request, instead of opening the
   * file named by the server.
   *
   * @return stream to send, or {@code null} (default) to open the server-requested file after it
   *     has been validated
   */
  default InputStream getLocalInfileInputStream() {
    return null;
  }

  /**
   * Indicates whether the result might be a COM_STMT_BULK result; passed to the complete result
   * parser.
   *
   * @return true if so (default false)
   */
  default boolean mightBeBulkResult() {
    return false;
  }

  /**
   * Validates a file requested by the server for LOAD DATA LOCAL INFILE against the current
   * query.
   *
   * @param fileName server file request path
   * @param context current connection context
   * @return true if file name corresponds to the query and query is a LOAD ... LOCAL INFILE;
   *     the default rejects every request
   */
  default boolean validateLocalFileName(String fileName, Context context) {
    return false;
  }
}