Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: CONJ-1238 rewriteBatchStatements #202

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
100 changes: 98 additions & 2 deletions src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java
Original file line number Diff line number Diff line change
@@ -5,13 +5,18 @@

import static org.mariadb.jdbc.util.constants.Capabilities.*;

import java.io.IOException;
import java.sql.*;
import java.util.*;
import org.mariadb.jdbc.client.ColumnDecoder;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.result.CompleteResult;
import org.mariadb.jdbc.client.result.Result;
import org.mariadb.jdbc.client.socket.impl.ByteCountingWriter;
import org.mariadb.jdbc.client.util.ClosableLock;
import org.mariadb.jdbc.client.util.Parameters;
import org.mariadb.jdbc.export.ExceptionFactory;
import org.mariadb.jdbc.export.MaxAllowedPacketException;
import org.mariadb.jdbc.message.ClientMessage;
import org.mariadb.jdbc.message.client.*;
import org.mariadb.jdbc.message.server.OkPacket;
@@ -48,7 +53,11 @@
super(sql, con, lock, autoGeneratedKeys, resultSetType, resultSetConcurrency, defaultFetchSize);
boolean noBackslashEscapes =
(con.getContext().getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0;
parser = ClientParser.parameterParts(sql, noBackslashEscapes);
if (con.getContext().getConf().rewriteBatchedStatements()) {
parser = ClientParser.rewritableParts(sql, noBackslashEscapes);
} else {
parser = ClientParser.parameterParts(sql, noBackslashEscapes);
}
parameters = new ParameterList(parser.getParamCount());
}

@@ -133,8 +142,15 @@
if (canUseBulk && batchParameters.size() > 1 && !clientParser.isMultiQuery()) {
executeBatchBulk(escapeTimeout(sql));
return true;
} else if (conf.rewriteBatchedStatements()
&& parser.getValuesBracketPositions() != null
&& autoGeneratedKeys == Statement.NO_GENERATED_KEYS
&& batchParameters.size() > 1) {
// values rewritten in one query :
// INSERT INTO X(a,b) VALUES (1,2),(3,4), ...
executeBatchRewrite();
} else {
executeBatchPipeline();
executeBatchPipeline();
}
}

@@ -168,6 +184,86 @@
throw bue;
}
}

/**
* Send n * COM_QUERY + n * read answer
*
* @throws SQLException if IOException / Command error
*/
private void executeBatchRewrite() throws SQLException {
ClientMessage[] packets = getBatchPackets(preSqlCmd(), parser, batchParameters, con.getContext());

try {
results =
con.getClient()
.executePipeline(
packets,
this,
0,
maxRows,
ResultSet.CONCUR_READ_ONLY,
ResultSet.TYPE_FORWARD_ONLY,
closeOnCompletion,
false);
} catch (SQLException bue) {
results = null;
throw bue;
}
}

public static ClientMessage[] getBatchPackets(String preSqlCmd, ClientParser parser, List<Parameters> parametersList, Context context) throws SQLException {
Configuration conf = context.getConf();
List<ClientMessage> messages = new ArrayList<>();
int startValuePos = parser.getValuesBracketPositions().get(0);
int endValuePos = parser.getValuesBracketPositions().get(1);

int staticLength = startValuePos + (parser.getQuery().length - endValuePos); // non-repeating sections
int repeatLength = (endValuePos - startValuePos - parser.getParamPositions().size()); // repeating section, excluding parameters

int startIndex = 0;
int currentIndex = 0;
int totalParameterList = parametersList.size();
int maxPacketLength = conf.maxAllowedPacket() == null ? 0x00ffffff + 4 : conf.maxAllowedPacket(); // from PacketWriter.MAX_PACKET_LENGTH

int packetLength = 1 + staticLength;
do {
Parameters parameters = parametersList.get(currentIndex);
packetLength += repeatLength;
ByteCountingWriter byteCountingWriter = new ByteCountingWriter();
for (int i = 0; i < parser.getParamPositions().size() && packetLength <= maxPacketLength; i++) {
try {
parameters.get(i).encodeText(byteCountingWriter, context); // ensureReplayable ?
} catch (IOException e) {

Check warning on line 236 in src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java

Codecov / codecov/patch

src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java#L236

Added line #L236 was not covered by tests
// cannot get an IOException writing to a ByteCountingWriter
throw new IllegalStateException("IOException in ByteCountingWriter", e);

Check warning on line 238 in src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java

Codecov / codecov/patch

src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java#L238

Added line #L238 was not covered by tests
}
packetLength += byteCountingWriter.getByteCount();
byteCountingWriter.resetCount();
}
if (packetLength >= maxPacketLength) {
if (startIndex == currentIndex) {
// exceeded maxPacketLength for a single set of parameters, throw an exception
// this is how MaxAllowedPacketExceptions are converted to SqlExceptions in StandardClient
MaxAllowedPacketException mape = new MaxAllowedPacketException("query size is >= to max_allowed_packet", false);
throw context.getExceptionFactory()
.withSql(parser.getSql())
.create(

Check warning on line 250 in src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java

Codecov / codecov/patch

src/main/java/org/mariadb/jdbc/ClientPreparedStatement.java#L247-L250

Added lines #L247 - L250 were not covered by tests
"Packet too big for current server max_allowed_packet value", "HZ000", mape);
}
messages.add(new QueryWithParametersPacket(preSqlCmd, parser, parametersList.subList(startIndex, currentIndex)));
packetLength = 1 + staticLength;
startIndex = currentIndex;
} else {
packetLength += 1; // the comma between VALUES() blocks
}
currentIndex++;
} while (currentIndex < totalParameterList);
if (startIndex != currentIndex) {
messages.add(new QueryWithParametersPacket(preSqlCmd, parser, parametersList.subList(startIndex, currentIndex)));
}

return messages.toArray(new ClientMessage[] {} );
}

