ConcurrentHashMapStorage.java
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.map.storage.chm;
import org.keycloak.models.KeycloakTransaction;
import org.keycloak.models.map.common.StringKeyConverter;
import org.keycloak.models.map.common.AbstractEntity;
import org.keycloak.models.map.common.DeepCloner;
import org.keycloak.models.map.common.EntityField;
import org.keycloak.models.map.common.HasRealmId;
import org.keycloak.models.map.common.UpdatableEntity;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.jboss.logging.Logger;
import org.keycloak.models.map.storage.CrudOperations;
import org.keycloak.models.map.storage.MapStorage;
import org.keycloak.models.map.storage.ModelEntityUtil;
import org.keycloak.models.map.storage.QueryParameters;
import org.keycloak.models.map.storage.chm.MapModelCriteriaBuilder.UpdatePredicatesFunc;
import org.keycloak.models.map.storage.criteria.DefaultModelCriteria;
import org.keycloak.storage.SearchableModelField;
import java.util.function.Consumer;
import java.util.Collection;
import java.util.Set;
import java.util.function.BiFunction;
/**
 * Transaction-scoped view over a {@link CrudOperations} store.
 * <p>
 * Modifications ({@link #create}, {@link #updateIfChanged}, {@link #delete(String)} and
 * {@link #delete(QueryParameters)}) are not applied to the underlying store right away. They are recorded as tasks
 * in a {@link LinkedHashMap} keyed by (realm ID, entity key) and executed against the store in insertion order on
 * {@link #commit()}. Read operations overlay the pending tasks on top of the store content, so the caller sees its
 * own uncommitted changes. Entities handed out to callers are deep clones that are registered in the task list.
 * <p>
 * When the underlying store implements {@link HasRealmId}, the realm passed to {@link #setRealmId(String)} is
 * forwarded to it. Every task also remembers the realm that was current when it was recorded.
 * <p>
 * NOTE(review): the task list is a plain {@link LinkedHashMap} without synchronization, so an instance presumably
 * has to be confined to a single thread / session — confirm with callers.
 *
 * @param <K> key type understood by {@code keyConverter}
 * @param <V> entity type
 * @param <M> model type the query parameters refer to
 * @param <CRUD> type of the underlying store
 */
public class ConcurrentHashMapStorage<K, V extends AbstractEntity & UpdatableEntity, M, CRUD extends CrudOperations<V, M>> implements MapStorage<V, M>, KeycloakTransaction, HasRealmId {

    private final static Logger log = Logger.getLogger(ConcurrentHashMapStorage.class);

    protected boolean active;
    protected boolean rollback;
    /** Pending operations, executed in insertion order on commit. */
    protected final TaskMap tasks = new TaskMap();

    /** The underlying store the tasks are eventually applied to. */
    protected final CRUD map;
    protected final StringKeyConverter<K> keyConverter;
    protected final DeepCloner cloner;
    protected final Map<SearchableModelField<? super M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates;
    /** Field that receives the current realm ID when post-processing read entities; may be {@code null}. */
    protected final EntityField<V> realmIdEntityField;
    /** Realm of the current operation; {@code null} when the underlying store is not realm-aware. */
    private String realmId;
    /** Whether {@link #map} implements {@link HasRealmId} and hence needs the realm forwarded to it. */
    private final boolean mapHasRealmId;

    /**
     * Key of a pending task: the entity key together with the realm that was current when the task was recorded.
     */
    protected static final class TaskKey {

        private final String realmId;
        private final String key;

        public TaskKey(String realmId, String key) {
            this.realmId = realmId;
            this.key = key;
        }

        private String getRealmId() {
            return this.realmId;
        }

        public String getKey() {
            return key;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.key, this.realmId);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            final TaskKey other = (TaskKey) obj;
            return Objects.equals(this.key, other.key) && Objects.equals(this.realmId, other.realmId);
        }

        @Override
        public String toString() {
            return key + " / " + realmId;
        }

