GeminiVLMParserTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.parser.vlm;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.ByteArrayInputStream;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.xml.sax.helpers.DefaultHandler;
import org.apache.tika.exception.TikaException;
import org.apache.tika.http.TikaTestHttpServer;
import org.apache.tika.io.TikaInputStream;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.mime.MediaType;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.sax.BodyContentHandler;
public class GeminiVLMParserTest {
private static final ObjectMapper MAPPER = new ObjectMapper();
private TikaTestHttpServer server;
private GeminiVLMParser parser;
private VLMOCRConfig config;
@BeforeEach
void setUp() throws Exception {
server = new TikaTestHttpServer();
config = new VLMOCRConfig();
config.setBaseUrl(server.url());
config.setModel("gemini-2.5-flash");
config.setPrompt("Extract all text from this document.");
config.setMaxTokens(4096);
config.setTimeoutSeconds(10);
config.setApiKey("test-gemini-key");
// Queue 200 for the GET /v1beta/models health check in initialize()
server.enqueue(new TikaTestHttpServer.MockResponse(200, "{\"models\":[]}"));
parser = new GeminiVLMParser(config);
parser.initialize();
server.clearRequests(); // discard the health-check request from the log
}
@AfterEach
void tearDown() {
server.shutdown();
}
@Test
void testSuccessfulImageOcr() throws Exception {
server.enqueue(new TikaTestHttpServer.MockResponse(200,
buildGeminiResponse("Hello from Gemini!", 80, 15)));
Metadata metadata = new Metadata();
metadata.set(Metadata.CONTENT_TYPE, "image/png");
BodyContentHandler handler = new BodyContentHandler();
try (TikaInputStream tis = TikaInputStream.get(
new ByteArrayInputStream(new byte[]{(byte) 0x89, 'P', 'N', 'G'}))) {
parser.parse(tis, handler, metadata, new ParseContext());
}
assertTrue(handler.toString().contains("Hello from Gemini!"));
assertEquals("gemini-2.5-flash", metadata.get(AbstractVLMParser.VLM_MODEL));
assertEquals("80", metadata.get(AbstractVLMParser.VLM_PROMPT_TOKENS));
assertEquals("15", metadata.get(AbstractVLMParser.VLM_COMPLETION_TOKENS));
TikaTestHttpServer.RecordedRequest request = server.takeRequest();
assertTrue(request.path().contains("/v1beta/models/gemini-2.5-flash:generateContent"));
assertTrue(request.path().contains("key=test-gemini-key"));
assertEquals("POST", request.method());
JsonNode body = MAPPER.readTree(request.body());
JsonNode contents = body.get("contents");
assertNotNull(contents);
assertEquals(1, contents.size());
JsonNode parts = contents.get(0).get("parts");
assertEquals(2, parts.size());
assertEquals("Extract all text from this document.", parts.get(0).get("text").asText());
JsonNode inlineData = parts.get(1).get("inline_data");
assertNotNull(inlineData);
assertEquals("image/png", inlineData.get("mime_type").asText());
assertNotNull(inlineData.get("data").asText());
assertEquals(4096, body.get("generationConfig").get("maxOutputTokens").asInt());
}
@Test
void testPdfSupport() throws Exception {
server.enqueue(new TikaTestHttpServer.MockResponse(200,
buildGeminiResponse("PDF content extracted", 200, 50)));
Metadata metadata = new Metadata();
metadata.set(Metadata.CONTENT_TYPE, "application/pdf");
BodyContentHandler handler = new BodyContentHandler();
byte[] fakePdf = "%PDF-1.4 fake content".getBytes(java.nio.charset.StandardCharsets.UTF_8);
try (TikaInputStream tis = TikaInputStream.get(new ByteArrayInputStream(fakePdf))) {
parser.parse(tis, handler, metadata, new ParseContext());
}
assertTrue(handler.toString().contains("PDF content extracted"));
TikaTestHttpServer.RecordedRequest request = server.takeRequest();
JsonNode body = MAPPER.readTree(request.body());
JsonNode inlineData =
body.get("contents").get(0).get("parts").get(1).get("inline_data");
assertEquals("application/pdf", inlineData.get("mime_type").asText());
}
@Test
void testSupportedTypesIncludesPdf() {
assertTrue(parser.getSupportedTypes(new ParseContext())
.contains(MediaType.application("pdf")));
}
@Test
void testSupportedTypesIncludesImages() {
ParseContext ctx = new ParseContext();
assertTrue(parser.getSupportedTypes(ctx).stream()
.anyMatch(mt -> mt.toString().contains("png")));
assertTrue(parser.getSupportedTypes(ctx).stream()
.anyMatch(mt -> mt.toString().contains("heic")));
}
@Test
void testApiKeyAsQueryParam() throws Exception {
server.enqueue(new TikaTestHttpServer.MockResponse(200,
buildGeminiResponse("ok", 10, 5)));
Metadata metadata = new Metadata();
metadata.set(Metadata.CONTENT_TYPE, "image/jpeg");
try (TikaInputStream tis = TikaInputStream.get(
new ByteArrayInputStream(new byte[]{1, 2}))) {
parser.parse(tis, new BodyContentHandler(), metadata, new ParseContext());
}
TikaTestHttpServer.RecordedRequest request = server.takeRequest();
assertTrue(request.path().contains("key=test-gemini-key"),
"API key should be in query params, not header");
// Gemini does NOT use Bearer auth
assertEquals(null, request.header("authorization"));
}
@Test
void testServerError() throws Exception {
server.enqueue(new TikaTestHttpServer.MockResponse(500,
"{\"error\":{\"message\":\"internal\"}}"));
Metadata metadata = new Metadata();
metadata.set(Metadata.CONTENT_TYPE, "image/png");
assertThrows(TikaException.class, () -> {
try (TikaInputStream tis = TikaInputStream.get(
new ByteArrayInputStream(new byte[]{1, 2, 3}))) {
parser.parse(tis, new DefaultHandler(), metadata, new ParseContext());
}
});
}
@Test
void testGeminiErrorResponse() {
String errorJson = "{\"error\":{\"code\":400,\"message\":\"Invalid API key\"}}";
assertThrows(TikaException.class,
() -> parser.extractResponseText(errorJson, new Metadata()));
}
@Test
void testExtractResponseTextMultipleParts() throws Exception {
String json = "{\"candidates\":[{\"content\":{\"parts\":["
+ "{\"text\":\"Part one\"},"
+ "{\"text\":\"Part two\"}"
+ "],\"role\":\"model\"}}],"
+ "\"usageMetadata\":{\"promptTokenCount\":50,\"candidatesTokenCount\":20}}";
Metadata metadata = new Metadata();
String result = parser.extractResponseText(json, metadata);
assertEquals("Part one\nPart two", result);
assertEquals("50", metadata.get(AbstractVLMParser.VLM_PROMPT_TOKENS));
assertEquals("20", metadata.get(AbstractVLMParser.VLM_COMPLETION_TOKENS));
}
@Test
void testBuildRequestJson() {
String json = parser.buildRequestJson(config, "AAAA", "application/pdf");
assertTrue(json.contains("\"mime_type\":\"application/pdf\""));
assertTrue(json.contains("\"data\":\"AAAA\""));
assertTrue(json.contains("\"maxOutputTokens\":4096"));
assertTrue(json.contains("Extract all text from this document."));
assertTrue(!json.contains("\"messages\""));
assertTrue(!json.contains("\"max_tokens\""));
}
@Test
void testSkipOcr() throws Exception {
config.setSkipOcr(true);
parser = new GeminiVLMParser(config);
try (TikaInputStream tis = TikaInputStream.get(
new ByteArrayInputStream(new byte[]{1, 2}))) {
Metadata metadata = new Metadata();
metadata.set(Metadata.CONTENT_TYPE, "image/png");
parser.parse(tis, new BodyContentHandler(), metadata, new ParseContext());
}
assertEquals(0, server.getRequestCount());
}
@Test
void testDefaultConfig() {
GeminiVLMParser defaultParser = new GeminiVLMParser();
assertEquals("https://generativelanguage.googleapis.com", defaultParser.getBaseUrl());
assertEquals("gemini-2.5-flash", defaultParser.getModel());
}
private String buildGeminiResponse(String text, int promptTokens, int completionTokens) {
return String.format(java.util.Locale.ROOT,
"{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"%s\"}],"
+ "\"role\":\"model\"},\"finishReason\":\"STOP\"}],"
+ "\"usageMetadata\":{\"promptTokenCount\":%d,"
+ "\"candidatesTokenCount\":%d,\"totalTokenCount\":%d}}",
text.replace("\\", "\\\\").replace("\"", "\\\"")
.replace("\n", "\\n"),
promptTokens, completionTokens, promptTokens + completionTokens);
}
}