KMSBenchmark.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.crypto.key.kms.server;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.crypto.key.KeyProvider;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.util.ExitUtil;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.hadoop.util.KMSUtil;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Main class for a series of KMS benchmarks.
 *
 * Each benchmark measures the throughput and average execution time
 * of a specific KMS operation, e.g. encryption or decryption of
 * Data Encryption Keys.
 *
 * The benchmark does not involve any Hadoop components other than
 * the KMS. Each operation is executed by directly calling the
 * respective KMS operation.
 *
 * For usage, please see <a href="http://hadoop.apache.org/docs/current/
 * hadoop-project-dist/hadoop-common/Benchmarking.html#KMSBenchmark">
 * the documentation</a>.
 * If you change the usage of this program, please also update the
 * documentation accordingly.
 */
public class KMSBenchmark implements Tool {
  private static final Logger LOG =
          LoggerFactory.getLogger(KMSBenchmark.class);

  private static final String GENERAL_OPTIONS_USAGE = "[-threads int] |" +
          " [-numops int] | [{-warmup (true|false)}]";

  // Configuration of this tool instance; set by the constructor and by
  // ToolRunner through setConf().
  private Configuration config;

  private final KeyProviderCryptoExtension kp;
  // Most recently generated encrypted key; the decrypt benchmark decrypts
  // it. volatile because encrypt daemons overwrite it concurrently and
  // decrypt daemons read it from other threads.
  private volatile KeyProviderCryptoExtension.EncryptedKeyVersion eek = null;
  private String encryptionKeyName = "systest";
  private boolean createEncryptionKey = false;
  private boolean warmupKey = false;
  private List<String> keys = new ArrayList<String>();

  /**
   * Creates the benchmark and prepares the encryption key.
   *
   * Besides the per-operation options, the following options are
   * handled here. They are searched for from index 2 onwards, i.e. after
   * the {@code -op <name>} pair:
   * <ul>
   *   <li>{@code -createkey <name>}: use key {@code <name>} instead of the
   *   default {@code systest} key, creating it if the provider does not
   *   list it yet.</li>
   *   <li>{@code -warmup <true|false>}: call
   *   {@code warmUpEncryptedKeys} for the selected key before
   *   benchmarking.</li>
   * </ul>
   * Finally an initial encrypted key is generated from the selected key so
   * that the decrypt benchmark has something to decrypt.
   *
   * @param conf configuration used to locate the key provider
   * @param args raw command line arguments
   * @throws IOException if the key provider is not configured or a key
   *         provider call fails with an I/O error
   */
  KMSBenchmark(Configuration conf, String[] args)
          throws IOException {
    config = conf;
    kp = createKeyProviderCryptoExtension(config);
    // Parse options before touching any key, so that key creation, warm-up
    // and the initial EDEK all use the key the user asked for.
    for (int i = 2; i < args.length; i++) {
      if (args[i].equals("-warmup")) {
        if (i + 1 == args.length) {
          printUsage();
        }
        warmupKey = Boolean.parseBoolean(args[++i]);
      } else if (args[i].equals("-createkey")) {
        if (i + 1 == args.length) {
          printUsage();
        }
        encryptionKeyName = args[++i];
        createEncryptionKey = true;
      }
    }
    try {
      if (createEncryptionKey) {
        keys = kp.getKeys();
        if (!keys.contains(encryptionKeyName)) {
          kp.createKey(encryptionKeyName, KeyProvider.options(conf));
        } else {
          LOG.warn("encryption key already exists: {}",
                  encryptionKeyName);
        }
      }
      if (warmupKey) {
        kp.warmUpEncryptedKeys(encryptionKeyName);
      }
    } catch (GeneralSecurityException e) {
      LOG.warn(" failed to create or warmup encryption key", e);
    }
    // Seed the EDEK used by the decrypt benchmark. This must happen after
    // the (possibly just created) key has been selected above.
    try {
      eek = kp.generateEncryptedKey(encryptionKeyName);
    } catch (GeneralSecurityException e) {
      LOG.warn("failed to generate key", e);
    }
  }

