DoSFilter.java
/*
* Copyright 2015 ZXing authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.zxing.web;
import com.google.common.base.Preconditions;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Timer;
import java.util.concurrent.TimeUnit;
/**
* A simplistic {@link Filter} that rejects requests from hosts that are sending too many
* requests in too short a time.
*
* @author Sean Owen
*/
public abstract class DoSFilter implements Filter {

  /** Timer handed to the tracker; created in {@link #init} and cancelled in {@link #destroy}. */
  private Timer timer;
  /** Consulted for every candidate source address of a request to decide whether it is banned. */
  private DoSTracker sourceAddrTracker;

  /**
   * Reads the filter's init parameters and builds the {@link DoSTracker} used by this filter.
   *
   * <p>Recognized init parameters:
   * <ul>
   *  <li>{@code maxAccessPerTime} (required): positive integer</li>
   *  <li>{@code accessTimeSec} (required): positive integer number of seconds; converted to
   *   milliseconds before being passed to the tracker</li>
   *  <li>{@code maxEntries} (optional): positive integer; {@link Integer#MAX_VALUE} if absent</li>
   *  <li>{@code maxLoad} (optional): positive floating-point value; {@code null} if absent</li>
   * </ul>
   *
   * <p>The simple name of the concrete subclass is used as the timer's thread name and is also
   * passed to the tracker.
   *
   * @param filterConfig configuration supplying the init parameters above
   * @throws NumberFormatException if a required parameter is missing, or any supplied value
   *  cannot be parsed as a number
   * @throws IllegalArgumentException if a supplied value is not positive
   */
  @Override
  public void init(FilterConfig filterConfig) {
    int maxAccessPerTime =
        parsePositiveInt("maxAccessPerTime", filterConfig.getInitParameter("maxAccessPerTime"));
    int accessTimeSec =
        parsePositiveInt("accessTimeSec", filterConfig.getInitParameter("accessTimeSec"));
    long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS);

    String maxEntriesValue = filterConfig.getInitParameter("maxEntries");
    int maxEntries = Integer.MAX_VALUE;
    if (maxEntriesValue != null) {
      maxEntries = parsePositiveInt("maxEntries", maxEntriesValue);
    }

    String maxLoadValue = filterConfig.getInitParameter("maxLoad");
    Double maxLoad = null;
    if (maxLoadValue != null) {
      maxLoad = Double.valueOf(maxLoadValue);
      // NaN fails this comparison too, so it is rejected along with non-positive values
      Preconditions.checkArgument(maxLoad > 0.0, "maxLoad must be positive: %s", maxLoad);
    }

    // All parameters are validated before the Timer (and its thread) is created
    String name = getClass().getSimpleName();
    timer = new Timer(name);
    sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries, maxLoad);
  }

  /**
   * Parses an init parameter value as a strictly positive {@code int}.
   *
   * @param paramName name of the init parameter, used only in the error message
   * @param value raw parameter value
   * @return the parsed value, always greater than zero
   * @throws NumberFormatException if {@code value} is {@code null} or not a valid integer
   * @throws IllegalArgumentException if the parsed value is not positive
   */
  private static int parsePositiveInt(String paramName, String value) {
    int parsed = Integer.parseInt(value);
    Preconditions.checkArgument(parsed > 0, "%s must be positive: %s", paramName, parsed);
    return parsed;
  }

  /**
   * Rejects the request with a short HTTP 429 response if its source is currently banned;
   * otherwise passes it down the filter chain unchanged.
   *
   * @param request incoming request; cast to {@link HttpServletRequest}
   * @param response outgoing response; cast to {@link HttpServletResponse} only when rejecting
   * @param chain remainder of the filter chain, invoked only for requests that are not banned
   * @throws IOException if writing the rejection body fails, or if thrown by the chain
   * @throws ServletException if thrown by the chain
   */
  @Override
  public void doFilter(ServletRequest request,
                       ServletResponse response,
                       FilterChain chain) throws IOException, ServletException {
    if (isBanned((HttpServletRequest) request)) {
      HttpServletResponse servletResponse = (HttpServletResponse) response;
      // Send very short response as requests may be very frequent
      servletResponse.setStatus(429); // 429 = Too Many Requests from RFC 6585
      servletResponse.getWriter().write("Forbidden");
    } else {
      chain.doFilter(request, response);
    }
  }

  /**
   * Asks the tracker whether the request's source is banned. Two candidate addresses are
   * checked: the first comma-separated element of the {@code x-forwarded-for} header (when it is
   * present and not blank) and the connection's remote address.
   *
   * @param request request whose source addresses are checked
   * @return {@code true} if the tracker reports either address as banned
   */
  private boolean isBanned(HttpServletRequest request) {
    // NOTE(review): x-forwarded-for is supplied by the client or an upstream proxy; confirm the
    // deployment sits behind a proxy that sets it reliably.
    String remoteHost = request.getHeader("x-forwarded-for");
    if (remoteHost != null) {
      int comma = remoteHost.indexOf(',');
      if (comma >= 0) {
        remoteHost = remoteHost.substring(0, comma);
      }
      remoteHost = remoteHost.trim();
      if (remoteHost.isEmpty()) {
        // A blank header (or blank first element) identifies no host. Treat it as absent so that
        // unrelated clients are not all tracked together under the empty-string key.
        remoteHost = null;
      }
    }
    // Non-short-circuit "|" below is on purpose: the tracker is consulted for both addresses
    // even when the first one is already banned.
    return
        (remoteHost != null && sourceAddrTracker.isBanned(remoteHost)) |
        sourceAddrTracker.isBanned(request.getRemoteAddr());
  }

  /**
   * Cancels the timer created in {@link #init}, if there is one.
   */
  @Override
  public void destroy() {
    if (timer != null) {
      timer.cancel();
    }
  }

}