Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1947501: Create AST API in JDBC #2112

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/java/net/snowflake/client/core/QueryExecDTO.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class QueryExecDTO {

public QueryExecDTO(
String sqlText,
String dataframeAst,
boolean describeOnly,
Integer sequenceId,
Map<String, ParameterBindingDTO> bindings,
Expand All @@ -50,6 +51,7 @@ public QueryExecDTO(
boolean internal,
boolean asyncExec) {
this.sqlText = sqlText;
this.dataframeAst = dataframeAst;
this.describeOnly = describeOnly;
this.sequenceId = sequenceId;
this.bindings = bindings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ public abstract SFBaseResultSet execute(
throws SQLException, SFException;

/**
* Executes the given SQL string.
* Executes the given SQL string, with dataframe AST parameter.
*
* @param sql The SQL string to execute, synchronously.
* @param dataframeAst ...
* @param dataframeAst encoded string representation of the dataframe AST
* @param parametersBinding parameters to bind
* @param caller the JDBC interface method that called this method, if any
* @param execTimeData OOB telemetry object to record timings
Expand Down
63 changes: 35 additions & 28 deletions src/main/java/net/snowflake/client/core/SFStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ private void sanityCheckQuery(String sql) throws SQLException {
* Execute SQL query with an option for describe only
*
* @param sql sql statement
* @param dataframeAst ...
* @param dataframeAst encoded string representation of the dataframe AST
* @param describeOnly true if describe only
* @return query result set
* @throws SQLException if connection is already closed
Expand All @@ -130,23 +130,26 @@ private SFBaseResultSet executeQuery(
CallingMethod caller,
ExecTimeTelemetryData execTimeData)
throws SQLException, SFException {
sanityCheckQuery(sql);
// if dataframeAst is passed, then can skip sql checks
if (dataframeAst == null) {
sanityCheckQuery(sql);

String trimmedSql = sql.trim();

// snowflake specific client side commands
if (isFileTransfer(trimmedSql)) {
// Server side value or Connection string value is false then disable the PUT/GET command
if ((session != null && !(session.getJdbcEnablePutGet() && session.getEnablePutGet()))) {
// PUT/GET command disabled either on server side or in the client connection string
logger.debug("Executing file transfer locally is disabled: {}", sql);
throw new SnowflakeSQLException("File transfers have been disabled.");
}

String trimmedSql = sql.trim();
// PUT/GET command
logger.debug("Executing file transfer locally: {}", sql);

// snowflake specific client side commands
if (isFileTransfer(trimmedSql)) {
// Server side value or Connection string value is false then disable the PUT/GET command
if ((session != null && !(session.getJdbcEnablePutGet() && session.getEnablePutGet()))) {
// PUT/GET command disabled either on server side or in the client connection string
logger.debug("Executing file transfer locally is disabled: {}", sql);
throw new SnowflakeSQLException("File transfers have been disabled.");
return executeFileTransfer(sql);
}

// PUT/GET command
logger.debug("Executing file transfer locally: {}", sql);

return executeFileTransfer(sql);
}

// NOTE: It is intentional two describeOnly parameters are specified.
Expand Down Expand Up @@ -191,6 +194,7 @@ public SFPreparedStatementMetaData describe(String sql) throws SFException, SQLE
* <p>
*
* @param sql sql statement
* @param dataframeAst encoded string representation of the dataframe AST
* @param parameterBindings binding information
* @param describeOnly true if just showing result set metadata
* @param internal true if internal command not showing up in the history
Expand Down Expand Up @@ -321,6 +325,7 @@ public Void call() throws SQLException {
* A helper method to build URL and submit the SQL to snowflake for exec
*
* @param sql sql statement
* @param dataframeAst encoded string representation of the dataframe AST
* @param mediaType media type
* @param bindValues map of binding values
* @param describeOnly whether only show the result set metadata
Expand Down Expand Up @@ -767,7 +772,7 @@ private void cancelHelper(String sql, String mediaType, CancellationReason cance
* Execute sql
*
* @param sql sql statement.
* @param dataframeAst ...
* @param dataframeAst encoded string representation of the dataframe AST
* @param asyncExec is async exec
* @param parametersBinding parameters to bind
* @param caller the JDBC interface method that called this method, if any
Expand All @@ -787,22 +792,24 @@ public SFBaseResultSet execute(
throws SQLException, SFException {
TelemetryService.getInstance().updateContext(session.getSnowflakeConnectionString());

// todo: if (dataframeAst == null)
sanityCheckQuery(sql);
// if dataframeAst is passed, then no need for sql checks and can skip
if (dataframeAst == null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write a comment above this line explaining why this code is skipped.

sanityCheckQuery(sql);

session.injectedDelay();
session.injectedDelay();

if (session.getPreparedStatementLogging()) {
logger.info("Execute: {}", sql);
} else {
logger.debug("Execute: {}", sql);
}
if (session.getPreparedStatementLogging()) {
logger.info("Execute: {}", sql);
} else {
logger.debug("Execute: {}", sql);
}

String trimmedSql = sql.trim();
String trimmedSql = sql.trim();

if (trimmedSql.length() >= 20 && trimmedSql.toLowerCase().startsWith("set-sf-property")) {
executeSetProperty(sql);
return null;
if (trimmedSql.length() >= 20 && trimmedSql.toLowerCase().startsWith("set-sf-property")) {
executeSetProperty(sql);
return null;
}
}
return executeQuery(sql, dataframeAst, parametersBinding, false, asyncExec, caller, execTimeData);
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/net/snowflake/client/core/StmtUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ public static StmtOutput execute(StmtInput stmtInput, ExecTimeTelemetryData exec
QueryExecDTO sqlJsonBody =
new QueryExecDTO(
stmtInput.sql,
stmtInput.dataframeAst,
stmtInput.describeOnly,
stmtInput.sequenceId,
stmtInput.bindValues,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,13 @@ public ResultSet executeAsyncQuery(String sql) throws SQLException {
return rs;
}

// todo: add doc
/**
* Execute dataframeAst query
*
* @param dataframeAst encoded string representation of the dataframe AST
* @return ResultSet
* @throws SQLException if {@link #executeQueryInternal(String, Map)} throws an exception
*/
public ResultSet executeDataframeAst(String dataframeAst) throws SQLException {
ExecTimeTelemetryData execTimeData =
new ExecTimeTelemetryData("ResultSet Statement.executeQuery(String)", this.batchID);
Expand Down Expand Up @@ -284,6 +290,7 @@ private void setQueryIdWhenValidOrNull(String queryId) {
* Internal method for executing a query with bindings accepted.
*
* @param sql sql statement
* @param dataframeAst encoded string representation of the dataframe AST
* @param asyncExec execute query asynchronously
* @param parameterBindings parameters bindings
* @return query result set
Expand All @@ -309,7 +316,7 @@ ResultSet executeQueryInternal(
} else {
sfResultSet =
sfBaseStatement.execute(
sql, parameterBindings, SFBaseStatement.CallingMethod.EXECUTE_QUERY, execTimeData);
sql, dataframeAst, parameterBindings, SFBaseStatement.CallingMethod.EXECUTE_QUERY, execTimeData);
resultSetMetadataHandler(sfResultSet);
}
sfResultSet.setSession(this.connection.getSFBaseSession());
Expand Down
71 changes: 71 additions & 0 deletions src/test/java/net/snowflake/client/jdbc/DataframeAstTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package net.snowflake.client.jdbc;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.sql.ResultSet;
import java.sql.SQLException;
import net.snowflake.client.core.HttpClientSettingsKey;
import net.snowflake.client.core.HttpUtil;
import net.snowflake.client.core.SFBaseStatement;
import net.snowflake.client.core.SFConnectionHandler;
import net.snowflake.client.core.SFSession;
import net.snowflake.client.core.SFStatement;
import org.apache.http.client.methods.HttpPost;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;

/** Tests that the dataframe AST payload is propagated into the query request body. */
public class DataframeAstTest {

  /**
   * Verifies that {@link SnowflakeStatementV1#executeDataframeAst(String)} places the encoded AST
   * string into the {@code dataframeAst} field of the JSON request body posted to the server.
   *
   * @throws SQLException if statement construction fails
   * @throws IOException if the captured request body cannot be read or parsed
   */
  @Test
  public void testSendAst() throws SQLException, IOException {
    String ast = "dummyAst";

    // Wire up just enough mocked connection/session state for SFStatement to build a request.
    SnowflakeConnectionV1 mockedConn = mock(SnowflakeConnectionV1.class);
    SFConnectionHandler mockedHandler = mock(SFConnectionHandler.class);
    when(mockedConn.getHandler()).thenReturn(mockedHandler);
    HttpClientSettingsKey mockedHttpClientSettingsKey = mock(HttpClientSettingsKey.class);
    SFSession mockedSession = mock(SFSession.class);
    SFBaseStatement sfBaseStatement = new SFStatement(mockedSession);
    when(mockedHandler.getSFStatement()).thenReturn(sfBaseStatement);
    when(mockedSession.getServerUrl()).thenReturn("dummy");
    when(mockedSession.getHttpClientKey()).thenReturn(mockedHttpClientSettingsKey);
    when(mockedHttpClientSettingsKey.getGzipDisabled()).thenReturn(true);

    // Capture the HttpPost that SFStatement submits through the static HttpUtil.executeRequest.
    ArgumentCaptor<HttpPost> captor = ArgumentCaptor.forClass(HttpPost.class);
    try (MockedStatic<HttpUtil> mockedHttpUtil = Mockito.mockStatic(HttpUtil.class)) {
      mockedHttpUtil
          .when(
              () ->
                  HttpUtil.executeRequest(
                      captor.capture(),
                      anyInt(),
                      anyInt(),
                      anyInt(),
                      anyInt(),
                      anyInt(),
                      any(),
                      anyBoolean(),
                      anyBoolean(),
                      any(),
                      any()))
          .thenReturn("dummy");
      SnowflakeStatementV1 stmt =
          new SnowflakeStatementV1(
              mockedConn,
              ResultSet.TYPE_FORWARD_ONLY,
              ResultSet.CONCUR_READ_ONLY,
              ResultSet.CLOSE_CURSORS_AT_COMMIT);
      stmt.executeDataframeAst(ast);
    } catch (Exception ignored) {
      // The canned "dummy" response cannot be parsed into a result set, so execution is
      // expected to fail AFTER the request has been captured; the failure is harmless here.
    }

    // Decode the captured request body and confirm the AST round-tripped into it.
    // NOTE: was previously a bare `assert`, which is a no-op unless the JVM runs with -ea;
    // use a JUnit assertion so the test can actually fail.
    HttpPost result = captor.getValue();
    String requestBody;
    try (InputStream content = result.getEntity().getContent()) {
      requestBody = new String(content.readAllBytes(), StandardCharsets.UTF_8);
    }
    ObjectMapper mapper = new ObjectMapper();
    assertEquals(ast, mapper.readTree(requestBody).get("dataframeAst").asText());
  }
}
11 changes: 11 additions & 0 deletions src/test/java/net/snowflake/client/jdbc/MockConnectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,17 @@ public SFBaseResultSet execute(
return new MockJsonResultSet(mockedResponse, sfSession);
}

/**
 * Mock implementation of the dataframe-AST-aware execute overload. Ignores every argument
 * (including {@code dataframeAst}) and returns the canned JSON response wrapped in a
 * {@code MockJsonResultSet}, mirroring the sibling execute override above.
 */
@Override
public SFBaseResultSet execute(
String sql,
String dataframeAst,
Map<String, ParameterBindingDTO> parametersBinding,
CallingMethod caller,
ExecTimeTelemetryData execTimeData)
throws SQLException, SFException {
return new MockJsonResultSet(mockedResponse, sfSession);
}

@Override
public SFBaseResultSet asyncExecute(
String sql,
Expand Down