SerdeResolverUtils.java
/*
* Copyright 2022-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.stream.binder.kafka.streams;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serdes;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.kafka.support.serializer.JacksonJsonSerde;
/**
 * Utility class that contains various methods to help resolve {@link Serde Serdes}.
 *
 * @author Chris Bono
 * @since 4.0
 */
abstract class SerdeResolverUtils {

	private static final Log LOG = LogFactory.getLog(SerdeResolverUtils.class);

	/** Classnames of the standard built-in Serdes supported in {@link Serdes#serdeFrom(Class)}. */
	private static final Set<String> STANDARD_SERDE_CLASSNAMES = Set.of(
			Serdes.String().getClass().getName(),
			Serdes.Short().getClass().getName(),
			Serdes.Integer().getClass().getName(),
			Serdes.Long().getClass().getName(),
			Serdes.Float().getClass().getName(),
			Serdes.Double().getClass().getName(),
			Serdes.ByteArray().getClass().getName(),
			Serdes.ByteBuffer().getClass().getName(),
			Serdes.Bytes().getClass().getName(),
			Serdes.UUID().getClass().getName());

	/**
	 * Return the {@code Serde<?>} to use for the specified type using the following steps until a match is found.
	 * <ul>
	 * <li>the closest matching configured {@code Serde<?>} bean if one exists</li>
	 * <li>the Kafka Streams built-in serde if the target type is one of the built-in types supported by
	 * {@link Serdes#serdeFrom(Class)} (String, Short, Integer, Long, Float, Double, byte[], ByteBuffer, Bytes
	 * and UUID)</li>
	 * <li>the fallback serde if specified and not one of the Kafka Streams built-in serdes</li>
	 * <li>a {@link JacksonJsonSerde} if the target type is not exactly {@code Object}</li>
	 * <li>the fallback as the last resort</li>
	 * </ul>
	 * @param context the application context
	 * @param targetType the target type to find the serde for
	 * @param fallbackSerde the serde to use when no type can be inferred (may be {@code null})
	 * @return serde to use for the target type or {@code fallbackSerde} as outlined in the method description
	 */
	static Serde<?> resolveForType(ConfigurableApplicationContext context, ResolvableType targetType, /*@Nullable*/ Serde<?> fallbackSerde) {
		Class<?> genericRawClazz = targetType.getRawClass();
		// We don't attempt to find a matching Serde for type '?' - just return fallback
		if (genericRawClazz == null) {
			return fallbackSerde;
		}
		// A user-configured Serde bean wins; the list is ordered most-specific match first
		List<String> matchingSerdes = beanNamesForMatchingSerdes(context, targetType);
		if (!matchingSerdes.isEmpty()) {
			return context.getBean(matchingSerdes.get(0), Serde.class);
		}
		// Use standard serde for built-in types
		Serde<?> standardDefaultSerde = getStandardDefaultSerde(genericRawClazz);
		if (standardDefaultSerde != null) {
			return standardDefaultSerde;
		}
		// Use fallback if specified and not from std defaults (we know from above that type is not std default
		// so using a fallback that is std default type would not work)
		if (fallbackSerde != null && !isSerdeFromStandardDefaults(fallbackSerde)) {
			return fallbackSerde;
		}
		// Use JacksonJsonSerde if type is not exactly Object
		if (genericRawClazz != Object.class) {
			return new JacksonJsonSerde<>(genericRawClazz);
		}
		// Finally, just resort to using the fallback
		return fallbackSerde;
	}

	/**
	 * Look up the Kafka built-in serde for a raw class via {@link Serdes#serdeFrom(Class)}.
	 * @param genericRawClazz the raw class to find a built-in serde for
	 * @return the built-in serde, or {@code null} when {@code Serdes.serdeFrom} rejects the class
	 */
	private static Serde<?> getStandardDefaultSerde(Class<?> genericRawClazz) {
		try {
			return Serdes.serdeFrom(genericRawClazz);
		}
		catch (IllegalArgumentException ex) {
			// Not a built-in type - this is an expected outcome, so only trace it
			if (LOG.isTraceEnabled()) {
				LOG.trace(ex);
			}
		}
		return null;
	}

	/**
	 * Determine whether the given serde is one of the built-in serdes listed in {@link #STANDARD_SERDE_CLASSNAMES}.
	 * @param serde the serde to check (may be {@code null})
	 * @return {@code true} if the serde's class name is one of the standard built-in serde class names
	 */
	private static boolean isSerdeFromStandardDefaults(Serde<?> serde) {
		if (serde == null) {
			return false;
		}
		return STANDARD_SERDE_CLASSNAMES.contains(serde.getClass().getName());
	}

