PerMessageDeflateExtension.java
/*
* Copyright (c) 2013, 2017 Oracle and/or its affiliates. All rights reserved.
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License v. 2.0, which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the
* Eclipse Public License v. 2.0 are satisfied: GNU General Public License,
* version 2 with the GNU Classpath Exception, which is available at
* https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
*/
package org.glassfish.tyrus.ext.extension.deflate;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import org.glassfish.tyrus.core.extension.ExtendedExtension;
import org.glassfish.tyrus.core.frame.Frame;
/**
* Compression Extensions for WebSocket draft-ietf-hybi-permessage-compression-15
* <p>
* http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-15
* <pre>TODO:
* - parameters (window sizes, context takeovers).
* - context (some utility methods to get the typed params - T getParam(Class<T>))
* </pre>
*
* @author Pavel Bucek (pavel.bucek at oracle.com)
*/
public class PerMessageDeflateExtension implements ExtendedExtension {

    private static final Pool<byte[]> BYTE_ARRAY_POOL = new Pool<byte[]>() {
        @Override
        byte[] create() {
            return new byte[8192];
        }
    };

    private static final String INFLATER = PerMessageDeflateExtension.class.getName() + ".INFLATER";
    private static final String DEFLATER = PerMessageDeflateExtension.class.getName() + ".DEFLATER";

    /**
     * Context property holding the RSV1 bit of the first frame of an incoming fragmented message.
     * <p>
     * It is kept until the final fragment arrives, so continuation frames can be classified.
     */
    private static final String INCOMING_COMPRESSED =
            PerMessageDeflateExtension.class.getName() + ".INCOMING_COMPRESSED";

    private static final Logger LOGGER = Logger.getLogger(PerMessageDeflateExtension.class.getName());
    private static final boolean DEBUG = LOGGER.isLoggable(Level.FINE);

    /**
     * Octets removed by the sender from the end of a compressed message and re-appended by the receiver.
     */
    private static final byte[] TAIL = {0x00, 0x00, (byte) 0xff, (byte) 0xff};

    /**
     * WebSocket opcode of a continuation frame.
     */
    private static final int CONTINUATION_OPCODE = 0x00;

    /**
     * Decompresses an incoming data frame that belongs to a compressed message.
     * <p>
     * Control frames and frames of uncompressed messages are returned unchanged. When inflating fails, the
     * original frame is returned unchanged as well.
     *
     * @param context extension context holding the per-connection {@link Inflater}.
     * @param frame   incoming frame.
     * @return frame with decompressed payload and RSV1 cleared, or the original frame.
     */
    @Override
    public Frame processIncoming(ExtensionContext context, Frame frame) {
        final Inflater decompresser = (Inflater) context.getProperties().get(INFLATER);

        if (DEBUG) {
            LOGGER.fine("Incoming frame: " + frame);
        }

        if (frame.isControlFrame() || !isCompressedIncomingFrame(context, frame)) {
            return frame;
        }

        // Decompress the bytes
        final int payloadLength = (int) frame.getPayloadLength();
        final List<PartialResultWithLength<byte[]>> wholeResult = new ArrayList<PartialResultWithLength<byte[]>>();

        int decompressed = processCompressed(decompresser, frame.getPayloadData(), payloadLength, wholeResult);
        if (decompressed == -1) {
            recycleAll(wholeResult);
            return frame;
        }
        int wholeResultLength = decompressed;

        // The sender strips TAIL from the end of the message only, so it is re-appended after the last fragment.
        if (frame.isFin()) {
            decompressed = processCompressed(decompresser, TAIL, TAIL.length, wholeResult);
            if (decompressed == -1) {
                recycleAll(wholeResult);
                return frame;
            }
            wholeResultLength += decompressed;
        }

        final byte[] completeResult = joinAndRecycle(wholeResult, wholeResultLength);
        return Frame.builder(frame).payloadData(completeResult).rsv1(false).build();
    }

    /**
     * Decides whether a (non-control) incoming frame carries compressed data.
     * <p>
     * Per-message-deflate draft, chapter 8.2.3.1: "Note that the RSV1 bit is set only on the first frame."
     * The first frame's RSV1 bit is stored in the context until the final fragment is seen. This keeps
     * continuation frames of an uncompressed message from being inflated. If no state was recorded for a
     * continuation frame, it is treated as compressed.
     *
     * @param context extension context used to carry state between fragments.
     * @param frame   incoming non-control frame.
     * @return {@code true} when the frame payload should be inflated.
     */
    private static boolean isCompressedIncomingFrame(ExtensionContext context, Frame frame) {
        final boolean compressed;
        if (frame.getOpcode() == CONTINUATION_OPCODE) {
            final Object state = context.getProperties().get(INCOMING_COMPRESSED);
            compressed = state == null || (Boolean) state;
        } else {
            compressed = frame.isRsv1();
        }

        if (frame.isFin()) {
            context.getProperties().remove(INCOMING_COMPRESSED);
        } else {
            context.getProperties().put(INCOMING_COMPRESSED, compressed);
        }
        return compressed;
    }

