/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.mapred;

import static org.fusesource.leveldbjni.JniDBFactory.asString;
import static org.fusesource.leveldbjni.JniDBFactory.bytes;

import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.GlobalEventExecutor;

import javax.annotation.Nonnull;

import org.apache.hadoop.thirdparty.com.google.common.cache.CacheBuilder;
import org.apache.hadoop.thirdparty.com.google.common.cache.CacheLoader;
import org.apache.hadoop.thirdparty.com.google.common.cache.LoadingCache;
import org.apache.hadoop.thirdparty.com.google.common.cache.RemovalListener;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataInputByteBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.proto.ShuffleHandlerRecoveryProtos.JobShuffleInfoProto;
import org.apache.hadoop.mapreduce.MRConfig;
import org.apache.hadoop.mapreduce.security.token.JobTokenIdentifier;
import org.apache.hadoop.mapreduce.security.token.JobTokenSecretManager;
import org.apache.hadoop.metrics2.MetricsSystem;
import org.apache.hadoop.metrics2.annotation.Metric;
import org.apache.hadoop.metrics2.annotation.Metrics;
import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
import org.apache.hadoop.metrics2.lib.MutableCounterInt;
import org.apache.hadoop.metrics2.lib.MutableCounterLong;
import org.apache.hadoop.metrics2.lib.MutableGaugeInt;
import org.apache.hadoop.security.proto.SecurityProtos.TokenProto;
import org.apache.hadoop.security.ssl.SSLFactory;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto;
import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
import org.apache.hadoop.yarn.server.api.AuxiliaryService;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer.ContainerLocalizer;
import org.apache.hadoop.yarn.server.records.Version;
import org.apache.hadoop.yarn.server.records.impl.pb.VersionPBImpl;
import org.apache.hadoop.yarn.server.utils.LeveldbIterator;
import org.fusesource.leveldbjni.JniDBFactory;
import org.fusesource.leveldbjni.internal.NativeDB;
import org.iq80.leveldb.DB;
import org.iq80.leveldb.DBException;
import org.iq80.leveldb.Options;
import org.slf4j.LoggerFactory;

import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.thirdparty.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.hadoop.thirdparty.protobuf.ByteString;

public class ShuffleHandler extends AuxiliaryService {

  public static final org.slf4j.Logger LOG =
      LoggerFactory.getLogger(ShuffleHandler.class);
  public static final org.slf4j.Logger AUDITLOG =
      LoggerFactory.getLogger(ShuffleHandler.class.getName()+".audit");
  public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache";
  public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true;

  public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes";
  public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024;

  public static final String MAX_WEIGHT =
      "mapreduce.shuffle.pathcache.max-weight";
  public static final int DEFAULT_MAX_WEIGHT = 10 * 1024 * 1024;

  public static final String EXPIRE_AFTER_ACCESS_MINUTES =
      "mapreduce.shuffle.pathcache.expire-after-access-minutes";
  public static final int DEFAULT_EXPIRE_AFTER_ACCESS_MINUTES = 5;

  public static final String CONCURRENCY_LEVEL =
      "mapreduce.shuffle.pathcache.concurrency-level";
  public static final int DEFAULT_CONCURRENCY_LEVEL = 16;
  
  // pattern to identify errors related to the client closing the socket early
  // idea borrowed from Netty SslHandler
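  // e.g. "Connection reset by peer" or "Broken pipe"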
  public static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile(
      "^.*(?:connection.*reset|connection.*closed|broken.*pipe).*$",
      Pattern.CASE_INSENSITIVE);

  private static final String STATE_DB_NAME = "mapreduce_shuffle_state";
  private static final String STATE_DB_SCHEMA_VERSION_KEY = "shuffle-schema-version";
  protected static final Version CURRENT_VERSION_INFO = 
      Version.newInstance(1, 0);

  private static final String DATA_FILE_NAME = "file.out";
  private static final String INDEX_FILE_NAME = "file.out.index";

  public static final HttpResponseStatus TOO_MANY_REQ_STATUS =
      new HttpResponseStatus(429, "TOO MANY REQUESTS");
  // This should be kept in sync with Fetcher.FETCH_RETRY_DELAY_DEFAULT
  public static final long FETCH_RETRY_DELAY = 1000L;
  public static final String RETRY_AFTER_HEADER = "Retry-After";

  private int port;
  private EventLoopGroup bossGroup;
  private EventLoopGroup workerGroup;

  @SuppressWarnings("checkstyle:VisibilityModifier")
  protected final ChannelGroup allChannels =
      new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

  private SSLFactory sslFactory;

  @SuppressWarnings("checkstyle:VisibilityModifier")
  protected JobTokenSecretManager secretManager;
  @SuppressWarnings("checkstyle:VisibilityModifier")
  protected Map<String, String> userRsrc;

  private DB stateDb = null;

  public static final String MAPREDUCE_SHUFFLE_SERVICEID =
      "mapreduce_shuffle";

  public static final String SHUFFLE_PORT_CONFIG_KEY = "mapreduce.shuffle.port";
  public static final int DEFAULT_SHUFFLE_PORT = 13562;

  public static final String SHUFFLE_LISTEN_QUEUE_SIZE =
      "mapreduce.shuffle.listen.queue.size";
  public static final int DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE = 128;

  public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED =
      "mapreduce.shuffle.connection-keep-alive.enable";
  public static final boolean DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED = false;

  public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT =
      "mapreduce.shuffle.connection-keep-alive.timeout";
  public static final int DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT = 5; //seconds

  public static final String SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE =
      "mapreduce.shuffle.mapoutput-info.meta.cache.size";
  public static final int DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE =
      1000;

  public static final String CONNECTION_CLOSE = "close";

  public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY =
      "mapreduce.shuffle.ssl.file.buffer.size";

  public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024;

  public static final String MAX_SHUFFLE_CONNECTIONS = "mapreduce.shuffle.max.connections";
  public static final int DEFAULT_MAX_SHUFFLE_CONNECTIONS = 0; // 0 implies no limit
  
  public static final String MAX_SHUFFLE_THREADS = "mapreduce.shuffle.max.threads";
  // 0 implies Netty default of 2 * number of available processors
  public static final int DEFAULT_MAX_SHUFFLE_THREADS = 0;
  
  public static final String SHUFFLE_BUFFER_SIZE = 
      "mapreduce.shuffle.transfer.buffer.size";
  public static final int DEFAULT_SHUFFLE_BUFFER_SIZE = 128 * 1024;
  
  public static final String  SHUFFLE_TRANSFERTO_ALLOWED = 
      "mapreduce.shuffle.transferTo.allowed";
  public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true;
  public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = 
      false;
  static final String TIMEOUT_HANDLER = "timeout";

  /**
   * The maximum number of files a single GET request can open
   * simultaneously during shuffle.
   */
  public static final String SHUFFLE_MAX_SESSION_OPEN_FILES =
      "mapreduce.shuffle.max.session-open-files";
  public static final int DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES = 3;


  @Metrics(about="Shuffle output metrics", context="mapred")
  static class ShuffleMetrics implements ChannelFutureListener {
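    // Intended to be registered as a listener on each shuffle transfer's
    // write future: operationComplete fires once the transfer finishes and
    // updates the success/failure counters and the connection gauge.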
    @Metric("Shuffle output in bytes")
        MutableCounterLong shuffleOutputBytes;
    @Metric("# of failed shuffle outputs")
        MutableCounterInt shuffleOutputsFailed;
    @Metric("# of succeeeded shuffle outputs")
        MutableCounterInt shuffleOutputsOK;
    @Metric("# of current shuffle connections")
        MutableGaugeInt shuffleConnections;

    @Override
    public void operationComplete(ChannelFuture future) throws Exception {
      if (future.isSuccess()) {
        shuffleOutputsOK.incr();
      } else {
        shuffleOutputsFailed.incr();
      }
      shuffleConnections.decr();
    }
  }

  @SuppressWarnings("checkstyle:VisibilityModifier")
  protected final MetricsSystem ms;
  @SuppressWarnings("checkstyle:VisibilityModifier")
  final ShuffleMetrics metrics;

  ShuffleHandler(MetricsSystem ms) {
    super(MAPREDUCE_SHUFFLE_SERVICEID);
    this.ms = ms;
    metrics = ms.register(new ShuffleMetrics());
  }

  public ShuffleHandler() {
    this(DefaultMetricsSystem.instance());
  }

  /**
   * Serialize the shuffle port into a ByteBuffer for use later on.
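   *
   * <p>Round-trip sketch (hypothetical caller; 13562 is the default port):
   * <pre>{@code
   * ByteBuffer meta = ShuffleHandler.serializeMetaData(13562);
   * int port = ShuffleHandler.deserializeMetaData(meta); // 13562
   * }</pre>
   *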
   * @param port the port to be sent to the ApplicationMaster
   * @return the serialized form of the port.
   * @throws IOException on failure
   */
  public static ByteBuffer serializeMetaData(int port) throws IOException {
    //TODO these bytes should be versioned
    DataOutputBuffer portDob = new DataOutputBuffer();
    portDob.writeInt(port);
    return ByteBuffer.wrap(portDob.getData(), 0, portDob.getLength());
  }

  /**
   * A helper function to deserialize the metadata returned by ShuffleHandler.
   * @param meta the metadata returned by the ShuffleHandler
   * @return the port the Shuffle Handler is listening on to serve shuffle data.
   * @throws IOException on failure
   */
  public static int deserializeMetaData(ByteBuffer meta) throws IOException {
    //TODO this should be returning a class not just an int
    DataInputByteBuffer in = new DataInputByteBuffer();
    in.reset(meta);
    return in.readInt();
  }

  /**
   * A helper function to serialize the JobTokenIdentifier to be sent to the
   * ShuffleHandler as ServiceData.
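   *
   * <p>Round-trip sketch (assumes {@code jobToken} is an existing
   * {@code Token<JobTokenIdentifier>}):
   * <pre>{@code
   * ByteBuffer serviceData = ShuffleHandler.serializeServiceData(jobToken);
   * Token<JobTokenIdentifier> recovered =
   *     ShuffleHandler.deserializeServiceData(serviceData);
   * }</pre>
   *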
   * @param jobToken the job token to be used for authentication of
   * shuffle data requests.
   * @return the serialized version of the jobToken.
   * @throws IOException on failure
   */
  public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken)
      throws IOException {
    //TODO these bytes should be versioned
    DataOutputBuffer jobTokenDob = new DataOutputBuffer();
    jobToken.write(jobTokenDob);
    return ByteBuffer.wrap(jobTokenDob.getData(), 0, jobTokenDob.getLength());
  }

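  /**
   * A helper function to deserialize the service data produced by
   * {@link #serializeServiceData} back into the job token used to
   * authenticate shuffle requests.
   * @param secret the serialized job token
   * @return the deserialized job token
   * @throws IOException on failure
   */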
  public static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret)
      throws IOException {
    DataInputByteBuffer in = new DataInputByteBuffer();
    in.reset(secret);
    Token<JobTokenIdentifier> jt = new Token<>();
    jt.readFields(in);
    return jt;
  }

  @Override
  public void initializeApplication(ApplicationInitializationContext context) {

    String user = context.getUser();
    ApplicationId appId = context.getApplicationId();
    ByteBuffer secret = context.getApplicationDataForService();
    // TODO these bytes should be versioned
    try {
      Token<JobTokenIdentifier> jt = deserializeServiceData(secret);
      // TODO: Once Shuffle is out of NM, this can use MR APIs
      JobID jobId = new JobID(Long.toString(appId.getClusterTimestamp()), appId.getId());
      recordJobShuffleInfo(jobId, user, jt);
    } catch (IOException e) {
      LOG.error("Error during initApp", e);
      // TODO add API to AuxiliaryServices to report failures
    }
  }

  @Override
  public void stopApplication(ApplicationTerminationContext context) {
    ApplicationId appId = context.getApplicationId();
    JobID jobId = new JobID(Long.toString(appId.getClusterTimestamp()), appId.getId());
    try {
      removeJobShuffleInfo(jobId);
    } catch (IOException e) {
      LOG.error("Error during stopApp", e);
      // TODO add API to AuxiliaryServices to report failures
    }
  }

  @Override
  protected void serviceInit(Configuration conf) throws Exception {
    int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS,
                                        DEFAULT_MAX_SHUFFLE_THREADS);
    // Since Netty 4.x, passing 0 to the NioEventLoopGroup constructor below
    // would default the thread count to
    // io.netty.channel.MultithreadEventLoopGroup.DEFAULT_EVENT_LOOP_THREADS.
    // However, this explicit thread-count logic predates that behavior,
    // so keep it for now.
    if (maxShuffleThreads == 0) {
      maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors();
    }

    ThreadFactory bossFactory = new ThreadFactoryBuilder()
        .setNameFormat("ShuffleHandler Netty Boss #%d")
        .build();
    ThreadFactory workerFactory = new ThreadFactoryBuilder()
        .setNameFormat("ShuffleHandler Netty Worker #%d")
        .build();
    
    bossGroup = new NioEventLoopGroup(1, bossFactory);
    workerGroup = new NioEventLoopGroup(maxShuffleThreads, workerFactory);
    super.serviceInit(new Configuration(conf));
  }

  protected ShuffleChannelHandlerContext createHandlerContext() {
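    // The path cache maps (jobId, user, attemptId) to the local index/data
    // file paths so repeated fetches for the same map output can skip the
    // local-dirs lookup. Entries are weighted by their combined string
    // lengths to bound the cache's approximate memory footprint, and soft
    // values let the GC reclaim entries under memory pressure.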
    Configuration conf = getConfig();

    final LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache =
        CacheBuilder.newBuilder()
            .expireAfterAccess(
                conf.getInt(EXPIRE_AFTER_ACCESS_MINUTES,
                    DEFAULT_EXPIRE_AFTER_ACCESS_MINUTES),
                TimeUnit.MINUTES)
            .softValues()
            .concurrencyLevel(
                conf.getInt(CONCURRENCY_LEVEL, DEFAULT_CONCURRENCY_LEVEL))
            .removalListener(
                (RemovalListener<AttemptPathIdentifier, AttemptPathInfo>)
                    notification -> {
                      if (LOG.isDebugEnabled()) {
                        LOG.debug("PathCache Eviction: " + notification.getKey()
                            + ", Reason=" + notification.getCause());
                      }
                    })
            .maximumWeight(conf.getInt(MAX_WEIGHT, DEFAULT_MAX_WEIGHT))
            .weigher((key, value) -> key.jobId.length() + key.user.length()
                + key.attemptId.length()
                + value.indexPath.toString().length()
                + value.dataPath.toString().length())
            .build(new CacheLoader<AttemptPathIdentifier, AttemptPathInfo>() {
              @Override
              public AttemptPathInfo load(@Nonnull AttemptPathIdentifier key) throws
                  Exception {
                String base = getBaseLocation(key.jobId, key.user);
                String attemptBase = base + key.attemptId;
                Path indexFileName = getAuxiliaryLocalPathHandler()
                    .getLocalPathForRead(attemptBase + "/" + INDEX_FILE_NAME);
                Path mapOutputFileName = getAuxiliaryLocalPathHandler()
                    .getLocalPathForRead(attemptBase + "/" + DATA_FILE_NAME);

                if (LOG.isDebugEnabled()) {
                  LOG.debug("Loaded : " + key + " via loader");
                }
                return new AttemptPathInfo(indexFileName, mapOutputFileName);
              }
            });

    return new ShuffleChannelHandlerContext(conf,
        userRsrc,
        secretManager,
        pathCache,
        new IndexCache(new JobConf(conf)),
        metrics,
        allChannels
    );
  }

  // TODO change AbstractService to throw InterruptedException
  @Override
  protected void serviceStart() throws Exception {
    Configuration conf = getConfig();
    userRsrc = new ConcurrentHashMap<>();
    secretManager = new JobTokenSecretManager();
    recoverState(conf);

    if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY,
        MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) {
      LOG.info("Encrypted shuffle is enabled.");
      sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf);
      sslFactory.init();
    }

    ShuffleChannelHandlerContext handlerContext = createHandlerContext();
    ServerBootstrap bootstrap = new ServerBootstrap();
    bootstrap.group(bossGroup, workerGroup)
        .channel(NioServerSocketChannel.class)
        .option(ChannelOption.SO_BACKLOG,
            conf.getInt(SHUFFLE_LISTEN_QUEUE_SIZE,
                DEFAULT_SHUFFLE_LISTEN_QUEUE_SIZE))
        .childOption(ChannelOption.SO_KEEPALIVE, true)
        .childHandler(new ShuffleChannelInitializer(
            handlerContext,
            sslFactory)
        );
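    // Bind to the configured port. If port 0 was configured, the OS assigns
    // an ephemeral port, so re-read the actual bound port from the channel
    // and publish it back into the configuration.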
    port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
    Channel ch = bootstrap.bind(new InetSocketAddress(port)).sync().channel();
    port = ((InetSocketAddress)ch.localAddress()).getPort();
    allChannels.add(ch);
    conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port));
    handlerContext.setPort(port);
    LOG.info(getName() + " listening on port " + port);
    super.serviceStart();
  }

  @Override
  protected void serviceStop() throws Exception {
    allChannels.close().awaitUninterruptibly(10, TimeUnit.SECONDS);

    if (sslFactory != null) {
      sslFactory.destroy();
    }

    if (stateDb != null) {
      stateDb.close();
    }
    ms.unregisterSource(ShuffleMetrics.class.getSimpleName());

    if (bossGroup != null) {
      bossGroup.shutdownGracefully();
    }

    if (workerGroup != null) {
      workerGroup.shutdownGracefully();
    }

    super.serviceStop();
  }

  @Override
  public synchronized ByteBuffer getMetaData() {
    try {
      return serializeMetaData(port); 
    } catch (IOException e) {
      LOG.error("Error during getMeta", e);
      // TODO add API to AuxiliaryServices to report failures
      return null;
    }
  }

  private void recoverState(Configuration conf) throws IOException {
    Path recoveryRoot = getRecoveryPath();
    if (recoveryRoot != null) {
      startStore(recoveryRoot);
      Pattern jobPattern = Pattern.compile(JobID.JOBID_REGEX);
      LeveldbIterator iter = null;
      try {
        iter = new LeveldbIterator(stateDb);
        iter.seek(bytes(JobID.JOB));
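        // LevelDB iterates keys in sorted order, so after seeking to the
        // "job" prefix all job entries appear contiguously; stop at the
        // first key that no longer looks like a job ID.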
        while (iter.hasNext()) {
          Map.Entry<byte[],byte[]> entry = iter.next();
          String key = asString(entry.getKey());
          if (!jobPattern.matcher(key).matches()) {
            break;
          }
          recoverJobShuffleInfo(key, entry.getValue());
        }
      } catch (DBException e) {
        throw new IOException("Database error during recovery", e);
      } finally {
        if (iter != null) {
          iter.close();
        }
      }
    }
  }

  private void startStore(Path recoveryRoot) throws IOException {
    Options options = new Options();
    options.createIfMissing(false);
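    // Try to open an existing database first; only create a fresh one (and
    // record the schema version) if none is found.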
    Path dbPath = new Path(recoveryRoot, STATE_DB_NAME);
    LOG.info("Using state database at " + dbPath + " for recovery");
    File dbfile = new File(dbPath.toString());
    try {
      stateDb = JniDBFactory.factory.open(dbfile, options);
    } catch (NativeDB.DBException e) {
      if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
        LOG.info("Creating state database at " + dbfile);
        options.createIfMissing(true);
        try {
          stateDb = JniDBFactory.factory.open(dbfile, options);
          storeVersion();
        } catch (DBException dbExc) {
          throw new IOException("Unable to create state store", dbExc);
        }
      } else {
        throw e;
      }
    }
    checkVersion();
  }
  
  @VisibleForTesting
  Version loadVersion() throws IOException {
    byte[] data = stateDb.get(bytes(STATE_DB_SCHEMA_VERSION_KEY));
    // if no version was stored previously, treat it as CURRENT_VERSION_INFO
    if (data == null || data.length == 0) {
      return getCurrentVersion();
    }
    return new VersionPBImpl(VersionProto.parseFrom(data));
  }

  private void storeSchemaVersion(Version version) throws IOException {
    String key = STATE_DB_SCHEMA_VERSION_KEY;
    byte[] data = 
        ((VersionPBImpl) version).getProto().toByteArray();
    try {
      stateDb.put(bytes(key), data);
    } catch (DBException e) {
      throw new IOException(e.getMessage(), e);
    }
  }
  
  private void storeVersion() throws IOException {
    storeSchemaVersion(CURRENT_VERSION_INFO);
  }
  
  // Only used for test
  @VisibleForTesting
  void storeVersion(Version version) throws IOException {
    storeSchemaVersion(version);
  }

  protected Version getCurrentVersion() {
    return CURRENT_VERSION_INFO;
  }
  
  /**
   * 1) Versioning scheme: major.minor, e.g. 1.0, 1.1, 1.2 ... 1.25, 2.0 etc.
   * 2) Any incompatible change of DB schema is a major upgrade, and any
   *    compatible change of DB schema is a minor upgrade.
   * 3) Within a minor upgrade, say 1.1 to 1.2:
   *    overwrite the version info and proceed as normal.
   * 4) Within a major upgrade, say 1.2 to 2.0:
   *    throw an exception and instruct the user to use a separate upgrade
   *    tool to upgrade the shuffle info or remove the incompatible old state.
   */
  private void checkVersion() throws IOException {
    Version loadedVersion = loadVersion();
    LOG.info("Loaded state DB schema version info " + loadedVersion);
    if (loadedVersion.equals(getCurrentVersion())) {
      return;
    }
    if (loadedVersion.isCompatibleTo(getCurrentVersion())) {
      LOG.info("Storing state DB schema version info " + getCurrentVersion());
      storeVersion();
    } else {
      throw new IOException(
        "Incompatible version for state DB schema: expecting DB schema version " 
            + getCurrentVersion() + ", but loading version " + loadedVersion);
    }
  }

  private void addJobToken(JobID jobId, String user,
      Token<JobTokenIdentifier> jobToken) {
    userRsrc.put(jobId.toString(), user);
    secretManager.addTokenForJob(jobId.toString(), jobToken);
    LOG.info("Added token for " + jobId.toString());
  }

  private void recoverJobShuffleInfo(String jobIdStr, byte[] data)
      throws IOException {
    JobID jobId;
    try {
      jobId = JobID.forName(jobIdStr);
    } catch (IllegalArgumentException e) {
      throw new IOException("Bad job ID " + jobIdStr + " in state store", e);
    }

    JobShuffleInfoProto proto = JobShuffleInfoProto.parseFrom(data);
    String user = proto.getUser();
    TokenProto tokenProto = proto.getJobToken();
    Token<JobTokenIdentifier> jobToken = new Token<>(
        tokenProto.getIdentifier().toByteArray(),
        tokenProto.getPassword().toByteArray(),
        new Text(tokenProto.getKind()), new Text(tokenProto.getService()));
    addJobToken(jobId, user, jobToken);
  }

  private void recordJobShuffleInfo(JobID jobId, String user,
      Token<JobTokenIdentifier> jobToken) throws IOException {
    if (stateDb != null) {
      TokenProto tokenProto = TokenProto.newBuilder()
          .setIdentifier(ByteString.copyFrom(jobToken.getIdentifier()))
          .setPassword(ByteString.copyFrom(jobToken.getPassword()))
          .setKind(jobToken.getKind().toString())
          .setService(jobToken.getService().toString())
          .build();
      JobShuffleInfoProto proto = JobShuffleInfoProto.newBuilder()
          .setUser(user).setJobToken(tokenProto).build();
      try {
        stateDb.put(bytes(jobId.toString()), proto.toByteArray());
      } catch (DBException e) {
        throw new IOException("Error storing " + jobId, e);
      }
    }
    addJobToken(jobId, user, jobToken);
  }

  private void removeJobShuffleInfo(JobID jobId) throws IOException {
    String jobIdStr = jobId.toString();
    secretManager.removeTokenForJob(jobIdStr);
    userRsrc.remove(jobIdStr);
    if (stateDb != null) {
      try {
        stateDb.delete(bytes(jobIdStr));
      } catch (DBException e) {
        throw new IOException("Unable to remove " + jobId
            + " from state store", e);
      }
    }
  }

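  /**
   * Closes keep-alive connections whose writer has been idle longer than the
   * configured keep-alive timeout. The timeout only takes effect once
   * {@link #setEnabledTimeout} has been called with {@code true}.
   */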
  static class TimeoutHandler extends IdleStateHandler {
    private final int connectionKeepAliveTimeOut;
    private boolean enabledTimeout;

    TimeoutHandler(int connectionKeepAliveTimeOut) {
      // Disable the reader-idle and all-idle timeouts; set the writer-idle
      // timeout to the configured keep-alive timeout value.
      super(0, connectionKeepAliveTimeOut, 0, TimeUnit.SECONDS);
      this.connectionKeepAliveTimeOut = connectionKeepAliveTimeOut;
    }

    void setEnabledTimeout(boolean enabledTimeout) {
      this.enabledTimeout = enabledTimeout;
    }

    @Override
    public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) {
      if (e.state() == IdleState.WRITER_IDLE && enabledTimeout) {
        LOG.debug("Closing channel as writer was idle for {} seconds", connectionKeepAliveTimeOut);
        ctx.channel().close();
      }
    }
  }

  @SuppressWarnings("checkstyle:VisibilityModifier")
  static class AttemptPathInfo {
    // TODO Change this over to just store local dir indices, instead of the
    // entire path. Far more efficient.
    public final Path indexPath;
    public final Path dataPath;

    AttemptPathInfo(Path indexPath, Path dataPath) {
      this.indexPath = indexPath;
      this.dataPath = dataPath;
    }
  }

  @SuppressWarnings("checkstyle:VisibilityModifier")
  static class AttemptPathIdentifier {
    public final String jobId;
    public final String user;
    public final String attemptId;

    AttemptPathIdentifier(String jobId, String user, String attemptId) {
      this.jobId = jobId;
      this.user = user;
      this.attemptId = attemptId;
    }

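    // Note: 'user' is intentionally excluded from equals/hashCode; a job ID
    // is assumed to map to a single user.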
    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }

      AttemptPathIdentifier that = (AttemptPathIdentifier) o;

      if (!attemptId.equals(that.attemptId)) {
        return false;
      }
      if (!jobId.equals(that.jobId)) {
        return false;
      }

      return true;
    }

    @Override
    public int hashCode() {
      int result = jobId.hashCode();
      result = 31 * result + attemptId.hashCode();
      return result;
    }

    @Override
    public String toString() {
      return "AttemptPathIdentifier{" +
          "attemptId='" + attemptId + '\'' +
          ", jobId='" + jobId + '\'' +
          '}';
    }
  }

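  /**
   * Builds the relative base path under which a job's shuffle output lives
   * on the NodeManager's local dirs, e.g. (hypothetical values):
   * {@code usercache/alice/appcache/application_1700000000000_0001/output/}.
   */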
  private static String getBaseLocation(String jobId, String user) {
    final JobID jobID = JobID.forName(jobId);
    final ApplicationId appID =
        ApplicationId.newInstance(Long.parseLong(jobID.getJtIdentifier()),
            jobID.getId());
    return ContainerLocalizer.USERCACHE + "/" + user + "/"
        + ContainerLocalizer.APPCACHE + "/"
        + appID + "/output" + "/";
  }
}