CGLIBEnhancedConverter.java
/*
* Copyright (C) 2006, 2007, 2008, 2010, 2011, 2013, 2014, 2015, 2016, 2018, 2020 XStream Committers.
* All rights reserved.
*
* The software in this package is published under the terms of the BSD
* style license a copy of which has been included with this distribution in
* the LICENSE.txt file.
*
* Created on 13. April 2006 by Joerg Schaible
*/
package com.thoughtworks.xstream.converters.reflection;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.thoughtworks.xstream.converters.ConversionException;
import com.thoughtworks.xstream.converters.MarshallingContext;
import com.thoughtworks.xstream.converters.UnmarshallingContext;
import com.thoughtworks.xstream.core.ClassLoaderReference;
import com.thoughtworks.xstream.io.HierarchicalStreamReader;
import com.thoughtworks.xstream.io.HierarchicalStreamWriter;
import com.thoughtworks.xstream.mapper.CGLIBMapper;
import com.thoughtworks.xstream.mapper.Mapper;
import net.sf.cglib.proxy.Callback;
import net.sf.cglib.proxy.CallbackFilter;
import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.Factory;
import net.sf.cglib.proxy.MethodInterceptor;
import net.sf.cglib.proxy.NoOp;
/**
* Converts a proxy created by the CGLIB {@link Enhancer}. Such a proxy is recreated while deserializing the proxy. The
* converter does only work, if<br>
* <ul>
* <li>the DefaultNamingPolicy is used for the proxy's name</li>
* <li>the proxy uses a factory or only one Callback is registered</li>
* <li>a possible super class has at least a protected default constructor</li>
* </ul>
* Note, that the this converter relies on the CGLIBMapper.
*
* @author Jörg Schaible
* @since 1.2
*/
public class CGLIBEnhancedConverter extends SerializableConverter {
private static String DEFAULT_NAMING_MARKER = "$$EnhancerByCGLIB$$";
private static String CALLBACK_MARKER = "CGLIB$CALLBACK_";
private transient Map<String, List<Field>> fieldCache;
/**
* Construct a CGLIBEnhancedConverter.
*
* @param mapper the mapper chain instance
* @param reflectionProvider the reflection provider
* @param classLoaderReference the reference to the {@link ClassLoader} of the XStream instance
* @since 1.4.5
*/
public CGLIBEnhancedConverter(
final Mapper mapper, final ReflectionProvider reflectionProvider,
final ClassLoaderReference classLoaderReference) {
super(mapper, new CGLIBFilteringReflectionProvider(reflectionProvider), classLoaderReference);
fieldCache = new HashMap<>();
}
/**
* @deprecated As of 1.4.5 use {@link #CGLIBEnhancedConverter(Mapper, ReflectionProvider, ClassLoaderReference)}
*/
@Deprecated
public CGLIBEnhancedConverter(
final Mapper mapper, final ReflectionProvider reflectionProvider, final ClassLoader classLoader) {
super(mapper, new CGLIBFilteringReflectionProvider(reflectionProvider), classLoader);
fieldCache = new HashMap<>();
}
/**
* @deprecated As of 1.4 use {@link #CGLIBEnhancedConverter(Mapper, ReflectionProvider, ClassLoaderReference)}
*/
@Deprecated
public CGLIBEnhancedConverter(final Mapper mapper, final ReflectionProvider reflectionProvider) {
this(mapper, new CGLIBFilteringReflectionProvider(reflectionProvider), CGLIBEnhancedConverter.class
.getClassLoader());
}
@Override
public boolean canConvert(final Class<?> type) {
return type != null && Enhancer.isEnhanced(type) && type.getName().indexOf(DEFAULT_NAMING_MARKER) > 0
|| type == CGLIBMapper.Marker.class;
}
@Override
public void marshal(final Object source, final HierarchicalStreamWriter writer, final MarshallingContext context) {
final Class<?> type = source.getClass();
final boolean hasFactory = Factory.class.isAssignableFrom(type);
writer.startNode("type", type);
context.convertAnother(type.getSuperclass());
writer.endNode();
writer.startNode("interfaces");
final Class<?>[] interfaces = type.getInterfaces();
for (final Class<?> interface1 : interfaces) {
if (interface1 == Factory.class) {
continue;
}
writer.startNode(mapper.serializedClass(interface1.getClass()), interface1.getClass());
context.convertAnother(interface1);
writer.endNode();
}
writer.endNode();
writer.startNode("hasFactory");
writer.setValue(String.valueOf(hasFactory));
writer.endNode();
Map<?, ?> callbackIndexMap = null;
final Callback[] callbacks = hasFactory ? ((Factory)source).getCallbacks() : getCallbacks(source);
if (callbacks.length > 1) {
if (hasFactory) {
callbackIndexMap = createCallbackIndexMap((Factory)source);
} else {
final ConversionException exception = new ConversionException(
"Cannot handle CGLIB enhanced proxies without factory that have multiple callbacks");
exception.add("proxy-superclass", type.getSuperclass().getName());
exception.add("number-of-callbacks", String.valueOf(callbacks.length));
throw exception;
}
writer.startNode("callbacks");
writer.startNode("mapping");
context.convertAnother(callbackIndexMap);
writer.endNode();
}
boolean hasInterceptor = false;
for (final Callback callback : callbacks) {
if (callback == null) {
final String name = mapper.serializedClass(null);
writer.startNode(name);
writer.endNode();
} else {
hasInterceptor = hasInterceptor || MethodInterceptor.class.isAssignableFrom(callback.getClass());
writer.startNode(mapper.serializedClass(callback.getClass()), callback.getClass());
context.convertAnother(callback);
writer.endNode();
}
}
if (callbacks.length > 1) {
writer.endNode();
}
try {
final Field field = type.getDeclaredField("serialVersionUID");
if (!field.isAccessible()) {
field.setAccessible(true);
}
final long serialVersionUID = field.getLong(null);
writer.startNode("serialVersionUID", String.class);
writer.setValue(String.valueOf(serialVersionUID));
writer.endNode();
} catch (final NoSuchFieldException e) {
// OK, ignore
} catch (final IllegalAccessException e) {
final ObjectAccessException exception = new ObjectAccessException("Cannot access field", e);
exception.add("field", type.getName() + ".serialVersionUID");
throw exception;
}
if (hasInterceptor) {
writer.startNode("instance");
super.doMarshalConditionally(source, writer, context);
writer.endNode();
}
}
private Callback[] getCallbacks(final Object source) {
final Class<?> type = source.getClass();
List<Field> fields = fieldCache.get(type.getName());
if (fields == null) {
fields = new ArrayList<>();
fieldCache.put(type.getName(), fields);
for (int i = 0; true; ++i) {
try {
final Field field = type.getDeclaredField(CALLBACK_MARKER + i);
if (!field.isAccessible()) {
field.setAccessible(true);
}
fields.add(field);
} catch (final NoSuchFieldException e) {
break;
}
}
}
final List<Callback> list = new ArrayList<>();
for (int i = 0; i < fields.size(); ++i) {
try {
final Field field = fields.get(i);
final Callback callback = (Callback)field.get(source);
list.add(callback);
} catch (final IllegalAccessException e) {
final ObjectAccessException exception = new ObjectAccessException("Cannot access field", e);
exception.add("field", type.getName() + "." + CALLBACK_MARKER + i);
throw exception;
}
}
return list.toArray(new Callback[list.size()]);
}
private Map<? super Object, ? super Object> createCallbackIndexMap(final Factory source) {
final Callback[] originalCallbacks = source.getCallbacks();
final Callback[] reverseEngineeringCallbacks = new Callback[originalCallbacks.length];
final Map<? super Object, ? super Object> callbackIndexMap = new HashMap<>();
int idxNoOp = -1;
for (int i = 0; i < originalCallbacks.length; i++) {
final Callback callback = originalCallbacks[i];
if (callback == null) {
reverseEngineeringCallbacks[i] = null;
} else if (NoOp.class.isAssignableFrom(callback.getClass())) {
reverseEngineeringCallbacks[i] = NoOp.INSTANCE;
idxNoOp = i;
} else {
reverseEngineeringCallbacks[i] = createReverseEngineeredCallbackOfProperType(callback, i,
callbackIndexMap);
}
}
try {
source.setCallbacks(reverseEngineeringCallbacks);
final Set<Class<?>> interfaces = new HashSet<>();
final Set<Method> methods = new HashSet<>();
Class<?> type = source.getClass();
do {
methods.addAll(Arrays.asList(type.getDeclaredMethods()));
methods.addAll(Arrays.asList(type.getMethods()));
final Class<?>[] implementedInterfaces = type.getInterfaces();
interfaces.addAll(Arrays.asList(implementedInterfaces));
type = type.getSuperclass();
} while (type != null);
for (final Iterator<Class<?>> iterator = interfaces.iterator(); iterator.hasNext();) {
type = iterator.next();
methods.addAll(Arrays.asList(type.getDeclaredMethods()));
}
for (final Iterator<Method> iter = methods.iterator(); iter.hasNext();) {
final Method method = iter.next();
if (!method.isAccessible()) {
method.setAccessible(true);
}
if (Factory.class.isAssignableFrom(method.getDeclaringClass())
|| (method.getModifiers() & (Modifier.FINAL | Modifier.STATIC)) > 0) {
iter.remove();
continue;
}
final Class<?>[] parameterTypes = method.getParameterTypes();
Method calledMethod = method;
try {
if ((method.getModifiers() & Modifier.ABSTRACT) > 0) {
calledMethod = source.getClass().getMethod(method.getName(), method.getParameterTypes());
}
callbackIndexMap.put(null, method);
calledMethod.invoke(source, parameterTypes == null
? (Object[])null
: createNullArguments(parameterTypes));
} catch (final IllegalAccessException e) {
final ObjectAccessException exception = new ObjectAccessException("Cannot access method", e);
exception.add("method", calledMethod.toString());
throw exception;
} catch (final InvocationTargetException e) {
// OK, ignore
} catch (final NoSuchMethodException e) {
final ConversionException exception = new ConversionException(
"CGLIB enhanced proxies wit abstract nethod that has not been implemented");
exception.add("proxy-superclass", type.getSuperclass().getName());
exception.add("method", method.toString());
throw exception;
}
if (callbackIndexMap.containsKey(method)) {
iter.remove();
}
}
if (idxNoOp >= 0) {
final Integer idx = Integer.valueOf(idxNoOp);
for (final Method method : methods) {
callbackIndexMap.put(method, idx);
}
}
} finally {
source.setCallbacks(originalCallbacks);
}
callbackIndexMap.remove(null);
return callbackIndexMap;
}
private Object[] createNullArguments(final Class<?>[] parameterTypes) {
final Object[] arguments = new Object[parameterTypes.length];
for (int i = 0; i < arguments.length; i++) {
final Class<?> type = parameterTypes[i];
if (type.isPrimitive()) {
if (type == byte.class) {
arguments[i] = Byte.valueOf((byte)0);
} else if (type == short.class) {
arguments[i] = Short.valueOf((short)0);
} else if (type == int.class) {
arguments[i] = Integer.valueOf(0);
} else if (type == long.class) {
arguments[i] = Long.valueOf(0);
} else if (type == float.class) {
arguments[i] = Float.valueOf(0);
} else if (type == double.class) {
arguments[i] = Double.valueOf(0);
} else if (type == char.class) {
arguments[i] = Character.valueOf('\0');
} else {
arguments[i] = Boolean.FALSE;
}
}
}
return arguments;
}
private Callback createReverseEngineeredCallbackOfProperType(final Callback callback, final int index,
final Map<? super Object, ? super Object> callbackIndexMap) {
Class<?> iface = null;
Class<?>[] interfaces = callback.getClass().getInterfaces();
for (int i = 0; i < interfaces.length; i++) {
if (Callback.class.isAssignableFrom(interfaces[i])) {
iface = interfaces[i];
if (iface == Callback.class) {
final ConversionException exception = new ConversionException("Cannot handle CGLIB callback");
exception.add("CGLIB-callback-type", callback.getClass().getName());
throw exception;
}
interfaces = iface.getInterfaces();
if (Arrays.asList(interfaces).contains(Callback.class)) {
break;
}
i = -1;
}
}
return (Callback)Proxy.newProxyInstance(iface.getClassLoader(), new Class[]{iface},
new ReverseEngineeringInvocationHandler(index, callbackIndexMap));
}
@Override
public Object unmarshal(final HierarchicalStreamReader reader, final UnmarshallingContext context) {
final Enhancer enhancer = new Enhancer();
reader.moveDown();
enhancer.setSuperclass((Class<?>)context.convertAnother(null, Class.class));
reader.moveUp();
reader.moveDown();
final List<Class<?>> interfaces = new ArrayList<>();
while (reader.hasMoreChildren()) {
reader.moveDown();
interfaces.add((Class<?>)context.convertAnother(null, mapper.realClass(reader.getNodeName())));
reader.moveUp();
}
enhancer.setInterfaces(interfaces.toArray(new Class[interfaces.size()]));
reader.moveUp();
reader.moveDown();
final boolean useFactory = Boolean.valueOf(reader.getValue()).booleanValue();
enhancer.setUseFactory(useFactory);
reader.moveUp();
final List<Callback> callbacksToEnhance = new ArrayList<>();
final List<Callback> callbacks = new ArrayList<>();
Map<Method, Integer> callbackIndexMap = null;
reader.moveDown();
if ("callbacks".equals(reader.getNodeName())) {
reader.moveDown();
@SuppressWarnings("unchecked")
final Map<Method, Integer> typedMap = (Map<Method, Integer>)context.convertAnother(null, HashMap.class);
callbackIndexMap = typedMap;
reader.moveUp();
while (reader.hasMoreChildren()) {
reader.moveDown();
readCallback(reader, context, callbacksToEnhance, callbacks);
reader.moveUp();
}
} else {
readCallback(reader, context, callbacksToEnhance, callbacks);
}
enhancer.setCallbacks(callbacksToEnhance.toArray(new Callback[callbacksToEnhance.size()]));
if (callbackIndexMap != null) {
enhancer.setCallbackFilter(new ReverseEngineeredCallbackFilter(callbackIndexMap));
}
reader.moveUp();
Object result = null;
while (reader.hasMoreChildren()) {
reader.moveDown();
if (reader.getNodeName().equals("serialVersionUID")) {
enhancer.setSerialVersionUID(Long.valueOf(reader.getValue()));
} else if (reader.getNodeName().equals("instance")) {
result = create(enhancer, callbacks, useFactory);
super.doUnmarshalConditionally(result, reader, context);
}
reader.moveUp();
}
if (result == null) {
result = create(enhancer, callbacks, useFactory);
}
return serializationMembers.callReadResolve(result);
}
private void readCallback(final HierarchicalStreamReader reader, final UnmarshallingContext context,
final List<Callback> callbacksToEnhance, final List<Callback> callbacks) {
final Callback callback = (Callback)context.convertAnother(null, mapper.realClass(reader.getNodeName()));
callbacks.add(callback);
if (callback == null) {
callbacksToEnhance.add(NoOp.INSTANCE);
} else {
callbacksToEnhance.add(callback);
}
}
private Object create(final Enhancer enhancer, final List<Callback> callbacks, final boolean useFactory) {
final Object result = enhancer.create();
if (useFactory) {
((Factory)result).setCallbacks(callbacks.toArray(new Callback[callbacks.size()]));
}
return result;
}
@Override
protected List<Class<?>> hierarchyFor(final Class<?> type) {
final List<Class<?>> typeHierarchy = super.hierarchyFor(type);
// drop the CGLIB proxy
typeHierarchy.remove(typeHierarchy.size() - 1);
return typeHierarchy;
}
@Override
protected Object readResolve() {
super.readResolve();
fieldCache = new HashMap<>();
return this;
}
private static class CGLIBFilteringReflectionProvider extends ReflectionProviderWrapper {
public CGLIBFilteringReflectionProvider(final ReflectionProvider reflectionProvider) {
super(reflectionProvider);
}
@Override
public void visitSerializableFields(final Object object, final Visitor visitor) {
wrapped.visitSerializableFields(object, new Visitor() {
@Override
public void visit(final String name, final Class<?> type, final Class<?> definedIn,
final Object value) {
if (!name.startsWith("CGLIB$")) {
visitor.visit(name, type, definedIn, value);
}
}
});
}
}
private static final class ReverseEngineeringInvocationHandler implements InvocationHandler {
private final Integer index;
private final Map<? super Object, ? super Object> indexMap;
public ReverseEngineeringInvocationHandler(
final int index, final Map<? super Object, ? super Object> indexMap) {
this.indexMap = indexMap;
this.index = Integer.valueOf(index);
}
@Override
public Object invoke(final Object proxy, final Method method, final Object[] args) throws Throwable {
indexMap.put(indexMap.get(null), index);
return null;
}
}
private static class ReverseEngineeredCallbackFilter implements CallbackFilter {
private final Map<Method, Integer> callbackIndexMap;
public ReverseEngineeredCallbackFilter(final Map<Method, Integer> callbackIndexMap) {
this.callbackIndexMap = callbackIndexMap;
}
@Override
public int accept(final Method method) {
if (!callbackIndexMap.containsKey(method)) {
final ConversionException exception = new ConversionException(
"CGLIB callback not detected in reverse engineering");
exception.add("CGLIB-callback", method.toString());
throw exception;
}
return callbackIndexMap.get(method).intValue();
}
}
}