FederationQueryRunner.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.server.federation.store.sql;

import org.apache.hadoop.classification.VisibleForTesting;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.CallableStatement;
import java.sql.ResultSet;
import java.util.Arrays;

import org.apache.hadoop.yarn.server.federation.store.sql.DatabaseProduct.DbType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.hadoop.yarn.server.federation.store.sql.DatabaseProduct.isDuplicateKeyError;

/**
 * QueryRunner is used to execute stored procedure SQL and parse the returned results.
 */
public class FederationQueryRunner {

  public final static String YARN_ROUTER_SEQUENCE_NUM = "YARN_ROUTER_SEQUENCE_NUM";

  public final static String YARN_ROUTER_CURRENT_KEY_ID = "YARN_ROUTER_CURRENT_KEY_ID";

  public final static String QUERY_SEQUENCE_TABLE_SQL =
      "SELECT nextVal FROM sequenceTable WHERE sequenceName = %s";

  public final static String INSERT_SEQUENCE_TABLE_SQL =
      "INSERT INTO sequenceTable(sequenceName, nextVal) VALUES(%s, %d)";

  public final static String UPDATE_SEQUENCE_TABLE_SQL =
      "UPDATE sequenceTable SET nextVal = %d WHERE sequenceName = %s";

  public final static String DELETE_QUEUE_SQL = "DELETE FROM policies WHERE queue = %s";

  public static final Logger LOG = LoggerFactory.getLogger(FederationQueryRunner.class);

  /**
   * Execute Stored Procedure SQL.
   *
   * @param conn      Database Connection.
   * @param procedure Stored Procedure SQL.
   * @param rsh       Result Set handler.
   * @param params    List of stored procedure parameters.
   * @param <T>       Generic T.
   * @return Stored Procedure Result Set.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  public <T> T execute(Connection conn, String procedure, ResultSetHandler<T> rsh, Object... params)
      throws SQLException {
    if (conn == null) {
      throw new SQLException("Null connection");
    }

    if (procedure == null) {
      throw new SQLException("Null Procedure SQL statement");
    }

    if (rsh == null) {
      throw new SQLException("Null ResultSetHandler");
    }

    CallableStatement stmt = null;
    T results = null;

    try {
      stmt = this.getCallableStatement(conn, procedure);
      this.fillStatement(stmt, params);
      stmt.executeUpdate();
      this.retrieveOutParameters(stmt, params);
      results = rsh.handle(params);
    } catch (SQLException e) {
      this.rethrow(e, procedure, params);
    } finally {
      close(stmt);
    }
    return results;
  }

  /**
   * Get CallableStatement from Conn.
   *
   * @param conn Database Connection.
   * @param procedure Stored Procedure SQL.
   * @return CallableStatement.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  @VisibleForTesting
  protected CallableStatement getCallableStatement(Connection conn, String procedure)
      throws SQLException {
    return conn.prepareCall(procedure);
  }

  /**
   * Set Statement parameters.
   *
   * @param stmt CallableStatement.
   * @param params Stored procedure parameters.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  public void fillStatement(CallableStatement stmt, Object... params)
      throws SQLException {
    for (int i = 0; i < params.length; i++) {
      if (params[i] != null) {
        if (stmt != null) {
          if (params[i] instanceof FederationSQLOutParameter) {
            FederationSQLOutParameter sqlOutParameter = (FederationSQLOutParameter) params[i];
            sqlOutParameter.register(stmt, i + 1);
          } else {
            stmt.setObject(i + 1, params[i]);
          }
        }
      }
    }
  }

  /**
   * Close Statement.
   *
   * @param stmt CallableStatement.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  public void close(Statement stmt) throws SQLException {
    if (stmt != null) {
      stmt.close();
      stmt = null;
    }
  }

  /**
   * Retrieve execution result from CallableStatement.
   *
   * @param stmt CallableStatement.
   * @param params Stored procedure parameters.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  private void retrieveOutParameters(CallableStatement stmt, Object[] params) throws SQLException {
    if (params != null && stmt != null) {
      for (int i = 0; i < params.length; i++) {
        if (params[i] instanceof FederationSQLOutParameter) {
          FederationSQLOutParameter sqlOutParameter = (FederationSQLOutParameter) params[i];
          sqlOutParameter.setValue(stmt, i + 1);
        }
      }
    }
  }

  /**
   * Re-throw SQL exception.
   *
   * @param cause SQLException.
   * @param sql Stored Procedure SQL.
   * @param params Stored procedure parameters.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  protected void rethrow(SQLException cause, String sql, Object... params)
      throws SQLException {

    String causeMessage = cause.getMessage();
    if (causeMessage == null) {
      causeMessage = "";
    }

    StringBuilder msg = new StringBuilder(causeMessage);
    msg.append(" Query: ");
    msg.append(sql);
    msg.append(" Parameters: ");

    if (params == null) {
      msg.append("[]");
    } else {
      msg.append(Arrays.deepToString(params));
    }

    SQLException e = new SQLException(msg.toString(), cause.getSQLState(), cause.getErrorCode());
    e.setNextException(cause);
    throw e;
  }

  /**
   * We query or update the SequenceTable.
   *
   * @param connection database conn.
   * @param sequenceName sequenceName, We currently have 2 sequences,
   * YARN_ROUTER_SEQUENCE_NUM and YARN_ROUTER_CURRENT_KEY_ID.
   * @param isUpdate true, means we will update the SequenceTable,
   * false, we query the SequenceTable.
   *
   * @return SequenceValue.
   * @throws SQLException An exception occurred when calling a stored procedure.
   */
  public int selectOrUpdateSequenceTable(Connection connection, String sequenceName,
      boolean isUpdate) throws SQLException {

    int maxSequenceValue = 0;
    boolean insertDone = false;
    boolean committed = false;
    Statement statement = null;

    try {

      // Step1. Query SequenceValue.
      while (maxSequenceValue == 0) {
        // Query SQL.
        String sql = String.format(QUERY_SEQUENCE_TABLE_SQL, quoteString(sequenceName));
        DbType dbType = DatabaseProduct.getDbType(connection);
        String forUpdateSQL = DatabaseProduct.addForUpdateClause(dbType, sql);
        statement = connection.createStatement();
        ResultSet rs = statement.executeQuery(forUpdateSQL);
        if (rs.next()) {
          maxSequenceValue = rs.getInt("nextVal");
        } else if (insertDone) {
          throw new SQLException("Invalid state of SEQUENCE_TABLE for " + sequenceName);
        } else {
          insertDone = true;
          close(statement);
          statement = connection.createStatement();
          String insertSQL = String.format(INSERT_SEQUENCE_TABLE_SQL, quoteString(sequenceName), 1);
          try {
            statement.executeUpdate(insertSQL);
          } catch (SQLException e) {
            // If the record is already inserted by some other thread continue to select.
            if (isDuplicateKeyError(dbType, e)) {
              continue;
            }
            LOG.error("Unable to insert into SEQUENCE_TABLE for {}.", sequenceName, e);
            throw e;
          } finally {
            close(statement);
          }
        }
      }

      // Step2. Increase SequenceValue.
      if (isUpdate) {
        int nextSequenceValue = maxSequenceValue + 1;
        close(statement);
        statement = connection.createStatement();
        String updateSQL =
            String.format(UPDATE_SEQUENCE_TABLE_SQL, nextSequenceValue, quoteString(sequenceName));
        statement.executeUpdate(updateSQL);
        maxSequenceValue = nextSequenceValue;
      }

      connection.commit();
      committed = true;
      return maxSequenceValue;
    } catch(SQLException e){
      throw new SQLException("Unable to selectOrUpdateSequenceTable due to: " + e.getMessage(), e);
    } finally {
      if (!committed) {
        rollbackDBConn(connection);
      }
      close(statement);
    }
  }

