// TestCookieHandling.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.client;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.ErrorFlightMetadata;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

/**
 * Tests for correct handling of cookies from the FlightClient using {@link ClientCookieMiddleware}.
 */
public class TestCookieHandling {
  // Header name used when a cookie is handed to the client (inserted by the tests and by the
  // server-side injector below).
  private static final String SET_COOKIE_HEADER = "Set-Cookie";
  // Header name the server-side injector looks for in incoming request headers.
  private static final String COOKIE_HEADER = "Cookie";
  // Created in setup() and closed in cleanup().
  private BufferAllocator allocator;
  // Server and client are created in startServerAndClient() and closed in cleanup().
  private FlightServer server;
  private FlightClient client;

  // Factory registered as the client's interceptor; also used directly by the tests to create the
  // middleware under test.
  private ClientCookieMiddlewareTestFactory testFactory = new ClientCookieMiddlewareTestFactory();
  // Middleware instance the tests feed headers into; initially built directly on testFactory.
  private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware(testFactory);

  /**
   * Creates a fresh root allocator and then starts the Flight server and client used by the
   * end-to-end test.
   *
   * @throws Exception if the server or client cannot be started
   */
  @BeforeEach
  public void setup() throws Exception {
    this.allocator = new RootAllocator(Long.MAX_VALUE);
    this.startServerAndClient();
  }

