TikaInputStreamTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.io;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.util.Locale;
import java.util.Random;

import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.tika.metadata.Metadata;
import org.apache.tika.metadata.TikaCoreProperties;

public class TikaInputStreamTest {

    private static final Logger LOG = LoggerFactory.getLogger(TikaInputStreamTest.class);

    @TempDir
    Path tempDir;

    @Test
    public void testFileBased() throws IOException {
        Path path = createTempFile("Hello, World!");
        TikaInputStream tis = TikaInputStream.get(path);
        assertTrue(tis.hasFile());
        assertNull(tis.getOpenContainer());

        assertEquals(path, TikaInputStream.get(tis).getPath(),
                "The file returned by the getFile() method should" +
                                " be the file used to instantiate a TikaInputStream");

        assertEquals("Hello, World!", readStream(tis),
                "The contents of the TikaInputStream should equal the" +
                        " contents of the underlying file");

        tis.close();
        assertTrue(Files.exists(path),
                "The close() method must not remove the file used to" +
                " instantiate a TikaInputStream");

    }

    @Test
    public void testStreamBased() throws IOException {
        InputStream input = IOUtils.toInputStream("Hello, World!", UTF_8);
        TikaInputStream tis = TikaInputStream.get(input);
        assertFalse(tis.hasFile());
        assertNull(tis.getOpenContainer());

        Path file = TikaInputStream.get(tis).getPath();
        assertTrue(file != null && Files.isRegularFile(file));
        assertTrue(tis.hasFile());
        assertNull(tis.getOpenContainer());

        assertEquals("Hello, World!", readFile(file),
                "The contents of the file returned by the getFile method" +
                        " should equal the contents of the TikaInputStream");

        assertEquals("Hello, World!", readStream(tis),
                "The contents of the TikaInputStream should not get modified" +
                        " by reading the file first");

        tis.close();
        assertFalse(Files.exists(file),
                "The close() method must remove the temporary file created by a TikaInputStream");
    }

    private Path createTempFile(String data) throws IOException {
        Path file = Files.createTempFile(tempDir, "tika-", ".tmp");
        Files.write(file, data.getBytes(UTF_8));
        return file;
    }

    private String readFile(Path file) throws IOException {
        return new String(Files.readAllBytes(file), UTF_8);
    }

    private String readStream(InputStream stream) throws IOException {
        return IOUtils.toString(stream, UTF_8);
    }

    @Test
    public void testGetMetadata() throws Exception {
        URL url = TikaInputStreamTest.class.getResource("test.txt");
        Metadata metadata = new Metadata();
        TikaInputStream.get(url, metadata).close();
        assertEquals("test.txt", metadata.get(TikaCoreProperties.RESOURCE_NAME_KEY));
        assertEquals(Long.toString(Files.size(Paths.get(url.toURI()))),
                metadata.get(Metadata.CONTENT_LENGTH));
    }

    // ========== New Caching Tests ==========

