// ExpirationTimeProducer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.integration.tests;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.CancelFlightInfoRequest;
import org.apache.arrow.flight.CancelFlightInfoResult;
import org.apache.arrow.flight.CancelStatus;
import org.apache.arrow.flight.FlightConstants;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.RenewFlightEndpointRequest;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

/**
 * The server used for testing FlightEndpoint.expiration_time.
 *
 * <p>GetFlightInfo() returns a FlightInfo that has the following three FlightEndpoints:
 *
 * <ol>
 *   <li>No expiration time
 *   <li>5 seconds expiration time
 *   <li>6 seconds expiration time
 * </ol>
 *
 * <p>The client can read data from the first endpoint only once, but can read data from the second
 * and third endpoints multiple times. The client can't re-read data from the second endpoint once 5
 * seconds have passed, and can't re-read data from the third endpoint once 6 seconds have passed.
 *
 * <p>The client can cancel a returned FlightInfo with the pre-defined CancelFlightInfo action. After
 * the action, the client can't read data from the endpoints, even within the original 6 seconds.
 *
 * <p>The client can extend the expiration time of a FlightEndpoint in a returned FlightInfo with the
 * pre-defined RenewFlightEndpoint action. The client can read data from the renewed endpoint
 * multiple times for 10 more seconds after the action.
 *
 * <p>Every call to GetFlightInfo() discards the state of previously returned endpoints, and the
 * state is not synchronized: calls are assumed not to be concurrent.
 */
final class ExpirationTimeProducer extends NoOpFlightProducer {
  /** Schema of the data served by every endpoint: a single non-nullable UINT4 column "number". */
  public static final Schema SCHEMA =
      new Schema(
          Collections.singletonList(Field.notNullable("number", Types.MinorType.UINT4.getType())));

  private final BufferAllocator allocator;

  /** Per-endpoint state, indexed by the number that prefixes each ticket ("{index}: {label}"). */
  private final List<EndpointStatus> statuses;

  /**
   * Creates a producer.
   *
   * @param allocator the allocator used to create the vectors streamed by {@link #getStream}
   */
  ExpirationTimeProducer(BufferAllocator allocator) {
    this.allocator = allocator;
    this.statuses = new ArrayList<>();
  }