    /**
     * Feeds {@code compressed[0..length)} to the inflater and collects all produced output.
     * <p>
     * Inflating continues until {@link Inflater#inflate(byte[])} returns 0. That happens when the inflater
     * needs more input, needs a dictionary, or has finished. This drains output still pending inside the
     * inflater. It also avoids spinning forever once the stream has finished with unread input left.
     *
     * @param decompresser   inflater to use.
     * @param compressed     compressed data.
     * @param length         number of bytes of {@code compressed} to use.
     * @param partialResults list to which pooled output buffers are appended.
     * @return number of decompressed bytes, or {@code -1} when the data is not valid deflate data.
     */
    private int processCompressed(Inflater decompresser, byte[] compressed, int length,
                                  List<PartialResultWithLength<byte[]>> partialResults) {
        decompresser.setInput(compressed, 0, length);
        int decompressedLength = 0;
        int partialResultLength;

        do {
            final byte[] result = BYTE_ARRAY_POOL.take();
            try {
                partialResultLength = decompresser.inflate(result);
            } catch (DataFormatException e) {
                BYTE_ARRAY_POOL.recycle(result);
                LOGGER.log(Level.INFO, e.getMessage(), e);
                return -1;
            }

            if (partialResultLength != 0) {
                partialResults.add(new PartialResultWithLength<byte[]>(partialResultLength, result));
                decompressedLength += partialResultLength;
            } else {
                BYTE_ARRAY_POOL.recycle(result);
            }
        } while (partialResultLength != 0);

        return decompressedLength;
    }

    /**
     * Compresses an outgoing data frame.
     * <p>
     * The payload is deflated with {@link Deflater#SYNC_FLUSH}. RSV1 is set only on the first frame of a
     * message. The trailing {@link #TAIL} is removed only from the final fragment, because the receiver
     * re-appends it only after that fragment. Control frames are returned unchanged.
     *
     * @param context extension context holding the per-connection {@link Deflater}.
     * @param frame   outgoing frame.
     * @return compressed frame, or the original frame for control frames.
     */
    @Override
    public Frame processOutgoing(ExtensionContext context, Frame frame) {
        final Deflater compresser = (Deflater) context.getProperties().get(DEFLATER);

        if (DEBUG) {
            LOGGER.fine("Outgoing frame: " + frame);
        }

        if (frame.isControlFrame()) {
            return frame;
        }

        final List<PartialResultWithLength<byte[]>> wholeResult = new ArrayList<PartialResultWithLength<byte[]>>();
        int wholeResultLength = 0;

        // Compress the bytes
        final int payloadLength = (int) frame.getPayloadLength();
        compresser.setInput(frame.getPayloadData(), 0, payloadLength);

        // With SYNC_FLUSH, deflate must be called again while it keeps producing output.
        int compressedDataLength;
        do {
            final byte[] output = BYTE_ARRAY_POOL.take();
            compressedDataLength = compresser.deflate(output, 0, output.length, Deflater.SYNC_FLUSH);
            if (compressedDataLength > 0) {
                wholeResult.add(new PartialResultWithLength<byte[]>(compressedDataLength, output));
                wholeResultLength += compressedDataLength;
            } else {
                BYTE_ARRAY_POOL.recycle(output);
            }
        } while (compressedDataLength > 0);

        final byte[] completeResult = joinAndRecycle(wholeResult, wholeResultLength);

        // Stripping TAIL from a non-final fragment would corrupt the stream seen by the receiver.
        final boolean strip = frame.isFin() && endsWithTail(completeResult);

        return Frame.builder(frame)
                .payloadData(completeResult)
                .payloadLength(strip ? completeResult.length - TAIL.length : completeResult.length)
                // "Per-Message Compressed" (RSV1) must not be set on non-first fragments of a message.
                .rsv1(frame.getOpcode() != CONTINUATION_OPCODE)
                .build();
    }