    @Test
    public void testMarkReset() throws IOException {
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            byte[] first = new byte[5];
            tis.read(first);
            assertEquals("Hello", str(first));

            tis.mark(100);

            byte[] next = new byte[2];
            tis.read(next);
            assertEquals(", ", str(next));

            tis.reset();

            byte[] again = new byte[2];
            tis.read(again);
            assertEquals(", ", str(again));
        }
    }

    @Test
    public void testMarkResetAtZero() throws IOException {
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.mark(100);

            byte[] buffer = new byte[data.length];
            tis.read(buffer);

            tis.reset();
            assertEquals(0, tis.getPosition());

            byte[] again = new byte[data.length];
            tis.read(again);
            assertArrayEquals(data, again);
        }
    }

    @Test
    public void testMultipleMarkReset() throws IOException {
        byte[] data = bytes("ABCDEFGHIJ");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.mark(100);
            byte[] buf = new byte[3];
            tis.read(buf);
            assertEquals("ABC", str(buf));

            tis.reset();
            tis.mark(100);
            buf = new byte[5];
            tis.read(buf);
            assertEquals("ABCDE", str(buf));

            tis.mark(100);
            buf = new byte[3];
            tis.read(buf);
            assertEquals("FGH", str(buf));

            tis.reset();
            buf = new byte[3];
            tis.read(buf);
            assertEquals("FGH", str(buf));
        }
    }

    @Test
    public void testRewind() throws IOException {
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            byte[] buffer = new byte[data.length];
            tis.read(buffer);
            assertEquals(data.length, tis.getPosition());

            tis.rewind();
            assertEquals(0, tis.getPosition());

            byte[] again = new byte[data.length];
            tis.read(again);
            assertArrayEquals(data, again);
        }
    }

    @Test
    public void testGetPathPreservesPosition() throws IOException {
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            tis.enableRewind(); // Enable caching for getPath() support after reading

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals(5, tis.getPosition());

            Path path = tis.getPath();
            assertNotNull(path);
            assertEquals(5, tis.getPosition());

            buf = new byte[2];
            tis.read(buf);
            assertEquals(", ", str(buf));
        }
    }

    @Test
    public void testFileBackedMarkReset() throws IOException {
        Path tempFile = createTempFile("ABCDEFGHIJ");

        try (TikaInputStream tis = TikaInputStream.get(tempFile)) {
            byte[] buf = new byte[3];
            tis.read(buf);
            assertEquals("ABC", str(buf));

            tis.mark(100);

            buf = new byte[4];
            tis.read(buf);
            assertEquals("DEFG", str(buf));

            tis.reset();
            assertEquals(3, tis.getPosition());

            buf = new byte[4];
            tis.read(buf);
            assertEquals("DEFG", str(buf));
        }
    }

    @Test
    public void testSkip() throws IOException {
        byte[] data = bytes("ABCDEFGHIJ");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.skip(3);
            assertEquals(3, tis.getPosition());

            byte[] buf = new byte[4];
            tis.read(buf);
            assertEquals("DEFG", str(buf));
        }
    }

    @Test
    public void testLargeStreamSpillsToFile() throws IOException {
        byte[] data = new byte[2 * 1024 * 1024]; // 2MB
        new Random(42).nextBytes(data);

        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            tis.enableRewind(); // Enable caching for rewind support

            byte[] buffer = new byte[data.length];
            int totalRead = 0;
            int n;
            while ((n = tis.read(buffer, totalRead, buffer.length - totalRead)) != -1) {
                totalRead += n;
                if (totalRead >= buffer.length) {
                    break;
                }
            }
            assertEquals(data.length, totalRead);

            tis.rewind();
            assertEquals(0, tis.getPosition());

            byte[] again = new byte[data.length];
            totalRead = 0;
            while ((n = tis.read(again, totalRead, again.length - totalRead)) != -1) {
                totalRead += n;
                if (totalRead >= again.length) {
                    break;
                }
            }
            assertArrayEquals(data, again);
        }
    }

    @Test
    public void testResetWithoutMark() throws IOException {
        byte[] data = bytes("Hello");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.read();
            assertThrows(IOException.class, tis::reset);
        }
    }

    @Test
    public void testPeek() throws IOException {
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            byte[] peekBuffer = new byte[5];
            int peeked = tis.peek(peekBuffer);
            assertEquals(5, peeked);
            assertEquals("Hello", str(peekBuffer));

            assertEquals(0, tis.getPosition());

            byte[] readBuffer = new byte[5];
            tis.read(readBuffer);
            assertEquals("Hello", str(readBuffer));
            assertEquals(5, tis.getPosition());
        }
    }

    @Test
    public void testPosition() throws IOException {
        byte[] data = bytes("ABCDEFGHIJ");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            assertEquals(0, tis.getPosition());

            tis.read();
            assertEquals(1, tis.getPosition());

            tis.read(new byte[3]);
            assertEquals(4, tis.getPosition());

            tis.skip(2);
            assertEquals(6, tis.getPosition());
        }
    }

    @Test
    public void testLength() throws IOException {
        byte[] data = bytes("Hello");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            assertTrue(tis.hasLength());
            assertEquals(5, tis.getLength());
        }
    }

    @Test
    public void testCloseShield() throws IOException {
        byte[] data = bytes("Hello");
        TikaInputStream tis = TikaInputStream.get(data);

        assertFalse(tis.isCloseShield());
        tis.setCloseShield();
        assertTrue(tis.isCloseShield());

        tis.close();

        byte[] buf = new byte[5];
        tis.read(buf);
        assertEquals("Hello", str(buf));

        tis.removeCloseShield();
        assertFalse(tis.isCloseShield());
        tis.close();
    }

    // ========== Randomized Tests ==========

    // --- ByteArray Backed Tests ---

    @RepeatedTest(50)
    public void testRandomizedSizeAndTypes() throws Exception {
        int minSz = 0;
        int maxSz = 2 * 1024 * 1024;

        long seed = System.currentTimeMillis();
        Random random = new Random(seed);
        int sz = minSz + random.nextInt(maxSz - minSz + 1);
        BackingType type = BackingType.values()[random.nextInt(BackingType.values().length)];
        runRandomizedTest(sz, type, seed);
    }

    @RepeatedTest(10)
    public void testRandomizedSizeStepsAndTypes() throws Exception {
        for (int sz : new int[]{ 0, 100, 8191, 8192, 8193, 100_000, 2 * 1024 * 1024 }) {
             for (BackingType type : BackingType.values()) {
               runRandomizedTest(sz, type);
             }
        }
    }

    @RepeatedTest(10)
    public void testRandomizedOperationsTest() throws Exception {
        long seed = Instant.now().getEpochSecond();
        for (int sz : new int[]{ 0, 100, 8191, 8192, 8193, 100_000, 2 * 1024 * 1024 }) {
            for (BackingType type : BackingType.values()) {
                runRandomizedOperationsTest(sz, type, seed);
            }
        }
    }




    /**
     * Backing strategy types for TikaInputStream.
     */
    private enum BackingType {
        BYTE_ARRAY,  // TikaInputStream.get(byte[]) - ByteArrayBackedStrategy
        FILE,        // TikaInputStream.get(Path) - FileBackedStrategy
        STREAM       // TikaInputStream.get(InputStream) - StreamBackedStrategy
    }

    /**
     * Reproduce a failing randomized test by specifying the backingType and seed.
     * Enable this test and set the parameters from a failing test's log output.
     * The size is derived from the seed using the same logic as the original test.
     */
    @Disabled("Enable this test to reproduce a specific failure using seed from logs")
    @Test
    public void reproduceRandomizedTestFailure() throws Exception {
        // Set these values from the failing test's log output:
        BackingType backingType = BackingType.STREAM;
        long seed = 1768257360409L;
        int size = 1;
        runRandomizedOperationsTest(size, backingType, seed);
    }

    private void runRandomizedTest(int size, BackingType backingType) throws Exception {
        runRandomizedTest(size, backingType, System.currentTimeMillis());
    }

    private void runRandomizedTest(int size, BackingType backingType, long seed) throws Exception {
        LOG.debug("runRandomizedTest: size={}, backingType={}, seed={}", size, backingType, seed);
        Random random = new Random(seed);

        byte[] data = new byte[size];
        if (size > 0) {
            random.nextBytes(data);
        }
        String expectedDigest = computeDigest(data);

        try (TikaInputStream tis = createTikaInputStream(data, backingType)) {
            byte[] readData = readAllBytes(tis);
            String actualDigest = computeDigest(readData);
            assertEquals(expectedDigest, actualDigest,
                    "Digest mismatch for size=" + size + ", backingType=" + backingType + ", seed=" + seed);
            assertEquals(size, readData.length);
        }

        try (TikaInputStream tis = createTikaInputStream(data, backingType)) {
            byte[] readData = readAllBytes(tis);
            tis.rewind();
            byte[] rereadData = readAllBytes(tis);
            assertArrayEquals(readData, rereadData,
                    "Data mismatch after rewind for size=" + size + ", backingType=" + backingType + ", seed=" + seed);
        }
    }

    private void runRandomizedOperationsTest(int size, BackingType backingType, long seed) throws Exception {
        LOG.debug("runRandomizedOperationsTest: size={}, backingType={}, seed={}", size, backingType, seed);
        String ctx = "size=" + size + ", backingType=" + backingType + ", seed=" + seed;
        Random random = new Random(seed);
        // Skip the first random value (used for size selection in the calling test)
        random.nextInt();

        byte[] data = new byte[size];
        if (size > 0) {
            random.nextBytes(data);
        }

        try (TikaInputStream tis = createTikaInputStream(data, backingType)) {
            int position = 0;
            int markPosition = -1;
            int numOps = random.nextInt(50) + 10;

            for (int op = 0; op < numOps; op++) {
                int operation = random.nextInt(9);

                switch (operation) {
                    case 0: // single byte read
                        if (position < size) {
                            int expectedByte = data[position] & 0xFF;
                            int actualByte = tis.read();
                            assertEquals(expectedByte, actualByte,
                                    "Single byte read mismatch at position " + position + ", " + ctx);
                            position++;
                        } else {
                            assertEquals(-1, tis.read(),
                                    "Expected EOF at position " + position + ", " + ctx);
                        }
                        break;

                    case 1: // bulk read
                        // Ensure readLen is at least 1 to avoid zero-length buffer reads
                        int readLen = random.nextInt(Math.min(1000, size + 100)) + 1;
                        byte[] buffer = new byte[readLen];
                        int bytesRead = tis.read(buffer);
                        if (position >= size) {
                            assertTrue(bytesRead <= 0,
                                    "Expected EOF, got " + bytesRead + " bytes, " + ctx);
                        } else {
                            assertTrue(bytesRead > 0,
                                    "Expected data, got " + bytesRead + ", " + ctx);
                            for (int i = 0; i < bytesRead; i++) {
                                assertEquals(data[position + i], buffer[i],
                                        "Bulk read mismatch at offset " + i + ", " + ctx);
                            }
                            position += bytesRead;
                        }
                        break;

                    case 2: // skip
                        long skipAmount = random.nextInt(size + 100);
                        long skipped = tis.skip(skipAmount);
                        assertTrue(skipped >= 0 && skipped <= skipAmount,
                                "Skip returned invalid value " + skipped + ", " + ctx);
                        position += (int) skipped;
                        if (position > size) {
                            position = size;
                        }
                        break;

                    case 3: // mark
                        int readLimit = random.nextInt(size + 100) + 1;
                        tis.mark(readLimit);
                        markPosition = position;
                        break;

                    case 4: // reset
                        if (markPosition >= 0) {
                            tis.reset();
                            position = markPosition;
                            markPosition = -1;
                        }
                        break;

                    case 5: // rewind
                        tis.rewind();
                        position = 0;
                        markPosition = -1;
                        break;

                    case 6: // getPath - forces spill to file for stream-backed
                        Path path = tis.getPath();
                        assertNotNull(path, "getPath() returned null, " + ctx);
                        assertTrue(Files.exists(path), "Path doesn't exist, " + ctx);
                        assertEquals(size, Files.size(path), "File size mismatch, " + ctx);
                        break;

                    case 7: // peek
                        int peekLen = random.nextInt(Math.min(100, size + 10)) + 1;
                        byte[] peekBuf = new byte[peekLen];
                        int peeked = tis.peek(peekBuf);
                        if (position >= size) {
                            assertTrue(peeked <= 0, "Expected EOF on peek, " + ctx);
                        } else {
                            assertTrue(peeked > 0, "Expected data on peek, " + ctx);
                            for (int i = 0; i < peeked; i++) {
                                assertEquals(data[position + i], peekBuf[i],
                                        "Peek mismatch at offset " + i + ", " + ctx);
                            }
                        }
                        // position doesn't change after peek, but peek() uses mark/reset
                        // internally which overwrites any existing mark
                        markPosition = -1;
                        break;

                    case 8: // available
                        int avail = tis.available();
                        assertTrue(avail >= 0, "available() returned negative, " + ctx);
                        break;

                    default:
                        break;
                }
                assertEquals(position, tis.getPosition(),
                        "Position mismatch after operation " + operation + ", " + ctx);
            }

            tis.rewind();
            assertEquals(0, tis.getPosition(), "Position should be 0 after rewind, " + ctx);
            byte[] finalRead = readAllBytes(tis);
            String expectedDigest = computeDigest(data);
            String actualDigest = computeDigest(finalRead);
            assertEquals(expectedDigest, actualDigest, "Final digest mismatch, " + ctx);
        }
    }

    @Test
    public void testMarkBeyondStreamLength() throws Exception {
        byte[] data = bytes("Short");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.mark(1000);
            byte[] buf = readAllBytes(tis);
            assertEquals("Short", str(buf));
            tis.reset();
            assertEquals(0, tis.getPosition());
            buf = readAllBytes(tis);
            assertEquals("Short", str(buf));
        }
    }

    @Test
    public void testSkipBeyondStreamLength() throws Exception {
        byte[] data = bytes("Short");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            long skipped = tis.skip(1000);
            assertEquals(5, skipped);
            assertEquals(-1, tis.read());
        }
    }

    @Test
    public void testMarkResetSkipCombination() throws Exception {
        byte[] data = bytes("ABCDEFGHIJKLMNOPQRSTUVWXYZ");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.mark(100);
            tis.skip(10);
            assertEquals(10, tis.getPosition());

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("KLMNO", str(buf));

            tis.reset();
            assertEquals(0, tis.getPosition());

            buf = new byte[5];
            tis.read(buf);
            assertEquals("ABCDE", str(buf));
        }
    }

    @Test
    public void testFileBackedMarkResetSkip() throws Exception {
        byte[] data = bytes("ABCDEFGHIJKLMNOPQRSTUVWXYZ");
        Path tempFile = createTempFile("ABCDEFGHIJKLMNOPQRSTUVWXYZ");

        try (TikaInputStream tis = TikaInputStream.get(tempFile)) {
            tis.skip(5);
            assertEquals(5, tis.getPosition());

            tis.mark(100);

            tis.skip(10);
            assertEquals(15, tis.getPosition());

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("PQRST", str(buf));

            tis.reset();
            assertEquals(5, tis.getPosition());

            buf = new byte[5];
            tis.read(buf);
            assertEquals("FGHIJ", str(buf));
        }
    }

    // ========== CachingSource Tests ==========

    @Test
    public void testCachingSourceUpdatesMetadataOnSpill() throws IOException {
        byte[] data = bytes("Hello, World!");
        Metadata metadata = new Metadata();
        // Don't set CONTENT_LENGTH - let CachingSource set it on spill

        try (TemporaryResources tmp = new TemporaryResources()) {
            CachingSource source = new CachingSource(
                    new ByteArrayInputStream(data), tmp, -1, metadata);
            source.enableRewind(); // Enable caching for spill support

            // Read all data
            byte[] buffer = new byte[data.length];
            int totalRead = 0;
            int n;
            while ((n = source.read(buffer, totalRead, buffer.length - totalRead)) != -1) {
                totalRead += n;
                if (totalRead >= buffer.length) break;
            }

            // Before spill, metadata should not have length
            assertNull(metadata.get(Metadata.CONTENT_LENGTH));

            // Force spill to file
            Path path = source.getPath(".tmp");
            assertNotNull(path);
            assertTrue(Files.exists(path));

            // After spill, metadata should have length
            assertEquals("13", metadata.get(Metadata.CONTENT_LENGTH));

            source.close();
        }
    }

    @Test
    public void testCachingSourceDoesNotOverwriteExistingMetadata() throws IOException {
        byte[] data = bytes("Hello, World!");
        Metadata metadata = new Metadata();
        // Pre-set CONTENT_LENGTH
        metadata.set(Metadata.CONTENT_LENGTH, "999");

        try (TemporaryResources tmp = new TemporaryResources()) {
            CachingSource source = new CachingSource(
                    new ByteArrayInputStream(data), tmp, -1, metadata);
            source.enableRewind(); // Enable caching for seek/spill support

            // Read and spill
            IOUtils.toByteArray(source);
            source.seekTo(0);
            Path path = source.getPath(".tmp");

            // Existing value should not be overwritten
            assertEquals("999", metadata.get(Metadata.CONTENT_LENGTH));

            source.close();
        }
    }

    @Test
    public void testCachingSourceSeekTo() throws IOException {
        byte[] data = bytes("ABCDEFGHIJ");

        try (TemporaryResources tmp = new TemporaryResources()) {
            CachingSource source = new CachingSource(
                    new ByteArrayInputStream(data), tmp, -1, null);
            source.enableRewind(); // Enable caching for seek support

            // Read first 5 bytes
            byte[] buf = new byte[5];
            source.read(buf);
            assertEquals("ABCDE", str(buf));

            // Seek back to position 2
            source.seekTo(2);

            // Read again
            buf = new byte[3];
            source.read(buf);
            assertEquals("CDE", str(buf));

            source.close();
        }
    }

    @Test
    public void testCachingSourceAfterSpill() throws IOException {
        byte[] data = bytes("ABCDEFGHIJ");

        try (TemporaryResources tmp = new TemporaryResources()) {
            CachingSource source = new CachingSource(
                    new ByteArrayInputStream(data), tmp, -1, null);
            source.enableRewind(); // Enable caching for spill/seek support

            // Read first 5 bytes
            byte[] buf = new byte[5];
            source.read(buf);
            assertEquals("ABCDE", str(buf));

            // Force spill
            Path path = source.getPath(".tmp");
            assertTrue(Files.exists(path));

            // Continue reading after spill
            buf = new byte[5];
            source.read(buf);
            assertEquals("FGHIJ", str(buf));

            // Seek back and read again
            source.seekTo(0);
            buf = new byte[10];
            source.read(buf);
            assertEquals("ABCDEFGHIJ", str(buf));

            source.close();
        }
    }

    // ========== enableRewind() Tests ==========

    @Test
    public void testEnableRewindByteArrayNoOp() throws Exception {
        // ByteArraySource is always rewindable - enableRewind() is no-op
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(data)) {
            tis.enableRewind(); // Should be no-op

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));

            tis.rewind();
            assertEquals(0, tis.getPosition());

            buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));
        }
    }

    @Test
    public void testEnableRewindFileNoOp() throws Exception {
        // FileSource is always rewindable - enableRewind() is no-op
        Path tempFile = createTempFile("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(tempFile)) {
            tis.enableRewind(); // Should be no-op

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));

            tis.rewind();
            assertEquals(0, tis.getPosition());

            buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));
        }
    }

    @Test
    public void testEnableRewindStreamEnablesCaching() throws Exception {
        // CachingSource starts in passthrough mode, enableRewind() enables caching
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            tis.enableRewind(); // Enable caching mode

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));

            tis.rewind();
            assertEquals(0, tis.getPosition());

            buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));
        }
    }

    @Test
    public void testEnableRewindAfterReadThrows() throws Exception {
        // enableRewind() must be called at position 0
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            tis.read(); // Read one byte, position is now 1
            assertEquals(1, tis.getPosition());

            assertThrows(IllegalStateException.class, tis::enableRewind,
                    "enableRewind() should throw when position != 0");
        }
    }

    @Test
    public void testEnableRewindMultipleCallsNoOp() throws Exception {
        // Multiple enableRewind() calls should be safe (no-op after first)
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            tis.enableRewind();
            tis.enableRewind(); // Should be no-op
            tis.enableRewind(); // Should be no-op

            byte[] buf = readAllBytes(tis);
            assertEquals("Hello, World!", str(buf));

            tis.rewind();
            buf = readAllBytes(tis);
            assertEquals("Hello, World!", str(buf));
        }
    }

    @Test
    public void testStreamWithoutEnableRewindCannotRewind() throws Exception {
        // Without enableRewind(), CachingSource is in passthrough mode
        // rewind() should fail after reading in passthrough mode
        byte[] data = bytes("Hello, World!");
        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            // Don't call enableRewind()

            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("Hello", str(buf));

            // rewind() internally calls reset() which calls seekTo()
            // In passthrough mode, seekTo() fails if not at current position
            assertThrows(IOException.class, tis::rewind,
                    "rewind() should fail in passthrough mode after reading");
        }
    }

    @Test
    public void testMarkResetThenEnableRewind() throws Exception {
        // Test transitioning from passthrough mode (using BufferedInputStream's mark/reset)
        // to caching mode via enableRewind()
        byte[] data = bytes("ABCDEFGHIJKLMNOPQRSTUVWXYZ");
        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data))) {
            // Passthrough mode - use BufferedInputStream's mark/reset
            tis.mark(100);
            byte[] buf = new byte[5];
            tis.read(buf);
            assertEquals("ABCDE", str(buf));

            tis.reset();  // Back to 0
            assertEquals(0, tis.getPosition());

            // Another mark/reset cycle in passthrough mode
            tis.mark(100);
            buf = new byte[10];
            tis.read(buf);
            assertEquals("ABCDEFGHIJ", str(buf));

            tis.reset();  // Back to 0 again
            assertEquals(0, tis.getPosition());

            // Now enable rewind (switches to caching mode)
            tis.enableRewind();

            // Should still work with caching mode
            buf = new byte[5];
            tis.read(buf);
            assertEquals("ABCDE", str(buf));

            tis.rewind();  // Full rewind now works
            assertEquals(0, tis.getPosition());

            buf = readAllBytes(tis);
            assertEquals("ABCDEFGHIJKLMNOPQRSTUVWXYZ", str(buf));
        }
    }

    // ========== Helper Methods ==========

    private TikaInputStream createTikaInputStream(byte[] data, boolean fileBacked) throws IOException {
        return createTikaInputStream(data, fileBacked ? BackingType.FILE : BackingType.STREAM);
    }

    private TikaInputStream createTikaInputStream(byte[] data, BackingType backingType) throws IOException {
        switch (backingType) {
            case BYTE_ARRAY:
                return TikaInputStream.get(data);
            case FILE:
                Path file = Files.createTempFile(tempDir, "test_", ".bin");
                Files.write(file, data);
                return TikaInputStream.get(file);
            case STREAM:
                TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(data));
                tis.enableRewind(); // Enable caching for rewind support in tests
                return tis;
            default:
                throw new IllegalArgumentException("Unknown backing type: " + backingType);
        }
    }

    private byte[] readAllBytes(TikaInputStream tis) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int n;
        while ((n = tis.read(buffer)) != -1) {
            baos.write(buffer, 0, n);
        }
        return baos.toByteArray();
    }

    private String computeDigest(byte[] data) throws NoSuchAlgorithmException {
        MessageDigest md = MessageDigest.getInstance("SHA-256");
        byte[] digest = md.digest(data);
        StringBuilder sb = new StringBuilder();
        for (byte b : digest) {
            sb.append(String.format(Locale.ROOT, "%02x", b));
        }
        return sb.toString();
    }

    private static byte[] bytes(String s) {
        return s.getBytes(UTF_8);
    }

    private static String str(byte[] b) {
        return new String(b, UTF_8);
    }
}