  /**
   * Returns a FlightInfo with three endpoints: one without an expiration time, one expiring 5
   * seconds from now, and one expiring 6 seconds from now.
   *
   * <p>The state of the endpoints handed out by a previous call is discarded.
   */
  @Override
  public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
    // Endpoint indices restart from 0 on every call, so earlier tickets are no longer tracked.
    statuses.clear();
    List<FlightEndpoint> endpoints = new ArrayList<>();
    Instant now = Instant.now();
    endpoints.add(addEndpoint("No expiration time", null));
    endpoints.add(addEndpoint("5 seconds", now.plus(5, ChronoUnit.SECONDS)));
    endpoints.add(addEndpoint("6 seconds", now.plus(6, ChronoUnit.SECONDS)));
    return new FlightInfo(SCHEMA, descriptor, endpoints, -1, -1);
  }

  /**
   * Streams a single row holding the endpoint index, unless the endpoint was cancelled, has
   * expired, or has no expiration time and was already read; those cases are reported to the
   * listener as NOT_FOUND.
   */
  @Override
  public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
    // Obviously, not safe (since we don't lock), but we assume calls are not concurrent
    int index = parseIndexFromTicket(ticket);
    EndpointStatus status = statuses.get(index);
    String rejection = rejectionReason(status);
    if (rejection != null) {
      listener.error(invalidFlight(rejection, ticket));
      return;
    }
    status.numGets++;

    try (final VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator)) {
      listener.start(root);
      UInt4Vector vector = (UInt4Vector) root.getVector(0);
      vector.setSafe(0, index);
      root.setRowCount(1);
      listener.putNext();
    }
    listener.completed();
  }

  /**
   * Handles the CancelFlightInfo and RenewFlightEndpoint actions; any other action type is
   * reported to the listener as INVALID_ARGUMENT.
   */
  @Override
  public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
    try {
      if (action.getType().equals(FlightConstants.CANCEL_FLIGHT_INFO.getType())) {
        CancelFlightInfoRequest request =
            CancelFlightInfoRequest.deserialize(ByteBuffer.wrap(action.getBody()));
        CancelStatus cancelStatus = CancelStatus.UNSPECIFIED;
        for (FlightEndpoint endpoint : request.getInfo().getEndpoints()) {
          int index = parseIndexFromTicket(endpoint.getTicket());
          EndpointStatus status = statuses.get(index);
          if (status.cancelled) {
            // An already-cancelled endpoint makes the whole request NOT_CANCELLABLE.
            cancelStatus = CancelStatus.NOT_CANCELLABLE;
          } else {
            status.cancelled = true;
            // Only upgrade from UNSPECIFIED so that an earlier NOT_CANCELLABLE is kept.
            if (cancelStatus == CancelStatus.UNSPECIFIED) {
              cancelStatus = CancelStatus.CANCELLED;
            }
          }
        }
        listener.onNext(new Result(new CancelFlightInfoResult(cancelStatus).serialize().array()));
      } else if (action.getType().equals(FlightConstants.RENEW_FLIGHT_ENDPOINT.getType())) {
        RenewFlightEndpointRequest request =
            RenewFlightEndpointRequest.deserialize(ByteBuffer.wrap(action.getBody()));
        FlightEndpoint endpoint = request.getFlightEndpoint();
        int index = parseIndexFromTicket(endpoint.getTicket());
        EndpointStatus status = statuses.get(index);
        if (status.cancelled) {
          listener.onError(
              CallStatus.INVALID_ARGUMENT
                  .withDescription("Invalid flight: cancelled: " + index)
                  .toRuntimeException());
          return;
        }

        // The renewed ticket keeps its "{index}:" prefix, so it still resolves to the same status.
        String ticketBody = ticketText(endpoint.getTicket());
        ticketBody += ": renewed (+ 10 seconds)";
        Ticket ticket = new Ticket(ticketBody.getBytes(StandardCharsets.UTF_8));
        Instant expiration = Instant.now().plus(10, ChronoUnit.SECONDS);
        status.expirationTime = expiration;
        FlightEndpoint newEndpoint =
            new FlightEndpoint(
                ticket, expiration, endpoint.getLocations().toArray(new Location[0]));
        listener.onNext(new Result(newEndpoint.serialize().array()));
      } else {
        listener.onError(
            CallStatus.INVALID_ARGUMENT
                .withDescription("Unknown action: " + action.getType())
                .toRuntimeException());
        return;
      }
    } catch (IOException | URISyntaxException e) {
      listener.onError(
          CallStatus.INTERNAL.withCause(e).withDescription(e.toString()).toRuntimeException());
      return;
    }
    listener.onCompleted();
  }

  /** Lists the two supported actions: CancelFlightInfo and RenewFlightEndpoint. */
  @Override
  public void listActions(CallContext context, StreamListener<ActionType> listener) {
    listener.onNext(FlightConstants.CANCEL_FLIGHT_INFO);
    listener.onNext(FlightConstants.RENEW_FLIGHT_ENDPOINT);
    listener.onCompleted();
  }

  /**
   * Registers a new endpoint status and builds the matching endpoint.
   *
   * @param ticket the human-readable label placed after the index in the ticket
   * @param expirationTime the expiration time of the endpoint, or {@code null} for none
   * @return an endpoint whose ticket is "{index}: {label}"
   */
  private FlightEndpoint addEndpoint(String ticket, Instant expirationTime) {
    Ticket flightTicket =
        new Ticket(
            String.format("%d: %s", statuses.size(), ticket).getBytes(StandardCharsets.UTF_8));
    statuses.add(new EndpointStatus(expirationTime));
    return new FlightEndpoint(flightTicket, expirationTime);
  }

  /**
   * Extracts the endpoint index from a ticket of the form "{index}: {label}".
   *
   * @param ticket the ticket received from the client
   * @return an index that is valid for {@link #statuses}
   * @throws RuntimeException an INVALID_ARGUMENT status when the ticket has no ':' separator or its
   *     prefix is not an integer, or a NOT_FOUND status when the index is not a known endpoint
   */
  private int parseIndexFromTicket(Ticket ticket) {
    final String contents = ticketText(ticket);
    int separator = contents.indexOf(':');
    if (separator == -1) {
      throw CallStatus.INVALID_ARGUMENT
          .withDescription("Invalid ticket: " + contents)
          .toRuntimeException();
    }
    int endpointIndex;
    try {
      endpointIndex = Integer.parseInt(contents.substring(0, separator));
    } catch (NumberFormatException e) {
      // A non-numeric prefix is a malformed ticket, not an internal server failure.
      throw CallStatus.INVALID_ARGUMENT
          .withCause(e)
          .withDescription("Invalid ticket: " + contents)
          .toRuntimeException();
    }
    if (endpointIndex < 0 || endpointIndex >= statuses.size()) {
      throw CallStatus.NOT_FOUND.withDescription("Out of bounds").toRuntimeException();
    }
    return endpointIndex;
  }

  /**
   * Decides whether an endpoint may currently be read.
   *
   * @return the reason the read must be rejected, or {@code null} when the read is allowed
   */
  private static String rejectionReason(EndpointStatus status) {
    if (status.cancelled) {
      return "cancelled";
    }
    if (status.expirationTime != null) {
      // Endpoints with an expiration time can be read repeatedly until they expire.
      return Instant.now().isAfter(status.expirationTime) ? "expired" : null;
    }
    // Endpoints without an expiration time can be read only once.
    return status.numGets > 0 ? "can't read multiple times" : null;
  }

  /** Builds the NOT_FOUND error reported when a ticket can't be read for the given reason. */
  private static RuntimeException invalidFlight(String reason, Ticket ticket) {
    return CallStatus.NOT_FOUND
        .withDescription("Invalid flight: " + reason + ": " + ticketText(ticket))
        .toRuntimeException();
  }

  /** Decodes the ticket bytes as UTF-8 text. */
  private static String ticketText(Ticket ticket) {
    return new String(ticket.getBytes(), StandardCharsets.UTF_8);
  }

  /** The status of a returned endpoint. */
  static final class EndpointStatus {
    /** When the endpoint expires, or {@code null} if it has no expiration time. */
    Instant expirationTime;

    /** How many times the endpoint has been read successfully. */
    int numGets;

    /** Whether the endpoint was cancelled by a CancelFlightInfo action. */
    boolean cancelled;

    EndpointStatus(Instant expirationTime) {
      this.expirationTime = expirationTime;
      this.numGets = 0;
      this.cancelled = false;
    }
  }
}