  /**
   * Base class for collecting operation statistics.
   *
   * Extend this class in order to collect statistics for a
   * specific KMS operation.
   */
  abstract class OperationStatsBase {
    protected static final String OP_ALL_NAME = "all";
    protected static final String OP_ALL_USAGE =
            "-op all <other ops options>";

    // number of threads
    private int  numThreads = 0;

    // number of operations requested
    private int  numOpsRequired = 0;

    // number of operations executed
    private int  numOpsExecuted = 0;

    // sum of times for each op
    private long cumulativeTime = 0;

    // time from start to finish
    private long elapsedTime = 0;

    private List<StatsDaemon> daemons;

    /**
     * Operation name.
     */
    abstract String getOpName();

    /**
     * Parse command line arguments.
     *
     * @param args arguments
     * @throws IOException if the arguments cannot be processed
     */
    abstract void parseArguments(List<String> args) throws IOException;

    /**
     * This corresponds to the arg1 argument of
     * {@link #executeOp(int, int, String)}, which can have
     * different meanings depending on the operation performed.
     *
     * @param daemonId id of the daemon calling this method
     * @return the argument
     */
    abstract String getExecutionArgument(int daemonId);

    /**
     * Execute kms operation.
     *
     * @param daemonId id of the daemon calling this method.
     * @param inputIdx serial index of the operation called by the daemon.
     * @param arg1 operation specific argument.
     * @return time of the individual kms call, in milliseconds.
     * @throws IOException if the operation fails; this stops the
     *         calling daemon.
     */
    abstract long executeOp(int daemonId, int inputIdx, String arg1)
            throws IOException;

    /**
     * Print the results of the benchmarking.
     */
    abstract void printResults();

    OperationStatsBase() {
      numOpsRequired = 10000;
      numThreads = 3;
    }

    /**
     * Runs the operation on {@code numThreads} daemons until
     * {@code numOpsRequired} operations have been attempted, then
     * aggregates the per-daemon statistics.
     *
     * @throws IOException declared for subclasses; not thrown here directly
     */
    void benchmark() throws IOException {
      daemons = new ArrayList<StatsDaemon>();
      // Initialized up front so elapsedTime stays meaningful even if we
      // leave early (monotonicNow() has an arbitrary origin).
      long start = Time.monotonicNow();
      try {
        numOpsExecuted = 0;
        cumulativeTime = 0;
        if (numThreads < 1) {
          return;
        }
        int[] opsPerThread = scheduleOps();
        for (int tIdx = 0; tIdx < numThreads; tIdx++) {
          daemons.add(new StatsDaemon(tIdx, opsPerThread[tIdx], this));
        }
        start = Time.monotonicNow();
        LOG.info("Starting "+numOpsRequired+" "+getOpName()+"(s).");
        for (StatsDaemon d : daemons) {
          d.start();
        }
      } finally {
        waitForDaemons();
        elapsedTime = Time.monotonicNow() - start;
        for (StatsDaemon d : daemons) {
          incrementStats(d.localNumOpsExecuted, d.localCumulativeTime);
          System.out.println(d.toString() + ": ops Exec = " +
                  d.localNumOpsExecuted);
        }
      }
    }

    /**
     * Splits {@code numOpsRequired} as evenly as possible across
     * {@code numThreads}. Each scheduled thread gets at least one op. If
     * {@code numThreads > numOpsRequired}, the trailing entries stay 0 and
     * those threads do nothing.
     */
    private int[] scheduleOps() {
      int[] opsPerThread = new int[numThreads];
      // thread index < numThreads
      int tIdx = 0;
      for (int opsScheduled = 0; opsScheduled < numOpsRequired;
           opsScheduled += opsPerThread[tIdx++]) {
        opsPerThread[tIdx] =
                (numOpsRequired - opsScheduled) / (numThreads - tIdx);
        if (opsPerThread[tIdx] == 0) {
          opsPerThread[tIdx] = 1;
        }
      }
      return opsPerThread;
    }

