// PersistentSessionsWorker.java

/*
 * Copyright 2024 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.models.sessions.infinispan.changes;

import org.jboss.logging.Logger;
import org.keycloak.common.util.Retry;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.utils.KeycloakModelUtils;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * Drains the queue of deferred {@code PersistentUpdate}s on a background thread and persists them, batching entries
 * if possible.
 * <p>
 * {@link #start()} starts a single worker thread for the queue passed to the constructor. The worker waits for an
 * entry, then takes up to {@code maxBatchSize} entries from the queue and writes them. The first two attempts write
 * the whole batch in one transaction. Later attempts write each remaining entry in its own transaction, so that one
 * faulty entry does not prevent the others from being written. Retrying is delegated to
 * {@code Retry.executeWithBackoff} with a duration of 10 seconds (presumably the overall retry budget; see
 * {@code Retry}). Entries are notified via {@code complete()} when written, or via {@code fail(...)} once retrying
 * gives up.
 * <p>
 * {@link #stop()} sets the stop flag, interrupts the worker and waits up to one minute for it to end. Entries still
 * sitting in the queue at that point are not processed by this class.
 *
 * @author Alexander Schwartz
 */
public class PersistentSessionsWorker {
    private static final Logger LOG = Logger.getLogger(PersistentSessionsWorker.class);

    private final KeycloakSessionFactory factory;
    // Queue of pending updates; filled elsewhere, drained by the BatchWorker.
    private final ArrayBlockingQueue<PersistentUpdate> asyncQueuePersistentUpdate;
    // Upper bound on the number of entries written in one batch.
    private final int maxBatchSize;
    private final List<Thread> threads = new ArrayList<>();
    // Set by stop(); volatile so the worker thread observes the change.
    private volatile boolean stop;

    /**
     * @param factory                    factory used to open the transactions that perform the updates
     * @param asyncQueuePersistentUpdate queue the worker drains
     * @param maxBatchSize               maximum number of entries taken from the queue for one batch
     */
    public PersistentSessionsWorker(KeycloakSessionFactory factory,
                                    ArrayBlockingQueue<PersistentUpdate> asyncQueuePersistentUpdate, int maxBatchSize) {
        this.factory = factory;
        this.asyncQueuePersistentUpdate = asyncQueuePersistentUpdate;
        this.maxBatchSize = maxBatchSize;
    }

    /**
     * Creates and starts the worker thread that drains the queue.
     */
    public void start() {
        threads.add(new BatchWorker(asyncQueuePersistentUpdate));
        threads.forEach(Thread::start);
    }

    private class BatchWorker extends Thread {
        private final ArrayBlockingQueue<PersistentUpdate> queue;

        public BatchWorker(ArrayBlockingQueue<PersistentUpdate> queue) {
            this.queue = queue;
        }

        /**
         * Processes batches until {@link #stop()} is called or the thread is interrupted.
         */
        @Override
        public void run() {
            Thread.currentThread().setName(this.getClass().getName());
            while (!stop) {
                try {
                    process(queue);
                } catch (InterruptedException e) {
                    // Interrupted (normally by stop()): keep the flag set and leave the loop.
                    Thread.currentThread().interrupt();
                    break;
                } catch (RuntimeException ex) {
                    // Keep the single worker alive. If this thread died, nobody would drain the bounded
                    // queue any more and it would eventually fill up.
                    LOG.warnf(ex, "Exception in worker thread");
                }
            }
        }

        /**
         * Waits up to one second for an entry. If one arrives, persists it together with any further entries that
         * are immediately available, up to {@code maxBatchSize} entries in total.
         *
         * @param queue the queue to take entries from
         * @throws InterruptedException if the thread is interrupted while waiting for an entry
         */
        private void process(ArrayBlockingQueue<PersistentUpdate> queue) throws InterruptedException {
            ArrayList<PersistentUpdate> batch = new ArrayList<>();
            // The timeout is only a backup in case the interrupt sent by stop() is swallowed by someone else.
            // It lets run() re-check the stop flag at least once per second.
            PersistentUpdate polled = queue.poll(1, TimeUnit.SECONDS);
            if (polled == null) {
                return;
            }
            batch.add(polled);
            queue.drainTo(batch, maxBatchSize - 1);
            try {
                LOG.debugf("Processing %d deferred session updates.", batch.size());
                Retry.executeWithBackoff(iteration -> {
                            if (iteration < 2) {
                                // attempt to write whole batch in the first two attempts
                                KeycloakModelUtils.runJobInTransaction(factory,
                                        innerSession -> batch.forEach(c -> c.perform(innerSession)));
                                batch.forEach(PersistentUpdate::complete);
                            } else {
                                // Fall back to one transaction per entry so one faulty entry cannot block the
                                // others. Successful entries are removed from the batch, so later attempts and
                                // the failure handling below only see the entries that are still pending.
                                LOG.warnf("Running single changes in iteration %d for %d entries", iteration, batch.size());
                                ArrayList<PersistentUpdate> performedChanges = new ArrayList<>();
                                List<Throwable> throwables = new ArrayList<>();
                                batch.forEach(change -> {
                                    try {
                                        KeycloakModelUtils.runJobInTransaction(factory,
                                                change::perform);
                                        change.complete();
                                        performedChanges.add(change);
                                    } catch (Throwable ex) {
                                        throwables.add(ex);
                                    }
                                });
                                batch.removeAll(performedChanges);
                                if (!throwables.isEmpty()) {
                                    // Throwing makes Retry schedule another attempt for the remaining entries.
                                    RuntimeException ex = new RuntimeException("unable to complete some changes");
                                    throwables.forEach(ex::addSuppressed);
                                    throw ex;
                                }
                            }
                        },
                        Duration.of(10, ChronoUnit.SECONDS), 0);
            } catch (RuntimeException ex) {
                // Retrying gave up: report the failure to every entry that is still pending.
                batch.forEach(o -> o.fail(ex));
                LOG.warnf(ex, "Unable to write %d deferred session updates", batch.size());
            }
        }
    }

    /**
     * Signals the worker to stop, interrupts it and waits up to one minute for it to terminate.
     *
     * @throws RuntimeException if the calling thread is interrupted while waiting; the caller's interrupt flag is
     *                          restored before throwing
     */
    public void stop() {
        stop = true;
        threads.forEach(Thread::interrupt);
        threads.forEach(t -> {
            try {
                t.join(TimeUnit.MINUTES.toMillis(1));
            } catch (InterruptedException e) {
                // Do not lose the caller's interruption status when converting to an unchecked exception.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            if (t.isAlive()) {
                LOG.warnf("Worker thread %s did not terminate within one minute", t.getName());
            }
        });
    }
}