/**
* Send n * (COM_QUERY + read answer)
39 changes: 37 additions & 2 deletions src/main/java/org/mariadb/jdbc/Configuration.java
Original file line number Diff line number Diff line change
@@ -153,6 +153,7 @@ public class Configuration {
// protocol
private boolean allowMultiQueries;
private boolean allowLocalInfile;
private boolean rewriteBatchedStatements;
private boolean useCompression;
private boolean useAffectedRows;
private boolean useBulkStmts;
@@ -334,14 +335,21 @@ private void initializeTimezoneConfig(Builder builder) {
}

private void initializeQueryConfig(Builder builder) {
this.dumpQueriesOnException =
this.dumpQueriesOnException =
builder.dumpQueriesOnException != null && builder.dumpQueriesOnException;
this.prepStmtCacheSize = builder.prepStmtCacheSize != null ? builder.prepStmtCacheSize : 250;
this.useAffectedRows = builder.useAffectedRows != null && builder.useAffectedRows;
this.useServerPrepStmts = builder.useServerPrepStmts != null && builder.useServerPrepStmts;
this.rewriteBatchedStatements = builder.rewriteBatchedStatements != null && builder.rewriteBatchedStatements;
// disable use server prepare if using client rewrite
if (this.rewriteBatchedStatements) {
this.useServerPrepStmts = false;
} else {
this.useServerPrepStmts = builder.useServerPrepStmts != null && builder.useServerPrepStmts;
}
this.connectionAttributes = builder.connectionAttributes;
this.allowLocalInfile = builder.allowLocalInfile == null || builder.allowLocalInfile;
this.allowMultiQueries = builder.allowMultiQueries != null && builder.allowMultiQueries;

}

private void initializeBulkConfig(Builder builder) {
@@ -1832,6 +1840,10 @@ public int prepStmtCacheSize() {
public boolean useAffectedRows() {
return useAffectedRows;
}

public boolean rewriteBatchedStatements() {
return rewriteBatchedStatements;
}

/**
* Use server prepared statement. IF false, using client prepared statement.
@@ -2345,6 +2357,7 @@ public static final class Builder implements Cloneable {
// protocol
private Boolean allowMultiQueries;
private Boolean allowLocalInfile;
private Boolean rewriteBatchedStatements;
private Boolean useCompression;
private Boolean useAffectedRows;
private Boolean useBulkStmts;
@@ -2863,6 +2876,28 @@ public Builder allowLocalInfile(Boolean allowLocalInfile) {
return this;
}

/**
* For insert queries, rewrite batchedStatement to execute in a single executeQuery.
*
* <p>example: "insert into ab (i) values (?)" with first batch values = 1, second = 2 will be
* rewritten "insert into ab (i) values (1), (2)".
*
* <p>If query cannot be rewriten in "multi-values", rewrite will use multi-queries : "INSERT
* INTO TABLE(col1) VALUES (?) ON DUPLICATE KEY UPDATE col2=?" with values [1,2] and [3,4] will
* be rewritten "INSERT INTO TABLE(col1) VALUES (1) ON DUPLICATE KEY UPDATE col2=2;INSERT INTO
* TABLE(col1) VALUES (3) ON DUPLICATE KEY UPDATE col2=4"
*
* <p>when active, the useServerPrepStmts option is set to false
*
* @param rewriteBatchedStatements to enable/disable rewrite
* @return this {@link Builder}
*/
public Builder rewriteBatchedStatements(Boolean rewriteBatchedStatements) {
this.rewriteBatchedStatements = rewriteBatchedStatements;
return this;
}


/**
* Indicate to compress exchanges with the database through gzip. This permits better
* performance when the database is not in the same location.
Original file line number Diff line number Diff line change
@@ -193,7 +193,7 @@ private static long applyOptionalCapabilities(long capabilities, Configuration c
capabilities |= Capabilities.FOUND_ROWS;
}

if (configuration.allowMultiQueries()) {
if (configuration.allowMultiQueries() || (configuration.rewriteBatchedStatements())) {
capabilities |= Capabilities.MULTI_STATEMENTS;
}

Original file line number Diff line number Diff line change
@@ -1059,7 +1059,7 @@ public List<Completion> executePipeline(
results.add(null);
// read remaining results
perMsgCounter++;
for (; perMsgCounter < responseMsg[readCounter - 1]; perMsgCounter++) {
for (; readCounter > 0 && perMsgCounter < responseMsg[readCounter - 1]; perMsgCounter++) {
try {
results.addAll(
readResponse(
Loading