    /**
     * Blocks until every daemon has terminated.
     *
     * join() is used instead of polling for two reasons. The measured
     * elapsed time is not inflated by a poll interval. A daemon that died
     * on an exception (and so never reached its op count), or was never
     * started, cannot make the benchmark wait forever. An interrupt does
     * not abort the wait, because the stats would be incomplete; the
     * interrupt status is restored once all daemons are done.
     */
    private void waitForDaemons() {
      boolean interrupted = false;
      for (StatsDaemon d : daemons) {
        boolean joined = false;
        while (!joined) {
          try {
            d.join();
            joined = true;
          } catch (InterruptedException e) {
            interrupted = true;
          }
        }
      }
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }

    void cleanUp() throws IOException {
    }

    int getNumOpsExecuted() {
      return numOpsExecuted;
    }

    long getCumulativeTime() {
      return cumulativeTime;
    }

    long getElapsedTime() {
      return elapsedTime;
    }

    long getAverageTime() {
      LOG.debug("getAverageTime, cumulativeTime = {}", cumulativeTime);
      LOG.debug("getAverageTime, numOpsExecuted = {}", numOpsExecuted);
      return numOpsExecuted == 0? 0 : cumulativeTime/numOpsExecuted;
    }

    double getOpsPerSecond() {
      return elapsedTime == 0?
              0 : 1000*(double)numOpsExecuted / elapsedTime;
    }

    String getClientName(int idx) {
      return getOpName() + "-client-" + idx;
    }

    void incrementStats(int ops, long time) {
      numOpsExecuted += ops;
      cumulativeTime += time;
    }

    int getNumThreads() {
      return numThreads;
    }

    void setNumThreads(int num) {
      numThreads = num;
    }

    int getNumOpsRequired() {
      return numOpsRequired;
    }

    void setNumOpsRequired(int num) {
      numOpsRequired = num;
    }

    /**
     * Parse first 2 arguments, corresponding to the "-op" option.
     *
     * @param args argument list
     * @return true if operation is all, which means that options not
     * related to this operation should be ignored, or false
     * otherwise, meaning that usage should be printed when an
     * unrelated option is encountered.
     */
    protected boolean verifyOpArgument(List<String> args) {
      if (args.size() < 2 || !args.get(0).startsWith("-op")) {
        printUsage();
      }

      // process common options
      String type = args.get(1);
      if (OP_ALL_NAME.equals(type)) {
        return true;
      }
      if (!getOpName().equals(type)) {
        printUsage();
      }
      return false;
    }

    /**
     * Verifies the {@code -op} pair and parses the options shared by all
     * operations: {@code -threads N} and {@code -numops N}. Other options
     * (e.g. {@code -warmup}, handled by the benchmark constructor) are
     * skipped. Prints usage and terminates if a value is missing.
     *
     * @param args argument list
     */
    protected void parseCommonArguments(List<String> args) {
      verifyOpArgument(args);
      for (int i = 2; i < args.size(); i++) {
        String option = args.get(i);
        boolean isThreads = option.equals("-threads");
        if (!isThreads && !option.equals("-numops")) {
          continue;
        }
        if (i + 1 == args.size()) {
          printUsage();
        }
        int value = Integer.parseInt(args.get(++i));
        if (isThreads) {
          setNumThreads(value);
        } else {
          setNumOpsRequired(value);
        }
      }
    }

    void printStats() {
      LOG.info("--- " + getOpName() + " stats  ---");
      LOG.info("# operations: " + getNumOpsExecuted());
      LOG.info("Elapsed Time: " + getElapsedTime());
      LOG.info(" Ops per sec: " + getOpsPerSecond());
      LOG.info("Average Time: " + getAverageTime());
    }
  }

  /**
   * One of the threads that perform stats operations.
   */
  private class StatsDaemon extends Thread {
    private final int daemonId;
    // volatile: terminate() may lower it from another thread.
    private volatile int opsPerThread;
    private String arg1;      // argument passed to executeOp()
    private volatile int  localNumOpsExecuted = 0;
    private volatile long localCumulativeTime = 0;
    private final OperationStatsBase statsOp;

