// ThemeDataSetGeneratorTest.java

/*******************************************************************************
 * Copyright (c) 2025 Eclipse RDF4J contributors.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Distribution License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/org/documents/edl-v10.php.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 *******************************************************************************/
// Some portions generated by Codex
package org.eclipse.rdf4j.benchmark.rio.util;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.eclipse.rdf4j.model.IRI;
import org.eclipse.rdf4j.model.Model;
import org.eclipse.rdf4j.model.util.Values;
import org.eclipse.rdf4j.model.vocabulary.RDF;
import org.junit.jupiter.api.Test;

/**
 * Smoke tests for the themed data set generators.
 * <p>
 * The generator class is loaded reflectively by name (see {@link #generateModel(String, String)}); each test invokes
 * one static config factory plus the matching static generator method and then checks that the resulting
 * {@link Model} contains resources of the expected {@code rdf:type}.
 * <p>
 * NOTE(review): the {@code DEFAULT_*} constants and the jitter helpers below presumably mirror the defaults and the
 * random draw order used inside the generator for the social media theme - confirm against the generator whenever
 * either side changes.
 */
class ThemeDataSetGeneratorTest {

	/** Namespace prefix of every IRI this test looks up in a generated model. */
	private static final String BASE = "http://example.com/theme/";
	private static final long DEFAULT_SEED = 42L;
	/** XOR-ed with {@link #DEFAULT_SEED} to seed the jitter {@link Random}. */
	private static final long JITTER_SEED_XOR = 0x9E3779B97F4A7C15L;
	private static final int DEFAULT_USER_COUNT = 20000;
	private static final int DEFAULT_TAG_COUNT = 50;
	private static final int[] DEFAULT_CLIQUE_SIZES = { 3, 4, 5, 6 };

	@Test
	void medicalRecordsGeneratorProducesPatients() throws Exception {
		Model model = generateModel("medicalConfig", "generateMedicalRecords");
		assertContainsTypes(model, "medical/Patient");
	}

	@Test
	void socialMediaGeneratorProducesUsers() throws Exception {
		Model model = generateModel("socialMediaConfig", "generateSocialMedia");
		assertContainsTypes(model, "social/User");
	}

	@Test
	void socialMediaGeneratorProducesCliques() throws Exception {
		Model model = generateModel("socialMediaConfig", "generateSocialMedia");
		IRI follows = Values.iri(BASE, "social/follows");
		int[] cliqueSizes = jitterCliqueSizes();
		// Without this guard an empty size array would let the test pass without checking a single edge.
		assertTrue(cliqueSizes.length > 0, "Expected at least one clique size to verify");
		// Cliques are expected on consecutive user ids, the first one starting at id 0.
		int startId = 0;
		for (int size : cliqueSizes) {
			assertClique(model, follows, startId, size);
			startId += size;
		}
	}

	@Test
	void libraryGeneratorProducesBooks() throws Exception {
		Model model = generateModel("libraryConfig", "generateLibrary");
		assertContainsTypes(model, "library/Book");
	}

	@Test
	void engineeringGeneratorProducesComponents() throws Exception {
		Model model = generateModel("engineeringConfig", "generateEngineering");
		assertContainsTypes(model, "engineering/Component");
	}

	@Test
	void highlyConnectedGeneratorProducesNodes() throws Exception {
		Model model = generateModel("highlyConnectedConfig", "generateHighlyConnected");
		assertContainsTypes(model, "connected/Node");
	}

	@Test
	void trainGeneratorProducesOperationalPoints() throws Exception {
		Model model = generateModel("trainConfig", "generateTrain");
		assertContainsTypes(model, "train/OperationalPoint", "train/SectionOfLine");
	}

	@Test
	void electricalGridGeneratorProducesSubstations() throws Exception {
		Model model = generateModel("electricalGridConfig", "generateElectricalGrid");
		assertContainsTypes(model, "grid/Substation");
	}

	@Test
	void pharmaGeneratorProducesDrugsAndTrials() throws Exception {
		Model model = generateModel("pharmaConfig", "generatePharma");
		assertContainsTypes(model, "pharma/Drug", "pharma/ClinicalTrial", "pharma/SideEffect");
	}

	/**
	 * Builds a model by reflectively calling a static config factory and a static generator method on
	 * {@code org.eclipse.rdf4j.benchmark.rio.util.ThemeDataSetGenerator}.
	 *
	 * @param configMethodName   name of the public static no-arg method that returns the config object
	 * @param generateMethodName name of the public static method that accepts that config; it is looked up using the
	 *                           runtime class of the config object as its single parameter type
	 * @return the value returned by the generator method, cast to {@link Model}
	 * @throws Exception if the class or a method cannot be resolved, or if either invocation fails
	 */
	private static Model generateModel(String configMethodName, String generateMethodName) throws Exception {
		Class<?> generator = Class.forName("org.eclipse.rdf4j.benchmark.rio.util.ThemeDataSetGenerator");
		Object config = generator.getMethod(configMethodName).invoke(null);
		Method generate = generator.getMethod(generateMethodName, config.getClass());
		return (Model) generate.invoke(null, config);
	}