	/**
	 * Find the names of all {@link Serde} beans that can be used for {@code targetType}.
	 *
	 * @param context the application context
	 * @param targetType the target type the serdes are being matched for
	 * @return list of bean names for matching serdes ordered by most specific match, or an empty list if no matches found
	 */
	static List<String> beanNamesForMatchingSerdes(ConfigurableApplicationContext context, ResolvableType targetType) {
		// We don't attempt to find a matching Serde for type '?'
		if (targetType.getRawClass() == null) {
			return Collections.emptyList();
		}
		ResolvableType serdeType = ResolvableType.forClassWithGenerics(Serde.class, targetType);
		String[] serdeBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(context.getBeanFactory(),
				serdeType, false, false);
		List<SerdeWithSpecificityScore> matchingSerdes = new ArrayList<>(serdeBeanNames.length);
		Arrays.stream(serdeBeanNames).forEach((beanName) -> {
			try {
				BeanDefinition beanDefinition = context.getBeanFactory().getMergedBeanDefinition(beanName);
				// Resolve the payload type as seen through the Serde interface rather than taking the first
				// generic of the bean's declared type. Otherwise beans declared as concrete implementations
				// (e.g. 'FooSerde implements Serde<Foo>') or types whose first type parameter is not the
				// payload would yield the wrong (or no) generic and skew the specificity score.
				ResolvableType serdeBeanGeneric = beanDefinition.getResolvableType().as(Serde.class).getGeneric(0);
				if (LOG.isDebugEnabled()) {
					LOG.debug("Found matching Serde<" + serdeBeanGeneric.getType() + "> under beanName=" + beanName);
				}
				matchingSerdes.add(new SerdeWithSpecificityScore(calculateScore(targetType, serdeBeanGeneric), beanName));
			}
			catch (Exception ex) {
				LOG.warn("Failed introspecting Serde bean '" + beanName + "'", ex);
			}
		});
		if (!matchingSerdes.isEmpty()) {
			// Highest score first; the sort is stable so equal scores keep bean discovery order
			return matchingSerdes.stream().sorted(Collections.reverseOrder())
					.map(SerdeWithSpecificityScore::getSerdeBeanName)
					.collect(Collectors.toList());
		}
		return Collections.emptyList();
	}

	/**
	 * Calculate a score to indicate how specific of a match one resolvable type is to another.
	 * <p><br>Simple string comparison (the number of matching leading characters between two type strings) is used to
	 * calculate the score. This approach avoids the recursive nature of the possible generic types, and leverages
	 * the type strings returned from {@link ResolvableType#toString()} and {@link Type#getTypeName()} which already
	 * include the properly handled generic types. The score is a composite of both of these properties because the
	 * 'toString' value does not include bounds values and works as a tie-breaker to distinguish non-exact
	 * matches that are closer in nature.
	 * <p><br><b>Example:</b>
	 * <pre>{@code
	 * -------------------------------------------------------------------------------------------------------
	 * targetType:   Foo<Date>            toString='Foo<Date>'     typeName='Foo<java.util.Date>'
	 * typeToCheck1: Foo<Date>            toString='Foo<Date>'     typeName='Foo<java.util.Date>'
	 * typeToCheck2: Foo<? extends Date>  toString='Foo<Date>'     typeName='Foo<? extends java.util.Date>'
	 * -------------------------------------------------------------------------------------------------------
	 * }</pre>
	 *
	 * If using only the 'toString' value then both types would have the same score. However, including the 'typeName'
	 * value in the score differentiates them - in this case it is clear that 'typeToCheck1' is a direct match.
	 *
	 * @param targetType the target type
	 * @param typeToScore the type to calculate a score for
	 * @return a score on how close of a match {@code typeToScore} is to {@code targetType} - the higher the score,
	 * the closer the match
	 */
	private static int calculateScore(ResolvableType targetType, ResolvableType typeToScore) {
		int score = countLeadingMatchingChars(targetType.getType().getTypeName(), typeToScore.getType().getTypeName());
		score += countLeadingMatchingChars(targetType.toString(), typeToScore.toString());
		return score;
	}

	/**
	 * Count how many leading characters two strings have in common.
	 * @param s1 the first string (may be {@code null})
	 * @param s2 the second string (may be {@code null})
	 * @return the length of the common prefix, or {@code 0} if either string is {@code null}
	 */
	private static int countLeadingMatchingChars(String s1, String s2) {
		if (s1 == null || s2 == null) {
			return 0;
		}
		int matchCount = 0;
		for (int i = 0; i < s1.length() && i < s2.length(); i++) {
			if (s1.charAt(i) != s2.charAt(i)) {
				break;
			}
			matchCount++;
		}
		return matchCount;
	}

	/**
	 * Private internal class used strictly to 'remember' a score for a serde and use it for sorting later.
	 * <p>Natural ordering is by ascending score only; it is not consistent with {@code equals} and is intended
	 * solely for ranking candidate beans.
	 */
	private static class SerdeWithSpecificityScore implements Comparable<SerdeWithSpecificityScore> {

		private final int score;

		private final String serdeBeanName;

		SerdeWithSpecificityScore(int score, String serdeBeanName) {
			this.score = score;
			this.serdeBeanName = Objects.requireNonNull(serdeBeanName);
		}

		String getSerdeBeanName() {
			return this.serdeBeanName;
		}

		@Override
		public int compareTo(SerdeWithSpecificityScore other) {
			return Integer.compare(this.score, other.score);
		}
	}
}