AbstractFramedStreamSinkConduit.java

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.conduits;

import io.undertow.UndertowMessages;
import org.xnio.Buffers;
import org.xnio.IoUtils;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayDeque;
import java.util.Deque;

import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;

/**
 * Utility class to ease the implementation of framed protocols. This class provides a queue of frames, and a callback
 * that can be invoked when a frame event occurs.
 * <p>
 * When a write takes place all frames are attempted to be written out at once via a gathering write. Frames can be
 * queued via {@link #queueFrame(io.undertow.conduits.AbstractFramedStreamSinkConduit.FrameCallBack, java.nio.ByteBuffer...)}.
 *
 * @author Stuart Douglas
 */
public class AbstractFramedStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {

    private final Deque<Frame> frameQueue = new ArrayDeque<>();
    /**
     * The total amount of data that has been queued to be written out
     */
    private long queuedData = 0;
    /**
     * The total number of buffers that have been queued to be written out
     */
    private int bufferCount = 0;

    private int state;

    private static final int FLAG_WRITES_TERMINATED = 1;
    private static final int FLAG_DELEGATE_SHUTDOWN = 2;

    /**
     * Construct a new instance.
     *
     * @param next the delegate conduit to set
     */
    protected AbstractFramedStreamSinkConduit(StreamSinkConduit next) {
        super(next);
    }

    /**
     * Queues a frame for sending. The frame is written out ahead of any data passed to a later
     * {@code write} call, and its callback is notified once all of its bytes have been written
     * (or the write fails).
     *
     * @param callback the callback to notify when the frame is done or failed; may be {@code null}
     * @param data     the buffers making up the frame
     */
    protected void queueFrame(FrameCallBack callback, ByteBuffer... data) {
        queuedData += Buffers.remaining(data);
        bufferCount += data.length;
        frameQueue.add(new Frame(callback, data, 0, data.length));
    }

