TestTSaslTransports.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.thrift.transport;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;
import org.apache.thrift.TConfiguration;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.ServerTestBase;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TServer.Args;
import org.apache.thrift.server.TSimpleServer;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TestTSaslTransports {

  private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);

  public static final String HOST = "localhost";
  public static final String SERVICE = "thrift-test";
  public static final String PRINCIPAL = "thrift-test-principal";
  public static final String PASSWORD = "super secret password";
  public static final String REALM = "thrift-test-realm";

  public static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
  public static final Map<String, String> UNWRAPPED_PROPS = null;

  public static final String WRAPPED_MECHANISM = "DIGEST-MD5";
  public static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();

  static {
    WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
    WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
  }

  private static final String testMessage1 =
      "Hello, world! Also, four "
          + "score and seven years ago our fathers brought forth on this "
          + "continent a new nation, conceived in liberty, and dedicated to the "
          + "proposition that all men are created equal.";

  private static final String testMessage2 =
      "I have a dream that one day "
          + "this nation will rise up and live out the true meaning of its creed: "
          + "'We hold these truths to be self-evident, that all men are created equal.'";

  public static class TestSaslCallbackHandler implements CallbackHandler {
    private final String password;

    public TestSaslCallbackHandler(String password) {
      this.password = password;
    }

    @Override
    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
      for (Callback c : callbacks) {
        if (c instanceof NameCallback) {
          ((NameCallback) c).setName(PRINCIPAL);
        } else if (c instanceof PasswordCallback) {
          ((PasswordCallback) c).setPassword(password.toCharArray());
        } else if (c instanceof AuthorizeCallback) {
          ((AuthorizeCallback) c).setAuthorized(true);
        } else if (c instanceof RealmCallback) {
          ((RealmCallback) c).setText(REALM);
        } else {
          throw new UnsupportedCallbackException(c);
        }
      }
    }
  }

  private static class ServerThread extends Thread {
    final String mechanism;
    final Map<String, String> props;
    volatile Throwable thrown;

    public ServerThread(String mechanism, Map<String, String> props) {
      this.mechanism = mechanism;
      this.props = props;
    }

    public void run() {
      try {
        internalRun();
      } catch (Throwable t) {
        thrown = t;
      }
    }

    private void internalRun() throws Exception {
      try (TServerSocket serverSocket =
          new TServerSocket(
              new TServerSocket.ServerSocketTransportArgs().port(ServerTestBase.PORT))) {
        acceptAndWrite(serverSocket);
      }
    }

    private void acceptAndWrite(TServerSocket serverSocket) throws Exception {
      TTransport serverTransport = serverSocket.accept();
      TTransport saslServerTransport =
          new TSaslServerTransport(
              mechanism,
              SERVICE,
              HOST,
              props,
              new TestSaslCallbackHandler(PASSWORD),
              serverTransport);

      saslServerTransport.open();

      byte[] inBuf = new byte[testMessage1.getBytes().length];
      // Deliberately read less than the full buffer to ensure
      // that TSaslTransport is correctly buffering reads. This
      // will fail for the WRAPPED test, if it doesn't work.
      saslServerTransport.readAll(inBuf, 0, 5);
      saslServerTransport.readAll(inBuf, 5, 10);
      saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
      LOGGER.debug("server got: {}", new String(inBuf));
      assertEquals(new String(inBuf), testMessage1);

      LOGGER.debug("server writing: {}", testMessage2);
      saslServerTransport.write(testMessage2.getBytes());
      saslServerTransport.flush();

      saslServerTransport.close();
    }
  }

  private void testSaslOpen(final String mechanism, final Map<String, String> props)
      throws Exception {
    ServerThread serverThread = new ServerThread(mechanism, props);
    serverThread.start();

    try {
      Thread.sleep(1000);
    } catch (InterruptedException e) {
      // Ah well.
    }

    try {
      TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
      TTransport saslClientTransport =
          new TSaslClientTransport(
              mechanism,
              PRINCIPAL,
              SERVICE,
              HOST,
              props,
              new TestSaslCallbackHandler(PASSWORD),
              clientSocket);
      saslClientTransport.open();
      LOGGER.debug("client writing: {}", testMessage1);
      saslClientTransport.write(testMessage1.getBytes());
      saslClientTransport.flush();

      byte[] inBuf = new byte[testMessage2.getBytes().length];
      saslClientTransport.readAll(inBuf, 0, inBuf.length);
      LOGGER.debug("client got: {}", new String(inBuf));
      assertEquals(new String(inBuf), testMessage2);

      TTransportException expectedException = null;
      try {
        saslClientTransport.open();
      } catch (TTransportException e) {
        expectedException = e;
      }
      assertNotNull(expectedException);

      saslClientTransport.close();
    } catch (Exception e) {
      LOGGER.warn("Exception caught", e);
      throw e;
    } finally {
      serverThread.interrupt();
      try {
        serverThread.join();
      } catch (InterruptedException e) {
        // Ah well.
      }
      assertNull(serverThread.thrown);
    }
  }

  @Test
  public void testUnwrappedOpen() throws Exception {
    testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
  }

  @Test
  public void testWrappedOpen() throws Exception {
    testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
  }

  @Test
  public void testAnonymousOpen() throws Exception {
    testSaslOpen("ANONYMOUS", null);
  }

  /**
   * Test that we get the proper exceptions thrown back the server when the client provides invalid
   * password.
   */
  @Test
  public void testBadPassword() throws Exception {
    ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
    serverThread.start();

    try {
      Thread.sleep(1000);
    } catch (InterruptedException e) {
      // Ah well.
    }

    TTransportException tte =
        assertThrows(
            TTransportException.class,
            () -> {
              TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
              TTransport saslClientTransport =
                  new TSaslClientTransport(
                      UNWRAPPED_MECHANISM,
                      PRINCIPAL,
                      SERVICE,
                      HOST,
                      UNWRAPPED_PROPS,
                      new TestSaslCallbackHandler("NOT THE PASSWORD"),
                      clientSocket);
              saslClientTransport.open();
            },
            "Was able to open transport with bad password");
    LOGGER.error("Exception for bad password", tte);
    assertNotNull(tte.getMessage());
    assertTrue(tte.getMessage().contains("Invalid response"));
    serverThread.interrupt();
    serverThread.join();
    assertNotNull(serverThread.thrown);
    assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
  }

  @Test
  public void testWithServer() throws Exception {
    new TestTSaslTransportsWithServer().testIt();
  }

  public static class TestTSaslTransportsWithServer extends ServerTestBase {

    private Thread serverThread;
    private TServer server;

    @Override
    public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
      return new TSaslClientTransport(
          WRAPPED_MECHANISM,
          PRINCIPAL,
          SERVICE,
          HOST,
          WRAPPED_PROPS,
          new TestSaslCallbackHandler(PASSWORD),
          underlyingTransport);
    }

    @Override
    public void startServer(
        final TProcessor processor,
        final TProtocolFactory protoFactory,
        final TTransportFactory factory)
        throws Exception {
      serverThread =
          new Thread() {
            public void run() {
              try {
                // Transport
                TServerSocket socket =
                    new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT));

                TTransportFactory factory =
                    new TSaslServerTransport.Factory(
                        WRAPPED_MECHANISM,
                        SERVICE,
                        HOST,
                        WRAPPED_PROPS,
                        new TestSaslCallbackHandler(PASSWORD));
                server =
                    new TSimpleServer(
                        new Args(socket)
                            .processor(processor)
                            .transportFactory(factory)
                            .protocolFactory(protoFactory));

                // Run it
                LOGGER.debug("Starting the server on port {}", PORT);
                server.serve();
              } catch (Exception e) {
                e.printStackTrace();
                fail(e);
              }
            }
          };
      serverThread.start();
      Thread.sleep(1000);
    }

    @Override
    public void stopServer() throws Exception {
      server.stop();
      try {
        serverThread.join();
      } catch (InterruptedException e) {
        LOGGER.debug("interrupted during sleep", e);
      }
    }
  }

  /** Implementation of SASL ANONYMOUS, used for testing client-side initial responses. */
  private static class AnonymousClient implements SaslClient {
    private final String username;
    private boolean hasProvidedInitialResponse;

    public AnonymousClient(String username) {
      this.username = username;
    }

    @Override
    public String getMechanismName() {
      return "ANONYMOUS";
    }

    @Override
    public boolean hasInitialResponse() {
      return true;
    }

    @Override
    public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
      if (hasProvidedInitialResponse) {
        throw new SaslException("Already complete!");
      }

      hasProvidedInitialResponse = true;
      return username.getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public boolean isComplete() {
      return hasProvidedInitialResponse;
    }

    @Override
    public byte[] unwrap(byte[] incoming, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override
    public byte[] wrap(byte[] outgoing, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override
    public Object getNegotiatedProperty(String propName) {
      return null;
    }

    @Override
    public void dispose() {}
  }

  private static class AnonymousServer implements SaslServer {
    private String user;

    @Override
    public String getMechanismName() {
      return "ANONYMOUS";
    }

    @Override
    public byte[] evaluateResponse(byte[] response) throws SaslException {
      this.user = new String(response, StandardCharsets.UTF_8);
      return null;
    }

    @Override
    public boolean isComplete() {
      return user != null;
    }

    @Override
    public String getAuthorizationID() {
      return user;
    }

    @Override
    public byte[] unwrap(byte[] incoming, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override
    public byte[] wrap(byte[] outgoing, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override
    public Object getNegotiatedProperty(String propName) {
      return null;
    }

    @Override
    public void dispose() {}
  }

  public static class SaslAnonymousFactory implements SaslClientFactory, SaslServerFactory {

    @Override
    public SaslClient createSaslClient(
        String[] mechanisms,
        String authorizationId,
        String protocol,
        String serverName,
        Map<String, ?> props,
        CallbackHandler cbh) {
      for (String mech : mechanisms) {
        if ("ANONYMOUS".equals(mech)) {
          return new AnonymousClient(authorizationId);
        }
      }
      return null;
    }

    @Override
    public SaslServer createSaslServer(
        String mechanism,
        String protocol,
        String serverName,
        Map<String, ?> props,
        CallbackHandler cbh) {
      if ("ANONYMOUS".equals(mechanism)) {
        return new AnonymousServer();
      }
      return null;
    }

    @Override
    public String[] getMechanismNames(Map<String, ?> props) {
      return new String[] {"ANONYMOUS"};
    }
  }

  static {
    java.security.Security.addProvider(new SaslAnonymousProvider());
  }

  public static class SaslAnonymousProvider extends java.security.Provider {
    public SaslAnonymousProvider() {
      super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
      put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
      put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
    }
  }

  private static class MockTTransport extends TTransport {

    byte[] badHeader = null;
    private final TMemoryInputTransport readBuffer;

    public MockTTransport(int mode) throws TTransportException {
      readBuffer = new TMemoryInputTransport();
      if (mode == 1) {
        // Invalid status byte
        badHeader = new byte[] {(byte) 0xFF, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x05};
      } else if (mode == 2) {
        // Valid status byte, negative payload length
        badHeader = new byte[] {(byte) 0x01, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF};
      } else if (mode == 3) {
        // Valid status byte, excessively large, bogus payload length
        badHeader = new byte[] {(byte) 0x01, (byte) 0x64, (byte) 0x00, (byte) 0x00, (byte) 0x00};
      }
      readBuffer.reset(badHeader);
    }

    @Override
    public boolean isOpen() {
      return true;
    }

    @Override
    public void open() throws TTransportException {}

    @Override
    public void close() {}

    @Override
    public int read(byte[] buf, int off, int len) throws TTransportException {
      return readBuffer.read(buf, off, len);
    }

    @Override
    public void write(byte[] buf, int off, int len) throws TTransportException {}

    @Override
    public TConfiguration getConfiguration() {
      return readBuffer.getConfiguration();
    }

    @Override
    public void updateKnownMessageSize(long size) throws TTransportException {
      readBuffer.updateKnownMessageSize(size);
    }

    @Override
    public void checkReadBytesAvailable(long numBytes) throws TTransportException {
      readBuffer.checkReadBytesAvailable(numBytes);
    }
  }

  @Test
  public void testBadHeader() {
    TSaslTransport saslTransport;
    try {
      saslTransport = new TSaslServerTransport(new MockTTransport(1));
      saslTransport.receiveSaslMessage();
      fail("Should have gotten an error due to incorrect status byte value.");
    } catch (TTransportException e) {
      assertEquals(e.getMessage(), "Invalid status -1");
    }
    try {
      saslTransport = new TSaslServerTransport(new MockTTransport(2));
      saslTransport.receiveSaslMessage();
      fail("Should have gotten an error due to negative payload length.");
    } catch (TTransportException e) {
      assertEquals(e.getMessage(), "Invalid payload header length: -1");
    }
    try {
      saslTransport = new TSaslServerTransport(new MockTTransport(3));
      saslTransport.receiveSaslMessage();
      fail("Should have gotten an error due to bogus (large) payload length.");
    } catch (TTransportException e) {
      assertEquals(e.getMessage(), "Invalid payload header length: 1677721600");
    }
  }
}