Tasks.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.fs.tosfs.common;

import org.apache.hadoop.util.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Stream;

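/**
 * Utility for running a task over a set of items, either sequentially on the calling thread or
 * on an {@link ExecutorService}, with optional retry, failure, revert and abort handling
 * configured through {@link Builder}.
 *
 * <p>Copied from Apache Iceberg.
 */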
public final class Tasks {

  // Shared logger used for retry, revert, abort and future-wait diagnostics in this file.
  private static final Logger LOG = LoggerFactory.getLogger(Tasks.class);

  // Private constructor: this class only hosts static members and nested types.
  private Tasks() {
  }

  /**
   * Unchecked exception that a task can throw to signal that retrying is pointless.
   *
   * <p>{@link Builder} registers this class in its default stop-retry list, so it ends the retry
   * loop unless a retry predicate or an only-retry list has been configured, in which case those
   * settings decide instead.
   */
  public static class UnrecoverableException extends RuntimeException {
    /**
     * @param message description of the failure
     */
    public UnrecoverableException(String message) {
      super(message);
    }

    /**
     * @param message description of the failure
     * @param cause the underlying failure
     */
    public UnrecoverableException(String message, Throwable cause) {
      super(message, cause);
    }

    /**
     * @param cause the underlying failure
     */
    public UnrecoverableException(Throwable cause) {
      super(cause);
    }
  }

  /**
   * Callback invoked for an item whose task failed.
   *
   * @param <I> the item type
   * @param <E> the checked exception type the callback itself may throw
   */
  @FunctionalInterface
  public interface FailureTask<I, E extends Exception> {
    /**
     * Handles the failure of a single item.
     *
     * @param item the item whose task failed
     * @param exception the exception thrown by the task
     * @throws E if the handler itself fails
     */
    void run(I item, Exception exception) throws E;
  }

  /**
   * A unit of work applied to a single item.
   *
   * @param <I> the item type
   * @param <E> the checked exception type the task may throw
   */
  @FunctionalInterface
  public interface Task<I, E extends Exception> {
    /**
     * Processes one item.
     *
     * @param item the item to process
     * @throws E if processing fails
     */
    void run(I item) throws E;
  }

  public static class Builder<I> {
    // Items the task is applied to; iterated once per run.
    private final Iterable<I> items;
    // When non-null, run() submits one job per item to this executor instead of looping inline.
    private ExecutorService service = null;
    // Optional callback invoked with the item and exception whenever a task fails.
    private FailureTask<I, ?> onFailure = null;
    // When true, items not yet started after the first failure are skipped (and aborted, if an
    // abort task is configured).
    private boolean stopOnFailure = false;
    // When true, collected task exceptions are thrown once all work has finished.
    private boolean throwFailureWhenFinished = true;
    // Optional task run on every item that succeeded when at least one item failed.
    private Task<I, ?> revertTask = null;
    // When true, no further reverts are started after a revert fails.
    private boolean stopRevertsOnFailure = false;
    // Optional task run on items that were never attempted because of an earlier failure.
    private Task<I, ?> abortTask = null;
    // When true, no further aborts are started after an abort fails.
    private boolean stopAbortsOnFailure = false;

    // retry settings
    // Exception types that end the retry loop; consulted only when neither a retry predicate
    // nor an only-retry list is configured.
    private List<Class<? extends Exception>> stopRetryExceptions =
        Lists.newArrayList(UnrecoverableException.class);
    // When non-null (and no predicate is set), only these exception types are retried.
    private List<Class<? extends Exception>> onlyRetryExceptions = null;
    // When non-null, this predicate alone decides whether a failure is retried.
    private Predicate<Exception> shouldRetryPredicate = null;
    private int maxAttempts = 1; // not all operations can be retried
    private long minSleepTimeMs = 1000; // 1 second
    private long maxSleepTimeMs = 600000; // 10 minutes
    private long maxDurationMs = 600000; // 10 minutes
    private double scaleFactor = 2.0; // exponential

    /**
     * Creates a builder over the given items with no executor, no retries and
     * failures thrown when the run finishes.
     *
     * @param items the items to run a task on
     */
    public Builder(Iterable<I> items) {
      this.items = items;
    }

    /**
     * Runs the tasks on the given executor, one submitted job per item.
     *
     * @param svc the executor to submit jobs to; {@code null} keeps execution on the caller thread
     * @return this builder
     */
    public Builder<I> executeWith(ExecutorService svc) {
      this.service = svc;
      return this;
    }

    /**
     * Registers a callback that is invoked with the item and the exception when a task fails.
     *
     * @param task the failure callback
     * @return this builder
     */
    public Builder<I> onFailure(FailureTask<I, ?> task) {
      this.onFailure = task;
      return this;
    }

    /**
     * Stops starting new tasks after the first failure.
     *
     * @return this builder
     */
    public Builder<I> stopOnFailure() {
      this.stopOnFailure = true;
      return this;
    }

    /**
     * Throws a collected task exception once the run has finished (the default).
     *
     * @return this builder
     */
    public Builder<I> throwFailureWhenFinished() {
      this.throwFailureWhenFinished = true;
      return this;
    }

    /**
     * Chooses whether collected task exceptions are thrown once the run has finished.
     *
     * @param throwWhenFinished {@code true} to throw, {@code false} to suppress
     * @return this builder
     */
    public Builder<I> throwFailureWhenFinished(boolean throwWhenFinished) {
      this.throwFailureWhenFinished = throwWhenFinished;
      return this;
    }

    /**
     * Suppresses throwing collected task exceptions once the run has finished.
     *
     * @return this builder
     */
    public Builder<I> suppressFailureWhenFinished() {
      this.throwFailureWhenFinished = false;
      return this;
    }

    /**
     * Registers a task that is run on every succeeded item when at least one item failed.
     *
     * @param task the revert task
     * @return this builder
     */
    public Builder<I> revertWith(Task<I, ?> task) {
      this.revertTask = task;
      return this;
    }

    /**
     * Stops starting further reverts once a revert has failed.
     *
     * @return this builder
     */
    public Builder<I> stopRevertsOnFailure() {
      this.stopRevertsOnFailure = true;
      return this;
    }

    /**
     * Registers a task that is run on items that were never attempted because of a failure.
     *
     * @param task the abort task
     * @return this builder
     */
    public Builder<I> abortWith(Task<I, ?> task) {
      this.abortTask = task;
      return this;
    }

    /**
     * Stops starting further aborts once an abort has failed.
     *
     * @return this builder
     */
    public Builder<I> stopAbortsOnFailure() {
      this.stopAbortsOnFailure = true;
      return this;
    }

    /**
     * Adds exception types that end the retry loop. The types are appended to the existing list,
     * which initially contains {@link UnrecoverableException}.
     *
     * @param exceptions exception types that must not be retried
     * @return this builder
     */
    @SafeVarargs public final Builder<I> stopRetryOn(Class<? extends Exception>... exceptions) {
      stopRetryExceptions.addAll(Arrays.asList(exceptions));
      return this;
    }

    /**
     * Sets a predicate that decides whether a failure is retried. When set, it takes precedence
     * over both the only-retry and the stop-retry exception lists.
     *
     * @param shouldRetry returns {@code true} for failures that should be retried
     * @return this builder
     */
    public Builder<I> shouldRetryTest(Predicate<Exception> shouldRetry) {
      this.shouldRetryPredicate = shouldRetry;
      return this;
    }

    /**
     * Disables retries: each task is attempted exactly once.
     *
     * @return this builder
     */
    public Builder<I> noRetry() {
      this.maxAttempts = 1;
      return this;
    }

    /**
     * Allows up to {@code nTimes} retries, i.e. {@code nTimes + 1} attempts in total.
     *
     * @param nTimes the number of retries after the first attempt
     * @return this builder
     */
    public Builder<I> retry(int nTimes) {
      this.maxAttempts = nTimes + 1;
      return this;
    }

    /**
     * Retries only failures of the given type, replacing any previously configured only-retry
     * list.
     *
     * @param exception the single exception type to retry
     * @return this builder
     */
    public Builder<I> onlyRetryOn(Class<? extends Exception> exception) {
      this.onlyRetryExceptions = Collections.singletonList(exception);
      return this;
    }

    /**
     * Retries only failures of the given types, replacing any previously configured only-retry
     * list.
     *
     * @param exceptions the exception types to retry
     * @return this builder
     */
    @SafeVarargs public final Builder<I> onlyRetryOn(Class<? extends Exception>... exceptions) {
      this.onlyRetryExceptions = Lists.newArrayList(exceptions);
      return this;
    }

    /**
     * Configures the delay between retries. The delay before the next attempt is
     * {@code min(backoffMinSleepTimeMs * backoffScaleFactor^(attempt - 1), backoffMaxSleepTimeMs)}
     * plus a random jitter of less than ten percent of that delay.
     *
     * @param backoffMinSleepTimeMs base delay in milliseconds
     * @param backoffMaxSleepTimeMs upper bound for the delay in milliseconds
     * @param backoffMaxRetryTimeMs elapsed time in milliseconds after which a failed retry is
     *        no longer retried
     * @param backoffScaleFactor multiplier applied to the delay for each further attempt
     * @return this builder
     */
    public Builder<I> exponentialBackoff(long backoffMinSleepTimeMs, long backoffMaxSleepTimeMs,
        long backoffMaxRetryTimeMs, double backoffScaleFactor) {
      this.minSleepTimeMs = backoffMinSleepTimeMs;
      this.maxSleepTimeMs = backoffMaxSleepTimeMs;
      this.maxDurationMs = backoffMaxRetryTimeMs;
      this.scaleFactor = backoffScaleFactor;
      return this;
    }

    /**
     * Runs a task that throws only unchecked exceptions on every item.
     *
     * @param task the task to apply to each item
     * @return the outcome reported by the execution strategy that was used
     */
    public boolean run(Task<I, RuntimeException> task) {
      return run(task, RuntimeException.class);
    }

    /**
     * Runs a task on every item, on the configured executor when one was supplied and on the
     * calling thread otherwise.
     *
     * @param task the task to apply to each item
     * @param exceptionClass the checked exception type the task is allowed to propagate
     * @param <E> the checked exception type
     * @return the outcome reported by the execution strategy that was used
     * @throws E if a task failed with that exception and failures are not suppressed
     */
    public <E extends Exception> boolean run(Task<I, E> task, Class<E> exceptionClass) throws E {
      return service == null
          ? runSingleThreaded(task, exceptionClass)
          : runParallel(task, exceptionClass);
    }

    /**
     * Runs the task on every item sequentially on the calling thread.
     *
     * <p>Exceptions thrown by a task are collected and passed to the failure handler, if any. If
     * anything failed, the revert task is applied to the items that succeeded and the abort task
     * to the items that were never attempted.
     *
     * @param task the task to apply to each item
     * @param exceptionClass the checked exception type the task is allowed to propagate
     * @param <E> the checked exception type
     * @return {@code true} only if every task succeeded; {@code false} when failures occurred
     *         but were suppressed
     * @throws E if a task failed with that exception and failures are not suppressed
     */
    private <E extends Exception> boolean runSingleThreaded(
        Task<I, E> task, Class<E> exceptionClass) throws E {
      List<I> succeeded = Lists.newArrayList();
      List<Throwable> exceptions = Lists.newArrayList();

      Iterator<I> iterator = items.iterator();
      boolean threw = true;
      try {
        while (iterator.hasNext()) {
          I item = iterator.next();
          try {
            runTaskWithRetry(task, item);
            succeeded.add(item);
          } catch (Exception e) {
            exceptions.add(e);

            if (onFailure != null) {
              tryRunOnFailure(item, e);
            }

            if (stopOnFailure) {
              break;
            }
          }
        }

        threw = false;

      } finally {
        // threw handles exceptions that were *not* caught by the catch block,
        // and exceptions that were caught and possibly handled by onFailure
        // are kept in exceptions.
        if (threw || !exceptions.isEmpty()) {
          if (revertTask != null) {
            boolean failed = false;
            for (I item : succeeded) {
              try {
                revertTask.run(item);
              } catch (Exception e) {
                failed = true;
                LOG.error("Failed to revert task", e);
                // keep going
              }
              if (stopRevertsOnFailure && failed) {
                break;
              }
            }
          }

          if (abortTask != null) {
            boolean failed = false;
            // whatever is left in the iterator was never attempted
            while (iterator.hasNext()) {
              try {
                abortTask.run(iterator.next());
              } catch (Exception e) {
                failed = true;
                LOG.error("Failed to abort task", e);
                // keep going
              }
              if (stopAbortsOnFailure && failed) {
                break;
              }
            }
          }
        }
      }

      // Anything that escaped the try block has already propagated out of the finally above,
      // so at this point the only recorded failures are the ones held in exceptions.
      if (throwFailureWhenFinished && !exceptions.isEmpty()) {
        Tasks.throwOne(exceptions, exceptionClass);
      }

      // Report suppressed failures as false, matching what runParallel returns.
      return exceptions.isEmpty();
    }

    /**
     * Invokes the configured failure handler for an item. An exception thrown by the handler is
     * logged and attached to the original failure as a suppressed exception; it never replaces
     * the original failure.
     *
     * @param item the item whose task failed
     * @param failure the exception thrown by the task
     */
    private void tryRunOnFailure(I item, Exception failure) {
      try {
        onFailure.run(item, failure);
      } catch (Exception failException) {
        // A handler may simply rethrow the failure it was given. Throwable.addSuppressed rejects
        // self-suppression with IllegalArgumentException, which would escape from here and mask
        // the real failure, so only attach exceptions that are distinct objects.
        if (failException != failure) {
          failure.addSuppressed(failException);
        }
        LOG.error("Failed to clean up on failure", failException);
        // keep going
      }
    }

    /**
     * Runs the task on every item by submitting one job per item to the configured executor,
     * then waits for all jobs and, if any task failed, submits revert jobs for the items that
     * succeeded.
     *
     * @param task the task to apply to each item
     * @param exceptionClass the checked exception type the task is allowed to propagate
     * @param <E> the checked exception type
     * @return {@code true} if no task failed
     * @throws E if a task failed with that exception and failures are not suppressed
     */
    private <E extends Exception> boolean runParallel(
        final Task<I, E> task, Class<E> exceptionClass) throws E {
      // Concurrent collections and flags shared between the submitted jobs and this thread.
      final Queue<I> succeeded = new ConcurrentLinkedQueue<>();
      final Queue<Throwable> exceptions = new ConcurrentLinkedQueue<>();
      final AtomicBoolean taskFailed = new AtomicBoolean(false);
      final AtomicBoolean abortFailed = new AtomicBoolean(false);
      final AtomicBoolean revertFailed = new AtomicBoolean(false);

      List<Future<?>> futures = Lists.newArrayList();

      for (final I item : items) {
        // submit a task for each item that will either run or abort the task
        futures.add(service.submit(() -> {
          // The stop-on-failure check happens when the job starts executing, not when it is
          // submitted, so jobs already running when another job fails still complete.
          if (!(stopOnFailure && taskFailed.get())) {
            // run the task with retries
            boolean threw = true;
            try {
              runTaskWithRetry(task, item);

              succeeded.add(item);

              threw = false;

            } catch (Exception e) {
              taskFailed.set(true);
              exceptions.add(e);

              if (onFailure != null) {
                tryRunOnFailure(item, e);
              }
            } finally {
              // also flags throwables that the catch block above does not handle
              if (threw) {
                taskFailed.set(true);
              }
            }

          } else if (abortTask != null) {
            // abort the task instead of running it
            if (stopAbortsOnFailure && abortFailed.get()) {
              return;
            }

            boolean failed = true;
            try {
              abortTask.run(item);
              failed = false;
            } catch (Exception e) {
              LOG.error("Failed to abort task", e);
              // swallow the exception
            } finally {
              if (failed) {
                abortFailed.set(true);
              }
            }
          }
        }));
      }

      // let the above tasks complete (or abort)
      exceptions.addAll(waitFor(futures));
      // reuse the list for the revert jobs below
      futures.clear();

      if (taskFailed.get() && revertTask != null) {
        // at least one task failed, revert any that succeeded
        for (final I item : succeeded) {
          futures.add(service.submit(() -> {
            if (stopRevertsOnFailure && revertFailed.get()) {
              return;
            }

            boolean failed = true;
            try {
              revertTask.run(item);
              failed = false;
            } catch (Exception e) {
              LOG.error("Failed to revert task", e);
              // swallow the exception
            } finally {
              if (failed) {
                revertFailed.set(true);
              }
            }
          }));
        }

        // let the revert tasks complete
        exceptions.addAll(waitFor(futures));
      }

      if (throwFailureWhenFinished && !exceptions.isEmpty()) {
        Tasks.throwOne(exceptions, exceptionClass);
      } else if (throwFailureWhenFinished && taskFailed.get()) {
        // a failure was flagged but no exception was recorded for it
        throw new RuntimeException("Task set failed with an uncaught throwable");
      }

      return !taskFailed.get();
    }

    /**
     * Runs the task on one item, retrying failed attempts with exponential backoff and jitter.
     *
     * <p>A failure is rethrown without a further retry when the attempt limit is reached, when
     * the configured maximum duration has elapsed after more than one attempt, or when the retry
     * settings (predicate, only-retry list, or stop-retry list, in that order of precedence)
     * reject the exception.
     *
     * @param task the task to run
     * @param item the item to run it on
     * @param <E> the checked exception type the task may throw
     * @throws E the last failure of the task when it is not retried
     * @throws RuntimeException if the thread is interrupted while sleeping between attempts
     */
    private <E extends Exception> void runTaskWithRetry(
        Task<I, E> task, I item) throws E {
      long start = System.currentTimeMillis();
      int attempt = 0;
      while (true) {
        attempt += 1;
        try {
          task.run(item);
          break;

        } catch (Exception e) {
          long durationMs = System.currentTimeMillis() - start;
          if (attempt >= maxAttempts || (durationMs > maxDurationMs && attempt > 1)) {
            if (durationMs > maxDurationMs) {
              LOG.info("Stopping retries after {} ms", durationMs);
            }
            throw e;
          }

          if (shouldRetryPredicate != null) {
            if (!shouldRetryPredicate.test(e)) {
              throw e;
            }

          } else if (onlyRetryExceptions != null) {
            // if onlyRetryExceptions are present, then this retries if one is found
            boolean matchedRetryException = false;
            for (Class<? extends Exception> exClass : onlyRetryExceptions) {
              if (exClass.isInstance(e)) {
                matchedRetryException = true;
                break;
              }
            }
            if (!matchedRetryException) {
              throw e;
            }

          } else {
            // otherwise, always retry unless one of the stop exceptions is found
            for (Class<? extends Exception> exClass : stopRetryExceptions) {
              if (exClass.isInstance(e)) {
                throw e;
              }
            }
          }

          // Keep the backoff in long arithmetic: the sleep bounds are longs, and narrowing to
          // int lets delay + jitter wrap to a negative value for large bounds, which would make
          // the sleep return immediately and turn the backoff into a hot retry loop.
          long delayMs = (long) Math.min(
              minSleepTimeMs * Math.pow(scaleFactor, attempt - 1), maxSleepTimeMs);
          long jitter =
              ThreadLocalRandom.current().nextLong(Math.max(1L, (long) (delayMs * 0.1)));
          // jitter is never negative, so this saturates instead of overflowing
          long sleepMs = delayMs > Long.MAX_VALUE - jitter ? Long.MAX_VALUE : delayMs + jitter;

          LOG.warn("Retrying task after failure: {}", e.getMessage(), e);

          try {
            TimeUnit.MILLISECONDS.sleep(sleepMs);
          } catch (InterruptedException ie) {
            // restore the interrupt flag for callers before giving up
            Thread.currentThread().interrupt();
            throw new RuntimeException(ie);
          }
        }
      }
    }
  }

  /**
   * Blocks until every future is done, then gathers the failures of jobs that ended with an
   * uncaught throwable.
   *
   * <p>An {@link Error} thrown by a job is rethrown with the failures gathered so far attached
   * as suppressed exceptions. If the waiting thread is interrupted, all futures are cancelled,
   * the interrupt flag is restored and a {@link RuntimeException} is thrown.
   *
   * @param futures the futures to wait for
   * @return the {@link ExecutionException}s of jobs that failed with a non-null cause
   */
  private static Collection<Throwable> waitFor(
      Collection<Future<?>> futures) {
    // Phase 1: poll until every future reports completion.
    while (!futures.stream().allMatch(Future::isDone)) {
      try {
        Thread.sleep(10);
      } catch (InterruptedException interrupted) {
        LOG.warn("Interrupted while waiting for tasks to finish", interrupted);

        futures.forEach(pending -> pending.cancel(true));
        Thread.currentThread().interrupt();
        throw new RuntimeException(interrupted);
      }
    }

    // Phase 2: all of the futures are done, get any uncaught exceptions.
    List<Throwable> uncaught = Lists.newArrayList();
    for (Future<?> finished : futures) {
      try {
        finished.get();

      } catch (InterruptedException interrupted) {
        LOG.warn("Interrupted while getting future results", interrupted);
        uncaught.forEach(interrupted::addSuppressed);
        Thread.currentThread().interrupt();
        throw new RuntimeException(interrupted);

      } catch (CancellationException cancelled) {
        // ignore cancellations

      } catch (ExecutionException wrapped) {
        Throwable cause = wrapped.getCause();
        if (cause instanceof Error) {
          uncaught.forEach(cause::addSuppressed);
          throw (Error) cause;
        }

        if (cause != null) {
          uncaught.add(wrapped);
        }

        LOG.warn("Task threw uncaught exception", cause);
      }
    }

    return uncaught;
  }

  /**
   * Iterable over the integers of the half-open interval [ 0, size ), in ascending order.
   */
  private static class Range implements Iterable<Integer> {
    private final int upperBound;

    Range(int size) {
      this.upperBound = size;
    }

    @Override
    public Iterator<Integer> iterator() {
      return new Iterator<Integer>() {
        // next value to hand out; each iterator starts again from zero
        private int cursor = 0;

        @Override
        public boolean hasNext() {
          return cursor < upperBound;
        }

        @Override
        public Integer next() {
          if (cursor >= upperBound) {
            throw new NoSuchElementException("No more items.");
          }
          return cursor++;
        }
      };
    }
  }

  /**
   * Creates a builder whose items are the integers from 0 (inclusive) to {@code upTo}
   * (exclusive).
   *
   * @param upTo the exclusive upper bound
   * @return a new builder over that range
   */
  public static Builder<Integer> range(int upTo) {
    return new Builder<>(new Range(upTo));
  }

  /**
   * Creates a builder over the elements of an iterator. The builder's iterable hands back this
   * same iterator instance on every call, so the items can be traversed only once.
   *
   * @param items the iterator supplying the items
   * @param <I> the item type
   * @return a new builder over the iterator's elements
   */
  public static <I> Builder<I> foreach(Iterator<I> items) {
    return new Builder<>(() -> items);
  }

  /**
   * Creates a builder over the elements of an iterable.
   *
   * @param items the items to run a task on
   * @param <I> the item type
   * @return a new builder over the items
   */
  public static <I> Builder<I> foreach(Iterable<I> items) {
    return new Builder<>(items);
  }

  /**
   * Creates a builder over the given items.
   *
   * @param items the items to run a task on
   * @param <I> the item type
   * @return a new builder over the items
   */
  @SafeVarargs public static <I> Builder<I> foreach(I... items) {
    return new Builder<>(Arrays.asList(items));
  }

  /**
   * Creates a builder over the elements of a stream, obtained through the stream's iterator.
   *
   * @param items the stream supplying the items
   * @param <I> the item type
   * @return a new builder over the stream's elements
   */
  public static <I> Builder<I> foreach(Stream<I> items) {
    return new Builder<>(items::iterator);
  }

  /**
   * Throws the first exception of the collection. Each remaining exception is attached to it as
   * a suppressed exception, unless it is an instance of the first exception's class.
   *
   * @param exceptions the collected failures; must contain at least one element
   * @param allowedException the checked exception type that may be thrown as-is
   * @param <E> the checked exception type
   * @throws E if the first exception is of that type
   */
  private static <E extends Exception> void throwOne(Collection<Throwable> exceptions,
      Class<E> allowedException) throws E {
    Iterator<Throwable> remaining = exceptions.iterator();
    Throwable primary = remaining.next();
    Class<? extends Throwable> primaryType = primary.getClass();

    remaining.forEachRemaining(other -> {
      if (!primaryType.isInstance(other)) {
        primary.addSuppressed(other);
      }
    });

    castAndThrow(primary, allowedException);
  }

  /**
   * Rethrows the given throwable with the most specific type the caller can handle. This method
   * never returns normally.
   *
   * <p>Runtime exceptions and errors are rethrown unchanged, an instance of
   * {@code exceptionClass} is rethrown as that checked type, and anything else is wrapped in a
   * {@link RuntimeException}.
   *
   * @param exception the throwable to rethrow
   * @param exceptionClass the checked exception type that may be thrown as-is
   * @param <E> the checked exception type
   * @throws E if {@code exception} is an instance of {@code exceptionClass}
   */
  public static <E extends Exception> void castAndThrow(
      Throwable exception, Class<E> exceptionClass) throws E {
    if (exception instanceof RuntimeException) {
      throw (RuntimeException) exception;
    } else if (exception instanceof Error) {
      throw (Error) exception;
    } else if (exceptionClass.isInstance(exception)) {
      // Class.cast performs the checked conversion without an unchecked generic cast
      throw exceptionClass.cast(exception);
    }
    throw new RuntimeException(exception);
  }
}