    @Override
    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }

    @Override
    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    @Override
    public int write(ByteBuffer src) throws IOException {
        if (anyAreSet(state, FLAG_WRITES_TERMINATED)) {
            throw UndertowMessages.MESSAGES.channelIsClosed();
        }
        return (int) doWrite(new ByteBuffer[]{src}, 0, 1);
    }

    @Override
    public long write(ByteBuffer[] srcs, int offs, int len) throws IOException {
        if (anyAreSet(state, FLAG_WRITES_TERMINATED)) {
            throw UndertowMessages.MESSAGES.channelIsClosed();
        }
        return doWrite(srcs, offs, len);
    }

    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Conduits.writeFinalBasic(this, src);
    }

    @Override
    public long writeFinal(ByteBuffer[] srcs, int offs, int len) throws IOException {
        return Conduits.writeFinalBasic(this, srcs, offs, len);
    }

    /**
     * Writes all queued frames followed by {@code additionalData} in a single gathering write.
     * <p>
     * On failure every queued frame's callback is failed, the queue is emptied, and the original
     * exception is rethrown.
     *
     * @param additionalData caller data to write after the queued frames; may be {@code null}
     * @param offs           offset into {@code additionalData}
     * @param len            number of buffers of {@code additionalData} to write
     * @return the number of bytes of {@code additionalData} that were written (0 if the queued
     *         frames were not completely written out)
     * @throws IOException if the delegate write fails
     */
    private long doWrite(ByteBuffer[] additionalData, int offs, int len) throws IOException {
        ByteBuffer[] buffers = gatherBuffers(additionalData, offs, len);
        try {
            long written = next.write(buffers, 0, buffers.length);
            // The gathering write is sequential, so anything past the queued data came from additionalData.
            queuedData = written > queuedData ? 0 : queuedData - written;
            return completeWrittenFrames(written);
        } catch (IOException | RuntimeException | Error e) {
            IOException cause = e instanceof IOException ? (IOException) e : new IOException(e);
            // On exception we fail every item in the frame queue; the original exception always wins,
            // but a callback failure is kept as a suppressed exception rather than silently dropped.
            try {
                failQueuedFrames(cause);
            } catch (RuntimeException | Error callbackFailure) {
                if (callbackFailure != e) {
                    e.addSuppressed(callbackFailure);
                }
            }
            throw e;
        }
    }

    /**
     * Builds the array for the gathering write: the buffers of every queued frame in queue order,
     * followed by {@code len} buffers of {@code additionalData} starting at {@code offs}.
     */
    private ByteBuffer[] gatherBuffers(ByteBuffer[] additionalData, int offs, int len) {
        ByteBuffer[] buffers = new ByteBuffer[bufferCount + (additionalData == null ? 0 : len)];
        int count = 0;
        for (Frame frame : frameQueue) {
            System.arraycopy(frame.data, frame.offs, buffers, count, frame.len);
            count += frame.len;
        }
        if (additionalData != null) {
            System.arraycopy(additionalData, offs, buffers, count, len);
        }
        return buffers;
    }

    /**
     * Attributes {@code written} bytes to the queued frames in order, removing and completing every
     * frame that has been fully written.
     *
     * @return the bytes left over after all queued frames were completed (i.e. caller data written),
     *         or 0 if a frame is still only partially written
     */
    private long completeWrittenFrames(long written) {
        long toAllocate = written;
        Frame frame;
        while ((frame = frameQueue.peek()) != null) {
            if (frame.remaining > toAllocate) {
                frame.remaining -= toAllocate;
                return 0;
            }
            frameQueue.poll(); //this frame is done, remove it
            //note that after we start calling done() we can't re-use the buffers[] array
            //as pooled buffers may have been returned to the pool and re-used
            FrameCallBack cb = frame.callback;
            if (cb != null) {
                cb.done();
            }
            bufferCount -= frame.len;
            toAllocate -= frame.remaining;
        }
        return toAllocate;
    }

    /**
     * Removes every queued frame and notifies its callback of {@code cause}, then resets the queued
     * data and buffer counters.
     * <p>
     * Each frame is dequeued before its callback runs, and a {@link RuntimeException} from one
     * callback does not prevent the remaining callbacks from being notified; the first such
     * exception is rethrown (with later ones suppressed) once all frames have been processed.
     */
    private void failQueuedFrames(IOException cause) {
        RuntimeException callbackFailure = null;
        try {
            Frame frame;
            while ((frame = frameQueue.poll()) != null) {
                FrameCallBack cb = frame.callback;
                if (cb == null) {
                    continue;
                }
                try {
                    cb.failed(cause);
                } catch (RuntimeException e) {
                    if (callbackFailure == null) {
                        callbackFailure = e;
                    } else if (e != callbackFailure) {
                        callbackFailure.addSuppressed(e);
                    }
                }
            }
        } finally {
            // Keep the queue and counters consistent even if a callback threw an Error.
            frameQueue.clear();
            bufferCount = 0;
            queuedData = 0;
        }
        if (callbackFailure != null) {
            throw callbackFailure;
        }
    }

    /**
     * @return the number of queued bytes that have not yet been written to the delegate
     */
    protected long queuedDataLength() {
        return queuedData;
    }

    /**
     * Queues any close frames (see {@link #queueCloseFrames()}) and marks writes as terminated.
     * If no data is queued the delegate is shut down immediately; otherwise the delegate shutdown is
     * deferred until {@link #flushQueuedData()} has written out all queued data.
     */
    @Override
    public void terminateWrites() throws IOException {
        if (anyAreSet(state, FLAG_WRITES_TERMINATED)) {
            return;
        }
        queueCloseFrames();
        state |= FLAG_WRITES_TERMINATED;
        if (queuedData == 0) {
            state |= FLAG_DELEGATE_SHUTDOWN;
            doTerminateWrites();
            finished();
        }
    }

    /**
     * Shuts down the delegate conduit. Called at most once, when writes are terminated and all
     * queued data has been written.
     */
    protected void doTerminateWrites() throws IOException {
        next.terminateWrites();
    }

    /**
     * Attempts to write out all queued data and flush the delegate. If writes have been terminated
     * and the queue has drained, the delegate is shut down (once) before flushing.
     *
     * @return {@code false} if queued data remains; otherwise the result of flushing the delegate
     * @throws IOException if writing or flushing fails
     */
    protected boolean flushQueuedData() throws IOException {
        if (queuedData > 0) {
            doWrite(null, 0, 0);
        }
        if (queuedData > 0) {
            return false;
        }
        if (anyAreSet(state, FLAG_WRITES_TERMINATED) && allAreClear(state, FLAG_DELEGATE_SHUTDOWN)) {
            doTerminateWrites();
            state |= FLAG_DELEGATE_SHUTDOWN;
            finished();
        }
        return next.flush();
    }

    /**
     * Discards all queued frames, failing their callbacks with a "channel is closed" exception, and
     * truncates the delegate conduit.
     * <p>
     * The frames are removed from the queue because their callbacks may already have released their
     * buffers (e.g. {@link PooledBufferFrameCallback}); a later flush must not write them out or
     * complete them a second time.
     */
    @Override
    public void truncateWrites() throws IOException {
        try {
            failQueuedFrames(UndertowMessages.MESSAGES.channelIsClosed());
        } finally {
            next.truncateWrites();
        }
    }

    /**
     * @return {@code true} once {@link #terminateWrites()} has been called
     */
    protected boolean isWritesTerminated() {
        return anyAreSet(state, FLAG_WRITES_TERMINATED);
    }

    /**
     * Hook invoked by {@link #terminateWrites()} before writes are marked terminated, allowing
     * subclasses to queue protocol close frames. Does nothing by default.
     */
    protected void queueCloseFrames() {

    }

    /**
     * Hook invoked after the delegate conduit has been shut down. Does nothing by default.
     */
    protected void finished() {

    }

    /**
     * Interface that is called when a frame event takes place. The events are:
     * <ul>
     * <li>
     * Done - The frame has been written out
     * </li>
     * <li>
     * Failed - The frame write failed
     * </li>
     * </ul>
     */
    public interface FrameCallBack {

        void done();

        void failed(IOException e);

    }

    /**
     * A queued frame: a slice of a buffer array plus the number of its bytes not yet written.
     */
    private static class Frame {

        final FrameCallBack callback;
        final ByteBuffer[] data;
        final int offs;
        final int len;
        /** Bytes of this frame that have not yet been written to the delegate. */
        long remaining;

        private Frame(FrameCallBack callback, ByteBuffer[] data, int offs, int len) {
            this.callback = callback;
            this.data = data;
            this.offs = offs;
            this.len = len;
            this.remaining = Buffers.remaining(data, offs, len);
        }
    }

    /**
     * Frame callback that closes a single pooled buffer whether the frame completes or fails.
     */
    protected static class PooledBufferFrameCallback implements FrameCallBack {

        private final PooledByteBuffer buffer;

        public PooledBufferFrameCallback(PooledByteBuffer buffer) {
            this.buffer = buffer;
        }

        @Override
        public void done() {
            buffer.close();
        }

        @Override
        public void failed(IOException e) {
            buffer.close();
        }
    }


    /**
     * Frame callback that closes several pooled buffers whether the frame completes or fails.
     */
    protected static class PooledBuffersFrameCallback implements FrameCallBack {

        private final PooledByteBuffer[] buffers;

        public PooledBuffersFrameCallback(PooledByteBuffer... buffers) {
            this.buffers = buffers;
        }

        @Override
        public void done() {
            for (PooledByteBuffer buffer : buffers) {
                buffer.close();
            }
        }

        @Override
        public void failed(IOException e) {
            done();
        }
    }
}