    /**
     * Checks whether {@code data} ends with {@link #TAIL}.
     *
     * @param data data to check.
     * @return {@code true} when the last {@code TAIL.length} bytes equal {@link #TAIL}.
     */
    private static boolean endsWithTail(byte[] data) {
        if (data.length < TAIL.length) {
            return false;
        }
        final int offset = data.length - TAIL.length;
        for (int i = 0; i < TAIL.length; i++) {
            if (data[offset + i] != TAIL[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Concatenates the partial results into one array and returns their buffers to {@link #BYTE_ARRAY_POOL}.
     *
     * @param partialResults partial results to concatenate, in order.
     * @param totalLength    sum of the lengths of all partial results.
     * @return newly allocated array of {@code totalLength} bytes.
     */
    private static byte[] joinAndRecycle(List<PartialResultWithLength<byte[]>> partialResults, int totalLength) {
        final byte[] complete = new byte[totalLength];
        int offset = 0;
        for (PartialResultWithLength<byte[]> partialResult : partialResults) {
            final int length = partialResult.getLength();
            final byte[] result = partialResult.getResult();
            System.arraycopy(result, 0, complete, offset, length);
            BYTE_ARRAY_POOL.recycle(result);
            offset += length;
        }
        return complete;
    }

    /**
     * Returns all buffers of the partial results to {@link #BYTE_ARRAY_POOL} without using their content.
     *
     * @param partialResults partial results to discard.
     */
    private static void recycleAll(List<PartialResultWithLength<byte[]>> partialResults) {
        for (PartialResultWithLength<byte[]> partialResult : partialResults) {
            BYTE_ARRAY_POOL.recycle(partialResult.getResult());
        }
    }

    /**
     * Creates the per-connection raw-deflate (no zlib header) compressor and decompressor.
     * <p>
     * Both are stored in the context properties.
     *
     * @param context extension context.
     */
    private void init(ExtensionContext context) {
        // TODO: configurable compression level
        Deflater compresser = new Deflater(9, true);
        Inflater decompresser = new Inflater(true);
        compresser.setStrategy(Deflater.DEFAULT_STRATEGY);

        context.getProperties().put(INFLATER, decompresser);
        context.getProperties().put(DEFLATER, compresser);
    }

    @Override
    public List<Parameter> onExtensionNegotiation(ExtensionContext context, List<Parameter> requestedParameters) {
        init(context);
        return Collections.<Parameter>emptyList();
    }

    @Override
    public void onHandshakeResponse(ExtensionContext context, List<Parameter> responseParameters) {
        init(context);
    }

    /**
     * Removes this extension's state from the context and releases the native zlib resources.
     *
     * @param context extension context.
     */
    @Override
    public void destroy(ExtensionContext context) {
        final Inflater decompresser = (Inflater) context.getProperties().get(INFLATER);
        final Deflater compresser = (Deflater) context.getProperties().get(DEFLATER);

        context.getProperties().remove(DEFLATER);
        context.getProperties().remove(INFLATER);
        context.getProperties().remove(INCOMING_COMPRESSED);

        if (decompresser != null) {
            decompresser.end();
        }

        if (compresser != null) {
            compresser.end();
        }
    }

    @Override
    public String getName() {
        return "permessage-deflate";
    }

    @Override
    public List<Parameter> getParameters() {
        return Collections.<Parameter>emptyList();
    }

    /**
     * Generic pool of instances of {@code T} that are expensive to create.
     * <p>
     * The backing queue is only weakly referenced, so the garbage collector may discard pooled instances.
     *
     * @author Jitendra Kotamraju
     * @author Pavel Bucek (pavel.bucek at oracle.com)
     */
    private abstract static class Pool<T> {

        // volatile since multiple threads may access queue reference
        private volatile WeakReference<ConcurrentLinkedQueue<T>> queue;

        /**
         * Gets an object from the pool.
         * <p>
         * If no object is available in the pool, this method creates a new one.
         *
         * @return always non-null.
         */
        public final T take() {
            T t = getQueue().poll();
            if (t == null) {
                return create();
            }
            return t;
        }

        /**
         * Creates a new instance to be added to the pool.
         *
         * @return new instance.
         */
        abstract T create();

        /**
         * Returns the current backing queue, creating a new one if it was never created or has been collected.
         *
         * @return backing queue.
         */
        private ConcurrentLinkedQueue<T> getQueue() {
            WeakReference<ConcurrentLinkedQueue<T>> q = queue;
            if (q != null) {
                ConcurrentLinkedQueue<T> d = q.get();
                if (d != null) {
                    return d;
                }
            }

            // overwrite the queue
            ConcurrentLinkedQueue<T> d = new ConcurrentLinkedQueue<T>();
            queue = new WeakReference<ConcurrentLinkedQueue<T>>(d);

            return d;
        }

        /**
         * Returns an object back to the pool.
         *
         * @param t object to return.
         */
        public final void recycle(T t) {
            getQueue().offer(t);
        }
    }

    /**
     * A pooled buffer together with the number of valid bytes it holds.
     *
     * @param <T> buffer type.
     */
    private static class PartialResultWithLength<T> {

        private final int length;
        private final T result;

        private PartialResultWithLength(int length, T result) {
            this.length = length;
            this.result = result;
        }

        public int getLength() {
            return length;
        }

        public T getResult() {
            return result;
        }
    }
}