OpenAIVLMParserTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.parser.vlm;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.ByteArrayInputStream;
import java.io.IOException;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.xml.sax.helpers.DefaultHandler;

import org.apache.tika.exception.TikaException;
import org.apache.tika.http.TikaTestHttpServer;
import org.apache.tika.io.TikaInputStream;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.sax.BodyContentHandler;

/**
 * Tests for {@link OpenAIVLMParser} against a local {@link TikaTestHttpServer}, which
 * answers with queued mock responses and records the requests it receives.
 */
public class OpenAIVLMParserTest {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    private TikaTestHttpServer server;
    private OpenAIVLMParser parser;
    private VLMOCRConfig config;

    /**
     * Starts a fresh mock server, points a default config at it and initializes a parser.
     * The health-check request issued by {@code initialize()} is answered and then cleared,
     * so each test only sees the requests it triggers itself.
     */
    @BeforeEach
    void setUp() throws Exception {
        server = new TikaTestHttpServer();

        config = new VLMOCRConfig();
        config.setBaseUrl(server.url());
        config.setModel("test-model");
        config.setPrompt("Extract text from this image.");
        config.setMaxTokens(1024);
        config.setTimeoutSeconds(10);

        // Queue 200 for the GET /v1/models health check in initialize()
        server.enqueue(new TikaTestHttpServer.MockResponse(200, "{\"object\":\"list\"}"));
        parser = new OpenAIVLMParser(config);
        parser.initialize();
        server.clearRequests(); // discard the health-check request from the log
    }

    @AfterEach
    void tearDown() {
        server.shutdown();
    }

