TraceSessionRepositoryAspect.java
/*
* Copyright 2018-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.sleuth.instrument.session;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.cloud.sleuth.CurrentTraceContext;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.session.FindByIndexNameSessionRepository;
import org.springframework.session.ReactiveSessionRepository;
import org.springframework.session.SessionRepository;
import org.springframework.util.ReflectionUtils;
/**
 * Aspect around {@link SessionRepository} and {@link ReactiveSessionRepository} method
 * execution.
 * <p>
 * Every intercepted public repository call is redirected to the matching method of a
 * tracing wrapper ({@link TraceSessionRepository},
 * {@link TraceFindByIndexNameSessionRepository} or
 * {@link TraceReactiveSessionRepository}) that delegates to the original target. If no
 * matching method exists on the wrapper, the original method is executed untouched.
 *
 * @author Marcin Grzejszczak
 * @since 3.1.0
 */
@Aspect
public class TraceSessionRepositoryAspect {

	private static final Log log = LogFactory.getLog(TraceSessionRepositoryAspect.class);

	private final Tracer tracer;

	private final CurrentTraceContext currentTraceContext;

	public TraceSessionRepositoryAspect(Tracer tracer, CurrentTraceContext currentTraceContext) {
		this.tracer = tracer;
		this.currentTraceContext = currentTraceContext;
	}

	/**
	 * Intercepts public methods of any {@link SessionRepository} implementation (for
	 * example {@code RedisIndexedSessionRepository}) and runs them through a tracing
	 * wrapper.
	 * @param pjp the intercepted join point
	 * @return the result of the wrapper's method, or of the original method when the
	 * target is already a tracing wrapper or no matching wrapper method exists
	 * @throws Throwable whatever the invoked method throws
	 */
	@Around("execution(public * org.springframework.session.SessionRepository.*(..))")
	public Object wrapSessionRepository(ProceedingJoinPoint pjp) throws Throwable {
		SessionRepository target = (SessionRepository) pjp.getTarget();
		if (target instanceof TraceSessionRepository) {
			// Already traced - avoid double wrapping.
			return pjp.proceed();
		}
		// NOTE(review): a new wrapper is allocated on every intercepted call.
		target = wrapSessionRepository(target);
		return callMethodOnWrappedObject(pjp, target);
	}

	/**
	 * Wraps the given repository in the most specific tracing wrapper available.
	 * @param target repository to wrap
	 * @return a {@link TraceFindByIndexNameSessionRepository} when {@code target}
	 * implements {@link FindByIndexNameSessionRepository}, otherwise a
	 * {@link TraceSessionRepository}
	 */
	private SessionRepository wrapSessionRepository(SessionRepository target) {
		if (target instanceof FindByIndexNameSessionRepository) {
			return new TraceFindByIndexNameSessionRepository(this.tracer, (FindByIndexNameSessionRepository) target);
		}
		return new TraceSessionRepository(this.tracer, target);
	}

	/**
	 * Invokes the method corresponding to the intercepted one on {@code target}, falling
	 * back to the original execution when no such method can be found.
	 * @param pjp the intercepted join point (supplies signature and arguments)
	 * @param target the tracing wrapper to invoke the method on
	 * @param <T> type of the wrapper
	 * @return the invocation result
	 * @throws Throwable the exception thrown by the invoked method, unwrapped from
	 * reflection
	 */
	private <T> Object callMethodOnWrappedObject(ProceedingJoinPoint pjp, T target) throws Throwable {
		Method method = getMethod(pjp, target);
		if (method != null) {
			if (log.isDebugEnabled()) {
				log.debug("Found a corresponding method on the trace representation [" + method + "]");
			}
			try {
				return method.invoke(target, pjp.getArgs());
			}
			catch (InvocationTargetException ex) {
				// Surface the wrapper's own exception exactly as pjp.proceed() would,
				// instead of wrapping checked exceptions in UndeclaredThrowableException.
				throw ex.getTargetException();
			}
			catch (IllegalAccessException ex) {
				throw new IllegalStateException("Could not access method [" + method + "]", ex);
			}
		}
		if (log.isDebugEnabled()) {
			log.debug("Method [" + pjp.getSignature().getName()
					+ "] not found on the trace representation. Will run the original one.");
		}
		return pjp.proceed();
	}

