PipesMessageTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.pipes.core.protocol;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test;
class PipesMessageTest {
@Test
void testRoundTripEmptyPayload() throws IOException {
for (PipesMessageType type : new PipesMessageType[]{
PipesMessageType.PING, PipesMessageType.ACK,
PipesMessageType.READY, PipesMessageType.SHUT_DOWN}) {
PipesMessage original = new PipesMessage(type, new byte[0]);
PipesMessage roundTripped = roundTrip(original);
assertEquals(type, roundTripped.type());
assertEquals(0, roundTripped.payload().length);
}
}
@Test
void testRoundTripWithPayload() throws IOException {
byte[] payload = "hello world".getBytes(StandardCharsets.UTF_8);
PipesMessage original = PipesMessage.finished(payload);
PipesMessage roundTripped = roundTrip(original);
assertEquals(PipesMessageType.FINISHED, roundTripped.type());
assertArrayEquals(payload, roundTripped.payload());
}
@Test
void testRoundTripAllTypes() throws IOException {
byte[] payload = "test".getBytes(StandardCharsets.UTF_8);
for (PipesMessageType type : PipesMessageType.values()) {
PipesMessage original = new PipesMessage(type, payload);
PipesMessage roundTripped = roundTrip(original);
assertEquals(type, roundTripped.type());
assertArrayEquals(payload, roundTripped.payload());
}
}
@Test
void testWorkingMessageRoundTrip() throws IOException {
PipesMessage original = PipesMessage.working(42L);
PipesMessage roundTripped = roundTrip(original);
assertEquals(PipesMessageType.WORKING, roundTripped.type());
assertEquals(42L, roundTripped.lastProgressMillis());
}
@Test
void testDesyncDetectionBadMagic() {
byte[] bad = new byte[]{0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00};
assertThrows(ProtocolDesyncException.class, () ->
PipesMessage.read(new DataInputStream(new ByteArrayInputStream(bad))));
}
@Test
void testDesyncDetectionPartialMagic() {
// First byte correct, second wrong
byte[] bad = new byte[]{0x54, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00};
assertThrows(ProtocolDesyncException.class, () ->
PipesMessage.read(new DataInputStream(new ByteArrayInputStream(bad))));
}
@Test
void testEofBeforeMagic() {
byte[] empty = new byte[0];
assertThrows(EOFException.class, () ->
PipesMessage.read(new DataInputStream(new ByteArrayInputStream(empty))));
}
@Test
void testEofAfterFirstMagicByte() {
byte[] partial = new byte[]{0x54};
assertThrows(EOFException.class, () ->
PipesMessage.read(new DataInputStream(new ByteArrayInputStream(partial))));
}
@Test
void testNegativePayloadLength() throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
dos.write(PipesMessage.MAGIC_0);
dos.write(PipesMessage.MAGIC_1);
dos.write(PipesMessageType.FINISHED.getByte());
dos.writeInt(-1); // negative length
dos.flush();
assertThrows(IOException.class, () ->
PipesMessage.read(new DataInputStream(new ByteArrayInputStream(baos.toByteArray()))));
}
@Test
void testOversizedPayloadRejection() throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
dos.write(PipesMessage.MAGIC_0);
dos.write(PipesMessage.MAGIC_1);
dos.write(PipesMessageType.FINISHED.getByte());
dos.writeInt(PipesMessage.MAX_PAYLOAD_BYTES + 1);
dos.flush();
assertThrows(IOException.class, () ->
PipesMessage.read(new DataInputStream(new ByteArrayInputStream(baos.toByteArray()))));
}
@Test
void testRequiresAckAssertions() {
assertFalse(PipesMessageType.PING.requiresAck());
assertFalse(PipesMessageType.ACK.requiresAck());
assertFalse(PipesMessageType.NEW_REQUEST.requiresAck());
assertFalse(PipesMessageType.SHUT_DOWN.requiresAck());
assertFalse(PipesMessageType.READY.requiresAck());
assertFalse(PipesMessageType.WORKING.requiresAck());
assertTrue(PipesMessageType.STARTUP_FAILED.requiresAck());
assertTrue(PipesMessageType.INTERMEDIATE_RESULT.requiresAck());
assertTrue(PipesMessageType.FINISHED.requiresAck());
assertTrue(PipesMessageType.OOM.requiresAck());
assertTrue(PipesMessageType.TIMEOUT.requiresAck());
assertTrue(PipesMessageType.UNSPECIFIED_CRASH.requiresAck());
}
@Test
void testGetByteAndLookupInverse() {
for (PipesMessageType type : PipesMessageType.values()) {
byte b = type.getByte();
PipesMessageType looked = PipesMessageType.lookup(b);
assertEquals(type, looked, "lookup(getByte()) failed for " + type);
}
}
@Test
void testLookupUnknownByte() {
assertThrows(IllegalArgumentException.class, () -> PipesMessageType.lookup(0xFF));
assertThrows(IllegalArgumentException.class, () -> PipesMessageType.lookup(0x00));
}
@Test
void testExitCodes() {
assertTrue(PipesMessageType.OOM.getExitCode().isPresent());
assertEquals(18, PipesMessageType.OOM.getExitCode().getAsInt());
assertTrue(PipesMessageType.TIMEOUT.getExitCode().isPresent());
assertEquals(17, PipesMessageType.TIMEOUT.getExitCode().getAsInt());
assertTrue(PipesMessageType.UNSPECIFIED_CRASH.getExitCode().isPresent());
assertEquals(19, PipesMessageType.UNSPECIFIED_CRASH.getExitCode().getAsInt());
assertFalse(PipesMessageType.PING.getExitCode().isPresent());
assertFalse(PipesMessageType.FINISHED.getExitCode().isPresent());
assertFalse(PipesMessageType.READY.getExitCode().isPresent());
}
@Test
void testConvenienceFactories() throws IOException {
assertEquals(PipesMessageType.PING, roundTrip(PipesMessage.ping()).type());
assertEquals(PipesMessageType.ACK, roundTrip(PipesMessage.ack()).type());
assertEquals(PipesMessageType.READY, roundTrip(PipesMessage.ready()).type());
assertEquals(PipesMessageType.SHUT_DOWN, roundTrip(PipesMessage.shutDown()).type());
byte[] data = "test".getBytes(StandardCharsets.UTF_8);
assertEquals(PipesMessageType.NEW_REQUEST, roundTrip(PipesMessage.newRequest(data)).type());
assertEquals(PipesMessageType.FINISHED, roundTrip(PipesMessage.finished(data)).type());
assertEquals(PipesMessageType.INTERMEDIATE_RESULT, roundTrip(PipesMessage.intermediateResult(data)).type());
assertEquals(PipesMessageType.STARTUP_FAILED, roundTrip(PipesMessage.startupFailed(data)).type());
assertEquals(PipesMessageType.OOM, roundTrip(PipesMessage.crash(PipesMessageType.OOM, data)).type());
}
private PipesMessage roundTrip(PipesMessage msg) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
msg.write(new DataOutputStream(baos));
return PipesMessage.read(new DataInputStream(new ByteArrayInputStream(baos.toByteArray())));
}
}