Jetty12WebSocketDestination.java
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.transport.websocket.jetty12;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.net.URL;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Enumeration;
import java.util.Locale;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.logging.Level;
import java.util.logging.Logger;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletConfig;
import jakarta.servlet.ServletContext;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.cxf.Bus;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.transport.http.DestinationRegistry;
import org.apache.cxf.transport.http_jetty.JettyHTTPDestination;
import org.apache.cxf.transport.http_jetty.JettyHTTPHandler;
import org.apache.cxf.transport.http_jetty.JettyHTTPServerEngineFactory;
import org.apache.cxf.transport.websocket.InvalidPathException;
import org.apache.cxf.transport.websocket.WebSocketConstants;
import org.apache.cxf.transport.websocket.WebSocketDestinationService;
import org.apache.cxf.transport.websocket.WebSocketUtils;
import org.apache.cxf.transport.websocket.jetty.WebSocketServletHolder;
import org.apache.cxf.transport.websocket.jetty.WebSocketVirtualServletRequest;
import org.apache.cxf.transport.websocket.jetty.WebSocketVirtualServletResponse;
import org.apache.cxf.workqueue.WorkQueueManager;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator;
import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketServerContainer;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
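/**
 * Jetty 12 HTTP destination that additionally accepts WebSocket upgrade requests.
 * Each message received on an upgraded session is wrapped in a virtual servlet
 * request/response pair and dispatched through the regular HTTP invocation path.
 */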
public class Jetty12WebSocketDestination extends JettyHTTPDestination implements
WebSocketDestinationService {
// Logger for this class.
private static final Logger LOG = LogUtils.getL7dLogger(Jetty12WebSocketDestination.class);
//REVISIT make these keys configurable
// Header read from each virtual WebSocket request; its value is echoed back under RESPONSE_ID_KEY.
private static final String REQUEST_ID_KEY = WebSocketConstants.DEFAULT_REQUEST_ID_KEY;
// Header set on the virtual response to carry the request id back to the client.
private static final String RESPONSE_ID_KEY = WebSocketConstants.DEFAULT_RESPONSE_ID_KEY;
// Executes the service invocations for incoming WebSocket messages (the bus's automatic work queue).
private final Executor executor;
// Cached by getWebSocketContainer(ServletContext) on first use.
private JettyWebSocketServerContainer webSocketServerContainer;
// Last value of the "org.apache.cxf.transport.endpoint.address" request attribute seen by invoke(...);
// Jetty12WebSocketHolder.getAttribute falls back to it when the upgrade request cannot be queried.
private Object address;
public Jetty12WebSocketDestination(Bus bus, DestinationRegistry registry, EndpointInfo ei,
JettyHTTPServerEngineFactory serverEngineFactory) throws IOException {
super(bus, registry, ei,
serverEngineFactory == null ? null : new URL(getNonWSAddress(ei)),
serverEngineFactory);
executor = bus.getExtension(WorkQueueManager.class).getAutomaticWorkQueue();
}
/**
 * Dispatches the request directly to the superclass {@code invoke}, without the
 * WebSocket upgrade attempt that this class's own {@code invoke} performs first.
 *
 * @param config the servlet configuration, passed through unchanged
 * @param context the servlet context, passed through unchanged
 * @param req the request to service
 * @param resp the response to populate
 * @throws IOException if the superclass invocation fails
 */
@Override
public void invokeInternal(ServletConfig config,
                           ServletContext context,
                           HttpServletRequest req,
                           HttpServletResponse resp) throws IOException {
    super.invoke(config, context, req, resp);
}
/**
 * Services a request by first offering it to the WebSocket container for an upgrade and,
 * when no upgrade takes place, dispatching it through the superclass {@code invoke}.
 * <p>
 * This method is also called for every WebSocket message with a virtual request/response
 * pair (and {@code null} config and context), so an upgrade attempt that ends in an
 * exception is treated the same as "not upgraded" rather than as a failure.
 *
 * @param config the servlet configuration, may be {@code null}
 * @param context the servlet context used to look up the WebSocket container
 * @param request the request to upgrade or service
 * @param response the response to populate
 * @throws IOException if the superclass invocation fails
 */
public void invoke(final ServletConfig config,
                   final ServletContext context,
                   final HttpServletRequest request,
                   final HttpServletResponse response) throws IOException {
    final String addressKey = "org.apache.cxf.transport.endpoint.address";
    final JettyWebSocketServerContainer wssc = getWebSocketContainer(context);
    final JettyWebSocketCreator creator = getCreator();
    // Keep a per-call snapshot: the shared field may be overwritten by a concurrent
    // request before the attribute is restored below.
    final Object endpointAddress = request.getAttribute(addressKey);
    address = endpointAddress;
    try {
        if (wssc.upgrade(creator, request, response)) {
            return;
        }
    } catch (Exception ex) {
        // Best effort: fall through to regular handling, but leave a trace for diagnosis.
        LOG.log(Level.FINE, "WebSocket upgrade was not performed, handling as a regular request", ex);
    }
    // Re-apply the attribute after the upgrade attempt so the regular invocation sees it.
    if (endpointAddress != null) {
        request.setAttribute(addressKey, endpointAddress);
    }
    super.invoke(config, context, request, response);
}
/**
 * Converts a WebSocket endpoint address into its HTTP counterpart
 * ({@code ws:} becomes {@code http:}, {@code wss:} becomes {@code https:}).
 * Addresses that do not use one of these two schemes, including {@code null},
 * are returned unchanged.
 *
 * @param endpointInfo the endpoint whose address is converted
 * @return the HTTP form of the address, or the original address if it is not a WebSocket address
 */
private static String getNonWSAddress(EndpointInfo endpointInfo) {
    final String address = endpointInfo.getAddress();
    // Match the scheme including the colon so that addresses merely starting with
    // the letters "ws" are not rewritten.
    if (address != null && (address.startsWith("ws:") || address.startsWith("wss:"))) {
        return "http" + address.substring(2);
    }
    return address;
}
/**
 * Returns the endpoint address in the form produced by {@code getNonWSAddress}.
 *
 * @param endpointInfo the endpoint whose address is requested
 * @return the converted address
 */
@Override
protected String getAddress(EndpointInfo endpointInfo) {
    final String nonWsAddress = getNonWSAddress(endpointInfo);
    return nonWsAddress;
}
/**
 * Returns a new creator that produces {@code Jetty12WebSocket} endpoint instances.
 *
 * @return a fresh creator on every call
 */
public JettyWebSocketCreator getCreator() {
    final JettyWebSocketCreator creator = new Creator();
    return creator;
}
/**
 * Returns the WebSocket server container for the given servlet context, caching it
 * in this destination after the first successful lookup.
 *
 * @param context the servlet context used for the lookup when nothing is cached yet
 * @return the cached container, or the one obtained via {@code getContainer} /
 *         {@code ensureContainer} on the first call
 */
public synchronized JettyWebSocketServerContainer getWebSocketContainer(ServletContext context) {
    if (webSocketServerContainer != null) {
        return webSocketServerContainer;
    }
    JettyWebSocketServerContainer container = JettyWebSocketServerContainer.getContainer(context);
    if (container == null) {
        // No container registered on this context yet: ask Jetty to set one up.
        container = JettyWebSocketServerContainer.ensureContainer(context);
    }
    webSocketServerContainer = container;
    return container;
}
/**
 * Creates the handler registered with the engine: a {@code JettyWebSocketHandler}
 * that is given this destination as its third constructor argument.
 *
 * @param jhd the destination passed as the handler's first argument
 * @param cmExact the context-match flag passed to the handler
 * @return the new WebSocket-aware handler
 */
@Override
protected JettyHTTPHandler createJettyHTTPHandler(JettyHTTPDestination jhd, boolean cmExact) {
    final JettyHTTPHandler webSocketHandler = new JettyWebSocketHandler(jhd, cmExact, this);
    return webSocketHandler;
}
/**
 * Activate receipt of incoming messages: registers this destination with the
 * registry (if any) and, when an engine is present, installs and initializes
 * the handler created by {@code createJettyHTTPHandler}.
 */
protected void activate() {
    synchronized (this) {
        if (registry != null) {
            registry.addDestination(this);
        }
    }
    LOG.log(Level.FINE, "Activating receipt of incoming messages");
    if (engine == null) {
        return;
    }
    // The handler is registered as a servant first and initialized against the server afterwards.
    handler = createJettyHTTPHandler(this, contextMatchOnExact());
    engine.addServant(nurl, handler);
    ((JettyWebSocketHandler)handler).initHandler(engine.getServer());
}
/**
 * Shuts the destination down by delegating to the superclass; no additional
 * WebSocket-specific cleanup is performed here.
 */
@Override
public void shutdown() {
    super.shutdown();
}
/**
 * Services one WebSocket message by wrapping it in a virtual servlet request/response
 * pair and dispatching it through {@code invoke(ServletConfig, ...)}.
 * <p>
 * The work is handed to {@code executeServiceTask} so that the WebSocket message
 * callback is not blocked by the service invocation. Because the task may run after
 * this method has returned, {@code data} must not be modified by the caller afterwards.
 *
 * @param data the buffer holding the message bytes
 * @param offset index of the first message byte in {@code data}
 * @param length number of message bytes
 * @param session the WebSocket session the message arrived on
 */
private void invoke(final byte[] data, final int offset, final int length, final Session session) {
    executeServiceTask(() -> {
        HttpServletResponse response = null;
        try {
            final WebSocketServletHolder holder = new Jetty12WebSocketHolder(session);
            response = createServletResponse(holder);
            final HttpServletRequest request = createServletRequest(data, offset, length, holder, session);
            final String requestId = request.getHeader(REQUEST_ID_KEY);
            if (requestId != null) {
                // Echo the request id back, unless it could be used to inject extra header lines.
                if (!WebSocketUtils.isContainingCRLF(requestId)) {
                    response.setHeader(RESPONSE_ID_KEY, requestId);
                } else {
                    LOG.warning("Invalid characters (CR/LF) in header " + REQUEST_ID_KEY);
                }
            }
            invoke(null, null, request, response);
        } catch (InvalidPathException ex) {
            reportErrorStatus(session, 400, response);
        } catch (Exception e) {
            LOG.log(Level.WARNING, "Failed to invoke service", e);
            reportErrorStatus(session, 500, response);
        }
    });
}
/**
 * Submits the task to the executor; if the executor rejects it, a warning is
 * logged and the task is run in the calling thread instead.
 *
 * @param r the service invocation task
 */
private void executeServiceTask(Runnable r) {
    try {
        executor.execute(r);
        return;
    } catch (RejectedExecutionException rejected) {
        LOG.warning(
            "Executor queue is full, run the service invocation task in caller thread."
            + " Users can specify a larger executor queue to avoid this.");
    }
    // Only reached after a rejection: fall back to running the task synchronously.
    r.run();
}
/**
 * Reports an error status to the client through the virtual response.
 * Failures while reporting are logged and never propagated.
 *
 * @param session the session the failed message arrived on (currently unused)
 * @param i the HTTP-style status code to send
 * @param resp the response to send the error on; may be {@code null} when the
 *             response could not be created, in which case the status is only logged
 */
private void reportErrorStatus(Session session, int i, HttpServletResponse resp) {
    if (resp == null) {
        // The caller failed before a response existed; there is nothing to send the status on.
        LOG.warning("Unable to report error status " + i + ": no response is available");
        return;
    }
    try {
        resp.sendError(i);
    } catch (IOException e) {
        LOG.log(Level.WARNING, "Failed to report error status " + i, e);
    }
}
/**
 * Builds the virtual servlet request for one WebSocket message.
 *
 * @param data the buffer holding the message bytes
 * @param offset index of the first message byte in {@code data}
 * @param length number of message bytes
 * @param holder the servlet-style view of the WebSocket session
 * @param session the WebSocket session the message arrived on
 * @return the virtual request reading from the given byte range
 * @throws IOException if the virtual request cannot be constructed
 */
private WebSocketVirtualServletRequest createServletRequest(byte[] data, int offset, int length,
                                                            WebSocketServletHolder holder,
                                                            Session session)
    throws IOException {
    final ByteArrayInputStream messageStream = new ByteArrayInputStream(data, offset, length);
    return new WebSocketVirtualServletRequest(holder, messageStream, session);
}
/**
 * Builds the virtual servlet response backed by the given holder.
 *
 * @param holder the servlet-style view of the WebSocket session
 * @return the virtual response
 * @throws IOException if the virtual response cannot be constructed
 */
private WebSocketVirtualServletResponse createServletResponse(WebSocketServletHolder holder)
    throws IOException {
    final WebSocketVirtualServletResponse virtualResponse = new WebSocketVirtualServletResponse(holder);
    return virtualResponse;
}
/**
 * Creator handed to the WebSocket container on upgrade; it returns a new
 * {@code Jetty12WebSocket} endpoint for every upgrade request and ignores
 * both of its arguments.
 */
private final class Creator implements JettyWebSocketCreator {

    @Override
    public Object createWebSocket(JettyServerUpgradeRequest req, JettyServerUpgradeResponse resp) {
        final Jetty12WebSocket endpoint = new Jetty12WebSocket();
        return endpoint;
    }
}
/**
 * Servlet-style view of a Jetty WebSocket {@link Session}, used by the virtual
 * servlet request and response. Only the values that can be derived from the
 * session or its upgrade request are provided; every other accessor returns a
 * fixed placeholder ({@code null}, {@code 0}, {@code false} or an empty string).
 */
class Jetty12WebSocketHolder implements WebSocketServletHolder {
    final Session session;

    Jetty12WebSocketHolder(Session s) {
        this.session = s;
    }

    // ---- values that are not available for a WebSocket session ----

    public String getAuthType() {
        return null;
    }

    public String getLocalAddr() {
        return null;
    }

    public String getLocalName() {
        return null;
    }

    public int getLocalPort() {
        return 0;
    }

    public Locale getLocale() {
        return null;
    }

    public Enumeration<Locale> getLocales() {
        return null;
    }

    public String getProtocol() {
        return null;
    }

    public String getRemoteAddr() {
        return null;
    }

    public String getRemoteHost() {
        return null;
    }

    public int getRemotePort() {
        return 0;
    }

    public DispatcherType getDispatcherType() {
        return null;
    }

    public boolean isSecure() {
        return false;
    }

    public String getScheme() {
        return "ws";
    }

    public String getServerName() {
        return null;
    }

    public String getServletPath() {
        return "";
    }

    public ServletContext getServletContext() {
        return null;
    }

    public Principal getUserPrincipal() {
        return null;
    }

    // ---- values derived from the session and its upgrade request ----

    public String getContextPath() {
        final UpgradeRequest upgradeRequest = session.getUpgradeRequest();
        return getHttpServletRequest(upgradeRequest).getContextPath();
    }

    public String getRequestURI() {
        return requestPath();
    }

    public StringBuffer getRequestURL() {
        final String url = session.getUpgradeRequest().getRequestURI().toString();
        return new StringBuffer(url);
    }

    public String getPathInfo() {
        return requestPath();
    }

    public String getPathTranslated() {
        return requestPath();
    }

    public int getServerPort() {
        final InetSocketAddress local = (InetSocketAddress)session.getLocalSocketAddress();
        return local.getPort();
    }

    /**
     * Looks the attribute up on the servlet request behind the upgrade request.
     * If that lookup throws, the endpoint address attribute is answered from the
     * value remembered by the enclosing destination and any other name yields {@code null}.
     */
    public Object getAttribute(String name) {
        try {
            return getHttpServletRequest(session.getUpgradeRequest()).getAttribute(name);
        } catch (Exception ex) {
            return name.equals("org.apache.cxf.transport.endpoint.address") ? address : null;
        }
    }

    @Override
    public void write(byte[] data, int offset, int length) throws IOException {
        final ByteBuffer frame = ByteBuffer.wrap(data, offset, length);
        session.sendBinary(frame, null);
    }

    /** Path component of the upgrade request URI. */
    private String requestPath() {
        return session.getUpgradeRequest().getRequestURI().getPath();
    }

    /**
     * Returns the servlet request behind a server-side upgrade request.
     *
     * @throws IllegalStateException if the upgrade request is not a {@code JettyServerUpgradeRequest}
     */
    private HttpServletRequest getHttpServletRequest(final UpgradeRequest upgradeRequest) {
        if (!(upgradeRequest instanceof JettyServerUpgradeRequest)) {
            throw new IllegalStateException("Unsupported upgrade request class: " + upgradeRequest.getClass());
        }
        return ((JettyServerUpgradeRequest)upgradeRequest).getHttpServletRequest();
    }
}
/**
 * Annotated Jetty WebSocket endpoint. Every text or binary message is copied into a
 * private byte array and handed to the enclosing destination's
 * {@code invoke(byte[], int, int, Session)} for servicing.
 */
@org.eclipse.jetty.websocket.api.annotations.WebSocket
public class Jetty12WebSocket {
    // Written by onOpen and read by the message handlers.
    volatile Session session;

    /** Remembers the session so that later messages can be serviced against it. */
    @org.eclipse.jetty.websocket.api.annotations.OnWebSocketOpen
    public void onOpen(Session sess) {
        this.session = sess;
    }

    /**
     * Handles a text message by encoding it as UTF-8 and servicing the bytes.
     *
     * @param message the received text
     */
    @org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage
    public void onMessage(String message) {
        // The Charset overload cannot fail, unlike the charset-name overload.
        final byte[] bdata = message.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        invoke(bdata, 0, bdata.length, session);
    }

    /**
     * Handles a binary message by copying its remaining bytes and servicing the copy.
     *
     * @param message the received payload; its remaining bytes are consumed
     * @param callback completed once the payload has been copied, so that the
     *                 framework can release the payload buffer; {@code null} is tolerated
     */
    @org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage
    public void onBinaryMessage(ByteBuffer message, Callback callback) {
        final byte[] payload = new byte[message.remaining()];
        message.get(payload);
        // The copy is independent of the framework buffer from here on, so the
        // callback can be completed before the (asynchronous) service invocation.
        if (callback != null) {
            callback.succeed();
        }
        invoke(payload, 0, payload.length, session);
    }

    /** Invoked when the WebSocket is closed; nothing needs to be cleaned up here. */
    @org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose
    public void onClose(int code, String message) {
        // no per-endpoint state to release
    }
}
}