TestLeak.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;
/** Tests for scenarios where Flight could leak memory. */
public class TestLeak {
  /** Number of rows written into each column of a batch. */
  private static final int ROWS = 2048;

  /**
   * Number of double columns in {@link #getSchema()}; the fields are named {@code "0"} through
   * {@code "10"}.
   */
  private static final int COLUMNS = 11;

  /** How long a test waits for the server-side call to signal that it has finished. */
  private static final long CALL_TIMEOUT_SECONDS = 60;

  /** Builds a schema of {@link #COLUMNS} nullable double fields named by their index. */
  private static Schema getSchema() {
    final Field[] fields = new Field[COLUMNS];
    for (int col = 0; col < COLUMNS; col++) {
      fields[col] =
          Field.nullable(
              Integer.toString(col), new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
    }
    return new Schema(Arrays.asList(fields));
  }

  /**
   * Allocates every column of {@code root}, fills it with {@link #ROWS} copies of {@code 10.0}, and
   * sets the row count. {@code root} must have been created from {@link #getSchema()}.
   */
  private static void fillRoot(VectorSchemaRoot root) {
    for (int col = 0; col < COLUMNS; col++) {
      final Float8Vector vector = (Float8Vector) root.getVector(Integer.toString(col));
      vector.allocateNew();
      for (int row = 0; row < ROWS; row++) {
        vector.setSafe(row, 10.);
      }
    }
    root.setRowCount(ROWS);
  }

  /**
   * Waits for the producer to count down {@code callFinished}.
   *
   * <p>The result of {@link CountDownLatch#await(long, TimeUnit)} is checked. A timeout therefore
   * fails the test with a clear message instead of silently continuing and then tearing down the
   * allocator while the call may still be running.
   *
   * @param callFinished latch released by {@link LeakFlightProducer} when its call completes
   * @throws AssertionError if the latch is not released within {@link #CALL_TIMEOUT_SECONDS}
   * @throws InterruptedException if interrupted while waiting
   */
  private static void awaitCallFinished(CountDownLatch callFinished) throws InterruptedException {
    if (!callFinished.await(CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
      throw new AssertionError(
          "Server-side call did not finish within " + CALL_TIMEOUT_SECONDS + " seconds");
    }
  }

  /**
   * Ensure that if the client cancels, the server does not leak memory.
   *
   * <p>In gRPC, canceling the stream from the client sends an event to the server. Once processed,
   * gRPC will start silently rejecting messages sent by the server. However, Flight depends on gRPC
   * processing these messages in order to free the associated memory.
   */
  @Test
  public void testCancelingDoGetDoesNotLeak() throws Exception {
    final CountDownLatch callFinished = new CountDownLatch(1);
    try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        final FlightServer s =
            FlightServer.builder(
                    allocator,
                    forGrpcInsecure(LOCALHOST, 0),
                    new LeakFlightProducer(allocator, callFinished))
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      // FlightStream is AutoCloseable. Close it so that any client-side buffers it still holds
      // are released before the allocator is closed.
      try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
        stream.getRoot();
        stream.cancel("Cancel", null);
        // Wait for the call to finish. (Closing the allocator while a call is ongoing is a
        // guaranteed leak.)
        awaitCallFinished(callFinished);
      }
      s.shutdown();
      s.awaitTermination();
    }
  }

  /** Ensure that writing to a DoPut stream that the server has already cancelled does not block. */
  @Test
  public void testCancelingDoPutDoesNotBlock() throws Exception {
    final CountDownLatch callFinished = new CountDownLatch(1);
    try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        final FlightServer s =
            FlightServer.builder(
                    allocator,
                    forGrpcInsecure(LOCALHOST, 0),
                    new LeakFlightProducer(allocator, callFinished))
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      try (final VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), allocator)) {
        final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
        final SyncPutListener listener = new SyncPutListener();
        final FlightClient.ClientStreamListener stream =
            client.startPut(descriptor, root, listener);
        // Wait for the server to cancel
        awaitCallFinished(callFinished);
        fillRoot(root);
        // Unlike DoGet, this method fairly reliably will write the message to the stream, so even
        // without the fix for ARROW-7343, this won't leak memory. However, it will block if
        // FlightClient doesn't check for cancellation.
        stream.putNext();
        stream.completed();
      }
      s.shutdown();
      s.awaitTermination();
    }
  }

  /**
   * A producer that exercises cancellation paths.
   *
   * <p>DoGet starts the stream, then writes its single batch only from the on-cancel handler, that
   * is, after the client has cancelled. DoPut immediately rejects the upload with {@code
   * CANCELLED}. Both count down {@code callFinished} once they are done.
   */
  private static class LeakFlightProducer extends NoOpFlightProducer {
    private final BufferAllocator allocator;
    private final CountDownLatch callFinished;

    public LeakFlightProducer(BufferAllocator allocator, CountDownLatch callFinished) {
      this.allocator = allocator;
      this.callFinished = callFinished;
    }

    @Override
    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
      final BufferAllocator childAllocator =
          allocator.newChildAllocator("foo", 0, Long.MAX_VALUE);
      final VectorSchemaRoot root = VectorSchemaRoot.create(TestLeak.getSchema(), childAllocator);
      root.allocateNew();
      listener.start(root);
      // We can't poll listener#isCancelled since gRPC has two distinct "is cancelled" flags.
      // TODO: should we continue leaking gRPC semantics? Can we even avoid this?
      listener.setOnCancelHandler(
          () -> {
            try {
              fillRoot(root);
              // Once the call is "really cancelled" (setOnCancelListener has run/is running), this
              // call is actually a no-op on the gRPC side and will leak the ArrowMessage unless
              // Flight checks for this.
              listener.putNext();
              listener.completed();
            } finally {
              try {
                root.close();
                childAllocator.close();
              } finally {
                // Don't let the test hang if we throw above
                callFinished.countDown();
              }
            }
          });
    }

    @Override
    public Runnable acceptPut(
        CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
      return () -> {
        flightStream.getRoot();
        ackStream.onError(CallStatus.CANCELLED.withDescription("CANCELLED").toRuntimeException());
        callFinished.countDown();
        ackStream.onCompleted();
      };
    }
  }
}