NettyClientBuilder.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.grpc;

import io.grpc.ManagedChannel;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.net.ssl.SSLException;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.LocationSchemes;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;

/**
 * A wrapper around gRPC's Netty channel builder.
 *
 * <p>It is recommended to use the Netty channel builder directly with {@link
 * org.apache.arrow.flight.FlightGrpcUtils#createFlightClient(BufferAllocator, ManagedChannel)}.
 * However, this class provides an adapter that implements the existing Flight-specific builder
 * interface while still producing a {@link NettyChannelBuilder} that callers can customize
 * further.
 *
 * <p>Instances are mutable and have no internal synchronization.
 */
public class NettyClientBuilder {
  /**
   * The maximum number of trace events to keep on the gRPC Channel. This value disables channel
   * tracing.
   */
  private static final int MAX_CHANNEL_TRACE_EVENTS = 0;

  protected BufferAllocator allocator;
  protected Location location;
  protected boolean forceTls = false;
  protected int maxInboundMessageSize = Integer.MAX_VALUE;
  protected InputStream trustedCertificates = null;
  protected InputStream clientCertificate = null;
  protected InputStream clientKey = null;
  protected String overrideHostname = null;
  protected List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
  protected boolean verifyServer = true;

  /**
   * Create an empty builder. A location must be set via {@link #location(Location)} before
   * {@link #build()} is called.
   */
  public NettyClientBuilder() {}

  /**
   * Create a builder for the given allocator and server location.
   *
   * @param allocator the allocator to associate with the client; must not be null
   * @param location the server location to connect to; must not be null
   */
  public NettyClientBuilder(BufferAllocator allocator, Location location) {
    this.allocator = Preconditions.checkNotNull(allocator);
    this.location = Preconditions.checkNotNull(location);
  }

  /** Force the client to connect over TLS, regardless of the location's scheme. */
  public NettyClientBuilder useTls() {
    this.forceTls = true;
    return this;
  }

  /**
   * Override the authority (hostname) checked for TLS. Only applied when TLS is enabled. Use with
   * caution in production.
   */
  public NettyClientBuilder overrideHostname(final String hostname) {
    this.overrideHostname = hostname;
    return this;
  }

  /**
   * Set the maximum inbound message size. The same limit is also applied to inbound metadata.
   *
   * @param maxSize the limit in bytes; must be positive
   * @throws IllegalArgumentException if {@code maxSize} is not positive
   */
  public NettyClientBuilder maxInboundMessageSize(int maxSize) {
    Preconditions.checkArgument(maxSize > 0, "maxInboundMessageSize must be positive");
    this.maxInboundMessageSize = maxSize;
    return this;
  }

  /**
   * Set the trusted TLS certificates used to verify the server. The stream is handed to Netty's
   * {@link SslContextBuilder} when {@link #build()} configures TLS.
   *
   * @param stream the trusted certificates; must not be null
   */
  public NettyClientBuilder trustedCertificates(final InputStream stream) {
    this.trustedCertificates = Preconditions.checkNotNull(stream);
    return this;
  }

  /**
   * Set the client certificate and private key that the client presents for mutual TLS.
   *
   * @param clientCertificate the client certificate chain; must not be null
   * @param clientKey the client private key; must not be null
   */
  public NettyClientBuilder clientCertificate(
      final InputStream clientCertificate, final InputStream clientKey) {
    // Validate both arguments before assigning either, so a failure leaves the builder unchanged.
    Preconditions.checkNotNull(clientCertificate, "clientCertificate");
    Preconditions.checkNotNull(clientKey, "clientKey");
    this.clientCertificate = clientCertificate;
    this.clientKey = clientKey;
    return this;
  }

  /** Return the allocator configured on this builder (may be null if never set). */
  public BufferAllocator allocator() {
    return allocator;
  }

  /**
   * Set the allocator.
   *
   * @param allocator the allocator; must not be null
   */
  public NettyClientBuilder allocator(BufferAllocator allocator) {
    this.allocator = Preconditions.checkNotNull(allocator);
    return this;
  }

  /**
   * Set the server location to connect to.
   *
   * @param location the location; must not be null
   */
  public NettyClientBuilder location(Location location) {
    this.location = Preconditions.checkNotNull(location);
    return this;
  }

  /** Return a read-only view of the middleware factories registered so far. */
  public List<FlightClientMiddleware.Factory> middleware() {
    return Collections.unmodifiableList(middleware);
  }

  /**
   * Register a client middleware factory.
   *
   * @param factory the factory to add; must not be null
   */
  public NettyClientBuilder intercept(FlightClientMiddleware.Factory factory) {
    middleware.add(Preconditions.checkNotNull(factory));
    return this;
  }

  /**
   * Set whether the server's TLS certificate is verified. When disabled, any server certificate
   * is accepted and no certificate options may be configured.
   */
  public NettyClientBuilder verifyServer(boolean verifyServer) {
    this.verifyServer = verifyServer;
    return this;
  }