  /**
   * Replaces the factory and middleware with new instances, closes the client, server and
   * allocator (in that order), and clears the references to them.
   *
   * @throws Exception if closing any of the resources fails
   */
  @AfterEach
  public void cleanup() throws Exception {
    // Swap in a brand-new factory and a middleware created by it.
    final ClientCookieMiddlewareTestFactory replacement = new ClientCookieMiddlewareTestFactory();
    this.testFactory = replacement;
    this.cookieMiddleware = replacement.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));

    AutoCloseables.close(this.client, this.server, this.allocator);

    this.client = null;
    this.server = null;
    this.allocator = null;
  }

  /** Verifies that a single {@code Set-Cookie} header is reported back unchanged. */
  @Test
  public void basicCookie() {
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));

    final CallHeaders responseHeaders = new ErrorFlightMetadata();
    responseHeaders.insert(SET_COOKIE_HEADER, "k=v");
    cookieMiddleware.onHeadersReceived(responseHeaders);

    final String storedCookies = cookieMiddleware.getValidCookiesAsString();
    assertEquals("k=v", storedCookies);
  }

  /**
   * Verifies that a cookie received on one call is still reported by middleware instances created
   * for later calls whose responses carry no {@code Set-Cookie} header.
   */
  @Test
  public void cookieStaysAfterMultipleRequests() {
    // First simulated call: the response sets the cookie.
    final CallHeaders cookieHeaders = new ErrorFlightMetadata();
    cookieHeaders.insert(SET_COOKIE_HEADER, "k=v");
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    cookieMiddleware.onHeadersReceived(cookieHeaders);
    assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());

    // Two further simulated calls with empty response headers must leave the cookie in place.
    for (int followUp = 0; followUp < 2; followUp++) {
      final CallHeaders emptyHeaders = new ErrorFlightMetadata();
      cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
      cookieMiddleware.onHeadersReceived(emptyHeaders);
      assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
    }
  }

  /**
   * Verifies that a cookie sent with {@code Max-Age=2} is still reported on a follow-up call and
   * is no longer reported after waiting five seconds.
   *
   * @throws InterruptedException if the thread is interrupted while waiting for the cookie to
   *     expire
   */
  @Disabled
  @Test
  public void cookieAutoExpires() throws InterruptedException {
    CallHeaders headersToSend = new ErrorFlightMetadata();
    headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    cookieMiddleware.onHeadersReceived(headersToSend);
    // Note: using max-age changes cookie version from 0->1, which quotes values.
    assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

    headersToSend = new ErrorFlightMetadata();
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    cookieMiddleware.onHeadersReceived(headersToSend);
    assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

    // Wait well past the 2 second Max-Age. An interrupt is propagated instead of being swallowed:
    // a shortened wait must fail the test rather than fall through to the assertion below while
    // the cookie may still be valid.
    Thread.sleep(5000);

    // Verify that the k cookie was discarded because it expired.
    assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
  }

  /**
   * Verifies that a stored cookie is dropped when a later response re-sends it with
   * {@code Max-Age=0}.
   */
  @Test
  public void cookieExplicitlyExpires() {
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    final CallHeaders initialHeaders = new ErrorFlightMetadata();
    initialHeaders.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
    cookieMiddleware.onHeadersReceived(initialHeaders);
    // Specifying max-age moves the cookie from version 0 to version 1, so the value is quoted.
    assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

    // The JDK considers Max-Age < 0 as not expired and only 0 as expired. That deviates from the
    // RFC, under which both zero and negative values mean the cookie is expired.
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    final CallHeaders expiringHeaders = new ErrorFlightMetadata();
    expiringHeaders.insert(SET_COOKIE_HEADER, "k=v; Max-Age=0");
    cookieMiddleware.onHeadersReceived(expiringHeaders);

    // The server declared the k cookie expired, so nothing may be left to send.
    final boolean noCookiesLeft = cookieMiddleware.getValidCookiesAsString().isEmpty();
    assertTrue(noCookiesLeft);
  }

  /**
   * Verifies that a stored cookie is dropped when a later response re-sends it with
   * {@code Max-Age=-1}.
   */
  @Disabled
  @Test
  public void cookieExplicitlyExpiresWithMaxAgeMinusOne() {
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    final CallHeaders initialHeaders = new ErrorFlightMetadata();
    initialHeaders.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
    cookieMiddleware.onHeadersReceived(initialHeaders);
    // Specifying max-age moves the cookie from version 0 to version 1, so the value is quoted.
    assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

    // Java's HttpCookie has a bug: it interprets a maxAge of -1 as "persistent cookie", whereas
    // the RFC spec says such a value should make the cookie expire immediately.
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    final CallHeaders negativeAgeHeaders = new ErrorFlightMetadata();
    negativeAgeHeaders.insert(SET_COOKIE_HEADER, "k=v; Max-Age=-1");
    cookieMiddleware.onHeadersReceived(negativeAgeHeaders);

    // The server declared the k cookie expired, so nothing may be left to send.
    final boolean noCookiesLeft = cookieMiddleware.getValidCookiesAsString().isEmpty();
    assertTrue(noCookiesLeft);
  }

  /**
   * Verifies that a second {@code Set-Cookie} header for the same name replaces the stored value.
   *
   * <p>Each simulated call obtains its middleware from {@code testFactory.onCallStarted}, the same
   * way the other tests in this class do, instead of reusing the instance created at field
   * initialisation.
   */
  @Test
  public void changeCookieValue() {
    CallHeaders headersToSend = new ErrorFlightMetadata();
    headersToSend.insert(SET_COOKIE_HEADER, "k=v");
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    cookieMiddleware.onHeadersReceived(headersToSend);
    assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());

    // Second simulated call: the server sends a new value for the same cookie name.
    headersToSend = new ErrorFlightMetadata();
    headersToSend.insert(SET_COOKIE_HEADER, "k=v2");
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    cookieMiddleware.onHeadersReceived(headersToSend);
    assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString());
  }

  /**
   * Verifies that two {@code Set-Cookie} headers in one response are both stored and reported as
   * a single {@code "; "}-separated string.
   *
   * <p>The middleware is obtained from {@code testFactory.onCallStarted}, the same way the other
   * tests in this class do, instead of reusing the instance created at field initialisation.
   */
  @Test
  public void multipleCookiesWithSetCookie() {
    CallHeaders headersToSend = new ErrorFlightMetadata();
    headersToSend.insert(SET_COOKIE_HEADER, "firstKey=firstVal");
    headersToSend.insert(SET_COOKIE_HEADER, "secondKey=secondVal");
    cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
    cookieMiddleware.onHeadersReceived(headersToSend);
    assertEquals(
        "firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString());
  }

  /**
   * Verifies against the real server and client that the cookie is still reported after two
   * handshakes and a {@code listFlights} call, checking the middleware most recently created by
   * the test factory after each call.
   */
  @Test
  public void cookieStaysAfterMultipleRequestsEndToEnd() {
    final String expectedCookies = "k=v";

    client.handshake();
    assertEquals(expectedCookies, testFactory.clientCookieMiddleware.getValidCookiesAsString());

    client.handshake();
    assertEquals(expectedCookies, testFactory.clientCookieMiddleware.getValidCookiesAsString());

    client.listFlights(Criteria.ALL);
    assertEquals(expectedCookies, testFactory.clientCookieMiddleware.getValidCookiesAsString());
  }

  /**
   * A server middleware component that injects SET_COOKIE_HEADER into the outgoing headers of a
   * call, unless the incoming headers of that call already carried COOKIE_HEADER.
   *
   * <p>Whether the cookie header was received is captured per middleware instance when the call
   * starts, so the decision for one call does not depend on headers seen by a call that started
   * later.
   */
  static class SetCookieHeaderInjector implements FlightServerMiddleware {
    // Snapshot for the call this middleware belongs to.
    private final boolean receivedCookieHeader;

    /**
     * Creates a middleware using the cookie state most recently recorded by the factory.
     *
     * @param factory the factory whose recorded state is copied at construction time
     */
    public SetCookieHeaderInjector(Factory factory) {
      this(factory.receivedCookieHeader);
    }

    /** Creates a middleware for a call with an explicitly supplied cookie state. */
    private SetCookieHeaderInjector(boolean receivedCookieHeader) {
      this.receivedCookieHeader = receivedCookieHeader;
    }

    /** Adds {@code Set-Cookie: k=v} when this call's request did not carry a cookie header. */
    @Override
    public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
      if (!receivedCookieHeader) {
        outgoingHeaders.insert(SET_COOKIE_HEADER, "k=v");
      }
    }

    @Override
    public void onCallCompleted(CallStatus status) {}

    @Override
    public void onCallErrored(Throwable err) {}

    /** Creates one {@link SetCookieHeaderInjector} per call, based on that call's headers. */
    static class Factory implements FlightServerMiddleware.Factory<SetCookieHeaderInjector> {
      // State of the most recently started call; only consulted by the public constructor.
      private boolean receivedCookieHeader = false;

      /**
       * Records whether the incoming headers contain COOKIE_HEADER and returns a middleware bound
       * to that result.
       */
      @Override
      public SetCookieHeaderInjector onCallStarted(
          CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
        final boolean hasCookie = null != incomingHeaders.get(COOKIE_HEADER);
        receivedCookieHeader = hasCookie;
        // Hand the result to the instance directly instead of having it read shared state later.
        return new SetCookieHeaderInjector(hasCookie);
      }
    }
  }

  /**
   * Cookie middleware factory for tests: it keeps a reference to the middleware it created most
   * recently so the enclosing test class can inspect it.
   */
  public static class ClientCookieMiddlewareTestFactory extends ClientCookieMiddleware.Factory {

    // Middleware returned by the latest onCallStarted invocation; null until the first one.
    private ClientCookieMiddleware clientCookieMiddleware;

    /** Creates a middleware bound to this factory, remembers it, and returns it. */
    @Override
    public ClientCookieMiddleware onCallStarted(CallInfo info) {
      final ClientCookieMiddleware created = new ClientCookieMiddleware(this);
      clientCookieMiddleware = created;
      return created;
    }
  }

  /**
   * Starts a Flight server on LOCALHOST with port 0 that has the cookie-injecting server
   * middleware installed, then builds a client for the server's location that uses
   * {@code testFactory} as its interceptor.
   *
   * @throws IOException if the server fails to start
   */
  private void startServerAndClient() throws IOException {
    final FlightProducer producer =
        new NoOpFlightProducer() {
          @Override
          public void listFlights(
              CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
            // Complete immediately without emitting any FlightInfo.
            listener.onCompleted();
          }
        };

    final SetCookieHeaderInjector.Factory cookieInjector = new SetCookieHeaderInjector.Factory();
    final FlightServer startedServer =
        FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer)
            .middleware(FlightServerMiddleware.Key.of("test"), cookieInjector)
            .build()
            .start();
    this.server = startedServer;

    this.client =
        FlightClient.builder(allocator, startedServer.getLocation())
            .intercept(testFactory)
            .build();
  }
}