AbstractConnectionProperty.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.jdbc;
import com.google.common.base.CharMatcher;
import com.google.common.base.Splitter;
import okhttp3.Protocol;
import java.io.File;
import java.sql.DriverPropertyInfo;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
/**
 * Base {@link ConnectionProperty} that reads a single string value from {@link Properties}
 * under a fixed key and turns it into a typed value with a {@link Converter}.
 * <p>
 * Whether the property is required, and whether it may be set at all, are decided by
 * predicates evaluated against the complete set of connection properties. This lets the
 * answer depend on other properties.
 *
 * @param <T> type of the converted property value
 */
abstract class AbstractConnectionProperty<T>
        implements ConnectionProperty<T>
{
    private final String key;
    private final Optional<String> defaultValue;
    // Decides, from all connection properties, whether this property must be present.
    private final Predicate<Properties> isRequired;
    // Decides, from all connection properties, whether this property may be present.
    private final Predicate<Properties> isAllowed;
    private final Converter<T> converter;

    /**
     * @param key property name looked up in the connection properties
     * @param defaultValue value reported by {@link #getDefault()}; {@link #getValue(Properties)}
     *        does not substitute it when the property is absent
     * @param isRequired predicate deciding whether the property must be present
     * @param isAllowed predicate deciding whether the property may be present
     * @param converter converts the raw string value into {@code T}
     * @throws NullPointerException if any argument is null
     */
    protected AbstractConnectionProperty(
            String key,
            Optional<String> defaultValue,
            Predicate<Properties> isRequired,
            Predicate<Properties> isAllowed,
            Converter<T> converter)
    {
        this.key = requireNonNull(key, "key is null");
        this.defaultValue = requireNonNull(defaultValue, "defaultValue is null");
        this.isRequired = requireNonNull(isRequired, "isRequired is null");
        this.isAllowed = requireNonNull(isAllowed, "isAllowed is null");
        this.converter = requireNonNull(converter, "converter is null");
    }

    /**
     * Creates a property without a default value.
     */
    protected AbstractConnectionProperty(
            String key,
            Predicate<Properties> required,
            Predicate<Properties> allowed,
            Converter<T> converter)
    {
        this(key, Optional.empty(), required, allowed, converter);
    }

    @Override
    public String getKey()
    {
        return key;
    }

    @Override
    public Optional<String> getDefault()
    {
        return defaultValue;
    }

    /**
     * Builds a {@link DriverPropertyInfo} carrying this property's key, its current value in
     * {@code mergedProperties} (null when unset), and whether it is required.
     */
    @Override
    public DriverPropertyInfo getDriverPropertyInfo(Properties mergedProperties)
    {
        String currentValue = mergedProperties.getProperty(key);
        DriverPropertyInfo result = new DriverPropertyInfo(key, currentValue);
        result.required = isRequired.test(mergedProperties);
        return result;
    }

    @Override
    public boolean isRequired(Properties properties)
    {
        return isRequired.test(properties);
    }

    /**
     * An absent property is always allowed. A present one is allowed only if the
     * {@code isAllowed} predicate accepts the full property set.
     */
    @Override
    public boolean isAllowed(Properties properties)
    {
        return !properties.containsKey(key) || isAllowed.test(properties);
    }

    /**
     * Reads and converts this property.
     *
     * @return the converted value, or empty if the property is absent and not required
     * @throws SQLException if the property is absent but required, or if the converter throws a
     *         {@link RuntimeException} (the message distinguishes empty from invalid values)
     */
    @Override
    public Optional<T> getValue(Properties properties)
            throws SQLException
    {
        String value = properties.getProperty(key);
        if (value == null) {
            if (isRequired(properties)) {
                throw new SQLException(format("Connection property '%s' is required", key));
            }
            return Optional.empty();
        }
        try {
            return Optional.of(converter.convert(value));
        }
        catch (RuntimeException e) {
            if (value.isEmpty()) {
                throw new SQLException(format("Connection property '%s' value is empty", key), e);
            }
            // NOTE(review): this message echoes the raw value; confirm that is acceptable for
            // properties that may carry sensitive data.
            throw new SQLException(format("Connection property '%s' value is invalid: %s", key, value), e);
        }
    }

    /**
     * Checks that the property is allowed and that its value (if any) converts.
     *
     * @throws SQLException if the property is present but not allowed, or if
     *         {@link #getValue(Properties)} fails
     */
    @Override
    public void validate(Properties properties)
            throws SQLException
    {
        if (!isAllowed(properties)) {
            throw new SQLException(format("Connection property '%s' is not allowed", key));
        }
        getValue(properties);
    }

    /** Predicate: the property is always required. */
    protected static final Predicate<Properties> REQUIRED = properties -> true;
    /** Predicate: the property is never required. */
    protected static final Predicate<Properties> NOT_REQUIRED = properties -> false;
    /** Predicate: the property is always allowed. */
    protected static final Predicate<Properties> ALLOWED = properties -> true;

    /**
     * Converts a raw string property value to {@code T}. Implementations signal invalid input by
     * throwing a {@link RuntimeException}, which {@link #getValue(Properties)} wraps in an
     * {@link SQLException}.
     */
    @FunctionalInterface
    interface Converter<T>
    {
        T convert(String value);
    }

    /** Returns the value unchanged. */
    protected static final Converter<String> STRING_CONVERTER = value -> value;

    /** Returns the value unchanged, rejecting the empty string. */
    protected static final Converter<String> NON_EMPTY_STRING_CONVERTER = value -> {
        checkArgument(!value.isEmpty(), "value is empty");
        return value;
    };

    /** Wraps the value in a {@link File}; the path is not checked for existence. */
    protected static final Converter<File> FILE_CONVERTER = File::new;

    /** Accepts {@code true} or {@code false}, case-insensitively. */
    protected static final Converter<Boolean> BOOLEAN_CONVERTER = value -> {
        switch (value.toLowerCase(ENGLISH)) {
            case "true":
                return true;
            case "false":
                return false;
        }
        throw new IllegalArgumentException("value must be 'true' or 'false'");
    };

    /**
     * Parses {@code key1:value1;key2:value2} into an immutable map. Keys and values must be
     * non-empty printable ASCII without spaces. Duplicate keys are rejected by the
     * {@code toImmutableMap} collector.
     */
    protected static final class StringMapConverter
            implements Converter<Map<String, String>>
    {
        // Printable ASCII excluding space (0x20): '!' (0x21) through '~' (0x7E).
        private static final CharMatcher PRINTABLE_ASCII = CharMatcher.inRange((char) 0x21, (char) 0x7E);

        public static final StringMapConverter STRING_MAP_CONVERTER = new StringMapConverter();

        private StringMapConverter() {}

        @Override
        public Map<String, String> convert(String value)
        {
            return Splitter.on(';').splitToList(value).stream()
                    .map(this::parseKeyValuePair)
                    .collect(toImmutableMap(entry -> entry.get(0), entry -> entry.get(1)));
        }

        /**
         * Splits a single {@code key:value} entry.
         *
         * @return a two-element list {@code [key, value]}
         * @throws IllegalArgumentException if the entry is malformed, or the key or value is empty
         *         or not printable ASCII
         */
        public List<String> parseKeyValuePair(String keyValue)
        {
            List<String> nameValue = Splitter.on(':').splitToList(keyValue);
            checkArgument(nameValue.size() == 2, "Malformed key value pair: %s", keyValue);
            String name = nameValue.get(0);
            String value = nameValue.get(1);
            checkArgument(!name.isEmpty(), "Key is empty");
            checkArgument(!value.isEmpty(), "Value is empty");
            checkArgument(PRINTABLE_ASCII.matchesAllOf(name), "Key contains spaces or is not printable ASCII: %s", name);
            // Identify the offending entry by its key. The value itself is deliberately not echoed
            // in case it is sensitive.
            checkArgument(PRINTABLE_ASCII.matchesAllOf(value), "Value for key %s contains spaces or is not printable ASCII", name);
            return nameValue;
        }
    }

    /**
     * Validates a comma-separated list of alphanumeric tokens. Whitespace around each token is
     * trimmed, and the tokens are returned re-joined with {@code ","}.
     */
    protected static final class ListValidateConvertor
            implements Converter<String>
    {
        // Compiled once; Pattern instances are immutable and safe to share across threads.
        private static final Pattern ALPHANUMERIC = Pattern.compile("^[a-zA-Z0-9]+$");

        public static final Converter<String> LIST_VALIDATE_CONVERTOR = new ListValidateConvertor();

        private ListValidateConvertor() {}

        @Override
        public String convert(String value)
        {
            return Splitter.on(',').trimResults().splitToList(value).stream()
                    .map(this::validatePattern)
                    .collect(Collectors.joining(","));
        }

        private String validatePattern(String value)
        {
            checkArgument(ALPHANUMERIC.matcher(value).matches(), "Input client tag should contain only alphanumeric characters: %s", value);
            return value;
        }
    }

    /**
     * Instantiates a semicolon-separated list of fully qualified {@link QueryInterceptor} class
     * names. Each class must implement {@link QueryInterceptor} and have an accessible no-arg
     * constructor.
     */
    protected static final class ClassListConverter
            implements Converter<List<QueryInterceptor>>
    {
        public static final ClassListConverter CLASS_LIST_CONVERTER = new ClassListConverter();

        private ClassListConverter() {}

        @Override
        public List<QueryInterceptor> convert(String value)
        {
            return Splitter.on(';').splitToList(value).stream()
                    .map(this::loadClass)
                    .collect(toImmutableList());
        }

        private QueryInterceptor loadClass(String interceptor)
        {
            try {
                // The class name comes from connection properties. Load it without running static
                // initializers, and check the type before constructing it. This prevents arbitrary
                // non-interceptor classes from being initialized or instantiated. The loader is the
                // same one Class.forName(String) would use when called from this class.
                Class<? extends QueryInterceptor> interceptorClass = Class.forName(interceptor, false, ClassListConverter.class.getClassLoader())
                        .asSubclass(QueryInterceptor.class);
                return interceptorClass.getDeclaredConstructor().newInstance();
            }
            catch (ReflectiveOperationException | RuntimeException | LinkageError e) {
                throw new IllegalArgumentException(format("Could not load QueryInterceptor classes from %s", interceptor), e);
            }
        }
    }

    /**
     * Parses a comma-separated list of HTTP protocols. The aliases {@code http11},
     * {@code http10} and {@code http2} are accepted case-insensitively. Any other name is passed
     * unchanged to okhttp's {@link Protocol#get(String)}. Duplicates are removed, keeping
     * first-occurrence order.
     */
    protected static final class HttpProtocolConverter
            implements Converter<List<Protocol>>
    {
        public static final HttpProtocolConverter HTTP_PROTOCOL_CONVERTER = new HttpProtocolConverter();

        private HttpProtocolConverter() {}

        @Override
        public List<Protocol> convert(String value)
        {
            return Splitter.on(',').splitToList(value).stream()
                    .map(this::loadProtocol)
                    .distinct()
                    .collect(toImmutableList());
        }

        private Protocol loadProtocol(String protocolName)
        {
            try {
                switch (protocolName.toLowerCase(ENGLISH)) {
                    case "http11":
                        return Protocol.HTTP_1_1;
                    case "http10":
                        return Protocol.HTTP_1_0;
                    case "http2":
                        return Protocol.HTTP_2;
                    default:
                        return Protocol.get(protocolName);
                }
            }
            catch (Exception e) {
                throw new IllegalArgumentException(format("Could not load OkhttpProtocol from %s", protocolName), e);
            }
        }
    }

    /** A predicate whose test may throw {@link SQLException}. */
    @FunctionalInterface
    protected interface CheckedPredicate<T>
    {
        boolean test(T t)
                throws SQLException;
    }

    /**
     * Adapts a {@link CheckedPredicate} to a {@link Predicate}. An {@link SQLException} thrown by
     * the wrapped predicate is not propagated; the result is {@code false} instead.
     * <p>
     * Presumably this lets a requirement/allowance check that depends on another, malformed
     * property evaluate to {@code false}, with the malformed property reporting its own error
     * when it is validated. TODO confirm against callers.
     */
    protected static <T> Predicate<T> checkedPredicate(CheckedPredicate<T> predicate)
    {
        return t -> {
            try {
                return predicate.test(t);
            }
            catch (SQLException e) {
                return false;
            }
        };
    }
}