    StatsDaemon(int daemonId, int nOps, OperationStatsBase op) {
      this.daemonId = daemonId;
      this.opsPerThread = nOps;
      this.statsOp = op;
      setName(toString());
    }

    @Override
    public void run() {
      localNumOpsExecuted = 0;
      localCumulativeTime = 0;
      arg1 = statsOp.getExecutionArgument(daemonId);
      try {
        benchmarkOne();
      } catch(IOException ex) {
        LOG.error("StatsDaemon " + daemonId + " failed: \n"
            + StringUtils.stringifyException(ex));
      }
    }

    @Override
    public String toString() {
      return "StatsDaemon-" + daemonId;
    }

    void benchmarkOne() throws IOException {
      for (int idx = 0; idx < opsPerThread; idx++) {
        long stat = statsOp.executeOp(daemonId, idx, arg1);
        localNumOpsExecuted++;
        localCumulativeTime += stat;
      }
    }

    boolean isInProgress() {
      return localNumOpsExecuted < opsPerThread;
    }

    /**
     * Schedule to stop this daemon.
     */
    void terminate() {
      opsPerThread = localNumOpsExecuted;
    }
  }

  /**
   * Encrypt key statistics.
   *
   * Each thread encrypts the key.
   */
  class EncryptKeyStats extends OperationStatsBase {
    // Operation types
    static final String OP_ENCRYPT_KEY = "encrypt";
    static final String OP_ENCRYPT_USAGE =
            "-op encrypt [-threads T -numops N -warmup F]";

    EncryptKeyStats(List<String> args) {
      super();
      parseArguments(args);
    }

    @Override
    String getOpName() {
      return OP_ENCRYPT_KEY;
    }

    @Override
    void parseArguments(List<String> args) {
      parseCommonArguments(args);
    }

    /**
     * Returns client name.
     */
    @Override
    String getExecutionArgument(int daemonId) {
      return getClientName(daemonId);
    }

    /**
     * Execute key encryption: generates a new encrypted key from the
     * benchmark's encryption key and stores it for the decrypt benchmark.
     * A GeneralSecurityException is logged and the op is still counted.
     */
    @Override
    long executeOp(int daemonId, int inputIdx, String clientName)
            throws IOException {
      long start = Time.monotonicNow();
      try {
        eek = kp.generateEncryptedKey(encryptionKeyName);
      } catch (GeneralSecurityException e) {
        LOG.warn("failed to generate encrypted key", e);
      }

      long end = Time.monotonicNow();
      return end-start;
    }

    @Override
    void printResults() {
      LOG.info("--- " + getOpName() + " inputs ---");
      LOG.info("nOps = " + getNumOpsRequired());
      LOG.info("nThreads = " + getNumThreads());
      printStats();
    }
  }

  /**
   * Decrypt key statistics.
   *
   * Each thread decrypts the key.
   */
  class DecryptKeyStats extends OperationStatsBase {
    // Operation types
    static final String OP_DECRYPT_KEY = "decrypt";
    static final String OP_DECRYPT_USAGE =
            "-op decrypt [-threads T -numops N -warmup F]";

    DecryptKeyStats(List<String> args) {
      super();
      parseArguments(args);
    }

    @Override
    String getOpName() {
      return OP_DECRYPT_KEY;
    }

    @Override
    void parseArguments(List<String> args) {
      parseCommonArguments(args);
    }

    /**
     * returns client name.
     */
    @Override
    String getExecutionArgument(int daemonId) {
      return getClientName(daemonId);
    }

    /**
     * Execute key decryption of the most recently generated encrypted key.
     *
     * @throws IOException if no encrypted key is available (initial
     *         generation failed), which stops the calling daemon instead of
     *         letting it die on a NullPointerException
     */
    @Override
    long executeOp(int daemonId, int inputIdx, String clientName)
        throws IOException {
      KeyProviderCryptoExtension.EncryptedKeyVersion toDecrypt = eek;
      if (toDecrypt == null) {
        throw new IOException("no encrypted key available to decrypt"
            + " (encryption key: " + encryptionKeyName + ")");
      }
      long start = Time.monotonicNow();
      try {
        kp.decryptEncryptedKey(toDecrypt);
      } catch (GeneralSecurityException e) {
        LOG.warn("failed to generate and/or decrypt key", e);
      }
      long end = Time.monotonicNow();
      return end - start;
    }

