RequestLimit.java

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.server.handlers;

import io.undertow.UndertowLogger;
import io.undertow.server.Connectors;
import io.undertow.server.ExchangeCompletionListener;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.SameThreadExecutor;

import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
 * Represents a limit on a number of running requests.
 * <p>
 * This is basically a counter with a configured set of limits, that is used by {@link RequestLimitingHandler}.
 * <p>
 * When the number of active requests goes over the configured max requests then requests will be suspended and queued.
 * <p>
 * If the queue is full requests will be rejected with a 503 Service Unavailable according to RFC7231 Section 6.6.4.
 * <p>
 * The reason why this is abstracted out into a separate class is so that multiple handlers can share the same state. This
 * allows for fine grained control of resources.
 *
 * @author Stuart Douglas
 * @see RequestLimitingHandler
 */
public class RequestLimit {
    @SuppressWarnings("unused")
    private volatile int requests;
    private volatile int max;

    private static final AtomicIntegerFieldUpdater<RequestLimit> requestsUpdater = AtomicIntegerFieldUpdater.newUpdater(RequestLimit.class, "requests");


    /**
     * The handler that will be invoked if the queue is full.
     */
    private volatile HttpHandler failureHandler = new ResponseCodeHandler(503);

    private final Queue<SuspendedRequest> queue;

    private final ExchangeCompletionListener COMPLETION_LISTENER = new ExchangeCompletionListener() {

        @Override
        public void exchangeEvent(final HttpServerExchange exchange, final NextListener nextListener) {
            SuspendedRequest task = null;
            boolean found = false;
            while ((task = queue.poll()) != null) {
                try {
                    task.exchange.addExchangeCompleteListener(COMPLETION_LISTENER);
                    task.exchange.dispatch(task.next);
                    found = true;
                    break;
                } catch (Throwable e) {
                    UndertowLogger.ROOT_LOGGER.error("Suspended request was skipped", e);
                }
            }

            if (!found) {
                decrementRequests();
            }

            nextListener.proceed();
        }
    };


    public RequestLimit(int maximumConcurrentRequests) {
        this(maximumConcurrentRequests, -1);
    }

    /**
     * Construct a new instance. The maximum number of concurrent requests must be at least one.
     *
     * @param maximumConcurrentRequests the maximum concurrent requests
     * @param queueSize                 The maximum number of requests to queue
     */
    public RequestLimit(int maximumConcurrentRequests, int queueSize) {
        if (maximumConcurrentRequests < 1) {
            throw new IllegalArgumentException("Maximum concurrent requests must be at least 1");
        }
        max = maximumConcurrentRequests;

        this.queue = new LinkedBlockingQueue<>(queueSize <= 0 ? Integer.MAX_VALUE : queueSize);
    }

    public void handleRequest(final HttpServerExchange exchange, final HttpHandler next) throws Exception {
        int oldVal, newVal;
        do {
            oldVal = requests;
            if (oldVal >= max) {
                exchange.dispatch(SameThreadExecutor.INSTANCE, new Runnable() {
                    @Override
                    public void run() {
                        //we have to try again in the sync block
                        //we need to have already dispatched for thread safety reasons
                        synchronized (RequestLimit.this) {
                            int oldVal, newVal;
                            do {
                                oldVal = requests;
                                if (oldVal >= max) {
                                    if (!queue.offer(new SuspendedRequest(exchange, next))) {
                                        Connectors.executeRootHandler(failureHandler, exchange);
                                    }
                                    return;
                                }
                                newVal = oldVal + 1;
                            } while (!requestsUpdater.compareAndSet(RequestLimit.this, oldVal, newVal));
                            exchange.addExchangeCompleteListener(COMPLETION_LISTENER);
                            exchange.dispatch(next);
                        }
                    }
                });
                return;
            }
            newVal = oldVal + 1;
        } while (!requestsUpdater.compareAndSet(this, oldVal, newVal));
        exchange.addExchangeCompleteListener(COMPLETION_LISTENER);
        next.handleRequest(exchange);
    }

    /**
     * Get the maximum concurrent requests.
     *
     * @return the maximum concurrent requests
     */
    public int getMaximumConcurrentRequests() {
        return max;
    }

    /**
     * Set the maximum concurrent requests.  The value must be greater than or equal to one.
     *
     * @param newMax the maximum concurrent requests
     */
    public int setMaximumConcurrentRequests(int newMax) {
        if (newMax < 1) {
            throw new IllegalArgumentException("Maximum concurrent requests must be at least 1");
        }
        int oldMax = this.max;
        this.max = newMax;
        if(newMax > oldMax) {
            synchronized (this) {
                while (!queue.isEmpty()) {
                    int oldVal, newVal;
                    do {
                        oldVal = requests;
                        if (oldVal >= max) {
                            return oldMax;
                        }
                        newVal = oldVal + 1;
                    } while (!requestsUpdater.compareAndSet(this, oldVal, newVal));
                    SuspendedRequest res = queue.poll();
                    res.exchange.dispatch(res.next);
                }
            }
        }
        return oldMax;
    }

    private void decrementRequests() {
        requestsUpdater.decrementAndGet(this);
    }

    public HttpHandler getFailureHandler() {
        return failureHandler;
    }

    public void setFailureHandler(HttpHandler failureHandler) {
        this.failureHandler = failureHandler;
    }

    private static final class SuspendedRequest {
        final HttpServerExchange exchange;
        final HttpHandler next;

        private SuspendedRequest(HttpServerExchange exchange, HttpHandler next) {
            this.exchange = exchange;
            this.next = next;
        }
    }
}