TestErrorMetadata.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.rpc.Status;
import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.protobuf.StatusProto;
import java.nio.charset.StandardCharsets;
import org.apache.arrow.flight.perf.impl.PerfOuterClass;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Test;
public class TestErrorMetadata {
  /** Decodes the binary {@code grpc-status-details-bin} metadata value into a {@link Status}. */
  private static final Metadata.BinaryMarshaller<Status> STATUS_MARSHALLER =
      ProtoUtils.metadataMarshaller(Status.getDefaultInstance());

  /** Ensure metadata attached to a gRPC error is propagated. */
  @Test
  public void testGrpcMetadata() throws Exception {
    PerfOuterClass.Perf perf =
        PerfOuterClass.Perf.newBuilder()
            .setStreamCount(12)
            .setRecordsPerBatch(1000)
            .setRecordsPerStream(1000000L)
            .build();
    StatusRuntimeExceptionProducer producer = new StatusRuntimeExceptionProducer(perf);
    try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        final FlightServer s =
            FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer)
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      final CallStatus flightStatus =
          FlightTestUtil.assertCode(
              FlightStatusCode.CANCELLED,
              () -> {
                // Close the stream even though next() is expected to throw; any exception from
                // close() is suppressed, so the asserted exception is still the one from next().
                try (FlightStream stream =
                    client.getStream(new Ticket("abs".getBytes(StandardCharsets.UTF_8)))) {
                  stream.next();
                }
              });
      ErrorFlightMetadata metadata = flightStatus.metadata();
      assertNotNull(metadata);
      assertEquals(2, metadata.keys().size());
      assertTrue(metadata.containsKey("grpc-status-details-bin"));
      Status status = STATUS_MARSHALLER.parseBytes(metadata.getByte("grpc-status-details-bin"));

      // Find the Perf message that the server packed into the status details.
      PerfOuterClass.Perf newPerf = null;
      for (Any details : status.getDetailsList()) {
        if (details.is(PerfOuterClass.Perf.class)) {
          try {
            newPerf = details.unpack(PerfOuterClass.Perf.class);
          } catch (InvalidProtocolBufferException e) {
            fail("Could not unpack Perf from gRPC status details", e);
          }
        }
      }
      assertNotNull(newPerf);
      assertEquals(perf, newPerf);
    }
  }

  /** Ensure metadata attached to a Flight error is propagated. */
  @Test
  public void testFlightMetadata() throws Exception {
    try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        final FlightServer s =
            FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), new CallStatusProducer())
                .build()
                .start();
        final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
      // Error delivered through ServerStreamListener.error() on a streaming call.
      CallStatus flightStatus =
          FlightTestUtil.assertCode(
              FlightStatusCode.INVALID_ARGUMENT,
              () -> {
                try (FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
                  stream.next();
                }
              });
      assertCustomMetadata(flightStatus);

      // Error thrown directly from getFlightInfo.
      flightStatus =
          FlightTestUtil.assertCode(
              FlightStatusCode.INVALID_ARGUMENT,
              () -> client.getInfo(FlightDescriptor.command(new byte[0])));
      assertCustomMetadata(flightStatus);
    }
  }

  /**
   * Asserts that {@code status} carries the {@code x-foo} and {@code x-bar-bin} entries inserted
   * by {@link CallStatusProducer}.
   *
   * @param status the status returned by {@code FlightTestUtil.assertCode}
   */
  private static void assertCustomMetadata(CallStatus status) {
    ErrorFlightMetadata metadata = status.metadata();
    assertNotNull(metadata);
    assertEquals("foo", metadata.get("x-foo"));
    assertArrayEquals(new byte[] {1}, metadata.getByte("x-bar-bin"));
  }

  /**
   * Producer whose getStream fails with a raw gRPC {@link StatusRuntimeException} whose status
   * details contain the configured {@code Perf} message.
   */
  private static class StatusRuntimeExceptionProducer extends NoOpFlightProducer {
    private final PerfOuterClass.Perf perf;

    private StatusRuntimeExceptionProducer(PerfOuterClass.Perf perf) {
      this.perf = perf;
    }

    @Override
    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
      StatusRuntimeException sre =
          StatusProto.toStatusRuntimeException(
              Status.newBuilder()
                  // 1 is the numeric gRPC status code CANCELLED; testGrpcMetadata expects the
                  // client to report FlightStatusCode.CANCELLED.
                  .setCode(1)
                  .setMessage("Testing 1 2 3")
                  .addDetails(Any.pack(perf, "arrow/meta/types"))
                  .build());
      listener.error(sre);
    }
  }

  /**
   * Producer whose getStream and getFlightInfo both fail with INVALID_ARGUMENT carrying custom
   * error metadata ({@code x-foo} and {@code x-bar-bin}).
   */
  private static class CallStatusProducer extends NoOpFlightProducer {
    private final ErrorFlightMetadata metadata;

    CallStatusProducer() {
      this.metadata = new ErrorFlightMetadata();
      metadata.insert("x-foo", "foo");
      metadata.insert("x-bar-bin", new byte[] {1});
    }

    /** Builds the INVALID_ARGUMENT error, with the custom metadata, reported by both calls. */
    private RuntimeException invalidArgument() {
      return CallStatus.INVALID_ARGUMENT
          .withDescription("Failed")
          .withMetadata(metadata)
          .toRuntimeException();
    }

    @Override
    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
      listener.error(invalidArgument());
    }

    @Override
    public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
      throw invalidArgument();
    }
  }
}