    /**
     * Covers the happy path. Checks the extracted text and the model / token-usage metadata.
     * Also checks the shape of the POSTed chat-completions body: a text part followed by a
     * base64 data-URL image part.
     */
    @Test
    void testSuccessfulOcr() throws Exception {
        String ocrText = "Hello, World!\nThis is extracted text.";

        server.enqueue(new TikaTestHttpServer.MockResponse(200,
                buildChatResponse(ocrText, 100, 20)));

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/png");

        BodyContentHandler handler = new BodyContentHandler();
        byte[] fakeImage = new byte[]{(byte) 0x89, 'P', 'N', 'G'};

        try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(fakeImage))) {
            parser.parse(tis, handler, metadata, new ParseContext());
        }

        assertTrue(handler.toString().contains("Hello, World!"));
        assertEquals("test-model", metadata.get(AbstractVLMParser.VLM_MODEL));
        assertEquals("100", metadata.get(AbstractVLMParser.VLM_PROMPT_TOKENS));
        assertEquals("20", metadata.get(AbstractVLMParser.VLM_COMPLETION_TOKENS));

        TikaTestHttpServer.RecordedRequest request = server.takeRequest();
        assertEquals("/v1/chat/completions", request.path());
        assertEquals("POST", request.method());

        JsonNode body = MAPPER.readTree(request.body());
        assertEquals("test-model", body.get("model").asText());
        assertEquals(1024, body.get("max_tokens").asInt());

        JsonNode messages = body.get("messages");
        assertNotNull(messages);
        assertEquals("user", messages.get(0).get("role").asText());

        JsonNode parts = messages.get(0).get("content");
        assertEquals("text", parts.get(0).get("type").asText());
        assertEquals("image_url", parts.get(1).get("type").asText());
        assertTrue(parts.get(1).get("image_url").get("url").asText()
                .startsWith("data:image/png;base64,"));
    }

    /** A 500 response from the completions endpoint must surface as a {@link TikaException}. */
    @Test
    void testServerError() throws Exception {
        server.enqueue(new TikaTestHttpServer.MockResponse(500, "{\"error\":\"boom\"}"));

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/png");

        assertThrows(TikaException.class, () -> {
            try (TikaInputStream tis = TikaInputStream.get(
                    new ByteArrayInputStream(new byte[]{1, 2, 3, 4}))) {
                parser.parse(tis, new DefaultHandler(), metadata, new ParseContext());
            }
        });
    }

    /** With {@code skipOcr} enabled, parsing must not send any request to the server. */
    @Test
    void testSkipOcr() throws Exception {
        config.setSkipOcr(true);
        parser = new OpenAIVLMParser(config);

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/png");
        BodyContentHandler handler = new BodyContentHandler();

        try (TikaInputStream tis = TikaInputStream.get(
                new ByteArrayInputStream(new byte[]{1, 2, 3, 4}))) {
            parser.parse(tis, handler, metadata, new ParseContext());
        }

        assertEquals(0, server.getRequestCount());
    }

    /** A 10-byte input below {@code minFileSizeToOcr} (100) must not be sent to the server. */
    @Test
    void testFileSizeFiltering() throws Exception {
        config.setMinFileSizeToOcr(100);
        parser = new OpenAIVLMParser(config);

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/png");

        try (TikaInputStream tis = TikaInputStream.get(
                new ByteArrayInputStream(new byte[10]))) {
            parser.parse(tis, new BodyContentHandler(), metadata, new ParseContext());
        }

        assertEquals(0, server.getRequestCount());
    }

    /** By default the API key is sent as {@code Authorization: Bearer <key>}. */
    @Test
    void testApiKeyHeader() throws Exception {
        config.setApiKey("sk-test-key");
        parser = new OpenAIVLMParser(config);

        server.enqueue(new TikaTestHttpServer.MockResponse(200,
                buildChatResponse("text", 10, 5)));

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/jpeg");

        try (TikaInputStream tis = TikaInputStream.get(
                new ByteArrayInputStream(new byte[]{(byte) 0xFF, (byte) 0xD8}))) {
            parser.parse(tis, new BodyContentHandler(), metadata, new ParseContext());
        }

        assertEquals("Bearer sk-test-key", server.takeRequest().header("authorization"));
    }

    /**
     * Covers Azure-style authentication. The key goes in a custom {@code api-key} header with
     * no prefix, no {@code Authorization} header is sent, and a deployment-specific
     * completions path is used.
     */
    @Test
    void testAzureStyleAuth() throws Exception {
        String azurePath =
                "/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-01";

        // Configure fully before constructing the parser, so it is built exactly once.
        config.setApiKey("azure-key-123");
        config.setCompletionsPath(azurePath);
        parser = new OpenAIVLMParser(config);
        parser.setApiKeyHeaderName("api-key");
        parser.setApiKeyPrefix("");
        parser.setCompletionsPath(azurePath);

        server.enqueue(new TikaTestHttpServer.MockResponse(200,
                buildChatResponse("text", 10, 5)));

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/jpeg");

        try (TikaInputStream tis = TikaInputStream.get(
                new ByteArrayInputStream(new byte[]{(byte) 0xFF, (byte) 0xD8}))) {
            parser.parse(tis, new BodyContentHandler(), metadata, new ParseContext());
        }

        TikaTestHttpServer.RecordedRequest request = server.takeRequest();
        assertEquals("azure-key-123", request.header("api-key"));
        assertNull(request.header("authorization"));
        assertTrue(request.path().startsWith(
                "/openai/deployments/gpt-4o/chat/completions"));
    }

    /** A non-default completions path yields no health-check URL. */
    @Test
    void testCustomCompletionsPathSkipsHealthCheck() {
        config.setCompletionsPath("/custom/path");
        assertNull(parser.getHealthCheckUrl(config));
    }

    /** The default completions path yields a health-check URL. */
    @Test
    void testDefaultCompletionsPathHasHealthCheck() {
        assertNotNull(parser.getHealthCheckUrl(config));
    }

    /**
     * A {@link VLMOCRConfig} placed in the {@link ParseContext} overrides the parser's own
     * config for that parse. Checked via the model and max_tokens in the request body.
     */
    @Test
    void testPerRequestConfigOverride() throws Exception {
        VLMOCRConfig override = new VLMOCRConfig();
        override.setBaseUrl(server.url());
        override.setModel("override-model");
        override.setPrompt("Custom.");
        override.setMaxTokens(2048);
        override.setTimeoutSeconds(10);

        server.enqueue(new TikaTestHttpServer.MockResponse(200,
                buildChatResponse("ok", 10, 5)));

        Metadata metadata = new Metadata();
        metadata.set(Metadata.CONTENT_TYPE, "image/png");
        ParseContext ctx = new ParseContext();
        ctx.set(VLMOCRConfig.class, override);

        try (TikaInputStream tis = TikaInputStream.get(
                new ByteArrayInputStream(new byte[]{1, 2}))) {
            parser.parse(tis, new BodyContentHandler(), metadata, ctx);
        }

        JsonNode body = MAPPER.readTree(server.takeRequest().body());
        assertEquals("override-model", body.get("model").asText());
        assertEquals(2048, body.get("max_tokens").asInt());
    }

    /** The request JSON carries the configured model and a data URL built from the base64 input. */
    @Test
    void testBuildRequestJson() {
        String json = parser.buildRequestJson(config, "AAAA", "image/png");
        assertTrue(json.contains("\"model\":\"test-model\""));
        assertTrue(json.contains("data:image/png;base64,AAAA"));
    }

    /** The first choice's content is returned, and usage is copied into metadata. */
    @Test
    void testExtractResponseText() throws Exception {
        Metadata metadata = new Metadata();
        String result = parser.extractResponseText(
                buildChatResponse("Hello", 50, 10), metadata);
        assertEquals("Hello", result);
        assertEquals("50", metadata.get(AbstractVLMParser.VLM_PROMPT_TOKENS));
    }

    /** An empty {@code choices} array is reported as a {@link TikaException}. */
    @Test
    void testExtractResponseTextNoChoices() {
        assertThrows(TikaException.class,
                () -> parser.extractResponseText("{\"choices\":[]}", new Metadata()));
    }

    /** PNG is among the supported types when OCR is enabled. */
    @Test
    void testSupportedTypes() {
        assertTrue(parser.getSupportedTypes(new ParseContext()).stream()
                .anyMatch(mt -> mt.toString().contains("png")));
    }

    /** No types are advertised when OCR is skipped. */
    @Test
    void testSupportedTypesWhenSkipped() {
        config.setSkipOcr(true);
        parser = new OpenAIVLMParser(config);
        assertEquals(0, parser.getSupportedTypes(new ParseContext()).size());
    }

    /**
     * Builds a minimal OpenAI chat-completions response body.
     *
     * <p>The JSON is produced by Jackson rather than by string formatting. This escapes
     * every character in {@code content} that JSON requires to be escaped (quotes,
     * backslashes, newlines, tabs and other control characters).
     *
     * @param content    text placed in {@code choices[0].message.content}
     * @param prompt     value for {@code usage.prompt_tokens}
     * @param completion value for {@code usage.completion_tokens}
     * @return the compact JSON response body
     * @throws IOException if Jackson fails to serialize the tree
     */
    private String buildChatResponse(String content, int prompt, int completion)
            throws IOException {
        var root = MAPPER.createObjectNode();
        root.putArray("choices").addObject().putObject("message").put("content", content);
        var usage = root.putObject("usage");
        usage.put("prompt_tokens", prompt);
        usage.put("completion_tokens", completion);
        return MAPPER.writeValueAsString(root);
    }
}