  /**
   * Create a Netty channel builder configured from this builder's settings.
   *
   * @return a new {@link NettyChannelBuilder}, which callers may customize further
   * @throws NullPointerException if no location has been set
   * @throws IllegalArgumentException if the location's URI scheme is missing or unsupported, or if
   *     server verification is disabled while certificate options are set
   * @throws UnsupportedOperationException if a domain socket location is used and no Netty native
   *     transport could be loaded
   * @throws RuntimeException wrapping an {@link SSLException} if the TLS context cannot be built
   */
  public NettyChannelBuilder build() {
    Preconditions.checkNotNull(location, "location must be set before calling build()");
    final String scheme = location.getUri().getScheme();
    if (scheme == null) {
      // A switch on a null String would throw an uninformative NullPointerException.
      throw new IllegalArgumentException("Location URI has no scheme: " + location.getUri());
    }

    final NettyChannelBuilder builder;
    switch (scheme) {
      case LocationSchemes.GRPC:
      case LocationSchemes.GRPC_INSECURE:
      case LocationSchemes.GRPC_TLS:
        builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
        break;
      case LocationSchemes.GRPC_DOMAIN_SOCKET:
        builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
        configureDomainSocketTransport(builder);
        break;
      default:
        throw new IllegalArgumentException("Scheme is not supported: " + scheme);
    }

    if (this.forceTls || LocationSchemes.GRPC_TLS.equals(scheme)) {
      configureTls(builder);
    } else {
      builder.usePlaintext();
    }

    builder
        .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
        .maxInboundMessageSize(maxInboundMessageSize)
        // The message size limit is deliberately reused as the inbound metadata limit.
        .maxInboundMetadataSize(maxInboundMessageSize);
    return builder;
  }

  /**
   * Enable transport security on {@code builder}. This installs an SSL context built from the
   * configured certificates and applies the hostname override, if any.
   */
  private void configureTls(NettyChannelBuilder builder) {
    builder.useTransportSecurity();

    final boolean hasTrustedCerts = this.trustedCertificates != null;
    final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
    if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
      throw new IllegalArgumentException(
          "FlightClient has been configured to disable server verification, "
              + "but certificate options have been specified.");
    }

    final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
    if (!this.verifyServer) {
      // Accept any server certificate without verification.
      sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
    } else {
      // Without explicit trusted certificates, no custom trust manager is installed.
      if (hasTrustedCerts) {
        sslContextBuilder.trustManager(this.trustedCertificates);
      }
      if (hasKeyCertPair) {
        sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
      }
    }

    try {
      builder.sslContext(sslContextBuilder.build());
    } catch (SSLException e) {
      throw new RuntimeException(e);
    }

    if (this.overrideHostname != null) {
      builder.overrideAuthority(this.overrideHostname);
    }
  }

  /**
   * Configure a Netty native domain-socket transport on {@code builder}. It tries epoll (Linux)
   * first and falls back to kqueue (BSD) only if the epoll classes are absent. The classes are
   * looked up reflectively because the implementations are platform-specific.
   *
   * <p>NOTE(review): the event loop group created here is passed to the channel builder and is
   * neither tracked nor shut down by this class. Confirm who owns its lifecycle.
   *
   * @throws UnsupportedOperationException if no usable native transport could be loaded
   */
  private static void configureDomainSocketTransport(NettyChannelBuilder builder) {
    try {
      try {
        // Linux
        applyNativeTransport(
            builder,
            "io.netty.channel.epoll.EpollDomainSocketChannel",
            "io.netty.channel.epoll.EpollEventLoopGroup");
      } catch (ClassNotFoundException e) {
        // BSD
        applyNativeTransport(
            builder,
            "io.netty.channel.kqueue.KQueueDomainSocketChannel",
            "io.netty.channel.kqueue.KQueueEventLoopGroup");
      }
    } catch (ReflectiveOperationException e) {
      throw new UnsupportedOperationException(
          "Could not find suitable Netty native transport implementation "
              + "for domain socket address.",
          e);
    }
  }

  /**
   * Set the channel type on {@code builder} to the named class. Also install a new instance of
   * the named event loop group class, created with its no-argument constructor.
   *
   * @throws ReflectiveOperationException if either class is missing or cannot be instantiated
   */
  private static void applyNativeTransport(
      NettyChannelBuilder builder, String channelClassName, String eventLoopGroupClassName)
      throws ReflectiveOperationException {
    // Domain socket channels are client Channels, not ServerChannels. Narrowing to ServerChannel
    // would make Class.asSubclass throw ClassCastException.
    builder.channelType(Class.forName(channelClassName).asSubclass(Channel.class));
    final EventLoopGroup elg =
        Class.forName(eventLoopGroupClassName)
            .asSubclass(EventLoopGroup.class)
            .getDeclaredConstructor()
            .newInstance();
    builder.eventLoopGroup(elg);
  }
}