StandardClient.java
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2025 MariaDB Corporation Ab
package org.mariadb.jdbc.client.impl;
import static org.mariadb.jdbc.client.impl.ConnectionHelper.enabledSslCipherSuites;
import static org.mariadb.jdbc.client.impl.ConnectionHelper.enabledSslProtocolSuites;
import static org.mariadb.jdbc.util.constants.Capabilities.SSL;
import java.io.*;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLNonTransientConnectionException;
import java.sql.SQLTimeoutException;
import java.time.DateTimeException;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.*;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.net.ssl.*;
import org.mariadb.jdbc.Configuration;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.ServerPreparedStatement;
import org.mariadb.jdbc.client.Client;
import org.mariadb.jdbc.client.Completion;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.context.BaseContext;
import org.mariadb.jdbc.client.context.RedoContext;
import org.mariadb.jdbc.client.result.Result;
import org.mariadb.jdbc.client.result.StreamingResult;
import org.mariadb.jdbc.client.socket.Reader;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.socket.impl.*;
import org.mariadb.jdbc.client.tls.MariaDbX509EphemeralTrustingManager;
import org.mariadb.jdbc.client.util.ClosableLock;
import org.mariadb.jdbc.client.util.MutableByte;
import org.mariadb.jdbc.export.ExceptionFactory;
import org.mariadb.jdbc.export.MaxAllowedPacketException;
import org.mariadb.jdbc.export.Prepare;
import org.mariadb.jdbc.export.SslMode;
import org.mariadb.jdbc.message.ClientMessage;
import org.mariadb.jdbc.message.client.*;
import org.mariadb.jdbc.message.server.*;
import org.mariadb.jdbc.plugin.*;
import org.mariadb.jdbc.plugin.authentication.AuthenticationPluginLoader;
import org.mariadb.jdbc.plugin.authentication.addon.ClearPasswordPlugin;
import org.mariadb.jdbc.plugin.authentication.standard.NativePasswordPlugin;
import org.mariadb.jdbc.plugin.tls.TlsSocketPluginLoader;
import org.mariadb.jdbc.util.Security;
import org.mariadb.jdbc.util.StringUtils;
import org.mariadb.jdbc.util.constants.Capabilities;
import org.mariadb.jdbc.util.constants.ServerStatus;
import org.mariadb.jdbc.util.log.Logger;
import org.mariadb.jdbc.util.log.Loggers;
/** Connection client */
public class StandardClient implements Client, AutoCloseable {
private static final Logger logger = Loggers.getLogger(StandardClient.class);
/** connection exception factory */
protected final ExceptionFactory exceptionFactory;
private static final Pattern REDIRECT_PATTERN =
Pattern.compile(
"(mariadb|mysql):\\/\\/(([^/@:]+)?(:([^/]+))?@)?(([^/:]+)(:([0-9]+))?)(\\/([^?]+)(\\?(.*))?)?$",
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);
private Socket socket;
private final MutableByte sequence = new MutableByte();
private final MutableByte compressionSequence = new MutableByte();
private final ClosableLock lock;
private Configuration conf;
private AuthenticationPlugin authPlugin;
private HostAddress hostAddress;
private final boolean disablePipeline;
/** connection context */
protected Context context;
/** packet writer */
protected Writer writer;
private boolean closed = false;
private Reader reader;
private byte[] certFingerprint = null;
private org.mariadb.jdbc.Statement streamStmt = null;
private ClientMessage streamMsg = null;
private int socketTimeout;
private final Consumer<String> redirectConsumer = this::redirect;
/**
* Constructor
*
* @param conf configuration
* @param hostAddress host
* @param lock thread locker
* @param skipPostCommands must connection post command be skipped
* @throws SQLException if connection fails
*/
@SuppressWarnings({"this-escape"})
public StandardClient(
Configuration conf, HostAddress hostAddress, ClosableLock lock, boolean skipPostCommands)
throws SQLException {
this.conf = conf;
this.lock = lock;
this.hostAddress = hostAddress;
this.exceptionFactory = new ExceptionFactory(conf, hostAddress);
this.disablePipeline = conf.disablePipeline();
this.socketTimeout = conf.socketTimeout();
this.socket = ConnectionHelper.connectSocket(conf, hostAddress);
try {
setupConnection(skipPostCommands);
} catch (SQLException e) {
handleConnectionError(e);
} catch (SocketTimeoutException e) {
handleTimeoutError(e);
} catch (IOException e) {
handleIOError(e);
}
}
private void setupConnection(boolean skipPostCommands) throws SQLException, IOException {
OutputStream out = socket.getOutputStream();
InputStream in =
conf.useReadAheadInput()
? new ReadAheadBufferedStream(socket.getInputStream())
: new BufferedInputStream(socket.getInputStream(), 16384);
assignStream(out, in, conf, null);
configureTimeout();
InitialHandshakePacket handshake = handleServerHandshake();
long clientCapabilities = setupClientCapabilities(handshake);
SSLSocket sslSocket = handleSSLConnection(handshake, clientCapabilities);
if (sslSocket != null) {
out = new BufferedOutputStream(sslSocket.getOutputStream(), 16384);
in =
conf.useReadAheadInput()
? new ReadAheadBufferedStream(sslSocket.getInputStream())
: new BufferedInputStream(sslSocket.getInputStream(), 16384);
assignStream(out, in, conf, handshake.getThreadId());
}
handleAuthentication(handshake, clientCapabilities);
setupCompression(in, out, clientCapabilities, handshake.getThreadId());
if (!skipPostCommands) {
postConnectionQueries();
}
setSocketTimeout(conf.socketTimeout());
}
private void setupCompression(
InputStream in, OutputStream out, long clientCapabilities, long threadId) {
if ((clientCapabilities & Capabilities.COMPRESS) != 0) {
assignStream(
new CompressOutputStream(out, compressionSequence),
new CompressInputStream(in, compressionSequence),
conf,
threadId);
}
}
private SSLSocket handleSSLConnection(InitialHandshakePacket handshake, long clientCapabilities)
throws SQLException, IOException {
updateThreadIds(handshake);
Configuration conf = context.getConf();
SslMode sslMode = determineSslMode(conf);
if (sslMode == SslMode.DISABLE) {
return null;
}
validateServerSslCapability();
sendSslRequest(handshake, clientCapabilities);
TlsSocketPlugin socketPlugin = TlsSocketPluginLoader.get(conf.tlsSocketType());
TrustManager[] trustManagers =
socketPlugin.getTrustManager(conf, context.getExceptionFactory(), hostAddress);
SSLSocket sslSocket = createSslSocket(conf, socketPlugin, trustManagers);
configureSslSocket(sslSocket, conf);
handleSslHandshake(sslSocket, trustManagers);
if (requiresHostnameVerification(sslMode)) {
verifyHostname(sslSocket, socketPlugin);
}
return sslSocket;
}
private void updateThreadIds(InitialHandshakePacket handshake) {
this.reader.setServerThreadId(handshake.getThreadId(), hostAddress);
this.writer.setServerThreadId(handshake.getThreadId(), hostAddress);
}
private SslMode determineSslMode(Configuration conf) {
return hostAddress.sslMode == null ? conf.sslMode() : hostAddress.sslMode;
}
private void validateServerSslCapability() throws SQLException {
if (!context.hasServerCapability(Capabilities.SSL)) {
throw context
.getExceptionFactory()
.create("Trying to connect with ssl, but ssl not enabled in the server", "08000");
}
}
private void sendSslRequest(InitialHandshakePacket handshake, long clientCapabilities)
throws IOException {
SslRequestPacket.create(
clientCapabilities | Capabilities.SSL, (byte) handshake.getDefaultCollation())
.encode(writer, context);
}
private SSLSocket createSslSocket(
Configuration conf, TlsSocketPlugin socketPlugin, TrustManager[] trustManagers)
throws SQLException, IOException {
try {
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(
socketPlugin.getKeyManager(conf, context.getExceptionFactory()), trustManagers, null);
return socketPlugin.createSocket(socket, sslContext.getSocketFactory());
} catch (KeyManagementException e) {
throw context.getExceptionFactory().create("Could not initialize SSL context", "08000", e);
} catch (NoSuchAlgorithmException e) {
throw context
.getExceptionFactory()
.create("SSLContext TLS Algorithm not unknown", "08000", e);
}
}
private void configureSslSocket(SSLSocket sslSocket, Configuration conf) throws SQLException {
enabledSslProtocolSuites(sslSocket, conf);
enabledSslCipherSuites(sslSocket, conf);
sslSocket.setUseClientMode(true);
}
private void handleSslHandshake(SSLSocket sslSocket, TrustManager[] trustManagers)
throws IOException {
sslSocket.startHandshake();
if (trustManagers.length > 0
&& trustManagers[0] instanceof MariaDbX509EphemeralTrustingManager) {
certFingerprint = ((MariaDbX509EphemeralTrustingManager) trustManagers[0]).getFingerprint();
}
}
private boolean requiresHostnameVerification(SslMode sslMode) {
return certFingerprint == null && sslMode == SslMode.VERIFY_FULL && hostAddress.host != null;
}
private void verifyHostname(SSLSocket sslSocket, TlsSocketPlugin socketPlugin)
throws SQLException {
try {
socketPlugin.verify(hostAddress.host, sslSocket.getSession(), context.getThreadId());
} catch (SSLException ex) {
throw context
.getExceptionFactory()
.create(
"SSL hostname verification failed : "
+ ex.getMessage()
+ "\nThis verification can be disabled using the sslMode to VERIFY_CA "
+ "but won't prevent man-in-the-middle attacks anymore",
"08006");
}
}
private void configureTimeout() throws SQLException {
if (conf.connectTimeout() > 0) {
setSocketTimeout(conf.connectTimeout());
} else if (conf.socketTimeout() > 0) {
setSocketTimeout(conf.socketTimeout());
}
}
private InitialHandshakePacket handleServerHandshake() throws SQLException, IOException {
ReadableByteBuf buf = reader.readReusablePacket(logger.isTraceEnabled());
if (buf.getByte() == -1) {
throwHandshakeError(buf);
}
return InitialHandshakePacket.decode(buf);
}
private void throwHandshakeError(ReadableByteBuf buf) throws SQLException {
ErrorPacket errorPacket = new ErrorPacket(buf, null);
throw this.exceptionFactory.create(
errorPacket.getMessage(), errorPacket.getSqlState(), errorPacket.getErrorCode());
}
private long setupClientCapabilities(InitialHandshakePacket handshake) {
this.exceptionFactory.setThreadId(handshake.getThreadId());
long capabilities =
ConnectionHelper.initializeClientCapabilities(
conf, handshake.getCapabilities(), hostAddress);
initializeContext(handshake, capabilities);
this.reader.setServerThreadId(handshake.getThreadId(), hostAddress);
this.writer.setServerThreadId(handshake.getThreadId(), hostAddress);
return capabilities;
}
private void initializeContext(InitialHandshakePacket handshake, long clientCapabilities) {
PrepareCache cache =
conf.cachePrepStmts() ? new PrepareCache(conf.prepStmtCacheSize(), this) : null;
Boolean isLoopback = null;
if (socket.getInetAddress() != null) isLoopback = socket.getInetAddress().isLoopbackAddress();
this.context =
conf.transactionReplay()
? new RedoContext(
hostAddress,
handshake,
clientCapabilities,
conf,
exceptionFactory,
cache,
isLoopback)
: new BaseContext(
hostAddress,
handshake,
clientCapabilities,
conf,
exceptionFactory,
cache,
isLoopback);
}
private void handleAuthentication(InitialHandshakePacket handshake, long clientCapabilities)
throws IOException, SQLException {
String authType = determineAuthType(handshake);
Credential credential =
ConnectionHelper.loadCredential(conf.credentialPlugin(), conf, hostAddress);
sendHandshakeResponse(handshake, clientCapabilities, credential, authType);
createAuthPlugin(handshake, credential, authType);
writer.flush();
authenticationHandler(credential, hostAddress);
}
private String determineAuthType(InitialHandshakePacket handshake) {
String authType = handshake.getAuthenticationPluginType();
CredentialPlugin credPlugin = conf.credentialPlugin();
if (credPlugin != null && credPlugin.defaultAuthenticationPluginType() != null) {
authType = credPlugin.defaultAuthenticationPluginType();
}
return authType;
}
private void handleConnectionError(SQLException e) throws SQLException {
destroySocket();
throw e;
}
private void handleTimeoutError(SocketTimeoutException e) throws SQLTimeoutException {
destroySocket();
throw new SQLTimeoutException(
String.format("Socket timeout when connecting to %s. %s", hostAddress, e.getMessage()),
"08000",
e);
}
private void handleIOError(IOException e) throws SQLException {
destroySocket();
throw exceptionFactory.create(
String.format("Could not connect to %s : %s", hostAddress, e.getMessage()), "08000", e);
}
private void sendHandshakeResponse(
InitialHandshakePacket handshake,
long clientCapabilities,
Credential credential,
String authType)
throws IOException {
new HandshakeResponse(
credential,
authType,
context.getSeed(),
conf,
hostAddress.host,
clientCapabilities,
(byte) handshake.getDefaultCollation())
.encode(writer, context);
}
private void createAuthPlugin(
InitialHandshakePacket handshake, Credential credential, String authType) {
authPlugin =
"mysql_clear_password".equals(authType)
? new ClearPasswordPlugin(credential.getPassword())
: new NativePasswordPlugin(credential.getPassword(), handshake.getSeed());
}
/**
* @param credential credential
* @param hostAddress host address
* @throws IOException if any socket error occurs
* @throws SQLException if any other kind of issue occurs
*/
public void authenticationHandler(Credential credential, HostAddress hostAddress)
throws IOException, SQLException {
writer.permitTrace(true);
Configuration conf = context.getConf();
ReadableByteBuf buf = reader.readReusablePacket();
authentication_loop:
while (true) {
switch (buf.getByte() & 0xFF) {
case 0xFE:
// *************************************************************************************
// Authentication Switch Request see
// https://mariadb.com/kb/en/library/connection/#authentication-switch-request
// *************************************************************************************
AuthSwitchPacket authSwitchPacket = AuthSwitchPacket.decode(buf);
AuthenticationPluginFactory authPluginFactory =
AuthenticationPluginLoader.get(authSwitchPacket.getPlugin(), conf);
if (authPluginFactory.requireSsl() && !context.hasClientCapability(SSL)) {
throw context
.getExceptionFactory()
.create(
"Cannot use authentication plugin "
+ authPluginFactory.type()
+ " if SSL is not enabled.",
"08000");
}
authPlugin =
authPluginFactory.initialize(
credential.getPassword(), authSwitchPacket.getSeed(), conf, hostAddress);
if (certFingerprint != null
&& (!authPlugin.isMitMProof()
|| credential.getPassword() == null
|| credential.getPassword().isEmpty())) {
throw context
.getExceptionFactory()
.create(
String.format(
"Cannot use authentication plugin %s with a Self signed certificates."
+ " Either set sslMode=trust, use password with a MitM-Proof"
+ " authentication plugin or provide server certificate to client",
authPluginFactory.type()));
}
buf = authPlugin.process(writer, reader, context);
break;
case 0xFF:
// *************************************************************************************
// ERR_Packet
// see https://mariadb.com/kb/en/library/err_packet/
// *************************************************************************************
ErrorPacket errorPacket = new ErrorPacket(buf, context);
throw context
.getExceptionFactory()
.create(
errorPacket.getMessage(), errorPacket.getSqlState(), errorPacket.getErrorCode());
case 0x00:
// *************************************************************************************
// OK_Packet -> Authenticated !
// see https://mariadb.com/kb/en/library/ok_packet/
// *************************************************************************************
OkPacket okPacket = OkPacket.parseWithInfo(buf, context);
// ssl certificates validation using client password
if (certFingerprint != null) {
// need to ensure server certificates
// pass only if :
// * connection method is MitM-proof (e.g. unix socket)
// * auth plugin is MitM-proof and check SHA2(user's hashed password, scramble,
// certificate fingerprint)
if (this.socket instanceof UnixDomainSocket) break authentication_loop;
if (!authPlugin.isMitMProof()
|| credential.getPassword() == null
|| credential.getPassword().isEmpty()
|| !validateFingerPrint(
authPlugin,
okPacket.getInfo(),
certFingerprint,
credential,
context.getSeed())) {
throw context
.getExceptionFactory()
.create(
"Self signed certificates. Either set sslMode=trust, use password with a"
+ " MitM-Proof authentication plugin or provide server certificate to"
+ " client",
"08000");
}
}
if (context.getRedirectUrl() != null
&& ((conf.permitRedirect() == null && conf.sslMode() == SslMode.VERIFY_FULL)
|| conf.permitRedirect())) redirect(context.getRedirectUrl());
break authentication_loop;
default:
throw context
.getExceptionFactory()
.create(
"unexpected data during authentication (header=" + (buf.getUnsignedByte()),
"08000");
}
}
writer.permitTrace(true);
}
private static boolean validateFingerPrint(
AuthenticationPlugin authPlugin,
byte[] validationHash,
byte[] fingerPrint,
Credential credential,
final byte[] seed) {
if (validationHash.length == 0) return false;
try {
assert (validationHash[0] == 0x01); // SHA256 encryption
byte[] hash = authPlugin.hash(credential);
final MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
messageDigest.update(hash);
messageDigest.update(seed);
messageDigest.update(fingerPrint);
final byte[] digest = messageDigest.digest();
final String hashHex = StringUtils.byteArrayToHexString(digest);
final String serverValidationHex =
new String(validationHash, 1, validationHash.length - 1, StandardCharsets.US_ASCII);
return hashHex.equals(serverValidationHex);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("SHA-256 MessageDigest expected to be not available", e);
}
}
public void redirect(String redirectUrl) {
if (redirectUrl != null
&& ((conf.permitRedirect() == null && conf.sslMode() == SslMode.VERIFY_FULL)
|| conf.permitRedirect())) {
// redirect only if not in a transaction
if ((this.context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0) {
this.context.setRedirectUrl(null);
Matcher matcher = REDIRECT_PATTERN.matcher(redirectUrl);
if (!matcher.matches()) {
logger.error(
"error parsing redirection string '"
+ redirectUrl
+ "'. format must be"
+ " 'mariadb/mysql://[<user>[:<password>]@]<host>[:<port>]/[<db>[?<opt1>=<value1>[&<opt2>=<value2>]]]'");
return;
}
try {
String redirectHost =
matcher.group(7) != null
? URLDecoder.decode(matcher.group(7), "utf8")
: matcher.group(6);
int redirectPort = matcher.group(9) != null ? Integer.parseInt(matcher.group(9)) : 3306;
if (this.getHostAddress() != null
&& redirectHost.equals(this.getHostAddress().host)
&& redirectPort == this.getHostAddress().port) {
// redirection to the same host, skip loop redirection
return;
}
// actually only options accepted are user and password
// there might be additional possible options in the future
String redirectUser = matcher.group(3);
String redirectPwd = matcher.group(5);
Configuration.Builder redirectConfBuilder =
this.context.getConf().toBuilder()
.addresses(HostAddress.from(redirectHost, redirectPort, true));
if (redirectUser != null) redirectConfBuilder.user(redirectUser);
if (redirectPwd != null) redirectConfBuilder.password(redirectPwd);
try {
Configuration redirectConf = redirectConfBuilder.build();
HostAddress redirectHostAddress = redirectConf.addresses().get(0);
StandardClient redirectClient =
new StandardClient(redirectConf, redirectHostAddress, lock, false);
// properly close current connection
this.close();
logger.info("redirecting connection " + hostAddress + " to " + redirectUrl);
// affect redirection to current client
this.closed = false;
this.socket = redirectClient.socket;
this.conf = redirectConf;
this.hostAddress = redirectHostAddress;
this.context = redirectClient.context;
this.writer = redirectClient.writer;
this.reader = redirectClient.reader;
} catch (SQLException e) {
logger.error("fail to redirect to '" + redirectUrl + "'");
}
} catch (UnsupportedEncodingException ee) {
// eat, still supporting java 8
}
} else {
this.context.setRedirectUrl(redirectUrl);
}
} else {
this.context.setRedirectUrl(null);
}
}
private void assignStream(OutputStream out, InputStream in, Configuration conf, Long threadId) {
this.writer =
new PacketWriter(
out, conf.maxQuerySizeToLog(), conf.maxAllowedPacket(), sequence, compressionSequence);
this.writer.setServerThreadId(threadId, hostAddress);
this.reader = new PacketReader(in, conf, sequence);
this.reader.setServerThreadId(threadId, hostAddress);
}
/** Closing socket in case of Connection error after socket creation. */
protected void destroySocket() {
closed = true;
try {
this.reader.close();
} catch (IOException ee) {
// eat exception
}
try {
this.writer.close();
} catch (IOException ee) {
// eat exception
}
try {
this.socket.close();
} catch (IOException ee) {
// eat exception
}
}
/**
* load server timezone and ensure this corresponds to client timezone
*
* @throws SQLException if any socket error.
*/
private void handleTimezone() throws SQLException {
if (conf.connectionTimeZone() == null || "LOCAL".equalsIgnoreCase(conf.connectionTimeZone())) {
context.setConnectionTimeZone(TimeZone.getDefault());
} else {
String zoneId = conf.connectionTimeZone();
if ("SERVER".equalsIgnoreCase(zoneId)) {
try {
Result res =
(Result)
execute(new QueryPacket("SELECT @@time_zone, @@system_time_zone"), true).get(0);
res.next();
zoneId = res.getString(1);
if ("SYSTEM".equals(zoneId)) {
zoneId = res.getString(2);
}
} catch (SQLException sqle) {
Result res =
(Result)
execute(
new QueryPacket(
"SHOW VARIABLES WHERE Variable_name in ("
+ "'system_time_zone',"
+ "'time_zone')"),
true)
.get(0);
String systemTimeZone = null;
while (res.next()) {
if ("system_time_zone".equals(res.getString(1))) {
systemTimeZone = res.getString(2);
} else {
zoneId = res.getString(2);
}
}
if ("SYSTEM".equals(zoneId)) {
zoneId = systemTimeZone;
}
}
}
try {
context.setConnectionTimeZone(TimeZone.getTimeZone(ZoneId.of(zoneId).normalized()));
} catch (DateTimeException e) {
try {
context.setConnectionTimeZone(
TimeZone.getTimeZone(ZoneId.of(zoneId, ZoneId.SHORT_IDS).normalized()));
} catch (DateTimeException e2) {
// unknown zone id
throw new SQLException(String.format("Unknown zoneId %s", zoneId), e);
}
}
}
}
private void postConnectionQueries() throws SQLException {
List<String> commands = new ArrayList<>();
List<String> galeraAllowedStates =
conf.galeraAllowedState() == null
? Collections.emptyList()
: Arrays.asList(conf.galeraAllowedState().split(","));
if (hostAddress != null
&& Boolean.TRUE.equals(hostAddress.primary)
&& !galeraAllowedStates.isEmpty()) {
commands.add("show status like 'wsrep_local_state'");
}
handleTimezone();
String sessionVariableQuery = createSessionVariableQuery(context);
if (sessionVariableQuery != null) commands.add(sessionVariableQuery);
if (conf.database() != null
&& conf.createDatabaseIfNotExist()
&& (hostAddress == null || hostAddress.primary)) {
String escapedDb = conf.database().replace("`", "``");
commands.add(String.format("CREATE DATABASE IF NOT EXISTS `%s`", escapedDb));
commands.add(String.format("USE `%s`", escapedDb));
}
if (conf.initSql() != null) {
commands.add(conf.initSql());
}
if (conf.nonMappedOptions().containsKey("initSql")) {
String[] initialCommands = conf.nonMappedOptions().get("initSql").toString().split(";");
Collections.addAll(commands, initialCommands);
}
if (!commands.isEmpty()) {
try {
List<Completion> res;
ClientMessage[] msgs = new ClientMessage[commands.size()];
for (int i = 0; i < commands.size(); i++) {
msgs[i] = new QueryPacket(commands.get(i));
}
res =
executePipeline(
msgs,
null,
0,
0L,
ResultSet.CONCUR_READ_ONLY,
ResultSet.TYPE_FORWARD_ONLY,
false,
true);
if (hostAddress != null
&& Boolean.TRUE.equals(hostAddress.primary)
&& !galeraAllowedStates.isEmpty()) {
ResultSet rs = (ResultSet) res.get(0);
if (rs.next()) {
if (!galeraAllowedStates.contains(rs.getString(2))) {
throw exceptionFactory.create(
String.format("fail to validate Galera state (State is %s)", rs.getString(2)));
}
} else {
throw exceptionFactory.create(
"fail to validate Galera state (unknown 'wsrep_local_state' state)");
}
res.remove(0);
}
} catch (SQLException sqlException) {
if (!conf.disconnectOnExpiredPasswords()
&& (sqlException.getErrorCode() == 1862 || sqlException.getErrorCode() == 1820)) {
// password has expired, but configuration expressly permit sandbox mode.
logger.info("connected in sandbox mode. only password change is permitted");
return;
}
if (conf.timezone() != null && !"disable".equalsIgnoreCase(conf.timezone())) {
// timezone is not valid
throw exceptionFactory.create(
String.format(
"Setting configured timezone '%s' fail on server.\n"
+ "Look at https://mariadb.com/kb/en/mysql_tzinfo_to_sql/ to load tz data on"
+ " server, or set timezone=disable to disable setting client timezone.",
conf.timezone()),
"HY000",
sqlException);
}
throw exceptionFactory.create("Initialization command fail", "08000", sqlException);
}
if (conf.returnMultiValuesGeneratedIds()) {
ClientMessage query = new QueryPacket("SELECT @@auto_increment_increment");
List<Completion> res = execute(query, true);
ResultSet rs = (ResultSet) res.get(0);
if (rs.next()) {
context.setAutoIncrement(rs.getLong(1));
}
}
}
}
/**
* Creates a query string for setting session variables based on context and configuration.
*
* @param context the connection context
* @return query string for setting session variables, or null if no variables need to be set
*/
public String createSessionVariableQuery(Context context) {
List<String> sessionCommands = new ArrayList<>();
addAutoCommitCommand(context, sessionCommands);
addTruncationCommand(sessionCommands);
addSessionTrackingCommand(context, sessionCommands);
addTimeZoneCommand(context, sessionCommands);
addTransactionIsolationCommand(context, sessionCommands);
addReadOnlyCommand(context, sessionCommands);
addCharsetCommand(context, sessionCommands);
addCustomSessionVariables(sessionCommands);
return buildFinalQuery(sessionCommands);
}
private void addAutoCommitCommand(Context context, List<String> commands) {
boolean canRelyOnConnectionFlag = isReliableConnectionFlag(context);
if (isAutoCommitUpdateRequired(context, canRelyOnConnectionFlag)) {
boolean autoCommitValue = conf.autocommit() == null || conf.autocommit();
commands.add("autocommit=" + (autoCommitValue ? "1" : "0"));
}
}
private boolean isReliableConnectionFlag(Context context) {
return context.getVersion().isMariaDBServer()
&& (context.getVersion().versionFixedMajorMinorGreaterOrEqual(10, 4, 33)
|| context.getVersion().versionFixedMajorMinorGreaterOrEqual(10, 5, 24)
|| context.getVersion().versionFixedMajorMinorGreaterOrEqual(10, 6, 17)
|| context.getVersion().versionFixedMajorMinorGreaterOrEqual(10, 11, 7)
|| context.getVersion().versionFixedMajorMinorGreaterOrEqual(11, 0, 5)
|| context.getVersion().versionFixedMajorMinorGreaterOrEqual(11, 1, 4)
|| context.getVersion().versionFixedMajorMinorGreaterOrEqual(11, 2, 3));
}
private boolean isAutoCommitUpdateRequired(Context context, boolean canRelyOnConnectionFlag) {
return (conf.autocommit() == null && (context.getServerStatus() & ServerStatus.AUTOCOMMIT) == 0)
|| (conf.autocommit() != null && !canRelyOnConnectionFlag)
|| (conf.autocommit() != null
&& canRelyOnConnectionFlag
&& ((context.getServerStatus() & ServerStatus.AUTOCOMMIT) > 0) != conf.autocommit());
}
private void addTruncationCommand(List<String> commands) {
if (conf.jdbcCompliantTruncation()) {
commands.add("sql_mode=CONCAT(@@sql_mode,',STRICT_TRANS_TABLES')");
}
}
private void addSessionTrackingCommand(Context context, List<String> commands) {
if (!isSessionTrackingSupported(context)) {
return;
}
StringBuilder concatValues =
new StringBuilder(",")
.append(
context.canUseTransactionIsolation() ? "transaction_isolation" : "tx_isolation");
if (conf.returnMultiValuesGeneratedIds()) {
concatValues.append(",auto_increment_increment");
}
commands.add(
"session_track_system_variables = CONCAT(@@global.session_track_system_variables,'"
+ concatValues
+ "')");
}
private boolean isSessionTrackingSupported(Context context) {
return context.hasClientCapability(Capabilities.CLIENT_SESSION_TRACK)
&& ((context.getVersion().isMariaDBServer()
&& (context.getVersion().versionGreaterOrEqual(10, 2, 2)))
|| context.getVersion().versionGreaterOrEqual(5, 7, 0));
}
private void addTimeZoneCommand(Context context, List<String> commands) {
if (!Boolean.TRUE.equals(conf.forceConnectionTimeZoneToSession())) {
return;
}
TimeZone connectionTz = context.getConnectionTimeZone();
ZoneId connectionZoneId = connectionTz.toZoneId();
if (connectionZoneId.normalized().equals(TimeZone.getDefault().toZoneId())) {
return;
}
if (connectionZoneId.getRules().isFixedOffset()) {
addFixedOffsetTimeZone(connectionZoneId, commands);
} else {
commands.add("time_zone='" + connectionZoneId.normalized() + "'");
}
}
private void addFixedOffsetTimeZone(ZoneId connectionZoneId, List<String> commands) {
ZoneOffset zoneOffset = connectionZoneId.getRules().getOffset(Instant.now());
if (zoneOffset.getTotalSeconds() == 0) {
commands.add("time_zone='+00:00'");
} else {
commands.add("time_zone='" + zoneOffset.getId() + "'");
}
}
private void addTransactionIsolationCommand(Context context, List<String> commands) {
if (conf.transactionIsolation() != null) {
String isolationVariable =
context.canUseTransactionIsolation() ? "transaction_isolation" : "tx_isolation";
commands.add(
String.format(
"@@session.%s='%s'", isolationVariable, conf.transactionIsolation().getValue()));
}
}
private void addReadOnlyCommand(Context context, List<String> commands) {
if (hostAddress != null
&& !hostAddress.primary
&& context.getVersion().versionGreaterOrEqual(5, 6, 5)) {
String readOnlyVariable =
context.canUseTransactionIsolation() ? "transaction_read_only" : "tx_read_only";
commands.add(String.format("@@session.%s=1", readOnlyVariable));
}
}
private void addCharsetCommand(Context context, List<String> commands) {
if (!isDefaultCharsetSufficient(context)) {
StringBuilder charsetCommand = new StringBuilder("NAMES utf8mb4");
if (conf.connectionCollation() != null) {
charsetCommand.append(" COLLATE ").append(conf.connectionCollation());
}
commands.add(charsetCommand.toString());
}
}
private boolean isDefaultCharsetSufficient(Context context) {
return context.getCharset() != null
&& "utf8mb4".equals(context.getCharset())
&& conf.connectionCollation() == null;
}
private void addCustomSessionVariables(List<String> commands) {
if (conf.sessionVariables() != null) {
commands.add(Security.parseSessionVariables(conf.sessionVariables()));
}
}
private String buildFinalQuery(List<String> commands) {
return commands.isEmpty() ? null : "set " + String.join(",", commands);
}
public void setReadOnly(boolean readOnly) throws SQLException {
if (closed) {
throw new SQLNonTransientConnectionException("Connection is closed", "08000", 1220);
}
}
/**
* Send client message to server
*
* @param message client message
* @return number of command send
* @throws SQLException if socket error occurs
*/
public int sendQuery(ClientMessage message) throws SQLException {
checkNotClosed();
try {
if (logger.isDebugEnabled() && message.description() != null) {
logger.debug("execute query: {}", message.description());
}
return message.encode(writer, context);
} catch (MaxAllowedPacketException maxException) {
if (maxException.isMustReconnect()) {
destroySocket();
throw exceptionFactory
.withSql(message.description())
.create(
"Packet too big for current server max_allowed_packet value",
"08000",
maxException);
}
throw exceptionFactory
.withSql(message.description())
.create(
"Packet too big for current server max_allowed_packet value", "HZ000", maxException);
} catch (IOException ioException) {
destroySocket();
throw exceptionFactory
.withSql(message.description())
.create("Socket error", "08000", ioException);
}
}
public List<Completion> execute(ClientMessage message, boolean canRedo) throws SQLException {
return execute(
message,
null,
0,
0L,
ResultSet.CONCUR_READ_ONLY,
ResultSet.TYPE_FORWARD_ONLY,
false,
canRedo);
}
public List<Completion> execute(
ClientMessage message, org.mariadb.jdbc.Statement stmt, boolean canRedo) throws SQLException {
return execute(
message,
stmt,
0,
0L,
ResultSet.CONCUR_READ_ONLY,
ResultSet.TYPE_FORWARD_ONLY,
false,
canRedo);
}
public List<Completion> executePipeline(
ClientMessage[] messages,
org.mariadb.jdbc.Statement stmt,
int fetchSize,
long maxRows,
int resultSetConcurrency,
int resultSetType,
boolean closeOnCompletion,
boolean canRedo)
throws SQLException {
List<Completion> results = new ArrayList<>();
int perMsgCounter = 0;
int readCounter = 0;
int[] responseMsg = new int[messages.length];
try {
if (disablePipeline) {
for (readCounter = 0; readCounter < messages.length; readCounter++) {
results.addAll(
execute(
messages[readCounter],
stmt,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion,
canRedo));
}
} else {
for (int i = 0; i < messages.length; i++) {
responseMsg[i] = sendQuery(messages[i]);
}
while (readCounter < messages.length) {
readCounter++;
for (perMsgCounter = 0; perMsgCounter < responseMsg[readCounter - 1]; perMsgCounter++) {
results.addAll(
readResponse(
stmt,
messages[readCounter - 1],
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion));
}
}
}
return results;
} catch (SQLException sqlException) {
if (!closed) {
results.add(null);
// read remaining results
perMsgCounter++;
for (; perMsgCounter < responseMsg[readCounter - 1]; perMsgCounter++) {
try {
results.addAll(
readResponse(
stmt,
messages[readCounter - 1],
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion));
} catch (SQLException e) {
// eat
}
}
for (int i = readCounter; i < messages.length; i++) {
for (int j = 0; j < responseMsg[i]; j++) {
try {
results.addAll(
readResponse(
stmt,
messages[i],
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion));
} catch (SQLException e) {
results.add(null);
}
}
}
// prepare associated to PrepareStatement need to be uncached
for (Completion result : results) {
if (result instanceof PrepareResultPacket && stmt instanceof ServerPreparedStatement) {
try {
((PrepareResultPacket) result).decrementUse(this, (ServerPreparedStatement) stmt);
} catch (SQLException e) {
// eat
}
}
}
}
int batchUpdateLength = 0;
for (ClientMessage message : messages) {
batchUpdateLength += message.batchUpdateLength();
}
throw exceptionFactory.createBatchUpdate(
results, batchUpdateLength, responseMsg, sqlException);
}
}
public List<Completion> execute(
ClientMessage message,
org.mariadb.jdbc.Statement stmt,
int fetchSize,
long maxRows,
int resultSetConcurrency,
int resultSetType,
boolean closeOnCompletion,
boolean canRedo)
throws SQLException {
int nbResp = sendQuery(message);
if (nbResp == 1) {
return readResponse(
stmt,
message,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
} else {
if (streamStmt != null) {
streamStmt.fetchRemaining();
streamStmt = null;
}
List<Completion> completions = new ArrayList<>();
try {
while (nbResp-- > 0) {
readResults(
stmt,
message,
completions,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
}
return completions;
} catch (SQLException e) {
while (nbResp-- > 0) {
try {
readResults(
stmt,
message,
completions,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
} catch (SQLException ee) {
// eat
}
}
throw e;
}
}
}
/**
* Read server responses for a client message
*
* @param stmt statement that issue the message
* @param message client message sent
* @param fetchSize fetch size
* @param maxRows maximum number of rows
* @param resultSetConcurrency concurrency
* @param resultSetType result-set type
* @param closeOnCompletion close statement on resultset completion
* @return list of result
* @throws SQLException if any error occurs
*/
public List<Completion> readResponse(
org.mariadb.jdbc.Statement stmt,
ClientMessage message,
int fetchSize,
long maxRows,
int resultSetConcurrency,
int resultSetType,
boolean closeOnCompletion)
throws SQLException {
checkNotClosed();
if (streamStmt != null) {
streamStmt.fetchRemaining();
streamStmt = null;
}
List<Completion> completions = new ArrayList<>();
readResults(
stmt,
message,
completions,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
return completions;
}
/**
* Read server response
*
* @param message client message that was sent
* @throws SQLException if any error occurs
*/
public void readResponse(ClientMessage message) throws SQLException {
checkNotClosed();
if (streamStmt != null) {
streamStmt.fetchRemaining();
streamStmt = null;
}
List<Completion> completions = new ArrayList<>();
readResults(
null,
message,
completions,
0,
0L,
ResultSet.CONCUR_READ_ONLY,
ResultSet.TYPE_FORWARD_ONLY,
false);
}
public void closePrepare(Prepare prepare) throws SQLException {
checkNotClosed();
try {
new ClosePreparePacket(prepare.getStatementId()).encode(writer, context);
} catch (IOException ioException) {
destroySocket();
throw exceptionFactory.create(
"Socket error during post connection queries: " + ioException.getMessage(),
"08000",
ioException);
}
}
public void readStreamingResults(
List<Completion> completions,
int fetchSize,
long maxRows,
int resultSetConcurrency,
int resultSetType,
boolean closeOnCompletion)
throws SQLException {
if (streamStmt != null) {
readResults(
streamStmt,
streamMsg,
completions,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion);
}
}
private void readResults(
org.mariadb.jdbc.Statement stmt,
ClientMessage message,
List<Completion> completions,
int fetchSize,
long maxRows,
int resultSetConcurrency,
int resultSetType,
boolean closeOnCompletion)
throws SQLException {
completions.add(
readPacket(
stmt,
message,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion));
while ((context.getServerStatus() & ServerStatus.MORE_RESULTS_EXISTS) > 0) {
completions.add(
readPacket(
stmt,
message,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion));
}
}
/**
* Read a MySQL packet from socket
*
* @param message client message issuing the result
* @return a mysql result
* @throws SQLException if any error occurs
*/
public Completion readPacket(ClientMessage message) throws SQLException {
return readPacket(
null, message, 0, 0L, ResultSet.CONCUR_READ_ONLY, ResultSet.TYPE_FORWARD_ONLY, false);
}
/**
* Read server response packet.
*
* @see <a href="https://mariadb.com/kb/en/mariadb/4-server-response-packets/">server response
* packets</a>
* @param stmt current statement (null if internal)
* @param message current message
* @param fetchSize default fetch size
* @param maxRows maximum row number
* @param resultSetConcurrency concurrency
* @param resultSetType type
* @param closeOnCompletion must resultset close statement on completion
* @return Completion
* @throws SQLException if any exception
*/
public Completion readPacket(
org.mariadb.jdbc.Statement stmt,
ClientMessage message,
int fetchSize,
long maxRows,
int resultSetConcurrency,
int resultSetType,
boolean closeOnCompletion)
throws SQLException {
try {
boolean traceEnable = logger.isTraceEnabled();
Completion completion =
message.readPacket(
stmt,
fetchSize,
maxRows,
resultSetConcurrency,
resultSetType,
closeOnCompletion,
reader,
writer,
context,
exceptionFactory,
lock,
traceEnable,
message,
redirectConsumer);
if (completion instanceof StreamingResult && !((StreamingResult) completion).loaded()) {
streamStmt = stmt;
streamMsg = message;
}
return completion;
} catch (SocketTimeoutException ste) {
destroySocket();
throw exceptionFactory
.withSql(message.description())
.create("Socket timout error", "08000", ste);
} catch (IOException ioException) {
destroySocket();
throw exceptionFactory
.withSql(message.description())
.create("Socket error", "08000", ioException);
}
}
/**
* Throw an exception if client is closed
*
* @throws SQLException if closed
*/
protected void checkNotClosed() throws SQLException {
if (closed) {
throw exceptionFactory.create("Connection is closed", "08000", 1220);
}
}
private void closeSocket() {
try {
try {
long maxCurrentMillis = System.currentTimeMillis() + 10;
socket.shutdownOutput();
socket.setSoTimeout(3);
InputStream is = socket.getInputStream();
//noinspection StatementWithEmptyBody
while (is.read() != -1 && System.currentTimeMillis() < maxCurrentMillis) {
// read byte
}
} catch (Throwable t) {
// eat exception
}
writer.close();
reader.close();
} catch (IOException e) {
// eat
} finally {
try {
socket.close();
} catch (IOException e) {
// socket closed, if any error, so not throwing error
}
}
}
public boolean isClosed() {
return closed;
}
public Context getContext() {
return context;
}
public void abort(Executor executor) throws SQLException {
if (executor == null) {
throw exceptionFactory.create("Cannot abort the connection: null executor passed");
}
// fireConnectionClosed(new ConnectionEvent(this));
boolean lockStatus = lock.tryLock();
if (!this.closed) {
this.closed = true;
logger.debug("aborting connection {}", context.getThreadId());
if (!lockStatus) {
// lock not available : query is running
// force end by executing an KILL connection
try (StandardClient cli = new StandardClient(conf, hostAddress, new ClosableLock(), true)) {
cli.execute(new QueryPacket("KILL " + context.getThreadId()), false);
} catch (SQLException e) {
// eat
}
} else {
try {
QuitPacket.INSTANCE.encode(writer, context);
} catch (IOException e) {
// eat
}
}
if (streamStmt != null) {
streamStmt.abort();
}
closeSocket();
}
if (lockStatus) {
lock.unlock();
}
}
public int getSocketTimeout() {
return this.socketTimeout;
}
public void setSocketTimeout(int milliseconds) throws SQLException {
try {
socketTimeout = milliseconds;
socket.setSoTimeout(milliseconds);
} catch (SocketException se) {
throw exceptionFactory.create("Cannot set the network timeout", "42000", se);
}
}
public void close() {
boolean locked = lock.tryLock();
if (!this.closed) {
this.closed = true;
try {
QuitPacket.INSTANCE.encode(writer, context);
} catch (IOException e) {
// eat
}
closeSocket();
}
if (locked) {
lock.unlock();
}
}
public String getSocketIp() {
return this.socket.getInetAddress() == null
? null
: this.socket.getInetAddress().getHostAddress();
}
public boolean isPrimary() {
return hostAddress.primary;
}
public ExceptionFactory getExceptionFactory() {
return exceptionFactory;
}
public HostAddress getHostAddress() {
return hostAddress;
}
public void reset() {
context.resetStateFlag();
context.resetPrepareCache();
}
}