// QueryWithParametersRewritePacket.java

// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2025 MariaDB Corporation Ab
package org.mariadb.jdbc.message.client;

import java.io.IOException;
import java.io.InputStream;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.util.Parameter;
import org.mariadb.jdbc.client.util.Parameters;
import org.mariadb.jdbc.plugin.codec.ByteArrayCodec;
import org.mariadb.jdbc.util.ClientParser;

/**
 * Batch execution using REWRITE: the parameter sets of a batch are appended to a single statement
 * as additional value rows ({@code ... VALUES (a, b),(c, d) ...}) and sent with the text protocol
 * (command byte {@code 0x03}). A new packet is started whenever the next row is not expected to
 * fit in the current one.
 *
 * @see <a href="https://mariadb.com/kb/en/com_query/">documentation</a>
 */
public final class QueryWithParametersRewritePacket implements RedoableClientMessage {
  /** Returned by {@link #estimateRowLength(Parameters)} when a parameter cannot report a size. */
  private static final int UNKNOWN_LENGTH = -1;

  private final String preSqlCmd;
  private final ClientParser parser;
  private List<Parameters> parametersList;

  /**
   * @param preSqlCmd pre sql command, written before the statement when not {@code null}
   * @param parser parser holding the statement bytes, the placeholder positions and the positions
   *     of the opening and closing bracket of the values part
   * @param batchParameterList parameters, one entry per batch row
   */
  public QueryWithParametersRewritePacket(
      String preSqlCmd, ClientParser parser, List<Parameters> batchParameterList) {
    this.preSqlCmd = preSqlCmd;
    this.parametersList = batchParameterList;
    this.parser = parser;
  }

  /**
   * Replaces the parameter list with a new list holding a clone of every {@link Parameters} entry,
   * so that this message no longer shares the list handed to the constructor.
   */
  public void saveParameters() {
    List<Parameters> savedList = new ArrayList<>(parametersList.size());
    for (Parameters parameterList : parametersList) {
      savedList.add(parameterList.clone());
    }
    this.parametersList = savedList;
  }

  /**
   * Replaces, in place, every non-null parameter that reports {@code canEncodeLongData()} with a
   * byte-array parameter holding the result of its {@code encodeData()}.
   *
   * @param context connection context (not used by this implementation)
   * @throws IOException if encoding a parameter fails on I/O
   * @throws SQLException if encoding a parameter fails
   */
  @Override
  public void ensureReplayable(Context context) throws IOException, SQLException {
    for (Parameters parameters : parametersList) {
      int parameterCount = parameters.size();
      for (int i = 0; i < parameterCount; i++) {
        Parameter p = parameters.get(i);
        if (!p.isNull() && p.canEncodeLongData()) {
          parameters.set(
              i, new org.mariadb.jdbc.codec.Parameter<>(ByteArrayCodec.INSTANCE, p.encodeData()));
        }
      }
    }
  }

  /**
   * Writes the whole batch as one or more rewritten statements.
   *
   * <p>Each packet starts with the command byte, the optional pre-command and the statement with
   * the first pending row. Further rows are appended as {@code ,(...)} while their estimated size
   * still fits; otherwise the statement is terminated, flushed, and the pending row opens the next
   * packet. A single row larger than the limit is still sent as is.
   *
   * @param writer socket writer
   * @param context connection context
   * @return number of packets (statements) sent
   * @throws IOException if the socket write fails
   * @throws SQLException if a row has fewer parameters than the statement has placeholders, or if
   *     a parameter cannot be encoded
   * @throws java.util.NoSuchElementException if the batch holds no row
   */
  public int encode(Writer writer, Context context) throws IOException, SQLException {
    Iterator<Parameters> paramIterator = parametersList.iterator();
    Parameters parameters = paramIterator.next();

    final byte[] query = parser.getQuery();
    final int valuesOpenPos = parser.getValuesBracketPositions().get(0);
    final int valuesClosePos = parser.getValuesBracketPositions().get(1);

    int rewritePacketNo = 0;
    boolean exhausted = false;

    while (!exhausted) {
      rewritePacketNo++;

      writer.initPacket();
      writer.writeByte(0x03);
      if (preSqlCmd != null) writer.writeAscii(preSqlCmd);

      ensureParameterCount(parameters, context);
      // statement from its very beginning, up to and including the last value of the first row
      int pos = writeRow(writer, context, parameters, 0);

      if (!paramIterator.hasNext()) {
        exhausted = true;
      } else {
        parameters = paramIterator.next();

        // when the first row alone already reaches the limit, send it on its own
        if (!writer.throwMaxAllowedLengthOr16M(writer.pos() + (query.length - pos))) {
          while (true) {
            ensureParameterCount(parameters, context);

            // Bytes required to append this row: end of the previous row (through the closing
            // bracket), the ',' separator, the row itself and the tail that terminates the
            // statement.
            int rowLength = estimateRowLength(parameters);
            int closingLength = valuesClosePos + 1 - pos;
            int tailLength = query.length - pos;
            if (rowLength == UNKNOWN_LENGTH
                || writer.throwMaxAllowedLengthOr16M(
                    writer.pos() + closingLength + 1 + rowLength + tailLength)) {
              // pending row opens the next packet
              break;
            }

            writer.writeBytes(query, pos, closingLength);
            writer.writeByte((byte) ',');
            pos = writeRow(writer, context, parameters, valuesOpenPos);

            if (!paramIterator.hasNext()) {
              exhausted = true;
              break;
            }
            parameters = paramIterator.next();
          }
        }
      }

      // Terminate the statement from the current cursor, not from the closing bracket, so that
      // any literal text between the last placeholder and the closing bracket is kept for the
      // last row as well; when there is none the bytes written are the same.
      writer.writeBytes(query, pos, query.length - pos);
      writer.flush();
    }

    return rewritePacketNo;
  }

