NodeCalcEvaluator.java
/**
* Copyright (c) 2018, RTE (http://www.rte-france.com)
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
* SPDX-License-Identifier: MPL-2.0
*/
package com.powsybl.timeseries.ast;
import com.powsybl.timeseries.DoubleMultiPoint;
import org.apache.commons.lang3.tuple.Pair;
import java.util.IdentityHashMap;
import java.util.Map;
/**
 * Visitor that evaluates a {@link NodeCalc} expression to a {@code double} for a given
 * {@link DoubleMultiPoint}.
 *
 * <p>Each {@code visit} method computes the value of one node type; for nodes with children the
 * already computed child value(s) are received as arguments. Each {@code iterate} method returns the
 * child node(s) of a node, or {@code null} when no child has to be evaluated.
 *
 * <p>NOTE(review): the traversal itself is implemented by {@code NodeCalc.accept}, which is not part
 * of this file; the way it combines {@code iterate} and {@code visit} is assumed from the signatures
 * used here and should be confirmed there.
 *
 * @author Geoffroy Jamgotchian {@literal <geoffroy.jamgotchian at rte-france.com>}
 */
public class NodeCalcEvaluator implements NodeCalcVisitor<Double, NodeCalcEvaluator.EvalContext> {

    /**
     * State shared by all visit/iterate calls of one evaluation: the point being evaluated and the
     * values already computed for {@link CachedNodeCalc} nodes.
     */
    static class EvalContext {
        /** Point the expression is evaluated on; may be {@code null} (checked where it is read). */
        DoubleMultiPoint multiPoint;
        /** Values of the {@link CachedNodeCalc} nodes evaluated so far, keyed by node. */
        Map<NodeCalc, Double> cache;

        EvalContext(DoubleMultiPoint point, Map<NodeCalc, Double> cache) {
            this.multiPoint = point;
            this.cache = cache;
        }
    }

    /**
     * Evaluates {@code nodeCalc} on {@code multiPoint} using a new evaluator and a new, empty cache.
     *
     * @param nodeCalc   root of the expression to evaluate
     * @param multiPoint point providing the time and the time series values
     * @return the value of the expression
     */
    public static double eval(NodeCalc nodeCalc, DoubleMultiPoint multiPoint) {
        return new NodeCalcEvaluator().evaluateWithCache(nodeCalc, multiPoint);
    }

    /**
     * Runs the evaluation with a fresh {@link IdentityHashMap} as cache, so cached nodes are matched
     * by reference rather than by {@code equals}.
     */
    private double evaluateWithCache(NodeCalc nodeCalc, DoubleMultiPoint multiPoint) {
        EvalContext evalContext = new EvalContext(multiPoint, new IdentityHashMap<>());
        // NOTE(review): the third argument is presumably the initial traversal depth - confirm
        // against NodeCalc.accept.
        return nodeCalc.accept(this, evalContext, 0);
    }

    /**
     * Returns the multi point of the context, failing with the same explicit error for every node
     * type that needs it.
     *
     * @throws IllegalStateException if the context has no multi point
     */
    private static DoubleMultiPoint requireMultiPoint(EvalContext evalContext) {
        if (evalContext.multiPoint == null) {
            throw new IllegalStateException("Multi point is null");
        }
        return evalContext.multiPoint;
    }

    /** Returns the literal converted with {@code toDouble()}. */
    @Override
    public Double visit(IntegerNodeCalc nodeCalc, EvalContext evalContext) {
        return nodeCalc.toDouble();
    }

    /** Returns the literal converted with {@code toDouble()}. */
    @Override
    public Double visit(FloatNodeCalc nodeCalc, EvalContext evalContext) {
        return nodeCalc.toDouble();
    }

    /** Returns the literal value of the node. */
    @Override
    public Double visit(DoubleNodeCalc nodeCalc, EvalContext evalContext) {
        return nodeCalc.getValue();
    }

    /** Returns the literal converted with {@code toDouble()}. */
    @Override
    public Double visit(BigDecimalNodeCalc nodeCalc, EvalContext evalContext) {
        return nodeCalc.toDouble();
    }

    /**
     * Applies the binary operator of the node to the two operand values.
     *
     * <p>Arithmetic operators return the arithmetic result; comparison operators return {@code 1d}
     * when the comparison holds and {@code 0d} otherwise.
     *
     * @param left  value of the left operand
     * @param right value of the right operand
     */
    @Override
    public Double visit(BinaryOperation nodeCalc, EvalContext evalContext, Double left, Double right) {
        // Unbox both operands up front so comparisons below are primitive comparisons.
        double leftValue = left;
        double rightValue = right;
        return switch (nodeCalc.getOperator()) {
            case PLUS -> leftValue + rightValue;
            case MINUS -> leftValue - rightValue;
            case MULTIPLY -> leftValue * rightValue;
            case DIVIDE -> leftValue / rightValue;
            case LESS_THAN -> leftValue < rightValue ? 1d : 0d;
            case LESS_THAN_OR_EQUALS_TO -> leftValue <= rightValue ? 1d : 0d;
            case GREATER_THAN -> leftValue > rightValue ? 1d : 0d;
            case GREATER_THAN_OR_EQUALS_TO -> leftValue >= rightValue ? 1d : 0d;
            case EQUALS -> leftValue == rightValue ? 1d : 0d;
            case NOT_EQUALS -> leftValue != rightValue ? 1d : 0d;
        };
    }