	/**
	 * Asserts that the model is non-empty and has at least one {@code rdf:type} statement for every given type.
	 *
	 * @param model          the generated model to inspect
	 * @param typeLocalNames type names relative to {@link #BASE}, for example {@code "medical/Patient"}
	 */
	private static void assertContainsTypes(Model model, String... typeLocalNames) {
		assertFalse(model.isEmpty(), "Generated model is empty");
		for (String localName : typeLocalNames) {
			IRI type = Values.iri(BASE, localName);
			assertTrue(model.contains(null, RDF.TYPE, type), () -> "Model has no resource of type " + type);
		}
	}

	/**
	 * Asserts that the users {@code startId .. startId + size - 1} form a clique: every ordered pair of distinct
	 * users is linked by a {@code follows} statement.
	 *
	 * @param model   the generated social media model
	 * @param follows the predicate that links a follower to the followed user
	 * @param startId id of the first user in the clique
	 * @param size    number of users in the clique
	 */
	private static void assertClique(Model model, IRI follows, int startId, int size) {
		for (int i = 0; i < size; i++) {
			IRI source = socialUser(startId + i);
			for (int j = 0; j < size; j++) {
				if (i == j) {
					// self-follow edges are not part of the check
					continue;
				}
				IRI target = socialUser(startId + j);
				assertTrue(model.contains(source, follows, target),
						() -> "Missing follows edge " + source + " -> " + target);
			}
		}
	}

	/**
	 * Returns the IRI this test expects for the social media user with the given id.
	 */
	private static IRI socialUser(int id) {
		return Values.iri(BASE, "social/user/" + id);
	}

	/**
	 * Recomputes the clique sizes from the default seed and the default social media settings.
	 * <p>
	 * Draw order on the shared {@link Random}: user count, tag count (value discarded), then one draw per clique
	 * size. NOTE(review): presumably this matches the generator's own draw order - confirm.
	 *
	 * @return the jittered clique sizes, in order
	 */
	private static int[] jitterCliqueSizes() {
		Random jitter = new Random(DEFAULT_SEED ^ JITTER_SEED_XOR);
		int userCount = jitterInt(jitter, DEFAULT_USER_COUNT, 1);
		// The result is unused; the call only advances the random sequence past the tag count draw.
		jitterInt(jitter, DEFAULT_TAG_COUNT, 1);
		return jitterCliqueSizes(jitter, DEFAULT_CLIQUE_SIZES, userCount);
	}

	/**
	 * Picks a value in {@code [base - base / 2, base + base / 2]}, with the lower bound raised to {@code minValue}
	 * when needed.
	 *
	 * @param random   source of randomness; not consulted when the range collapses to a single value
	 * @param base     centre of the jitter range
	 * @param minValue smallest value that may be returned
	 * @return the jittered value
	 */
	private static int jitterInt(Random random, int base, int minValue) {
		int delta = base / 2;
		int min = Math.max(minValue, base - delta);
		int max = Math.max(min, base + delta);
		if (min == max) {
			return min;
		}
		return min + random.nextInt(max - min + 1);
	}

	/**
	 * Jitters each base clique size (never below 2) while keeping the running total within {@code maxTotal}.
	 * <p>
	 * Processing stops as soon as fewer than two members remain; a size larger than the remaining budget is clamped
	 * to that budget.
	 *
	 * @param random    source of randomness, consulted at most once per base size
	 * @param baseSizes nominal clique sizes, processed in order
	 * @param maxTotal  upper bound for the sum of the returned sizes
	 * @return the jittered sizes; possibly fewer entries than {@code baseSizes}
	 */
	private static int[] jitterCliqueSizes(Random random, int[] baseSizes, int maxTotal) {
		List<Integer> sizes = new ArrayList<>(baseSizes.length);
		int remaining = maxTotal;
		for (int base : baseSizes) {
			if (remaining < 2) {
				break;
			}
			// jitterInt(.., 2) returns at least 2 and remaining is at least 2 here, so the clamped size is too.
			int size = Math.min(jitterInt(random, base, 2), remaining);
			sizes.add(size);
			remaining -= size;
		}
		return sizes.stream().mapToInt(Integer::intValue).toArray();
	}
}