WebSocketTimeoutTestCase.java
/*
* JBoss, Home of Professional Open Source.
* Copyright 2023 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.websockets.core.protocol;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.util.CharsetUtil;
import io.undertow.UndertowOptions;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.HttpOneOnly;
import io.undertow.util.NetworkUtils;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.WebSocketProtocolHandshakeHandler;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSockets;
import io.undertow.websockets.utils.FrameChecker;
import io.undertow.websockets.utils.WebSocketTestClient;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.xnio.FutureResult;
import org.xnio.OptionMap;
import org.xnio.Options;
import java.net.URI;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
/**
 * Checks that the server closes a web socket connection after the configured
 * {@link UndertowOptions#WEB_SOCKETS_READ_TIMEOUT} when the client stays silent.
 *
 * <p>The server is started with distinct values for the plain XNIO read/write timeouts and
 * for the web socket specific ones. The test performs one request/response exchange, then
 * stops sending and polls the client connection until it is reported closed or a watch
 * deadline passes.
 */
@RunWith(DefaultServer.class)
@HttpOneOnly
public class WebSocketTimeoutTestCase {

    /** Value set as {@code WEB_SOCKETS_READ_TIMEOUT}; the close is expected within roughly this many ms. */
    protected static final int TESTABLE_TIMEOUT_VALUE = 2000;
    /** Value set as {@code WEB_SOCKETS_WRITE_TIMEOUT}; not asserted on by this test. */
    protected static final int NON_TESTABLE_TIMEOUT_VALUE = 30180;
    /** Value set as the plain {@code READ_TIMEOUT} and {@code WRITE_TIMEOUT} server options. */
    protected static final int DEFAULTS_IO_TIMEOUT_VALUE = 500;

    /** Extra time, on top of the expected timeout, that the guard keeps watching before giving up. */
    private static final long WATCH_GRACE_MILLIS = 1000;
    /** Maximum accepted overshoot of the observed close time over {@link #TESTABLE_TIMEOUT_VALUE}. */
    private static final long TIMEOUT_TOLERANCE_MILLIS = 250;
    /** Period at which the guard polls the client connection state. */
    private static final long POLL_INTERVAL_MILLIS = 50;
    /** Sentinel published by the guard when the connection was still open at the watch deadline. */
    private static final long NO_TIMEOUT = -1L;

    private static ScheduledExecutorService SCHEDULER = null;

    /**
     * Installs the timeout related server options and creates the scheduler used for polling.
     */
    @DefaultServer.BeforeServerStarts
    public static void beforeTest() {
        DefaultServer.setServerOptions(OptionMap.builder()
                .set(Options.READ_TIMEOUT, DEFAULTS_IO_TIMEOUT_VALUE)
                .set(Options.WRITE_TIMEOUT, DEFAULTS_IO_TIMEOUT_VALUE)
                .set(UndertowOptions.WEB_SOCKETS_READ_TIMEOUT, TESTABLE_TIMEOUT_VALUE)
                .set(UndertowOptions.WEB_SOCKETS_WRITE_TIMEOUT, NON_TESTABLE_TIMEOUT_VALUE).getMap());
        SCHEDULER = Executors.newScheduledThreadPool(2);
    }

    /**
     * Shuts the polling scheduler down and resets the server options to an empty map.
     */
    @DefaultServer.AfterServerStops
    public static void afterTest() {
        // Guard against the scheduler never having been created so that the
        // server options are still reset in that case.
        if (SCHEDULER != null) {
            SCHEDULER.shutdown();
            SCHEDULER = null;
        }
        DefaultServer.setServerOptions(OptionMap.EMPTY);
    }

    /**
     * @return the web socket protocol version the test client connects with
     */
    protected WebSocketVersion getVersion() {
        return WebSocketVersion.V13;
    }

    /**
     * Exchanges a single "hello"/"world" message pair, then stays idle and asserts that the
     * client connection is reported closed no later than {@link #TESTABLE_TIMEOUT_VALUE} plus
     * {@link #TIMEOUT_TOLERANCE_MILLIS} after the exchange completed.
     *
     * @throws Exception if connecting, sending or waiting for one of the results fails
     */
    @Test
    public void testServerReadTimeout() throws Exception {
        // Server side: answer "hello" with "world", echo anything else.
        DefaultServer.setRootHandler(new WebSocketProtocolHandshakeHandler(
                (WebSocketConnectionCallback) (exchange, channel) -> {
                    channel.getReceiveSetter().set(new AbstractReceiveListener() {
                        @Override
                        protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) {
                            String string = message.getData();
                            if (string.equals("hello")) {
                                WebSockets.sendText("world", channel, null);
                            } else {
                                WebSockets.sendText(string, channel, null);
                            }
                        }
                    });
                    channel.resumeReceives();
                }));