        static TaskKey keyFor(String realmId, String id) {
            return new TaskKey(realmId, id);
        }
    }

    /**
     * Insertion-ordered task list. Every {@code String}-keyed accessor implicitly combines the given key with the
     * enclosing storage's <em>current</em> {@code realmId}. Stored values are never {@code null}.
     */
    protected class TaskMap {

        private final Map<TaskKey, MapTaskWithValue> map = new LinkedHashMap<>();

        public boolean isEmpty() {
            return map.isEmpty();
        }

        public boolean containsKey(String key) {
            return map.containsKey(TaskKey.keyFor(realmId, key));
        }

        public MapTaskWithValue get(String key) {
            return map.get(TaskKey.keyFor(realmId, key));
        }

        public MapTaskWithValue put(String key, MapTaskWithValue value) {
            return map.put(TaskKey.keyFor(realmId, key), value);
        }

        public void clear() {
            map.clear();
        }

        /** Returns tasks of <em>all</em> realms, in insertion order. */
        public Collection<MapTaskWithValue> values() {
            return map.values();
        }

        /** Returns task entries of <em>all</em> realms, in insertion order. */
        public Set<Entry<TaskKey, MapTaskWithValue>> entrySet() {
            return map.entrySet();
        }

        public MapTaskWithValue merge(String key, MapTaskWithValue value, BiFunction<? super MapTaskWithValue, ? super MapTaskWithValue, ? extends MapTaskWithValue> remappingFunction) {
            return map.merge(TaskKey.keyFor(realmId, key), value, remappingFunction);
        }
    }

    protected enum MapOperation {
        CREATE, UPDATE, DELETE,
    }

    public ConcurrentHashMapStorage(CRUD map, StringKeyConverter<K> keyConverter, DeepCloner cloner, Map<SearchableModelField<? super M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates) {
        this(map, keyConverter, cloner, fieldPredicates, null);
    }

    public ConcurrentHashMapStorage(CRUD map, StringKeyConverter<K> keyConverter, DeepCloner cloner, Map<SearchableModelField<? super M>, UpdatePredicatesFunc<K, V, M>> fieldPredicates, EntityField<V> realmIdEntityField) {
        this.map = map;
        this.keyConverter = keyConverter;
        this.cloner = cloner;
        this.fieldPredicates = fieldPredicates;
        this.realmIdEntityField = realmIdEntityField;
        this.mapHasRealmId = map instanceof HasRealmId;
    }

    @Override
    public void begin() {
        active = true;
    }

    /**
     * Executes all pending tasks against the underlying store, in the order they were recorded.
     * <p>
     * When the store is realm-aware, each task is executed with the store switched to the realm the task was
     * recorded in. Afterwards the store's realm is restored to what it was before the commit.
     *
     * @throws RuntimeException if the transaction has been marked rollback-only
     */
    @Override
    public void commit() {
        if (rollback) {
            throw new RuntimeException("Rollback only!");
        }

        if (tasks.isEmpty()) {
            return;
        }

        log.tracef("Commit - %s", map);
        final Consumer<String> setRealmId = mapHasRealmId ? ((HasRealmId) map)::setRealmId : a -> {};
        // Remember the store's realm so that it does not stay pointed at whichever realm the last task used
        final String originalMapRealmId = mapHasRealmId ? ((HasRealmId) map).getRealmId() : null;
        try {
            for (MapTaskWithValue value : tasks.values()) {
                setRealmId.accept(value.getRealmId());
                value.execute();
            }
        } finally {
            setRealmId.accept(originalMapRealmId);
        }
    }

    /** Discards all pending tasks without touching the underlying store. */
    @Override
    public void rollback() {
        tasks.clear();
    }

    @Override
    public void setRollbackOnly() {
        rollback = true;
    }

    @Override
    public boolean getRollbackOnly() {
        return rollback;
    }

    @Override
    public boolean isActive() {
        return active;
    }

    private MapModelCriteriaBuilder<K, V, M> createCriteriaBuilder() {
        return new MapModelCriteriaBuilder<>(keyConverter, fieldPredicates);
    }

    /**
     * Adds the given task for the given key. If a task is already pending for the key, both are composed so that
     * the existing one runs first.
     *
     * @param key entity key
     * @param task task to record
     */
    protected void addTask(String key, MapTaskWithValue task) {
        log.tracef("Adding operation %s for %s @ %08x", task.getOperation(), key, System.identityHashCode(task.getValue()));
        tasks.merge(key, task, MapTaskCompose::new);
    }

    /**
     * Returns a deep clone of an entity. If the entity is already in the transaction, returns the instance held there.
     * <p>
     * Usually used before giving an entity from a source back to the caller,
     * to prevent changing it directly in the data store, but to keep transactional properties.
     * The clone is enlisted via {@link #updateIfChanged}, so it is written back on commit only if it reports
     * {@code isUpdated()}.
     *
     * @param origEntity Original entity
     * @return the transaction-held instance for the entity's ID (may be {@code null} if a delete is pending for it)
     */
    public V registerEntityForChanges(V origEntity) {
        final String key = origEntity.getId();
        // If the entity is listed in the transaction already, return it directly
        final MapTaskWithValue current = tasks.get(key);
        if (current != null) {
            return current.getValue();
        }

        // Else enlist its copy in the transaction. Never return direct reference to the underlying map
        final V res = cloner.from(origEntity);
        return updateIfChanged(res, e -> e.isUpdated());
    }

    /**
     * Reads an entity by key, taking pending tasks (including bulk deletes) into account, and registers it for
     * changes.
     *
     * @param sKey entity key
     * @return the transaction-held entity, or {@code null} if it does not exist or has been deleted in this
     *         transaction
     */
    @Override
    public V read(String sKey) {
        try {
            // TODO: Consider using Optional rather than handling NPE
            final V entity = read(sKey, map::read);
            if (entity == null) {
                log.debugf("Could not read object for key %s", sKey);
                return null;
            }
            return postProcess(registerEntityForChanges(entity));
        } catch (NullPointerException ex) {
            // Kept as "not found" for backward compatibility, but no longer swallowed silently
            log.debugf(ex, "NullPointerException while reading object for key %s", sKey);
            return null;
        }
    }

    /**
     * Resolves an entity from the pending tasks, falling back to {@code defaultValueFunc} and then applying any
     * pending bulk deletes to the fallback value.
     */
    private V read(String key, Function<String, V> defaultValueFunc) {
        final MapTaskWithValue current = tasks.get(key);
        // If the key exists, then it has entered the "tasks" after bulk delete that could have
        // removed it, so looking through bulk deletes is irrelevant
        if (current != null) {
            return current.getValue();
        }

        // If the key does not exist, then it would be read fresh from the storage, but then it
        // could have been removed by some bulk delete in the existing tasks. Check it.
        final V value = defaultValueFunc.apply(key);
        for (MapTaskWithValue val : tasks.values()) {
            if (val instanceof ConcurrentHashMapStorage.BulkDeleteOperation) {
                final BulkDeleteOperation delOp = (BulkDeleteOperation) val;
                if (! delOp.getFilterForNonDeletedObjects().test(value)) {
                    return null;
                }
            }
        }

        return value;
    }

    /**
     * Returns the stream of records that match given criteria and includes changes made in this transaction, i.e.
     * the result contains updates and excludes records that have been deleted in this transaction.
     * <p>
     * Every stored entity passing through the stream is registered for changes. Entities created in this
     * transaction come first unless an ordering is requested.
     *
     * @param queryParameters query criteria and ordering
     * @return matching entities as seen from within this transaction
     */
    @Override
    public Stream<V> read(QueryParameters<M> queryParameters) {
        DefaultModelCriteria<M> mcb = queryParameters.getModelCriteriaBuilder();
        MapModelCriteriaBuilder<K,V,M> mapMcb = mcb.flashToModelCriteriaBuilder(createCriteriaBuilder());

        Predicate<? super V> filterOutAllBulkDeletedObjects = tasks.values().stream()
          .filter(BulkDeleteOperation.class::isInstance)
          .map(BulkDeleteOperation.class::cast)
          .map(BulkDeleteOperation::getFilterForNonDeletedObjects)
          .reduce(Predicate::and)
          .orElse(v -> true);

        Stream<V> updatedAndNotRemovedObjectsStream = this.map.read(queryParameters)
          .filter(filterOutAllBulkDeletedObjects)
          .map(this::getUpdated)      // If the object has been deleted in this transaction, getUpdated returns null, otherwise the transaction-held value
          .filter(Objects::nonNull)
          .map(this::registerEntityForChanges);

        updatedAndNotRemovedObjectsStream = postProcess(updatedAndNotRemovedObjectsStream);

        if (mapMcb != null) {
            // Add explicit filtering for the case when the map returns raw stream of untested values (ie. realize sequential scan)
            updatedAndNotRemovedObjectsStream = updatedAndNotRemovedObjectsStream
              .filter(e -> mapMcb.getKeyFilter().test(keyConverter.fromStringSafe(e.getId())))
              .filter(mapMcb.getEntityFilter());
        }

        // In case of created values stored in MapKeycloakTransaction, we need filter those according to the filter
        Stream<V> res = mapMcb == null
          ? updatedAndNotRemovedObjectsStream
          : Stream.concat(
              createdValuesStream(mapMcb.getKeyFilter(), mapMcb.getEntityFilter()),
              updatedAndNotRemovedObjectsStream
            );

        if (!queryParameters.getOrderBy().isEmpty()) {
            res = res.sorted(MapFieldPredicates.getComparator(queryParameters.getOrderBy().stream()));
        }

        return res;
    }

    /**
     * Counts matching entities by materializing {@link #read(QueryParameters)}; this therefore also registers the
     * matching stored entities for changes.
     */
    @Override
    public long getCount(QueryParameters<M> queryParameters) {
        return read(queryParameters).count();
    }

    /** Returns the transaction-held value for {@code orig}'s ID if a task exists, otherwise {@code orig}. */
    private V getUpdated(V orig) {
        MapTaskWithValue current = orig == null ? null : tasks.get(orig.getId());
        return current == null ? orig : current.getValue();
    }

    /**
     * Records creation of a clone of {@code value}. The key is determined by the store or, when the store does
     * not determine one, freshly generated by {@code keyConverter}.
     *
     * @return the (post-processed) clone that will be created on commit
     */
    @Override
    public V create(V value) {
        String key = map.determineKeyFromValue(value);
        if (key == null) {
            K newKey = keyConverter.yieldNewUniqueKey();
            key = keyConverter.keyToString(newKey);
            value = cloner.from(key, value);
        } else if (! key.equals(value.getId())) {
            value = cloner.from(key, value);
        } else {
            value = cloner.from(value);
        }

        addTask(key, new CreateOperation(value));
        return postProcess(value);
    }

    /**
     * Records a conditional update: on commit, {@code value} is written to the store if {@code shouldPut} accepts
     * it at that time.
     *
     * @param value entity to update; its ID is the task key
     * @param shouldPut evaluated on commit against the value
     * @return {@code value}
     */
    public V updateIfChanged(V value, Predicate<V> shouldPut) {
        String key = value.getId();
        log.tracef("Adding operation UPDATE_IF_CHANGED for %s @ %08x", key, System.identityHashCode(value));

        MapTaskWithValue op = new MapTaskWithValue(value) {
            @Override
            public void execute() {
                if (shouldPut.test(getValue())) {
                    map.update(getValue());
                }
            }
            @Override public MapOperation getOperation() { return MapOperation.UPDATE; }
        };
        return tasks.merge(key, op, this::merge).getValue();
    }

    /**
     * Records deletion of the entity with the given key, superseding any task pending for that key.
     *
     * @return always {@code true}; whether the entity exists is not checked here
     */
    @Override
    public boolean delete(String key) {
        tasks.merge(key, new DeleteOperation(key), this::merge);
        return true;
    }

    /**
     * Records a bulk delete. Pending tasks whose value matches the criteria are discarded immediately, and the
     * bulk delete itself is queued under an artificial key.
     * <p>
     * NOTE(review): tasks of all realms are inspected, not only those of the current realm — confirm intended.
     * NOTE(review): the returned count may overestimate, e.g. a stored entity with a pending update task is
     * counted both as a removed task and by the store count — confirm whether callers rely on exactness.
     *
     * @return number of discarded pending tasks plus the store's count of entities matching the criteria
     */
    @Override
    public long delete(QueryParameters<M> queryParameters) {
        log.tracef("Adding operation DELETE_BULK");

        K artificialKey = keyConverter.yieldNewUniqueKey();

        // Remove all tasks that create / update / delete objects deleted by the bulk removal.
        final BulkDeleteOperation bdo = new BulkDeleteOperation(queryParameters);
        Predicate<V> filterForNonDeletedObjects = bdo.getFilterForNonDeletedObjects();
        long res = 0;
        for (Iterator<Entry<TaskKey, MapTaskWithValue>> it = tasks.entrySet().iterator(); it.hasNext();) {
            Entry<TaskKey, MapTaskWithValue> me = it.next();
            if (! filterForNonDeletedObjects.test(me.getValue().getValue())) {
                log.tracef(" [DELETE_BULK] removing %s", me.getKey());
                it.remove();
                res++;
            }
        }

        tasks.put(keyConverter.keyToString(artificialKey), bdo);

        return res + bdo.getCount();
    }

    /**
     * Checks whether an entity with the given key exists from the point of view of this transaction.
     * <p>
     * If any bulk delete is pending, the entity is fully read via {@link #read(String)}, which also registers it
     * for changes.
     */
    @Override
    public boolean exists(String key) {
        final MapTaskWithValue current = tasks.get(key);
        if (current != null) {
            return current.getValue() != null;
        }

        // Check if there is a bulk delete operation in which case read the full entity
        for (MapTaskWithValue val : tasks.values()) {
            if (val instanceof ConcurrentHashMapStorage.BulkDeleteOperation) {
                return read(key) != null;
            }
        }

        return map.exists(key);
    }

    /**
     * Returns a snapshot of entities created in this transaction (in the current realm) that pass both filters.
     * Creates that replace a deleted entity are excluded, as the store read already yields them.
     */
    private Stream<V> createdValuesStream(Predicate<? super K> keyFilter, Predicate<? super V> entityFilter) {
        return this.tasks.entrySet().stream()
          .filter(me -> Objects.equals(realmId, me.getKey().getRealmId()) && keyFilter.test(keyConverter.fromStringSafe(me.getKey().getKey())))
          .map(Map.Entry::getValue)
          .filter(v -> v.containsCreate() && ! v.isReplace())
          .map(MapTaskWithValue::getValue)
          .filter(Objects::nonNull)
          .filter(entityFilter)
          // make a snapshot
          .collect(Collectors.toList()).stream();
    }

    /** Remapping function for the task list: a delete supersedes the pending task(s), anything else is appended. */
    private MapTaskWithValue merge(MapTaskWithValue oldValue, MapTaskWithValue newValue) {
        return newValue.getOperation() == MapOperation.DELETE
          ? newValue
          : new MapTaskCompose(oldValue, newValue);
    }

    /**
     * A pending operation, optionally carrying the entity value it operates on. Captures the storage's realm at
     * construction time so that commit can execute it in that realm.
     */
    protected abstract class MapTaskWithValue {
        protected final V value;
        private final String realmId;

        public MapTaskWithValue(V value) {
            this.value = value;
            this.realmId = ConcurrentHashMapStorage.this.realmId;
        }

        public V getValue() {
            return value;
        }

        public boolean containsCreate() {
            return MapOperation.CREATE == getOperation();
        }

        public boolean containsRemove() {
            return MapOperation.DELETE == getOperation();
        }

        public boolean isReplace() {
            return false;
        }

        public String getRealmId() {
            return realmId;
        }

        public abstract MapOperation getOperation();
        public abstract void execute();
    }

    /**
     * Two tasks for the same key executed in sequence; the value is that of the newer task. Its operation is
     * {@code null} because it is not a single operation.
     */
    private class MapTaskCompose extends MapTaskWithValue {

        private final MapTaskWithValue oldValue;
        private final MapTaskWithValue newValue;

        public MapTaskCompose(MapTaskWithValue oldValue, MapTaskWithValue newValue) {
            super(null);
            this.oldValue = oldValue;
            this.newValue = newValue;
        }

        @Override
        public void execute() {
            oldValue.execute();
            newValue.execute();
        }

        @Override
        public V getValue() {
            return newValue.getValue();
        }

        @Override
        public MapOperation getOperation() {
            return null;
        }

        @Override
        public boolean containsCreate() {
            return oldValue.containsCreate() || newValue.containsCreate();
        }

        @Override
        public boolean containsRemove() {
            return oldValue.containsRemove() || newValue.containsRemove();
        }

        /** {@code true} when a create follows a pending remove, directly or further down the composition chain. */
        @Override
        public boolean isReplace() {
            return (newValue.getOperation() == MapOperation.CREATE && oldValue.containsRemove()) ||
              (oldValue instanceof ConcurrentHashMapStorage.MapTaskCompose && ((MapTaskCompose) oldValue).isReplace());
        }
    }

    private class CreateOperation extends MapTaskWithValue {
        public CreateOperation(V value) {
            super(value);
        }

        @Override public void execute() { map.create(getValue()); }
        @Override public MapOperation getOperation() { return MapOperation.CREATE; }
    }

    /** Deletes a single entity by key; carries no value, so reads of the key within the transaction yield null. */
    private class DeleteOperation extends MapTaskWithValue {
        private final String key;

        public DeleteOperation(String key) {
            super(null);
            this.key = key;
        }

        @Override public void execute() { map.delete(key); }
        @Override public MapOperation getOperation() { return MapOperation.DELETE; }
    }

    /** Deletes every stored entity matching the query parameters. */
    private class BulkDeleteOperation extends MapTaskWithValue {

        private final QueryParameters<M> queryParameters;

        public BulkDeleteOperation(QueryParameters<M> queryParameters) {
            super(null);
            this.queryParameters = queryParameters;
        }

        @Override
        public void execute() {
            map.delete(queryParameters);
        }

        /**
         * Returns a predicate accepting {@code null} and every entity NOT matched by this bulk delete's key and
         * entity filters, i.e. the entities that survive the deletion.
         */
        public Predicate<V> getFilterForNonDeletedObjects() {
            DefaultModelCriteria<M> mcb = queryParameters.getModelCriteriaBuilder();
            MapModelCriteriaBuilder<K,V,M> mmcb = mcb.flashToModelCriteriaBuilder(createCriteriaBuilder());

            Predicate<? super V> entityFilter = mmcb.getEntityFilter();
            Predicate<? super K> keyFilter = mmcb.getKeyFilter();
            return v -> v == null || ! (keyFilter.test(keyConverter.fromStringSafe(v.getId())) && entityFilter.test(v));
        }

        @Override
        public MapOperation getOperation() {
            return MapOperation.DELETE;
        }

        private long getCount() {
            return map.getCount(queryParameters);
        }
    }

    /** Returns the underlying store's realm, or {@code null} if the store is not realm-aware. */
    @Override
    public String getRealmId() {
        if (mapHasRealmId) {
            return ((HasRealmId) map).getRealmId();
        }
        return null;
    }

    /**
     * Sets the realm for subsequent operations. For a store that is not realm-aware the realm is ignored and kept
     * as {@code null}.
     */
    @Override
    public void setRealmId(String realmId) {
        if (mapHasRealmId) {
            ((HasRealmId) map).setRealmId(realmId);
            this.realmId = realmId;
        } else {
            this.realmId = null;
        }
    }

    /** Fills the realm ID field of {@code value}, if unset, with the current realm. */
    private V postProcess(V value) {
        return (realmId == null || value == null)
          ? value
          : ModelEntityUtil.supplyReadOnlyFieldValueIfUnset(value, realmIdEntityField, realmId);
    }

    /** Stream variant of {@link #postProcess(AbstractEntity)}; captures the realm at call time. */
    private Stream<V> postProcess(Stream<V> stream) {
        if (this.realmId == null) {
            return stream;
        }

        String localRealmId = this.realmId;
        return stream.map((V value) -> ModelEntityUtil.supplyReadOnlyFieldValueIfUnset(value, realmIdEntityField, localRealmId));
    }
}