    @Override
    void printResults() {
      LOG.info("--- " + getOpName() + " inputs ---");
      LOG.info("nrOps = " + getNumOpsRequired());
      LOG.info("nrThreads = " + getNumThreads());
      printStats();
    }
  }

  static void printUsage() {
    System.err.println("Usage: KMSBenchmark"
        + "\n\t"    + OperationStatsBase.OP_ALL_USAGE
        + " | \n\t" + EncryptKeyStats.OP_ENCRYPT_USAGE
        + " | \n\t" + DecryptKeyStats.OP_DECRYPT_USAGE
        + " | \n\t" + GENERAL_OPTIONS_USAGE
    );
    System.err.println();
    GenericOptionsParser.printGenericCommandUsage(System.err);
    ExitUtil.terminate(-1);
  }

  /**
   * Creates a crypto extension over the key provider configured by
   * {@code hadoop.security.key.provider.path}.
   *
   * @param conf configuration holding the key provider path
   * @return the crypto extension wrapping the configured provider
   * @throws IOException if no key provider is configured
   */
  public static KeyProviderCryptoExtension createKeyProviderCryptoExtension(
          final Configuration conf) throws IOException {

    KeyProvider keyProvider = KMSUtil.createKeyProvider(conf,
            CommonConfigurationKeysPublic.HADOOP_SECURITY_KEY_PROVIDER_PATH);
    if (keyProvider == null) {
      throw new IOException("Key provider was not configured.");
    }
    return KeyProviderCryptoExtension.
            createKeyProviderCryptoExtension(keyProvider);
  }

  /**
   * Builds a benchmark for {@code args} and runs it through ToolRunner.
   *
   * @param conf configuration to use
   * @param args command line arguments
   * @throws Exception if setting up or running the benchmark fails
   */
  public static void runBenchmark(Configuration conf, String[] args)
      throws Exception {
    try {
      KMSBenchmark bench = new KMSBenchmark(conf, args);
      ToolRunner.run(bench, args);
    } finally {
      LOG.info("runBenchmark finished.");
    }
  }

  /**
   * Main method of the benchmark.
   * @param aArgs command line parameters
   * @return 0 on success; usage errors terminate through ExitUtil
   * @throws Exception if a benchmark fails
   */
  @Override // Tool
  public int run(String[] aArgs) throws Exception {
    List<String> args = new ArrayList<String>(Arrays.asList(aArgs));
    if (args.size() < 2 || !args.get(0).startsWith("-op")) {
      printUsage();
    }

    String type = args.get(1);
    boolean runAll = OperationStatsBase.OP_ALL_NAME.equals(type);

    List<OperationStatsBase> ops = new ArrayList<OperationStatsBase>();
    try {
      if (runAll || EncryptKeyStats.OP_ENCRYPT_KEY.equals(type)) {
        ops.add(new EncryptKeyStats(args));
      }
      if (runAll || DecryptKeyStats.OP_DECRYPT_KEY.equals(type)) {
        ops.add(new DecryptKeyStats(args));
      }
      if (ops.isEmpty()) {
        printUsage();
      }

      // run each benchmark
      for (OperationStatsBase op : ops) {
        LOG.info("Starting benchmark: " + op.getOpName());
        op.benchmark();
        op.cleanUp();
      }
      // print statistics
      for (OperationStatsBase op : ops) {
        LOG.info("");
        op.printResults();
      }
    } catch(Exception e) {
      LOG.error("failed to run benchmarks", e);
      throw e;
    }
    return 0;
  }

  public static void main(String[] args) throws Exception {
    runBenchmark(new Configuration(), args);
  }

  @Override // Configurable
  public void setConf(Configuration conf) {
    config = conf;
  }

  @Override // Configurable
  public Configuration getConf() {
    return config;
  }
}