  /**
   * Fails when a row does not provide a value for every placeholder of the statement.
   *
   * @param parameters row to check
   * @param context connection context providing the exception factory
   * @throws SQLException if the row holds fewer parameters than the statement has placeholders
   */
  private void ensureParameterCount(Parameters parameters, Context context) throws SQLException {
    if (parser.getParamCount() > parameters.size()) {
      throw context.getExceptionFactory().create("wrong number of parameters", "Y0000");
    }
  }

  /**
   * Writes one row: the statement text starting at {@code from}, with every placeholder replaced
   * by the text encoding of the corresponding parameter.
   *
   * @param writer socket writer
   * @param context connection context
   * @param parameters values of the row
   * @param from offset in the statement bytes where writing starts
   * @return offset in the statement bytes just after the last placeholder, or {@code from} when
   *     the statement has no placeholder
   * @throws IOException if the socket write fails
   * @throws SQLException if a parameter cannot be encoded
   */
  private int writeRow(Writer writer, Context context, Parameters parameters, int from)
      throws IOException, SQLException {
    final byte[] query = parser.getQuery();
    final int paramCount = parser.getParamCount();
    int pos = from;
    for (int i = 0; i < paramCount; i++) {
      int paramPos = parser.getParamPositions().get(i);
      writer.writeBytes(query, pos, paramPos - pos);
      pos = paramPos + 1;
      parameters.get(i).encodeText(writer, context);
    }
    return pos;
  }

  /**
   * Estimates the number of bytes {@link #writeRow} produces for an appended row, i.e. a row
   * written from the opening bracket of the values part.
   *
   * @param parameters values of the row
   * @return estimated length, or {@link #UNKNOWN_LENGTH} as soon as one parameter reports
   *     {@code -1} as its approximate length
   * @throws IOException declared for symmetry with the other helpers
   * @throws SQLException if a parameter fails to report its length
   */
  private int estimateRowLength(Parameters parameters) throws IOException, SQLException {
    final int paramCount = parser.getParamCount();
    int length = 0;
    for (int i = 0; i < paramCount; i++) {
      int paramSize = parameters.get(i).getApproximateTextProtocolLength();
      if (paramSize == -1) {
        return UNKNOWN_LENGTH;
      }
      // statement text preceding this placeholder: from the opening bracket for the first one,
      // from just after the previous placeholder for the others
      int textStart =
          i == 0
              ? parser.getValuesBracketPositions().get(0)
              : parser.getParamPositions().get(i - 1) + 1;
      length += parser.getParamPositions().get(i) - textStart;
      length += paramSize;
    }
    return length;
  }

  /**
   * @return {@code false}: this message uses the text protocol
   */
  public boolean binaryProtocol() {
    return false;
  }

  /**
   * @return always {@code null}
   */
  @Override
  public InputStream getLocalInfileInputStream() {
    return null;
  }

  /**
   * @return {@code "REWRITE: "} followed by the pre-command, when there is one, and the SQL text
   */
  public String description() {
    // preSqlCmd is optional; do not leak the literal "null" into the description
    return "REWRITE: " + (preSqlCmd != null ? preSqlCmd : "") + parser.getSql();
  }

  /**
   * @return number of rows in the batch
   */
  @Override
  public int batchUpdateLength() {
    return parametersList.size();
  }

  /**
   * @param fileName requested file name
   * @param context connection context
   * @return always {@code false}
   */
  @Override
  public boolean validateLocalFileName(String fileName, Context context) {
    return false;
  }
}