TNonblockingSSLSocket.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.thrift.transport;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Transport for use with async ssl client.
 *
 * <p>Layers a client-mode {@link SSLEngine} on top of {@link TNonblockingSocket}. The TLS handshake
 * is driven from {@link #startConnect()} / {@link #finishConnect()}; afterwards {@link
 * #read(ByteBuffer)} decrypts and {@link #write(ByteBuffer)} encrypts application data.
 *
 * <p>Note: although the underlying channel is non-blocking, the handshake, {@code read} and {@code
 * write} keep retrying the channel in a loop until they have transferred everything requested (or
 * the peer/session closes); they do not hand control back to the selector part-way.
 */
public class TNonblockingSSLSocket extends TNonblockingSocket implements SocketAddressProvider {

  private static final Logger LOGGER =
      LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName());

  private final SSLEngine sslEngine_;

  /**
   * Decrypted application bytes. While {@code limit == capacity} the buffer holds no unread data;
   * after an unwrap that produced data it is flipped, so {@code [position, limit)} is the data not
   * yet handed to a caller of {@link #read(ByteBuffer)}.
   */
  private final ByteBuffer appUnwrap;

  /** Encrypted bytes read from the channel but not yet unwrapped; kept in fill mode. */
  private final ByteBuffer netUnwrap;

  /** Scratch buffer for encrypted output of {@link SSLEngine#wrap}; emptied after each flush. */
  private final ByteBuffer netWrap;

  private boolean isHandshakeCompleted;

  /** Key from the last {@link #registerSelector} call; {@code null} until registered. */
  private SelectionKey selectionKey;

  /** Runs the tasks handed out by {@link SSLEngine#getDelegatedTask()}. */
  private final ExecutorService executorService = Executors.newSingleThreadExecutor();

  /**
   * Creates a client-mode SSL transport for {@code host:port}.
   *
   * @param host server host; also given to the SSL engine as the peer host
   * @param port server port; also given to the SSL engine as the peer port
   * @param timeout passed through to {@link TNonblockingSocket}
   * @param sslContext context used to create the {@link SSLEngine}
   * @throws IOException propagated from the {@link TNonblockingSocket} constructor
   * @throws TTransportException declared for callers; not thrown by this constructor body itself
   */
  protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext)
      throws IOException, TTransportException {
    super(host, port, timeout);
    sslEngine_ = sslContext.createSSLEngine(host, port);
    sslEngine_.setUseClientMode(true);

    int appBufferSize = sslEngine_.getSession().getApplicationBufferSize();
    int netBufferSize = sslEngine_.getSession().getPacketBufferSize();
    appUnwrap = ByteBuffer.allocate(appBufferSize);
    netUnwrap = ByteBuffer.allocate(netBufferSize);
    netWrap = ByteBuffer.allocate(netBufferSize);
    isHandshakeCompleted = false;
  }

  /**
   * Registers the channel with {@code selector} and remembers the returned key so that {@link
   * #read(ByteBuffer)} can adjust its interest set.
   */
  @Override
  public SelectionKey registerSelector(Selector selector, int interests) throws IOException {
    selectionKey = super.registerSelector(selector, interests);
    return selectionKey;
  }

  /**
   * Reports the transport as open only when the underlying socket is open AND the TLS handshake
   * has completed.
   */
  @Override
  public boolean isOpen() {
    // isConnected() does not return false after close(), but isOpen() does
    return super.isOpen() && isHandshakeCompleted;
  }

  /**
   * Not supported; connect with {@link #startConnect()} / {@link #finishConnect()} instead.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public void open() throws TTransportException {
    throw new UnsupportedOperationException(
        "open() is not implemented for TNonblockingSSLSocket");
  }

  /**
   * Fills {@code buffer} completely with decrypted application data, unwrapping further TLS
   * records from the channel as needed.
   *
   * @param buffer destination; all of its remaining space is filled
   * @return the number of bytes that were requested (i.e. {@code buffer.remaining()} on entry)
   * @throws TTransportException if the channel reaches end-of-stream, the SSL session is closed,
   *     or an I/O / SSL error occurs
   */
  @Override
  public synchronized int read(ByteBuffer buffer) throws TTransportException {
    int numBytes = buffer.remaining();
    // Only unwrap while the caller actually wants bytes; with an empty destination there is
    // nothing to wait for.
    while (buffer.hasRemaining()
        && (appUnwrap.limit() == appUnwrap.capacity()
            || appUnwrap.remaining() < buffer.remaining())) {
      // doUnwrap() clears appUnwrap, so hand over any buffered plaintext first.
      if (appUnwrap.limit() < appUnwrap.capacity() && appUnwrap.hasRemaining()) {
        buffer.put(appUnwrap);
      }
      try {
        if (doUnwrap() == -1) {
          throw new IOException("Unable to read " + numBytes + " bytes");
        }
      } catch (IOException iox) {
        throw new TTransportException(TTransportException.UNKNOWN, iox);
      }
    }
    if (buffer.hasRemaining()) {
      // Copy exactly what is still needed; the rest of appUnwrap stays for the next read.
      int originLimit = appUnwrap.limit();
      appUnwrap.limit(appUnwrap.position() + buffer.remaining());
      buffer.put(appUnwrap);
      appUnwrap.limit(originLimit);
    }
    // In SSL mode, the Thrift server may merge the frame size and body into a single TLS record.
    // Setting OP_WRITE to trigger subsequent read operations in the Thrift async client.
    if (selectionKey != null) {
      selectionKey.interestOps(SelectionKey.OP_WRITE);
    }
    return numBytes;
  }

  /**
   * Encrypts and sends all remaining bytes of {@code buffer}.
   *
   * @param buffer plaintext to send; fully consumed on success
   * @return the number of plaintext bytes that were requested
   * @throws TTransportException if the SSL session is closed or an I/O / SSL error occurs
   */
  @Override
  public synchronized int write(ByteBuffer buffer) throws TTransportException {
    int numBytes = buffer.remaining();

    while (buffer.hasRemaining()) {
      try {
        if (doWrap(buffer) == -1) {
          throw new IOException("Unable to write " + numBytes + " bytes");
        }
      } catch (IOException iox) {
        throw new TTransportException(TTransportException.UNKNOWN, iox);
      }
    }
    return numBytes;
  }

  /**
   * Shuts down the delegated-task executor, marks the engine's outbound side closed and closes the
   * underlying socket. No close_notify record is flushed to the peer by this method.
   */
  @Override
  public void close() {
    executorService.shutdown();
    sslEngine_.closeOutbound();
    super.close();
  }

  /**
   * Begins the TLS handshake and starts the TCP connect. If the connect completes immediately the
   * handshake is run to completion as well.
   *
   * @return {@code true} if already open, or if both connect and handshake completed now
   */
  @Override
  public boolean startConnect() throws IOException {
    if (this.isOpen()) {
      return true;
    }
    sslEngine_.beginHandshake();
    return super.startConnect() && doHandShake();
  }

  /**
   * Finishes the TCP connect and, once connected, runs the TLS handshake to completion.
   *
   * @return {@code true} if both the connect and the handshake completed
   */
  @Override
  public boolean finishConnect() throws IOException {
    return super.finishConnect() && doHandShake();
  }

  /**
   * Drives the handshake until the engine reports {@code FINISHED} or {@code NOT_HANDSHAKING},
   * retrying the non-blocking channel in a loop rather than yielding to the selector.
   *
   * @return {@code true} when the handshake completed; {@code false} if the peer closed the
   *     connection/session or an unexpected handshake status was reported
   */
  private synchronized boolean doHandShake() throws IOException {
    while (true) {
      HandshakeStatus hs = sslEngine_.getHandshakeStatus();
      switch (hs) {
        case NEED_UNWRAP:
          if (doUnwrap() == -1) {
            LOGGER.error("Unexpected. Handshake failed abruptly during unwrap");
            return false;
          }
          break;
        case NEED_WRAP:
          if (doWrap(ByteBuffer.wrap(new byte[0])) == -1) {
            LOGGER.error("Unexpected. Handshake failed abruptly during wrap");
            return false;
          }
          break;
        case NEED_TASK:
          doTask();
          break;
        case FINISHED:
        case NOT_HANDSHAKING:
          isHandshakeCompleted = true;
          return true;
        default:
          LOGGER.error("Unknown handshake status. Handshake failed");
          return false;
      }
    }
  }

  /**
   * Runs every delegated task the engine currently has. Each task is executed on {@link
   * #executorService} and awaited before the next is fetched, so the handshake loop does not spin
   * on {@code NEED_TASK} while tasks are still in flight.
   *
   * @throws SSLException if a task throws
   * @throws InterruptedIOException if the calling thread is interrupted while waiting
   */
  private void doTask() throws IOException {
    Runnable runnable;
    while ((runnable = sslEngine_.getDelegatedTask()) != null) {
      try {
        executorService.submit(runnable).get();
      } catch (InterruptedException ie) {
        Thread.currentThread().interrupt();
        InterruptedIOException iioe =
            new InterruptedIOException("Interrupted while running SSL delegated task");
        iioe.initCause(ie);
        throw iioe;
      } catch (ExecutionException ee) {
        throw new SSLException("SSL delegated task failed", ee.getCause());
      }
    }
  }

  /**
   * Reads whatever the channel has available and unwraps (at most) one TLS record into {@link
   * #appUnwrap}, which is cleared first and flipped if plaintext was produced.
   *
   * @return bytes read from the channel (possibly 0), or -1 on end-of-stream or a closed session
   * @throws SSLException if the engine fails to unwrap
   */
  private int doUnwrap() throws IOException {
    int num = getSocketChannel().read(netUnwrap);
    if (num < 0) {
      // Checked before flip() so netUnwrap is left in fill mode.
      LOGGER.error("Failed during read operation. Probably server is down");
      return -1;
    }
    netUnwrap.flip();
    SSLEngineResult unwrapResult;

    try {
      appUnwrap.clear();
      unwrapResult = sslEngine_.unwrap(netUnwrap, appUnwrap);
    } catch (SSLException ex) {
      LOGGER.error(ex.getMessage(), ex);
      throw ex;
    } finally {
      // Keep any partial record and return netUnwrap to fill mode for the next channel read.
      netUnwrap.compact();
    }

    switch (unwrapResult.getStatus()) {
      case OK:
        if (appUnwrap.position() > 0) {
          appUnwrap.flip();
        }
        break;
      case CLOSED:
        return -1;
      case BUFFER_OVERFLOW:
        throw new IllegalStateException("Failed to unwrap");
      case BUFFER_UNDERFLOW:
        // Not a full record yet; the caller retries after more bytes arrive.
        break;
    }
    return num;
  }

  /**
   * Wraps bytes from {@code appWrap} and writes the produced TLS data to the channel in full.
   *
   * @param appWrap plaintext to encrypt (may be empty, e.g. for handshake messages)
   * @return bytes written to the channel (0 if nothing was produced), or -1 if the session is
   *     closed
   * @throws SSLException if the engine fails to wrap
   */
  private int doWrap(ByteBuffer appWrap) throws IOException {
    int num = 0;
    SSLEngineResult wrapResult;
    try {
      wrapResult = sslEngine_.wrap(appWrap, netWrap);
    } catch (SSLException exc) {
      LOGGER.error(exc.getMessage(), exc);
      throw exc;
    }

    switch (wrapResult.getStatus()) {
      case OK:
        if (netWrap.position() > 0) {
          netWrap.flip();
          // A non-blocking channel may accept only part of the buffer. Keep writing until it is
          // drained; clearing early would drop the tail of the TLS record and corrupt the stream.
          while (netWrap.hasRemaining()) {
            num += getSocketChannel().write(netWrap);
          }
          netWrap.clear();
        }
        break;
      case BUFFER_UNDERFLOW:
        // try again later
        break;
      case BUFFER_OVERFLOW:
        throw new IllegalStateException("Failed to wrap");
      case CLOSED:
        LOGGER.error("SSL session is closed");
        return -1;
    }
    return num;
  }
}