    /**
     * Applies the unary operator of the node (absolute value, negation or identity) to the child
     * value.
     *
     * @param child value of the operand
     */
    @Override
    public Double visit(UnaryOperation nodeCalc, EvalContext evalContext, Double child) {
        double childValue = child;
        return switch (nodeCalc.getOperator()) {
            case ABS -> Math.abs(childValue);
            case NEGATIVE -> -childValue;
            case POSITIVE -> childValue;
        };
    }

    /** Returns the operand of the unary operation. */
    @Override
    public NodeCalc iterate(UnaryOperation nodeCalc, EvalContext evalContext) {
        return nodeCalc.getChild();
    }

    /** Returns the smaller of the child value and {@code nodeCalc.getMin()}. */
    @Override
    public Double visit(MinNodeCalc nodeCalc, EvalContext evalContext, Double child) {
        double childValue = child;
        return Math.min(childValue, nodeCalc.getMin());
    }

    /** Returns the child of the node. */
    @Override
    public NodeCalc iterate(MinNodeCalc nodeCalc, EvalContext evalContext) {
        return nodeCalc.getChild();
    }

    /** Returns the larger of the child value and {@code nodeCalc.getMax()}. */
    @Override
    public Double visit(MaxNodeCalc nodeCalc, EvalContext evalContext, Double child) {
        double childValue = child;
        return Math.max(childValue, nodeCalc.getMax());
    }

    /** Returns the child of the node. */
    @Override
    public NodeCalc iterate(MaxNodeCalc nodeCalc, EvalContext evalContext) {
        return nodeCalc.getChild();
    }

    /**
     * Returns the value of a cached node.
     *
     * <p>When {@code child} is {@code null} the value is read from the cache; otherwise
     * {@code child} is stored in the cache under {@code nodeCalc} and returned.
     *
     * @param child value of the child, or {@code null} when the child was not evaluated
     */
    @Override
    public Double visit(CachedNodeCalc nodeCalc, EvalContext evalContext, Double child) {
        double childValue;
        if (child == null) {
            // Unboxing: throws NullPointerException if the cache holds no value for this node.
            childValue = evalContext.cache.get(nodeCalc);
        } else {
            childValue = child;
            evalContext.cache.put(nodeCalc, childValue);
        }
        return childValue;
    }

    /**
     * Returns the child of the node, or {@code null} when the cache already contains this node so
     * that the child does not need to be evaluated again.
     */
    @Override
    public NodeCalc iterate(CachedNodeCalc nodeCalc, EvalContext evalContext) {
        return evalContext.cache.containsKey(nodeCalc) ? null : nodeCalc.getChild();
    }

    /**
     * Returns the time of the multi point as a {@code double}; {@code child} is not used.
     *
     * @throws IllegalStateException if the context has no multi point
     */
    @Override
    public Double visit(TimeNodeCalc nodeCalc, EvalContext evalContext, Double child) {
        // Same explicit failure as for time series nodes instead of a bare NullPointerException.
        return (double) requireMultiPoint(evalContext).getTime();
    }

    /** Returns {@code null}: no child is evaluated for a time node. */
    @Override
    public NodeCalc iterate(TimeNodeCalc nodeCalc, EvalContext evalContext) {
        return null;
    }

    /**
     * Returns the value the multi point holds for the time series number of the node.
     *
     * @throws IllegalStateException if the context has no multi point
     */
    @Override
    public Double visit(TimeSeriesNumNodeCalc nodeCalc, EvalContext evalContext) {
        return requireMultiPoint(evalContext).getValue(nodeCalc.getTimeSeriesNum());
    }

    /**
     * Always fails: a node referencing a time series by name cannot be evaluated here.
     *
     * @throws IllegalStateException always
     */
    @Override
    public Double visit(TimeSeriesNameNodeCalc nodeCalc, EvalContext evalContext) {
        throw new IllegalStateException("NodeCalc should have been resolved before");
    }

    /** Returns the smaller of the two operand values. */
    @Override
    public Double visit(BinaryMinCalc nodeCalc, EvalContext evalContext, Double left, Double right) {
        double leftValue = left;
        double rightValue = right;
        return Math.min(leftValue, rightValue);
    }

    /** Returns the larger of the two operand values. */
    @Override
    public Double visit(BinaryMaxCalc nodeCalc, EvalContext evalContext, Double left, Double right) {
        double leftValue = left;
        double rightValue = right;
        return Math.max(leftValue, rightValue);
    }

    /** Returns the left and right operands of a binary node as a pair. */
    @Override
    public Pair<NodeCalc, NodeCalc> iterate(AbstractBinaryNodeCalc nodeCalc, EvalContext evalContext) {
        return Pair.of(nodeCalc.getLeft(), nodeCalc.getRight());
    }
}