  public void updateSequenceTable(Connection connection, String sequenceName, int sequenceValue)
      throws SQLException {
    String updateSQL =
        String.format(UPDATE_SEQUENCE_TABLE_SQL, sequenceValue, quoteString(sequenceName));
    boolean committed = false;
    Statement statement = null;
    try {
      statement = connection.createStatement();
      statement.executeUpdate(updateSQL);
      connection.commit();
      committed = true;
    } catch (SQLException e) {
      throw new SQLException("Unable to updateSequenceTable due to: " + e.getMessage());
    } finally {
      if (!committed) {
        rollbackDBConn(connection);
      }
      close(statement);
    }
  }

  public void deletePolicyByQueue(Connection connection, String queue)
      throws SQLException {
    String deleteSQL = String.format(DELETE_QUEUE_SQL, quoteString(queue));
    boolean committed = false;
    Statement statement = null;
    try {
      statement = connection.createStatement();
      statement.executeUpdate(deleteSQL);
      connection.commit();
      committed = true;
    } catch (SQLException e) {
      throw new SQLException("Unable to deletePolicyByQueue due to: " + e.getMessage());
    } finally {
      if (!committed) {
        rollbackDBConn(connection);
      }
      close(statement);
    }
  }

  public void truncateTable(Connection connection, String tableName)
      throws SQLException {
    DbType dbType = DatabaseProduct.getDbType(connection);
    String deleteSQL = getTruncateStatement(dbType, tableName);
    boolean committed = false;
    Statement statement = null;
    try {
      statement = connection.createStatement();
      statement.execute(deleteSQL);
      connection.commit();
      committed = true;
    } catch (SQLException e) {
      throw new SQLException("Unable to truncateTable due to: " + e.getMessage());
    } finally {
      if (!committed) {
        rollbackDBConn(connection);
      }
      close(statement);
    }
  }

  private String getTruncateStatement(DbType dbType, String tableName) {
    if (isMYSQL(dbType)) {
      return ("DELETE FROM \"" + tableName + "\"");
    } else {
      return("DELETE FROM " + tableName);
    }
  }

  private boolean isMYSQL(DbType dbType) {
    return dbType == DbType.MYSQL;
  }

  static void rollbackDBConn(Connection dbConn) {
    try {
      if (dbConn != null && !dbConn.isClosed()) {
        dbConn.rollback();
      }
    } catch (SQLException e) {
      LOG.warn("Failed to rollback db connection ", e);
    }
  }

  static String quoteString(String input) {
    return "'" + input + "'";
  }
}