	/**
	 * Intercepts public methods of any {@link ReactiveSessionRepository} and runs them
	 * through a {@link TraceReactiveSessionRepository}.
	 * @param pjp the intercepted join point
	 * @return the result of the wrapper's method, or of the original method when the
	 * target is already a tracing wrapper or no matching wrapper method exists
	 * @throws Throwable whatever the invoked method throws
	 */
	@Around("execution(public * org.springframework.session.ReactiveSessionRepository.*(..))")
	public Object wrapReactiveSessionRepository(ProceedingJoinPoint pjp) throws Throwable {
		ReactiveSessionRepository target = (ReactiveSessionRepository) pjp.getTarget();
		if (target instanceof TraceReactiveSessionRepository) {
			// Already traced - avoid double wrapping.
			return pjp.proceed();
		}
		target = new TraceReactiveSessionRepository(this.tracer, this.currentTraceContext, target);
		return callMethodOnWrappedObject(pjp, target);
	}

	/**
	 * Finds the method on {@code tracingWrapper} that corresponds to the intercepted one.
	 * <p>
	 * First tries an exact match on name and declared parameter types. If that fails,
	 * which can happen when the target's signature uses more specific types than the
	 * wrapper, it picks the first public method with the same name whose parameter
	 * types accept the actual call arguments.
	 * @param pjp the intercepted join point
	 * @param tracingWrapper the wrapper to search
	 * @return the matching method, or {@code null} if none was found
	 */
	Method getMethod(ProceedingJoinPoint pjp, Object tracingWrapper) {
		MethodSignature signature = (MethodSignature) pjp.getSignature();
		Method method = signature.getMethod();
		Method foundMethodOnTracingWrapper = ReflectionUtils.findMethod(tracingWrapper.getClass(), method.getName(),
				method.getParameterTypes());
		if (foundMethodOnTracingWrapper != null) {
			if (log.isDebugEnabled()) {
				log.debug("Found an exact match for method execution [" + foundMethodOnTracingWrapper + "]");
			}
			return foundMethodOnTracingWrapper;
		}
		Method[] uniquePublicDeclaredMethodsOnTracingWrapper = ReflectionUtils
				.getUniqueDeclaredMethods(tracingWrapper.getClass(), m -> Modifier.isPublic(m.getModifiers()));
		if (uniquePublicDeclaredMethodsOnTracingWrapper.length == 0) {
			return null;
		}
		if (log.isTraceEnabled()) {
			log.trace("Will pick one of the unique declared methods ["
					+ Arrays.toString(uniquePublicDeclaredMethodsOnTracingWrapper) + "] that has a name ["
					+ method.getName() + "]");
		}
		Object[] argsOnOriginalObject = pjp.getArgs();
		return Arrays.stream(uniquePublicDeclaredMethodsOnTracingWrapper)
				.filter(m -> m.getName().equals(method.getName())
						&& paramsAreOfSameTypeOrInherited(argsOnOriginalObject, m.getParameterTypes()))
				.findFirst().orElse(null);
	}

	/**
	 * Checks whether the given call arguments can be passed to parameters of the given
	 * types.
	 * @param argsOnOriginalObject actual arguments of the intercepted call (elements may
	 * be {@code null})
	 * @param typeOnTracingWrapper parameter types of a candidate wrapper method
	 * @return {@code true} if the counts match and every argument is assignable to the
	 * corresponding parameter type
	 */
	private boolean paramsAreOfSameTypeOrInherited(Object[] argsOnOriginalObject, Class<?>[] typeOnTracingWrapper) {
		if (argsOnOriginalObject.length != typeOnTracingWrapper.length) {
			return false;
		}
		for (int i = 0; i < argsOnOriginalObject.length; i++) {
			Object arg = argsOnOriginalObject[i];
			Class<?> typeOnWrapper = typeOnTracingWrapper[i];
			if (arg == null) {
				// A null argument fits any reference-typed parameter but never a primitive
				// one; calling getClass() on it would throw NullPointerException.
				if (typeOnWrapper.isPrimitive()) {
					return false;
				}
				continue;
			}
			if (!typeOnWrapper.isAssignableFrom(arg.getClass())) {
				return false;
			}
		}
		return true;
	}

}