        final FutureResult<?> latch = new FutureResult<>();
        WebSocketTestClient client = new WebSocketTestClient(getVersion(), new URI("ws://" + NetworkUtils.formatPossibleIpv6Address(DefaultServer.getHostAddress("default")) + ":" + DefaultServer.getHostPort("default") + "/"));
        client.connect();
        client.send(new TextWebSocketFrame(Unpooled.copiedBuffer("hello", CharsetUtil.US_ASCII)), new FrameChecker(TextWebSocketFrame.class, "world".getBytes(CharsetUtil.US_ASCII), latch));
        // Block until the expected reply was checked; from here on the client stays silent.
        latch.getIoFuture().get();

        // Sample the clock once so that start and deadline are derived from the same instant.
        final long watchStart = System.currentTimeMillis();
        final long watchDuration = TESTABLE_TIMEOUT_VALUE + WATCH_GRACE_MILLIS;
        final long watchDeadline = watchStart + watchDuration;

        final FutureResult<Long> timeoutLatch = new FutureResult<>();
        ReadTimeoutChannelGuard readTimeoutChannelGuard = new ReadTimeoutChannelGuard(client, timeoutLatch, watchDeadline);
        final ScheduledFuture<?> sf = SCHEDULER.scheduleAtFixedRate(readTimeoutChannelGuard, 0, POLL_INTERVAL_MILLIS, TimeUnit.MILLISECONDS);
        readTimeoutChannelGuard.setTaskScheduledFuture(sf);

        final long watchTimeEnd = timeoutLatch.getIoFuture().get();
        if (watchTimeEnd == NO_TIMEOUT) {
            // Report the watch duration, not the absolute deadline timestamp.
            Assert.fail("Timeout did not happen... in time. Were waiting '" + watchDuration + "' ms, timeout should happen in '" + TESTABLE_TIMEOUT_VALUE + "' ms.");
        } else {
            long timeSpent = watchTimeEnd - watchStart;
            // Be generous and allow TIMEOUT_TOLERANCE_MILLIS of overshoot (there is "fuzz" coded
            // for 50ms in undertow as well, plus the polling interval of the guard).
            // NOTE(review): only an upper bound is asserted; a close that happens much earlier
            // than TESTABLE_TIMEOUT_VALUE also passes - confirm whether a lower bound is wanted.
            if (timeSpent > TESTABLE_TIMEOUT_VALUE + TIMEOUT_TOLERANCE_MILLIS) {
                Assert.fail("Timeout did not happen... in time. Socket timeout out in '" + timeSpent + "' ms, supposed to happen in '" + TESTABLE_TIMEOUT_VALUE + "' ms.");
            }
        }
    }

    /**
     * Periodic task that polls a client connection and publishes, exactly once, either the
     * wall-clock time at which the connection was first seen closed or {@link #NO_TIMEOUT}
     * when it was still open after the watch deadline.
     *
     * <p>The task is scheduled before its own {@link ScheduledFuture} can be handed to it, so
     * the hand-over is done through volatile fields: whichever of {@link #finish(long)} and
     * {@link #setTaskScheduledFuture(ScheduledFuture)} runs second performs the cancellation.
     */
    private static class ReadTimeoutChannelGuard implements Runnable {
        private final WebSocketTestClient channel;
        private final FutureResult<Long> resultHandler;
        private final long watchEnd;
        private volatile ScheduledFuture<?> sf;
        private volatile boolean finished;

        /**
         * @param channel       client whose open state is polled
         * @param resultHandler receives the close time in epoch ms, or {@link #NO_TIMEOUT}
         * @param watchEnd      epoch ms after which an open connection counts as "no timeout"
         */
        ReadTimeoutChannelGuard(final WebSocketTestClient channel, final FutureResult<Long> resultHandler, final long watchEnd) {
            super();
            this.channel = channel;
            this.resultHandler = resultHandler;
            this.watchEnd = watchEnd;
        }

        /**
         * Hands the guard the future of its own periodic schedule so that it can cancel itself.
         *
         * @param sf2 future returned when this guard was scheduled
         */
        public void setTaskScheduledFuture(ScheduledFuture<?> sf2) {
            this.sf = sf2;
            // The task may already have finished before the future was available.
            if (finished && sf2 != null) {
                sf2.cancel(false);
            }
        }

        @Override
        public void run() {
            if (finished) {
                // Result already published; only reachable while cancellation is pending.
                return;
            }
            final boolean deadlinePassed = System.currentTimeMillis() > watchEnd;
            if (!channelActive()) {
                finish(System.currentTimeMillis());
            } else if (deadlinePassed) {
                finish(NO_TIMEOUT);
            }
        }

        /**
         * Stops further polling and publishes the outcome.
         *
         * @param result close time in epoch ms, or {@link #NO_TIMEOUT}
         */
        private void finish(final long result) {
            // Write the flag before reading the future; the setter does the opposite order,
            // so at least one side observes the other and cancels the schedule.
            finished = true;
            final ScheduledFuture<?> future = sf;
            if (future != null) {
                future.cancel(false);
            }
            resultHandler.setResult(result);
        }

        private boolean channelActive() {
            return channel.isOpen();
        }
    }
}