// ConcurrentHashMapStorage.java

/*
 * Copyright 2020 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.keycloak.models.map.storage.chm;

import org.keycloak.models.KeycloakTransaction;
import org.keycloak.models.map.common.StringKeyConverter;
import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.common.DeepCloner;
import org.keycloak.models.map.common.EntityField;
import org.keycloak.models.map.common.HasRealmId;
import org.keycloak.models.map.common.UpdatableEntity;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.jboss.logging.Logger;
import org.keycloak.models.map.storage.CrudOperations;
import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelEntityUtil;
import org.keycloak.models.map.storage.QueryParameters;
import org.keycloak.models.map.storage.chm.MapModelCriteriaBuilder.UpdatePredicatesFunc;
import org.keycloak.models.map.storage.criteria.DefaultModelCriteria;
import org.keycloak.storage.SearchableModelField;
import java.util.function.Consumer;
import java.util.Collection;
import java.util.Set;
import java.util.function.BiFunction;

public class ConcurrentHashMapStorage<K, V extends AbstractEntity & UpdatableEntity, M, CRUD extends CrudOperations<V, M>> implements MapStorage<V, M>, KeycloakTransaction, HasRealmId {

    // Class-wide logger; used for trace/debug output about enlisted and committed operations.
    private final static Logger log = Logger.getLogger(ConcurrentHashMapStorage.class);

    // Set to true by begin(); no code in this class resets it.
    protected boolean active;
    // Rollback-only marker set by setRollbackOnly(); when true, commit() throws instead of executing tasks.
    protected boolean rollback;
    // Operations enlisted in this transaction, kept in insertion order and keyed by (realm ID, entity key).
    protected final TaskMap tasks = new TaskMap();
    // Underlying CRUD store the enlisted tasks are executed against on commit.
    protected final CRUD map;
    // Converts between String IDs and keys of type K, and yields new unique keys.
    protected final StringKeyConverter<K> keyConverter;
    // Produces the entity copies that are enlisted in the transaction and handed to callers.
    protected final DeepCloner cloner;
    // Per-field predicate factories passed to every MapModelCriteriaBuilder created by this storage.
    protected final Map<SearchableModelField<? super M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates;
    // Entity field supplied with the current realm ID in postProcess(); null when the four-argument constructor is used.
    protected final EntityField<V> realmIdEntityField;
    // Realm ID last passed to setRealmId(); always null when the store does not implement HasRealmId.
    private String realmId;
    // True when the underlying store implements HasRealmId (computed once in the constructor).
    private final boolean mapHasRealmId;

    /**
     * Immutable composite key of a task: the realm ID together with the string key of the entity.
     * Two keys are equal only when both components are equal, so the same entity key used under
     * different realm IDs maps to different tasks.
     */
    protected static final class TaskKey {
        private final String realmId;
        private final String key;

        /**
         * @param realmId realm ID component, may be {@code null}
         * @param key entity key component, may be {@code null}
         */
        public TaskKey(String realmId, String key) {
            this.realmId = realmId;
            this.key = key;
        }

        /**
         * Returns the realm ID component. Declared as {@code String} (it used to be {@code Object})
         * so that callers do not lose the static type of the field.
         */
        private String getRealmId() {
            return this.realmId;
        }

        /**
         * @return the entity key component
         */
        public String getKey() {
            return key;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.key, this.realmId);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final TaskKey other = (TaskKey) obj;
            return Objects.equals(this.key, other.key) && Objects.equals(this.realmId, other.realmId);
        }

        @Override
        public String toString() {
            return key + " / " + realmId;
        }

        /**
         * Factory shortcut for {@link #TaskKey(String, String)}.
         */
        static TaskKey keyFor(String realmId, String id) {
            return new TaskKey(realmId, id);
        }
    }

    /**
     * Insertion-ordered map of enlisted tasks. All methods accepting a {@code String} key combine it
     * with the realm ID that is set on the enclosing storage at the time of the call.
     */
    protected class TaskMap {

        private final Map<TaskKey, MapTaskWithValue> map = new LinkedHashMap<>();

        /** Builds the composite key for {@code key} using the realm ID currently set on the enclosing storage. */
        private TaskKey scoped(String key) {
            return TaskKey.keyFor(realmId, key);
        }

        /** @return {@code true} when no task is enlisted */
        public boolean isEmpty() {
            return map.isEmpty();
        }

        /** @return {@code true} when a task is enlisted under {@code key} for the current realm ID */
        public boolean containsKey(String key) {
            return map.containsKey(scoped(key));
        }

        /** @return the task enlisted under {@code key} for the current realm ID, or {@code null} */
        public MapTaskWithValue get(String key) {
            return map.get(scoped(key));
        }

        /** Stores {@code value} under {@code key} for the current realm ID and returns the previous task, if any. */
        public MapTaskWithValue put(String key, MapTaskWithValue value) {
            return map.put(scoped(key), value);
        }

        /** Drops all enlisted tasks. */
        public void clear() {
            map.clear();
        }

        /** @return live view of the enlisted tasks in insertion order */
        public Collection<MapTaskWithValue> values() {
            return map.values();
        }

        /** @return live view of the entries, across all realm IDs, in insertion order */
        public Set<Entry<TaskKey, MapTaskWithValue>> entrySet() {
            return map.entrySet();
        }

        /** Delegates to {@link Map#merge} using the composite key for {@code key} and the current realm ID. */
        public MapTaskWithValue merge(String key, MapTaskWithValue value, BiFunction<? super MapTaskWithValue, ? super MapTaskWithValue, ? extends MapTaskWithValue> remappingFunction) {
            return map.merge(scoped(key), value, remappingFunction);
        }
    }

    /** Kind of change a single enlisted task represents. */
    protected enum MapOperation {
        /** Reported by tasks that create an entity in the store. */
        CREATE,
        /** Reported by tasks that conditionally update an entity in the store. */
        UPDATE,
        /** Reported by both single-key and bulk delete tasks. */
        DELETE
    }

    /**
     * Creates a storage without a realm ID entity field; delegates to the five-argument
     * constructor with {@code null} as {@code realmIdEntityField}.
     */
    public ConcurrentHashMapStorage(
      CRUD map,
      StringKeyConverter<K> keyConverter,
      DeepCloner cloner,
      Map<SearchableModelField<? super M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates) {
        this(map, keyConverter, cloner, fieldPredicates, null);
    }

    /**
     * Creates a storage on top of the given CRUD store.
     *
     * @param map underlying store; whether it implements {@link HasRealmId} is determined here once
     * @param keyConverter converter between string IDs and keys of type {@code K}
     * @param cloner cloner used to copy entities
     * @param fieldPredicates predicates used when building model criteria
     * @param realmIdEntityField entity field to be supplied with the realm ID, may be {@code null}
     */
    public ConcurrentHashMapStorage(
      CRUD map,
      StringKeyConverter<K> keyConverter,
      DeepCloner cloner,
      Map<SearchableModelField<? super M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates,
      EntityField<V> realmIdEntityField) {
        this.map = map;
        this.mapHasRealmId = map instanceof HasRealmId;
        this.keyConverter = keyConverter;
        this.cloner = cloner;
        this.fieldPredicates = fieldPredicates;
        this.realmIdEntityField = realmIdEntityField;
    }

    /** Marks this transaction as active. */
    @Override
    public void begin() {
        this.active = true;
    }

    /**
     * Executes all enlisted tasks against the underlying store in the order they were enlisted.
     * When the store implements {@link HasRealmId}, its realm ID is switched to the realm ID
     * captured by each task right before that task is executed.
     *
     * @throws RuntimeException if the transaction has been marked rollback-only
     */
    @Override
    public void commit() {
        if (rollback) {
            throw new RuntimeException("Rollback only!");
        }
        if (tasks.isEmpty()) {
            return;
        }

        log.tracef("Commit - %s", map);
        for (MapTaskWithValue task : tasks.values()) {
            if (mapHasRealmId) {
                ((HasRealmId) map).setRealmId(task.getRealmId());
            }
            task.execute();
        }
    }

    /** Discards every enlisted task without touching the underlying store. */
    @Override
    public void rollback() {
        this.tasks.clear();
    }

    /** Marks the transaction rollback-only, which makes a later {@link #commit()} throw. */
    @Override
    public void setRollbackOnly() {
        this.rollback = true;
    }

    /** @return {@code true} once {@link #setRollbackOnly()} has been called */
    @Override
    public boolean getRollbackOnly() {
        return this.rollback;
    }

    /** @return {@code true} once {@link #begin()} has been called */
    @Override
    public boolean isActive() {
        return this.active;
    }

    /** Creates a fresh criteria builder bound to this storage's key converter and field predicates. */
    private MapModelCriteriaBuilder<K, V, M> createCriteriaBuilder() {
        final MapModelCriteriaBuilder<K, V, M> builder = new MapModelCriteriaBuilder<K, V, M>(keyConverter, fieldPredicates);
        return builder;
    }

    /**
     * Enlists the given task under the given key. If a task is already enlisted under that key,
     * the two are combined into a composite task that executes the existing task first and the
     * new one second; otherwise the task is stored as is.
     *
     * @param key string key of the entity the task operates on
     * @param task task to enlist
     */
    protected void addTask(String key, MapTaskWithValue task) {
        final MapOperation operation = task.getOperation();
        final int valueIdentity = System.identityHashCode(task.getValue());
        log.tracef("Adding operation %s for %s @ %08x", operation, key, valueIdentity);

        tasks.merge(key, task, MapTaskCompose::new);
    }

    /**
     * Returns the transactional instance of the given entity.
     * <p>
     * When a task is already enlisted under the entity's ID, the value held by that task is
     * returned (this may be {@code null}, e.g. for a delete task). Otherwise the entity is
     * cloned, the clone is enlisted as an update that is written on commit only if the clone
     * reports itself as updated, and the enlisted value is returned. The original instance is
     * thus never handed out directly.
     *
     * @param origEntity Original entity, must not be {@code null}
     * @return the value enlisted in this transaction for the entity's ID
     */
    public V registerEntityForChanges(V origEntity) {
        final String id = origEntity.getId();
        if (! tasks.containsKey(id)) {
            // Not enlisted yet: enlist a copy. Never return direct reference to the underlying map
            final V copy = cloner.from(origEntity);
            return updateIfChanged(copy, entity -> entity.isUpdated());
        }
        // Already enlisted: hand out whatever the transaction currently holds for this ID
        return tasks.get(id).getValue();
    }

    /**
     * Reads a single entity by its key, taking changes enlisted in this transaction into account.
     * The returned instance is registered for changes and post-processed with the realm ID.
     *
     * @param sKey string key of the entity
     * @return the transactional instance of the entity, or {@code null} when it cannot be found
     *         or when a {@link NullPointerException} is raised while reading it
     */
    @Override
    public V read(String sKey) {
        try {
            // TODO: Consider using Optional rather than handling NPE
            final V entity = read(sKey, map::read);
            if (entity == null) {
                log.debugf("Could not read object for key %s", sKey);
                return null;
            }
            return postProcess(registerEntityForChanges(entity));
        } catch (NullPointerException ex) {
            // Historically treated as "not found". Keep that contract, but do not hide the cause
            // completely: a NPE raised here may just as well indicate a defect further down.
            log.debugf(ex, "NullPointerException while reading object for key %s, treating it as not found", sKey);
            return null;
        }
    }

    /**
     * Resolves the value for a key, preferring the value enlisted in this transaction.
     *
     * @param key string key of the entity
     * @param defaultValueFunc loader used when no task is enlisted under the key
     * @return the enlisted value if a task exists for the key; otherwise the loaded value, or
     *         {@code null} when an enlisted bulk delete would remove the loaded value
     */
    private V read(String key, Function<String, V> defaultValueFunc) {
        // A task enlisted under this key entered "tasks" after any bulk delete that could have
        // removed it, so bulk deletes need not be consulted in that case
        if (tasks.containsKey(key)) {
            return tasks.get(key).getValue();
        }

        // Otherwise the value comes fresh from the storage, and an enlisted bulk delete may
        // already cover it
        final V stored = defaultValueFunc.apply(key);
        for (MapTaskWithValue task : tasks.values()) {
            if (! (task instanceof ConcurrentHashMapStorage.BulkDeleteOperation)) {
                continue;
            }
            final BulkDeleteOperation bulkDelete = (BulkDeleteOperation) task;
            final boolean survives = bulkDelete.getFilterForNonDeletedObjects().test(stored);
            if (! survives) {
                return null;
            }
        }

        return stored;
    }

    /**
     * Returns the stream of records that match given criteria and includes changes made in this transaction, i.e.
     * the result contains updates and excludes records that have been deleted in this transaction.
     * <p>
     * The part of the result that originates from the underlying store is evaluated lazily, and every
     * entity passing through it is registered for changes in this transaction. Entities created in this
     * transaction are snapshotted eagerly and placed before the store results, unless an ordering is requested.
     *
     * @param queryParameters criteria and ordering of the query
     * @return stream of transactional entity instances, sorted when {@code queryParameters} defines an ordering
     */
    @Override
    public Stream<V> read(QueryParameters<M> queryParameters) {
        DefaultModelCriteria<M> mcb = queryParameters.getModelCriteriaBuilder();
        MapModelCriteriaBuilder<K,V,M> mapMcb = mcb.flashToModelCriteriaBuilder(createCriteriaBuilder());

        // Conjunction of the "not deleted" filters of all enlisted bulk deletes; accepts everything when there is none
        Predicate<? super V> filterOutAllBulkDeletedObjects = tasks.values().stream()
          .filter(BulkDeleteOperation.class::isInstance)
          .map(BulkDeleteOperation.class::cast)
          .map(BulkDeleteOperation::getFilterForNonDeletedObjects)
          .reduce(Predicate::and)
          .orElse(v -> true);

        Stream<V> updatedAndNotRemovedObjectsStream = this.map.read(queryParameters)
          .filter(filterOutAllBulkDeletedObjects)
          .map(this::getUpdated)      // Replace with the value held by an enlisted task, if any; that value is null when the task holds none (e.g. a delete)
          .filter(Objects::nonNull)
          .map(this::registerEntityForChanges);

        // Supplies the realm ID entity field when a realm ID is set on this storage
        updatedAndNotRemovedObjectsStream = postProcess(updatedAndNotRemovedObjectsStream);

        // NOTE(review): presumably flashToModelCriteriaBuilder can return null, hence the checks below - confirm
        if (mapMcb != null) {
            // Add explicit filtering for the case when the map returns raw stream of untested values (ie. realize sequential scan)
            updatedAndNotRemovedObjectsStream = updatedAndNotRemovedObjectsStream
              .filter(e -> mapMcb.getKeyFilter().test(keyConverter.fromStringSafe(e.getId())))
              .filter(mapMcb.getEntityFilter());
        }

        // Values created in this transaction are not in the store yet, so they are filtered by the same criteria and prepended
        Stream<V> res = mapMcb == null
          ? updatedAndNotRemovedObjectsStream
          : Stream.concat(
              createdValuesStream(mapMcb.getKeyFilter(), mapMcb.getEntityFilter()),
              updatedAndNotRemovedObjectsStream
            );

        // Ordering is applied here, over both the created and the stored values
        if (!queryParameters.getOrderBy().isEmpty()) {
            res = res.sorted(MapFieldPredicates.getComparator(queryParameters.getOrderBy().stream()));
        }


        return res;
    }

    /**
     * Counts the records matching the given criteria by reading them through
     * {@link #read(QueryParameters)} and counting the resulting stream.
     */
    @Override
    public long getCount(QueryParameters<M> queryParameters) {
        final Stream<V> matching = read(queryParameters);
        return matching.count();
    }

    /**
     * Returns the value held by the task enlisted under the entity's ID, or the entity itself
     * when it is {@code null} or no task is enlisted for it.
     */
    private V getUpdated(V orig) {
        if (orig == null) {
            return orig;
        }
        final MapTaskWithValue pending = tasks.get(orig.getId());
        if (pending == null) {
            return orig;
        }
        return pending.getValue();
    }

    /**
     * Enlists creation of a copy of the given entity. The key is the one determined by the
     * underlying store for the value; when the store determines none, a new unique key is
     * generated. The copy carries that key as its ID.
     *
     * @param value entity to create
     * @return the enlisted, post-processed copy
     */
    @Override
    public V create(V value) {
        String key = map.determineKeyFromValue(value);
        final V stored;
        if (key == null) {
            key = keyConverter.keyToString(keyConverter.yieldNewUniqueKey());
            stored = cloner.from(key, value);
        } else if (key.equals(value.getId())) {
            // ID already matches the key, a plain copy is enough
            stored = cloner.from(value);
        } else {
            stored = cloner.from(key, value);
        }
        addTask(key, new CreateOperation(stored));
        return postProcess(stored);
    }

    /**
     * Enlists an update of the given value that is written to the underlying store on commit
     * only if {@code shouldPut} accepts the value at that time. The task is merged with any
     * task already enlisted under the same ID.
     *
     * @param value entity to enlist
     * @param shouldPut condition evaluated at commit time
     * @return the value held by the resulting task
     */
    public V updateIfChanged(V value, Predicate<V> shouldPut) {
        final String id = value.getId();
        log.tracef("Adding operation UPDATE_IF_CHANGED for %s @ %08x", id, System.identityHashCode(value));

        final MapTaskWithValue conditionalUpdate = new MapTaskWithValue(value) {
            @Override
            public MapOperation getOperation() {
                return MapOperation.UPDATE;
            }

            @Override
            public void execute() {
                final V candidate = getValue();
                if (shouldPut.test(candidate)) {
                    map.update(candidate);
                }
            }
        };
        final MapTaskWithValue enlisted = tasks.merge(id, conditionalUpdate, this::merge);
        return enlisted.getValue();
    }

    /**
     * Enlists deletion of the entity with the given key; a delete replaces any task already
     * enlisted under that key.
     *
     * @return always {@code true}
     */
    @Override
    public boolean delete(String key) {
        final MapTaskWithValue removal = new DeleteOperation(key);
        tasks.merge(key, removal, this::merge);
        return true;
    }

    /**
     * Enlists a bulk delete of all entities matching the given criteria. Tasks already enlisted
     * whose value matches the criteria are discarded, and the bulk delete is then enlisted under
     * a newly generated artificial key.
     *
     * @param queryParameters criteria of the entities to delete
     * @return number of discarded tasks plus the number of matching records reported by the store
     */
    @Override
    public long delete(QueryParameters<M> queryParameters) {
        log.tracef("Adding operation DELETE_BULK");

        final K bulkTaskKey = keyConverter.yieldNewUniqueKey();

        // Remove all tasks that create / update / delete objects deleted by the bulk removal.
        final BulkDeleteOperation bulkDelete = new BulkDeleteOperation(queryParameters);
        final Predicate<V> survivesBulkDelete = bulkDelete.getFilterForNonDeletedObjects();

        long discardedTasks = 0;
        final Iterator<Entry<TaskKey, MapTaskWithValue>> pending = tasks.entrySet().iterator();
        while (pending.hasNext()) {
            final Entry<TaskKey, MapTaskWithValue> entry = pending.next();
            if (survivesBulkDelete.test(entry.getValue().getValue())) {
                continue;
            }
            log.tracef(" [DELETE_BULK] removing %s", entry.getKey());
            pending.remove();
            discardedTasks++;
        }

        tasks.put(keyConverter.keyToString(bulkTaskKey), bulkDelete);

        // NOTE(review): an entity that has a discarded task and also matches in the store appears to be counted twice - confirm
        return discardedTasks + bulkDelete.getCount();
    }

    /**
     * Tells whether an entity with the given key exists from the point of view of this transaction.
     *
     * @return when a task is enlisted under the key, whether that task holds a value; otherwise,
     *         when any bulk delete is enlisted, whether {@link #read(String)} yields an entity;
     *         otherwise whatever the underlying store reports
     */
    @Override
    public boolean exists(String key) {
        if (tasks.containsKey(key)) {
            return tasks.get(key).getValue() != null;
        }

        // With a bulk delete enlisted, only reading the full entity tells whether it would survive
        final boolean bulkDeletePending = tasks.values().stream()
          .anyMatch(BulkDeleteOperation.class::isInstance);

        return bulkDeletePending
          ? read(key) != null
          : map.exists(key);
    }

    /**
     * Collects a snapshot of the values created in this transaction for the current realm ID
     * that pass both filters. Tasks that replace a previously removed entity are not included.
     *
     * @param keyFilter filter applied to the converted task key
     * @param entityFilter filter applied to the non-null task value
     * @return stream over the snapshot, in the order the tasks were enlisted
     */
    private Stream<V> createdValuesStream(Predicate<? super K> keyFilter, Predicate<? super V> entityFilter) {
        // Built eagerly so that the result is a snapshot, independent of later changes to the tasks
        final Stream.Builder<V> snapshot = Stream.builder();
        for (Entry<TaskKey, MapTaskWithValue> entry : this.tasks.entrySet()) {
            final TaskKey taskKey = entry.getKey();
            if (! Objects.equals(realmId, taskKey.getRealmId())) {
                continue;
            }
            if (! keyFilter.test(keyConverter.fromStringSafe(taskKey.getKey()))) {
                continue;
            }

            final MapTaskWithValue task = entry.getValue();
            if (! task.containsCreate() || task.isReplace()) {
                continue;
            }

            final V created = task.getValue();
            if (created != null && entityFilter.test(created)) {
                snapshot.add(created);
            }
        }
        return snapshot.build();
    }

    /**
     * Combines a task already enlisted under a key with a new one: a delete supersedes the old
     * task entirely, any other operation is chained after the old task.
     */
    private MapTaskWithValue merge(MapTaskWithValue oldValue, MapTaskWithValue newValue) {
        final MapTaskWithValue merged;
        switch (newValue.getOperation()) {
            case DELETE:
                merged = newValue;
                break;
            default:
                merged = new MapTaskCompose(oldValue, newValue);
                break;
        }
        return merged;
    }

    /**
     * A unit of work enlisted in the transaction, optionally carrying the entity value it
     * operates on. The realm ID set on the enclosing storage at construction time is captured
     * with the task.
     */
    protected abstract class MapTaskWithValue {
        protected final V value;
        private final String realmId;

        /**
         * @param value entity value held by this task, may be {@code null}
         */
        public MapTaskWithValue(V value) {
            this.realmId = ConcurrentHashMapStorage.this.realmId;
            this.value = value;
        }

        /** @return the kind of operation this task performs */
        public abstract MapOperation getOperation();

        /** Applies this task to the underlying store. */
        public abstract void execute();

        /** @return the entity value held by this task, may be {@code null} */
        public V getValue() {
            return value;
        }

        /** @return the realm ID captured when this task was constructed */
        public String getRealmId() {
            return realmId;
        }

        /** @return {@code true} when this task's operation is {@link MapOperation#CREATE} */
        public boolean containsCreate() {
            return getOperation() == MapOperation.CREATE;
        }

        /** @return {@code true} when this task's operation is {@link MapOperation#DELETE} */
        public boolean containsRemove() {
            return getOperation() == MapOperation.DELETE;
        }

        /** @return {@code false}; overridden by composite tasks */
        public boolean isReplace() {
            return false;
        }
    }

    /**
     * Composite of two tasks enlisted under the same key. Execution runs the older task first
     * and the newer one second; the exposed value is that of the newer task.
     */
    private class MapTaskCompose extends MapTaskWithValue {

        private final MapTaskWithValue oldValue;
        private final MapTaskWithValue newValue;

        public MapTaskCompose(MapTaskWithValue oldValue, MapTaskWithValue newValue) {
            // The composite holds no value of its own, see getValue()
            super(null);
            this.oldValue = oldValue;
            this.newValue = newValue;
        }

        /** A composite reports no single operation. */
        @Override
        public MapOperation getOperation() {
            return null;
        }

        @Override
        public V getValue() {
            return newValue.getValue();
        }

        @Override
        public void execute() {
            oldValue.execute();
            newValue.execute();
        }

        @Override
        public boolean containsCreate() {
            if (oldValue.containsCreate()) {
                return true;
            }
            return newValue.containsCreate();
        }

        @Override
        public boolean containsRemove() {
            if (oldValue.containsRemove()) {
                return true;
            }
            return newValue.containsRemove();
        }

        /**
         * @return {@code true} when the newer task is a create following a removal, or when the
         *         older task is itself a composite that is a replace
         */
        @Override
        public boolean isReplace() {
            if (newValue.getOperation() == MapOperation.CREATE && oldValue.containsRemove()) {
                return true;
            }
            if (oldValue instanceof ConcurrentHashMapStorage.MapTaskCompose) {
                return ((MapTaskCompose) oldValue).isReplace();
            }
            return false;
        }
    }

    /** Task that creates its value in the underlying store. */
    private class CreateOperation extends MapTaskWithValue {
        public CreateOperation(V value) {
            super(value);
        }

        @Override
        public MapOperation getOperation() {
            return MapOperation.CREATE;
        }

        @Override
        public void execute() {
            map.create(getValue());
        }
    }

    /** Task that deletes a single key from the underlying store; it holds no entity value. */
    private class DeleteOperation extends MapTaskWithValue {
        private final String key;

        public DeleteOperation(String key) {
            super(null);
            this.key = key;
        }

        @Override
        public MapOperation getOperation() {
            return MapOperation.DELETE;
        }

        @Override
        public void execute() {
            map.delete(key);
        }
    }

    /** Task that deletes all entities matching given query parameters; it holds no entity value. */
    private class BulkDeleteOperation extends MapTaskWithValue {

        private final QueryParameters<M> queryParameters;

        public BulkDeleteOperation(QueryParameters<M> queryParameters) {
            super(null);
            this.queryParameters = queryParameters;
        }

        @Override
        public MapOperation getOperation() {
            return MapOperation.DELETE;
        }

        @Override
        @SuppressWarnings("unchecked")
        public void execute() {
            map.delete(queryParameters);
        }

        /**
         * Builds a predicate accepting the entities this bulk delete would leave in place:
         * {@code null} is always accepted, any other entity is accepted unless both its key and
         * the entity itself match the criteria of this operation.
         */
        public Predicate<V> getFilterForNonDeletedObjects() {
            final DefaultModelCriteria<M> criteria = queryParameters.getModelCriteriaBuilder();
            final MapModelCriteriaBuilder<K,V,M> builder = criteria.flashToModelCriteriaBuilder(createCriteriaBuilder());

            final Predicate<? super K> keyFilter = builder.getKeyFilter();
            final Predicate<? super V> entityFilter = builder.getEntityFilter();
            return candidate -> {
                if (candidate == null) {
                    return true;
                }
                final boolean keyMatches = keyFilter.test(keyConverter.fromStringSafe(candidate.getId()));
                return ! (keyMatches && entityFilter.test(candidate));
            };
        }

        /** @return number of records in the underlying store matching the criteria of this operation */
        private long getCount() {
            return map.getCount(queryParameters);
        }
    }

    /**
     * @return the realm ID reported by the underlying store when it implements
     *         {@link HasRealmId}, {@code null} otherwise
     */
    @Override
    public String getRealmId() {
        return mapHasRealmId
          ? ((HasRealmId) map).getRealmId()
          : null;
    }

    /**
     * Sets the realm ID on the underlying store and remembers it in this storage. When the
     * underlying store does not implement {@link HasRealmId}, the given value is ignored and the
     * remembered realm ID is reset to {@code null}.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void setRealmId(String realmId) {
        if (! mapHasRealmId) {
            this.realmId = null;
            return;
        }
        ((HasRealmId) map).setRealmId(realmId);
        this.realmId = realmId;
    }

    /**
     * Supplies the realm ID entity field of the given value with the current realm ID via
     * {@code ModelEntityUtil.supplyReadOnlyFieldValueIfUnset}; returns the value untouched when
     * either the value or the current realm ID is {@code null}.
     */
    private V postProcess(V value) {
        if (realmId == null || value == null) {
            return value;
        }
        return ModelEntityUtil.supplyReadOnlyFieldValueIfUnset(value, realmIdEntityField, realmId);
    }

    /**
     * Stream variant of the realm ID post-processing. The realm ID is captured when this method
     * is called, not when the stream is consumed; the stream is returned as is when no realm ID is set.
     */
    private Stream<V> postProcess(Stream<V> stream) {
        final String localRealmId = this.realmId;
        if (localRealmId == null) {
            return stream;
        }

        return stream.map((V entity) -> ModelEntityUtil.supplyReadOnlyFieldValueIfUnset(entity, realmIdEntityField, localRealmId));
    }

}