diff --git a/e2e-jar-test/pom.xml b/e2e-jar-test/pom.xml index 7acf7a1e9..6e3652a4e 100644 --- a/e2e-jar-test/pom.xml +++ b/e2e-jar-test/pom.xml @@ -29,7 +29,7 @@ net.snowflake snowflake-ingest-sdk - 2.1.2-SNAPSHOT + 2.2.0 diff --git a/pom.xml b/pom.xml index e14af3d8e..b771ac80b 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ net.snowflake snowflake-ingest-sdk - 2.1.2-SNAPSHOT + 2.2.0 jar Snowflake Ingest SDK Snowflake Ingest SDK @@ -45,7 +45,7 @@ 3.14.0 1.3.1 1.11.0 - 2.16.1 + 2.17.0 32.0.1-jre 3.3.6 true @@ -60,13 +60,13 @@ 4.1.94.Final 9.37.3 3.1 - 1.13.1 + 1.14.1 2.0.9 UTF-8 - 3.19.6 + 4.27.2 net.snowflake.ingest.internal 1.7.36 - 1.1.10.4 + 1.1.10.5 3.16.1 0.13.0 @@ -343,13 +343,13 @@ net.bytebuddy byte-buddy - 1.10.19 + 1.14.9 test net.bytebuddy byte-buddy-agent - 1.10.19 + 1.14.9 test @@ -358,6 +358,18 @@ 3.7.7 test + + org.openjdk.jmh + jmh-core + 1.34 + test + + + org.openjdk.jmh + jmh-generator-annprocess + 1.34 + test + @@ -470,6 +482,13 @@ org.apache.parquet parquet-common + + + + javax.annotation + javax.annotation-api + + org.apache.parquet @@ -491,7 +510,7 @@ com.github.luben zstd-jni - 1.5.0-1 + 1.5.6-2 runtime @@ -527,6 +546,16 @@ mockito-core test + + org.openjdk.jmh + jmh-core + test + + + org.openjdk.jmh + jmh-generator-annprocess + test + org.powermock powermock-api-mockito2 @@ -723,8 +752,8 @@ true + to workaround https://issues.apache.org/jira/browse/MNG-7982. Now the dependency analyzer complains that + the dependency is unused, so we ignore it here--> org.apache.commons:commons-compress org.apache.commons:commons-configuration2 @@ -819,9 +848,9 @@ failFast + The list of allowed licenses. If you see the build failing due to "There are some forbidden licenses used, please + check your dependencies", verify the conditions of the license and add the reference to it here. + --> Apache License 2.0 BSD 2-Clause License @@ -844,6 +873,7 @@ BSD 2-Clause License |The BSD License The MIT License|MIT License + 3-Clause BSD License|BSD-3-Clause @@ -1133,9 +1163,9 @@ + Plugin executes license processing Python script, which copies third party license files into the directory + target/generated-licenses-info/META-INF/third-party-licenses, which is then included in the shaded JAR. 
+ --> org.codehaus.mojo exec-maven-plugin diff --git a/scripts/process_licenses.py b/scripts/process_licenses.py index bb43fbbf0..4a0377a8e 100644 --- a/scripts/process_licenses.py +++ b/scripts/process_licenses.py @@ -132,18 +132,22 @@ def main(): dependency_without_license_count += 1 missing_licenses_str += f"{dependency_lookup_key}: {license_name}\n" else: - raise Exception(f"The dependency {dependency_lookup_key} does not ship a license file, but neither is it not defined in ADDITIONAL_LICENSES_MAP") + raise Exception( + f"The dependency {dependency_lookup_key} does not ship a license file, but neither is it not " + f"defined in ADDITIONAL_LICENSES_MAP") with open(Path(target_dir, "ADDITIONAL_LICENCES"), "w") as additional_licenses_handle: additional_licenses_handle.write(missing_licenses_str) if dependency_count < 30: - raise Exception(f"Suspiciously low number of dependency JARs detected in {dependency_jars_path}: {dependency_count}") + raise Exception( + f"Suspiciously low number of dependency JARs detected in {dependency_jars_path}: {dependency_count}") print("License generation finished") print(f"\tTotal dependencies: {dependency_count}") print(f"\tTotal dependencies (with license): {dependency_with_license_count}") print(f"\tTotal dependencies (without license): {dependency_without_license_count}") print(f"\tIgnored dependencies: {dependency_ignored_count}") + if __name__ == "__main__": main() diff --git a/src/main/java/net/snowflake/ingest/connection/OAuthClient.java b/src/main/java/net/snowflake/ingest/connection/OAuthClient.java index 61c736a42..a592899ea 100644 --- a/src/main/java/net/snowflake/ingest/connection/OAuthClient.java +++ b/src/main/java/net/snowflake/ingest/connection/OAuthClient.java @@ -93,6 +93,7 @@ public void refreshToken() throws IOException { /** Helper method for making refresh request */ private HttpUriRequest makeRefreshTokenRequest() { + // TODO SNOW-1538108 Use SnowflakeServiceClient to make the request HttpPost post = new HttpPost(oAuthCredential.get().getOAuthTokenEndpoint()); post.addHeader(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER); post.addHeader(HttpHeaders.AUTHORIZATION, oAuthCredential.get().getAuthHeader()); diff --git a/src/main/java/net/snowflake/ingest/connection/RequestBuilder.java b/src/main/java/net/snowflake/ingest/connection/RequestBuilder.java index 55929400d..d2bf317d4 100644 --- a/src/main/java/net/snowflake/ingest/connection/RequestBuilder.java +++ b/src/main/java/net/snowflake/ingest/connection/RequestBuilder.java @@ -110,7 +110,7 @@ public class RequestBuilder { // Don't change! 
public static final String CLIENT_NAME = "SnowpipeJavaSDK"; - public static final String DEFAULT_VERSION = "2.1.2-SNAPSHOT"; + public static final String DEFAULT_VERSION = "2.2.0"; public static final String JAVA_USER_AGENT = "JAVA"; @@ -678,12 +678,23 @@ public HttpGet generateHistoryRangeRequest( */ public HttpPost generateStreamingIngestPostRequest( String payload, String endPoint, String message) { - LOGGER.debug("Generate Snowpipe streaming request: endpoint={}, payload={}", endPoint, payload); + final String requestId = UUID.randomUUID().toString(); + LOGGER.debug( + "Generate Snowpipe streaming request: endpoint={}, payload={}, requestId={}", + endPoint, + payload, + requestId); // Make the corresponding URI URI uri = null; try { uri = - new URIBuilder().setScheme(scheme).setHost(host).setPort(port).setPath(endPoint).build(); + new URIBuilder() + .setScheme(scheme) + .setHost(host) + .setPort(port) + .setPath(endPoint) + .setParameter(REQUEST_ID, requestId) + .build(); } catch (URISyntaxException e) { throw new SFException(e, ErrorCode.BUILD_REQUEST_FAILURE, message); } diff --git a/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java b/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java index cc8782dbd..4d3ea19aa 100644 --- a/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java +++ b/src/main/java/net/snowflake/ingest/streaming/OpenChannelRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming; @@ -150,7 +150,7 @@ public ZoneId getDefaultTimezone() { } public String getFullyQualifiedTableName() { - return String.format("%s.%s.%s", this.dbName, this.schemaName, this.tableName); + return Utils.getFullyQualifiedTableName(this.dbName, this.schemaName, this.tableName); } public OnErrorOption getOnErrorOption() { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java index 6d5dce17f..71a9d501e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java @@ -16,7 +16,6 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; -import java.util.stream.Collectors; import net.snowflake.ingest.connection.TelemetryService; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OffsetTokenVerificationFunction; @@ -400,10 +399,10 @@ public float getSize() { Set verifyInputColumns( Map row, InsertValidationResponse.InsertError error, int rowIndex) { // Map of unquoted column name -> original column name - Map inputColNamesMap = - row.keySet().stream() - .collect(Collectors.toMap(LiteralQuoteUtils::unquoteColumnName, value -> value)); - + Set originalKeys = row.keySet(); + Map inputColNamesMap = new HashMap<>(); + originalKeys.forEach( + key -> inputColNamesMap.put(LiteralQuoteUtils.unquoteColumnName(key), key)); // Check for extra columns in the row List extraCols = new ArrayList<>(); for (String columnName : inputColNamesMap.keySet()) { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java index 989be0fa1..90c0f2ac9 100644 --- 
a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java @@ -1,12 +1,15 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; -import java.util.Iterator; +import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.SFException; /** * In-memory cache that stores the active channels for a given Streaming Ingest client, and the @@ -23,6 +26,20 @@ class ChannelCache { String, ConcurrentHashMap>> cache = new ConcurrentHashMap<>(); + /** Flush information for each table including last flush time and if flush is needed */ + static class FlushInfo { + final long lastFlushTime; + final boolean needFlush; + + FlushInfo(long lastFlushTime, boolean needFlush) { + this.lastFlushTime = lastFlushTime; + this.needFlush = needFlush; + } + } + + /** Flush information for each table, only used when max chunks in blob is 1 */ + private final ConcurrentHashMap tableFlushInfo = new ConcurrentHashMap<>(); + /** * Add a channel to the channel cache * @@ -33,6 +50,11 @@ void addChannel(SnowflakeStreamingIngestChannelInternal channel) { this.cache.computeIfAbsent( channel.getFullyQualifiedTableName(), v -> new ConcurrentHashMap<>()); + // Update the last flush time for the table, add jitter to avoid all channels flush at the same + // time when the blobs are not interleaved + this.tableFlushInfo.putIfAbsent( + channel.getFullyQualifiedTableName(), new FlushInfo(System.currentTimeMillis(), false)); + SnowflakeStreamingIngestChannelInternal oldChannel = channels.put(channel.getName(), channel); // Invalidate old channel if it exits to block new inserts and return error to users earlier @@ -44,13 +66,84 @@ void addChannel(SnowflakeStreamingIngestChannelInternal channel) { } /** - * Returns an iterator over the (table, channels) in this map. 
+ * Get the last flush time for a table + * + * @param fullyQualifiedTableName fully qualified table name + * @return last flush time in milliseconds + */ + Long getLastFlushTime(String fullyQualifiedTableName) { + FlushInfo tableFlushInfo = this.tableFlushInfo.get(fullyQualifiedTableName); + if (tableFlushInfo == null) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format("Last flush time for table %s not found", fullyQualifiedTableName)); + } + return tableFlushInfo.lastFlushTime; + } + + /** + * Set the last flush time for a table as the current time * - * @return + * @param fullyQualifiedTableName fully qualified table name + * @param lastFlushTime last flush time in milliseconds */ - Iterator>>> - iterator() { - return this.cache.entrySet().iterator(); + void setLastFlushTime(String fullyQualifiedTableName, Long lastFlushTime) { + this.tableFlushInfo.compute( + fullyQualifiedTableName, + (k, v) -> { + if (v == null) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format("Last flush time for table %s not found", fullyQualifiedTableName)); + } + return new FlushInfo(lastFlushTime, v.needFlush); + }); + } + + /** + * Get need flush flag for a table + * + * @param fullyQualifiedTableName fully qualified table name + * @return need flush flag + */ + boolean getNeedFlush(String fullyQualifiedTableName) { + FlushInfo tableFlushInfo = this.tableFlushInfo.get(fullyQualifiedTableName); + if (tableFlushInfo == null) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format("Need flush flag for table %s not found", fullyQualifiedTableName)); + } + return tableFlushInfo.needFlush; + } + + /** + * Set need flush flag for a table + * + * @param fullyQualifiedTableName fully qualified table name + * @param needFlush need flush flag + */ + void setNeedFlush(String fullyQualifiedTableName, boolean needFlush) { + this.tableFlushInfo.compute( + fullyQualifiedTableName, + (k, v) -> { + if (v == null) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format("Need flush flag for table %s not found", fullyQualifiedTableName)); + } + return new FlushInfo(v.lastFlushTime, needFlush); + }); + } + + /** Returns an immutable set view of the mappings contained in the channel cache. */ + Set>>> + entrySet() { + return Collections.unmodifiableSet(cache.entrySet()); + } + + /** Returns an immutable set view of the keys contained in the channel cache. */ + Set keySet() { + return Collections.unmodifiableSet(cache.keySet()); } /** Close all channels in the channel cache */ diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java index 3e5265719..fe9542267 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelFlushContext.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; +import net.snowflake.ingest.utils.Utils; + /** * Channel immutable identification and encryption attributes. 
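For illustration, a minimal standalone sketch of the copy-on-write bookkeeping pattern the per-table flush state above relies on: an immutable FlushInfo value is swapped atomically through ConcurrentHashMap.compute, so updating one field cannot drop a concurrent update to the other. Class and method names below are hypothetical, not the SDK's; the real ChannelCache raises an SFException for unknown tables.

import java.util.concurrent.ConcurrentHashMap;

class TableFlushTrackerSketch {
  static final class FlushInfo {
    final long lastFlushTime;
    final boolean needFlush;

    FlushInfo(long lastFlushTime, boolean needFlush) {
      this.lastFlushTime = lastFlushTime;
      this.needFlush = needFlush;
    }
  }

  private final ConcurrentHashMap<String, FlushInfo> flushInfo = new ConcurrentHashMap<>();

  void register(String table) {
    // putIfAbsent keeps the earliest registration time if the table is already tracked
    flushInfo.putIfAbsent(table, new FlushInfo(System.currentTimeMillis(), false));
  }

  void setNeedFlush(String table, boolean needFlush) {
    // compute() replaces the whole immutable value atomically, preserving lastFlushTime;
    // the real ChannelCache throws an SFException when the table is unknown
    flushInfo.compute(
        table,
        (k, v) -> {
          if (v == null) {
            throw new IllegalStateException("Unknown table " + table);
          }
          return new FlushInfo(v.lastFlushTime, needFlush);
        });
  }

  public static void main(String[] args) {
    TableFlushTrackerSketch tracker = new TableFlushTrackerSketch();
    tracker.register("DB.SCHEMA.T1");
    tracker.setNeedFlush("DB.SCHEMA.T1", true);
    System.out.println(tracker.flushInfo.get("DB.SCHEMA.T1").needFlush); // prints: true
  }
}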
* @@ -36,12 +38,12 @@ class ChannelFlushContext { String encryptionKey, Long encryptionKeyId) { this.name = name; - this.fullyQualifiedName = String.format("%s.%s.%s.%s", dbName, schemaName, tableName, name); + this.fullyQualifiedName = + Utils.getFullyQualifiedChannelName(dbName, schemaName, tableName, name); this.dbName = dbName; this.schemaName = schemaName; this.tableName = tableName; - this.fullyQualifiedTableName = - String.format("%s.%s.%s", this.getDbName(), this.getSchemaName(), this.getTableName()); + this.fullyQualifiedTableName = Utils.getFullyQualifiedTableName(dbName, schemaName, tableName); this.channelSequencer = channelSequencer; this.encryptionKey = encryptionKey; this.encryptionKeyId = encryptionKeyId; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java index b98782ab9..025647f14 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelsStatusRequest.java @@ -1,14 +1,16 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; +import java.util.stream.Collectors; +import net.snowflake.ingest.utils.Utils; /** Class to deserialize a request from a channel status request */ -class ChannelsStatusRequest { +class ChannelsStatusRequest implements IStreamingIngestRequest { // Used to deserialize a channel request static class ChannelStatusRequestDTO { @@ -61,20 +63,12 @@ Long getClientSequencer() { } } - // Optional Request ID. Used for diagnostic purposes. 
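The Utils.getFullyQualifiedTableName / getFullyQualifiedChannelName helpers introduced above presumably centralize the dotted-name formatting that the replaced String.format calls performed; a toy sketch of that expected behavior, using local helper names rather than the SDK's implementation:

class QualifiedNameSketch {
  // Mirrors the String.format("%s.%s.%s", ...) pattern the diff replaces
  static String fullyQualifiedTableName(String db, String schema, String table) {
    return String.format("%s.%s.%s", db, schema, table);
  }

  // Mirrors the String.format("%s.%s.%s.%s", ...) pattern the diff replaces
  static String fullyQualifiedChannelName(String db, String schema, String table, String channel) {
    return String.format("%s.%s.%s.%s", db, schema, table, channel);
  }

  public static void main(String[] args) {
    System.out.println(fullyQualifiedTableName("MY_DB", "PUBLIC", "EVENTS"));          // MY_DB.PUBLIC.EVENTS
    System.out.println(fullyQualifiedChannelName("MY_DB", "PUBLIC", "EVENTS", "CH1")); // MY_DB.PUBLIC.EVENTS.CH1
  }
}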
- private String requestId; - // Channels in request private List channels; // Snowflake role used by client private String role; - @JsonProperty("request_id") - String getRequestId() { - return requestId; - } - @JsonProperty("role") public String getRole() { return role; @@ -85,11 +79,6 @@ public void setRole(String role) { this.role = role; } - @JsonProperty("request_id") - void setRequestId(String requestId) { - this.requestId = requestId; - } - @JsonProperty("channels") void setChannels(List channels) { this.channels = channels; @@ -99,4 +88,20 @@ void setChannels(List channels) { List getChannels() { return channels; } + + @Override + public String getStringForLogging() { + return String.format( + "ChannelsStatusRequest(role=%s, channels=[%s])", + role, + channels.stream() + .map( + r -> + Utils.getFullyQualifiedChannelName( + r.getDatabaseName(), + r.getSchemaName(), + r.getTableName(), + r.getChannelName())) + .collect(Collectors.joining(", "))); + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java index 278d4abea..72d89409f 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java @@ -18,6 +18,8 @@ public class ClientBufferParameters { private Constants.BdecParquetCompression bdecParquetCompression; + private Constants.BinaryStringEncoding binaryStringEncoding; + /** * Private constructor used for test methods * @@ -30,11 +32,13 @@ private ClientBufferParameters( boolean enableParquetInternalBuffering, long maxChunkSizeInBytes, long maxAllowedRowSizeInBytes, - Constants.BdecParquetCompression bdecParquetCompression) { + Constants.BdecParquetCompression bdecParquetCompression, + Constants.BinaryStringEncoding binaryStringEncoding) { this.enableParquetInternalBuffering = enableParquetInternalBuffering; this.maxChunkSizeInBytes = maxChunkSizeInBytes; this.maxAllowedRowSizeInBytes = maxAllowedRowSizeInBytes; this.bdecParquetCompression = bdecParquetCompression; + this.binaryStringEncoding = binaryStringEncoding; } /** @param clientInternal reference to the client object where the relevant parameters are set */ @@ -55,6 +59,10 @@ public ClientBufferParameters(SnowflakeStreamingIngestClientInternal clientInter clientInternal != null ? clientInternal.getParameterProvider().getBdecParquetCompressionAlgorithm() : ParameterProvider.BDEC_PARQUET_COMPRESSION_ALGORITHM_DEFAULT; + this.binaryStringEncoding = + clientInternal != null + ? 
clientInternal.getParameterProvider().getBinaryStringEncoding() + : ParameterProvider.BINARY_STRING_ENCODING_DEFAULT; } /** @@ -68,12 +76,14 @@ public static ClientBufferParameters test_createClientBufferParameters( boolean enableParquetInternalBuffering, long maxChunkSizeInBytes, long maxAllowedRowSizeInBytes, - Constants.BdecParquetCompression bdecParquetCompression) { + Constants.BdecParquetCompression bdecParquetCompression, + Constants.BinaryStringEncoding binaryStringEncoding) { return new ClientBufferParameters( enableParquetInternalBuffering, maxChunkSizeInBytes, maxAllowedRowSizeInBytes, - bdecParquetCompression); + bdecParquetCompression, + binaryStringEncoding); } public boolean getEnableParquetInternalBuffering() { @@ -91,4 +101,8 @@ public long getMaxAllowedRowSizeInBytes() { public Constants.BdecParquetCompression getBdecParquetCompression() { return bdecParquetCompression; } + + public Constants.BinaryStringEncoding getBinaryStringEncoding() { + return binaryStringEncoding; + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java new file mode 100644 index 000000000..79b282079 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureRequest.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** Class used to serialize client configure request */ +class ClientConfigureRequest implements IStreamingIngestRequest { + /** + * Constructor for client configure request + * + * @param role Role to be used for the request. + */ + ClientConfigureRequest(String role) { + this.role = role; + } + + @JsonProperty("role") + private String role; + + // File name for the GCS signed url request + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonProperty("file_name") + private String fileName; + + String getRole() { + return role; + } + + void setRole(String role) { + this.role = role; + } + + String getFileName() { + return fileName; + } + + void setFileName(String fileName) { + this.fileName = fileName; + } + + @Override + public String getStringForLogging() { + return String.format("ClientConfigureRequest(role=%s, file_name=%s)", getRole(), getFileName()); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java new file mode 100644 index 000000000..03a1d3576 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientConfigureResponse.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** Class used to deserialize responses from configure endpoint */ +@JsonIgnoreProperties(ignoreUnknown = true) +class ClientConfigureResponse extends StreamingIngestResponse { + @JsonProperty("prefix") + private String prefix; + + @JsonProperty("status_code") + private Long statusCode; + + @JsonProperty("message") + private String message; + + @JsonProperty("stage_location") + private FileLocationInfo stageLocation; + + @JsonProperty("deployment_id") + private Long deploymentId; + + String getPrefix() { + return prefix; + } + + void setPrefix(String prefix) { + this.prefix = prefix; + } + + @Override + Long getStatusCode() { + return statusCode; + } + + void setStatusCode(Long statusCode) { + this.statusCode = statusCode; + } + + String getMessage() { + return message; + } + + void setMessage(String message) { + this.message = message; + } + + FileLocationInfo getStageLocation() { + return stageLocation; + } + + void setStageLocation(FileLocationInfo stageLocation) { + this.stageLocation = stageLocation; + } + + Long getDeploymentId() { + return deploymentId; + } + + void setDeploymentId(Long deploymentId) { + this.deploymentId = deploymentId; + } + + String getClientPrefix() { + if (this.deploymentId == null) { + return this.prefix; + } + return this.prefix + "_" + this.deploymentId; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java index 814423c28..3e13c4ffc 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java @@ -24,17 +24,14 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.time.format.DateTimeParseException; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.function.Supplier; import net.snowflake.client.jdbc.internal.google.common.collect.Sets; import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat; import net.snowflake.client.jdbc.internal.snowflake.common.util.Power10; import net.snowflake.ingest.streaming.internal.serialization.ByteArraySerializer; import net.snowflake.ingest.streaming.internal.serialization.ZonedDateTimeSerializer; +import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import org.apache.commons.codec.DecoderException; @@ -86,6 +83,18 @@ class DataValidationUtil { objectMapper.registerModule(module); } + // Caching the powers of 10 that are used for checking the range of numbers because computing them + // on-demand is expensive. + private static final BigDecimal[] POWER_10 = makePower10Table(); + + private static BigDecimal[] makePower10Table() { + BigDecimal[] power10 = new BigDecimal[Power10.sb16Size]; + for (int i = 0; i < Power10.sb16Size; i++) { + power10[i] = new BigDecimal(Power10.sb16Table[i]); + } + return power10; + } + /** * Validates and parses input as JSON. All types in the object tree must be valid variant types, * see {@link DataValidationUtil#isAllowedSemiStructuredType}. 
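A condensed, self-contained sketch of the range check that the cached powers of ten above feed into (checkValueInRange later in this file): a NUMBER(precision, scale) value is rejected when |value| >= 10^(precision - scale). The cache below is a simplified stand-in for the Power10.sb16Table-backed table in the diff.

import java.math.BigDecimal;

class NumberRangeCheckSketch {
  private static final int MAX_PRECISION = 38; // Snowflake NUMBER max precision
  private static final BigDecimal[] POWER_10 = new BigDecimal[MAX_PRECISION + 1];

  static {
    for (int i = 0; i <= MAX_PRECISION; i++) {
      POWER_10[i] = BigDecimal.TEN.pow(i); // computed once instead of per validated row
    }
  }

  static boolean fitsInNumber(BigDecimal value, int precision, int scale) {
    BigDecimal bound =
        (precision >= scale && precision - scale < POWER_10.length)
            ? POWER_10[precision - scale]
            : BigDecimal.TEN.pow(precision - scale);
    return value.abs().compareTo(bound) < 0;
  }

  public static void main(String[] args) {
    System.out.println(fitsInNumber(new BigDecimal("123.45"), 5, 2)); // true: |v| < 10^3
    System.out.println(fitsInNumber(new BigDecimal("1234.5"), 5, 2)); // false: |v| >= 10^3
  }
}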
@@ -615,7 +624,7 @@ static int validateAndParseDate(String columnName, Object input, long insertRowI * @return Validated array */ static byte[] validateAndParseBinary( - String columnName, Object input, Optional maxLengthOptional, long insertRowIndex) { + String columnName, Object input, Optional maxLengthOptional, long insertRowIndex, Constants.BinaryStringEncoding binaryStringEncoding) { byte[] output; if (input instanceof byte[]) { // byte[] is a mutable object, we need to create a defensive copy to protect against @@ -625,12 +634,30 @@ static byte[] validateAndParseBinary( output = new byte[originalInputArray.length]; System.arraycopy(originalInputArray, 0, output, 0, originalInputArray.length); } else if (input instanceof String) { - try { - String stringInput = ((String) input).trim(); - output = Hex.decodeHex(stringInput); - } catch (DecoderException e) { + if(binaryStringEncoding == Constants.BinaryStringEncoding.BASE64) { + try { + String stringInput = ((String) input).trim(); + // Remove double quotes if present + if (stringInput.length() >= 2 && stringInput.startsWith("\"") && stringInput.endsWith("\"")) { + stringInput = stringInput.substring(1, stringInput.length() - 1); + } + Base64.Decoder decoder = Base64.getDecoder(); + output = decoder.decode(stringInput); + } catch (IllegalArgumentException e) { + throw valueFormatNotAllowedException( + columnName, "BINARY", "Not a valid base64 string", insertRowIndex); + } + } else if (binaryStringEncoding == Constants.BinaryStringEncoding.HEX) { + try { + String stringInput = ((String) input).trim(); + output = Hex.decodeHex(stringInput); + } catch (DecoderException e) { + throw valueFormatNotAllowedException( + columnName, "BINARY", "Not a valid hex string", insertRowIndex); + } + } else { throw valueFormatNotAllowedException( - columnName, "BINARY", "Not a valid hex string", insertRowIndex); + columnName, "BINARY", "Unsupported binary string format " + binaryStringEncoding.name(), insertRowIndex); } } else { throw typeNotAllowedException( @@ -823,7 +850,11 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR static void checkValueInRange( BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) { - if (bigDecimalValue.abs().compareTo(BigDecimal.TEN.pow(precision - scale)) >= 0) { + BigDecimal comparand = + (precision >= scale) && (precision - scale) < POWER_10.length + ? POWER_10[precision - scale] + : BigDecimal.TEN.pow(precision - scale); + if (bigDecimalValue.abs().compareTo(comparand) >= 0) { throw new SFException( ErrorCode.INVALID_FORMAT_ROW, String.format( diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java new file mode 100644 index 000000000..322b53acf --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
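For the new BinaryStringEncoding option used in validateAndParseBinary above, a small standalone illustration of the two string forms the validator accepts for the same BINARY payload: hex (the pre-existing behavior) and base64 (the new option). The payload and the toHex helper are invented for the example.

import java.nio.charset.StandardCharsets;
import java.util.Base64;

class BinaryEncodingSketch {
  // Hex-encodes bytes the way a caller would submit them when the hex encoding is configured
  static String toHex(byte[] bytes) {
    StringBuilder sb = new StringBuilder(bytes.length * 2);
    for (byte b : bytes) {
      sb.append(String.format("%02x", b));
    }
    return sb.toString();
  }

  public static void main(String[] args) {
    byte[] payload = "snowflake".getBytes(StandardCharsets.UTF_8);

    // With the HEX setting, a BINARY column value is passed as a hex string:
    String hexValue = toHex(payload); // "736e6f77666c616b65"
    // With the new BASE64 setting, the same payload is passed base64-encoded instead:
    String base64Value = Base64.getEncoder().encodeToString(payload); // "c25vd2ZsYWtl"

    System.out.println(hexValue);
    System.out.println(base64Value);
    // The validator decodes with Hex.decodeHex(...) or Base64.getDecoder().decode(...) accordingly.
  }
}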
+ */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import net.snowflake.ingest.streaming.DropChannelRequest; +import net.snowflake.ingest.utils.Utils; + +/** Class used to serialize the {@link DropChannelRequest} */ +class DropChannelRequestInternal implements IStreamingIngestRequest { + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("role") + private String role; + + @JsonProperty("channel") + private String channel; + + @JsonProperty("table") + private String table; + + @JsonProperty("database") + private String database; + + @JsonProperty("schema") + private String schema; + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonProperty("client_sequencer") + Long clientSequencer; + + DropChannelRequestInternal( + String requestId, + String role, + String database, + String schema, + String table, + String channel, + Long clientSequencer) { + this.requestId = requestId; + this.role = role; + this.database = database; + this.schema = schema; + this.table = table; + this.channel = channel; + this.clientSequencer = clientSequencer; + } + + String getRequestId() { + return requestId; + } + + String getRole() { + return role; + } + + String getChannel() { + return channel; + } + + String getTable() { + return table; + } + + String getDatabase() { + return database; + } + + String getSchema() { + return schema; + } + + Long getClientSequencer() { + return clientSequencer; + } + + String getFullyQualifiedTableName() { + return Utils.getFullyQualifiedTableName(database, schema, table); + } + + @Override + public String getStringForLogging() { + return String.format( + "DropChannelRequest(requestId=%s, role=%s, db=%s, schema=%s, table=%s, channel=%s," + + " clientSequencer=%s)", + requestId, role, database, schema, table, channel, clientSequencer); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java b/src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java new file mode 100644 index 000000000..add98a6fb --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FileLocationInfo.java @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; + +/** Class used to deserialized volume information response by server */ +class FileLocationInfo { + /** The stage type */ + @JsonProperty("locationType") + private String locationType; + + /** The container or bucket */ + @JsonProperty("location") + private String location; + + /** The path of the target file */ + @JsonProperty("path") + private String path; + + /** The credentials required for the stage */ + @JsonProperty("creds") + private Map credentials; + + /** AWS/S3/GCS region (S3/GCS only) */ + @JsonProperty("region") + private String region; + + /** The Azure Storage endpoint (Azure only) */ + @JsonProperty("endPoint") + private String endPoint; + + /** The Azure Storage account (Azure only) */ + @JsonProperty("storageAccount") + private String storageAccount; + + /** GCS gives us back a presigned URL instead of a cred */ + @JsonProperty("presignedUrl") + private String presignedUrl; + + /** Whether to encrypt/decrypt files on the stage */ + @JsonProperty("isClientSideEncrypted") + private boolean isClientSideEncrypted; + + /** Whether to use s3 regional URL (AWS Only) */ + @JsonProperty("useS3RegionalUrl") + private boolean useS3RegionalUrl; + + /** A unique id for volume assigned by server */ + @JsonProperty("volumeHash") + private String volumeHash; + + String getLocationType() { + return locationType; + } + + void setLocationType(String locationType) { + this.locationType = locationType; + } + + String getLocation() { + return location; + } + + void setLocation(String location) { + this.location = location; + } + + String getPath() { + return path; + } + + void setPath(String path) { + this.path = path; + } + + Map getCredentials() { + return credentials; + } + + void setCredentials(Map credentials) { + this.credentials = credentials; + } + + String getRegion() { + return region; + } + + void setRegion(String region) { + this.region = region; + } + + String getEndPoint() { + return endPoint; + } + + void setEndPoint(String endPoint) { + this.endPoint = endPoint; + } + + String getStorageAccount() { + return storageAccount; + } + + void setStorageAccount(String storageAccount) { + this.storageAccount = storageAccount; + } + + String getPresignedUrl() { + return presignedUrl; + } + + void setPresignedUrl(String presignedUrl) { + this.presignedUrl = presignedUrl; + } + + boolean getIsClientSideEncrypted() { + return this.isClientSideEncrypted; + } + + void setIsClientSideEncrypted(boolean isClientSideEncrypted) { + this.isClientSideEncrypted = isClientSideEncrypted; + } + + boolean getUseS3RegionalUrl() { + return this.useS3RegionalUrl; + } + + void setUseS3RegionalUrl(boolean useS3RegionalUrl) { + this.useS3RegionalUrl = useS3RegionalUrl; + } + + String getVolumeHash() { + return this.volumeHash; + } + + void setVolumeHash(String volumeHash) { + this.volumeHash = volumeHash; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index f08196477..84e1a2561 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -1,10 +1,9 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. 
*/ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.utils.Constants.BLOB_EXTENSION_TYPE; import static net.snowflake.ingest.utils.Constants.DISABLE_BACKGROUND_FLUSH; import static net.snowflake.ingest.utils.Constants.MAX_BLOB_SIZE_IN_BYTES; import static net.snowflake.ingest.utils.Constants.MAX_THREAD_COUNT; @@ -19,13 +18,12 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.util.ArrayList; -import java.util.Calendar; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.TimeZone; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -34,11 +32,10 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; -import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; @@ -84,9 +81,6 @@ List>> getData() { private static final Logging logger = new Logging(FlushService.class); - // Increasing counter to generate a unique blob name per client - private final AtomicLong counter; - // The client that owns this flush service private final SnowflakeStreamingIngestClientInternal owningClient; @@ -102,18 +96,21 @@ List>> getData() { // Reference to the channel cache private final ChannelCache channelCache; - // Reference to the Streaming Ingest stage - private final StreamingIngestStage targetStage; + // Reference to the Streaming Ingest storage manager + private final IStorageManager storageManager; // Reference to register service private final RegisterService registerService; - // Indicates whether we need to schedule a flush - @VisibleForTesting volatile boolean isNeedFlush; - - // Latest flush time + /** + * Client level last flush time and need flush flag. This two variables are used when max chunk in + * blob is not 1. When max chunk in blob is 1, flush service ignores these variables and uses + * table level last flush time and need flush flag. See {@link ChannelCache.FlushInfo}. 
+ */ @VisibleForTesting volatile long lastFlushTime; + @VisibleForTesting volatile boolean isNeedFlush; + // Indicates whether it's running as part of the test private final boolean isTestMode; @@ -122,23 +119,24 @@ List>> getData() { // blob encoding version private final Constants.BdecVersion bdecVersion; + private volatile int numProcessors = Runtime.getRuntime().availableProcessors(); /** - * Constructor for TESTING that takes (usually mocked) StreamingIngestStage + * Default constructor * - * @param client - * @param cache - * @param isTestMode + * @param client the owning client + * @param cache the channel cache + * @param storageManager the storage manager + * @param isTestMode whether the service is running in test mode */ FlushService( SnowflakeStreamingIngestClientInternal client, ChannelCache cache, - StreamingIngestStage targetStage, // For testing + IStorageManager storageManager, boolean isTestMode) { this.owningClient = client; this.channelCache = cache; - this.targetStage = targetStage; - this.counter = new AtomicLong(0); + this.storageManager = storageManager; this.registerService = new RegisterService<>(client, isTestMode); this.isNeedFlush = false; this.lastFlushTime = System.currentTimeMillis(); @@ -148,40 +146,6 @@ List>> getData() { createWorkers(); } - /** - * Default constructor - * - * @param client - * @param cache - * @param isTestMode - */ - FlushService( - SnowflakeStreamingIngestClientInternal client, ChannelCache cache, boolean isTestMode) { - this.owningClient = client; - this.channelCache = cache; - try { - this.targetStage = - new StreamingIngestStage( - isTestMode, - client.getRole(), - client.getHttpClient(), - client.getRequestBuilder(), - client.getName(), - DEFAULT_MAX_UPLOAD_RETRIES); - } catch (SnowflakeSQLException | IOException err) { - throw new SFException(err, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE); - } - - this.registerService = new RegisterService<>(client, isTestMode); - this.counter = new AtomicLong(0); - this.isNeedFlush = false; - this.lastFlushTime = System.currentTimeMillis(); - this.isTestMode = isTestMode; - this.latencyTimerContextMap = new ConcurrentHashMap<>(); - this.bdecVersion = this.owningClient.getParameterProvider().getBlobFormatVersion(); - createWorkers(); - } - /** * Updates performance stats enabled * @@ -203,36 +167,65 @@ private CompletableFuture statsFuture() { /** * @param isForce if true will flush regardless of other conditions - * @param timeDiffMillis Time in milliseconds since the last flush + * @param tablesToFlush list of tables to flush + * @param flushStartTime the time when the flush started * @return */ - private CompletableFuture distributeFlush(boolean isForce, long timeDiffMillis) { + private CompletableFuture distributeFlush( + boolean isForce, Set tablesToFlush, Long flushStartTime) { return CompletableFuture.runAsync( () -> { - logFlushTask(isForce, timeDiffMillis); - distributeFlushTasks(); + logFlushTask(isForce, tablesToFlush, flushStartTime); + distributeFlushTasks(tablesToFlush); + long prevFlushEndTime = System.currentTimeMillis(); + this.lastFlushTime = prevFlushEndTime; this.isNeedFlush = false; - this.lastFlushTime = System.currentTimeMillis(); - return; + tablesToFlush.forEach( + table -> { + this.channelCache.setLastFlushTime(table, prevFlushEndTime); + this.channelCache.setNeedFlush(table, false); + }); }, this.flushWorker); } /** If tracing is enabled, print always else, check if it needs flush or is forceful. 
*/ - private void logFlushTask(boolean isForce, long timeDiffMillis) { + private void logFlushTask(boolean isForce, Set tablesToFlush, long flushStartTime) { + boolean isNeedFlush = + this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1 + ? tablesToFlush.stream().anyMatch(channelCache::getNeedFlush) + : this.isNeedFlush; + long currentTime = System.currentTimeMillis(); + final String logInfo; + if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1) { + logInfo = + String.format( + "Tables=[%s]", + tablesToFlush.stream() + .map( + table -> + String.format( + "(name=%s, isNeedFlush=%s, timeDiffMillis=%s, currentDiffMillis=%s)", + table, + channelCache.getNeedFlush(table), + flushStartTime - channelCache.getLastFlushTime(table), + currentTime - channelCache.getLastFlushTime(table))) + .collect(Collectors.joining(", "))); + } else { + logInfo = + String.format( + "isNeedFlush=%s, timeDiffMillis=%s, currentDiffMillis=%s", + isNeedFlush, flushStartTime - this.lastFlushTime, currentTime - this.lastFlushTime); + } + final String flushTaskLogFormat = String.format( - "Submit forced or ad-hoc flush task on client=%s, isForce=%s," - + " isNeedFlush=%s, timeDiffMillis=%s, currentDiffMillis=%s", - this.owningClient.getName(), - isForce, - this.isNeedFlush, - timeDiffMillis, - System.currentTimeMillis() - this.lastFlushTime); + "Submit forced or ad-hoc flush task on client=%s, isForce=%s, %s", + this.owningClient.getName(), isForce, logInfo); if (logger.isTraceEnabled()) { logger.logTrace(flushTaskLogFormat); } - if (!logger.isTraceEnabled() && (this.isNeedFlush || isForce)) { + if (!logger.isTraceEnabled() && (isNeedFlush || isForce)) { logger.logDebug(flushTaskLogFormat); } } @@ -248,27 +241,65 @@ private CompletableFuture registerFuture() { } /** - * Kick off a flush job and distribute the tasks if one of the following conditions is met: - *
<li>Flush is forced by the users
- * <li>One or more buffers have reached the flush size
- * <li>Periodical background flush when a time interval has reached
+ * Kick off a flush job and distribute the tasks. The flush service behaves differently based on
+ * the max chunks in blob:
+ *
+ * <ul>
+ *   <li>The max chunks in blob is not 1 (interleaving is allowed), every channel will be flushed
+ *       together if one of the following conditions is met:
+ *       <ul>
+ *         <li>Flush is forced by the users
+ *         <li>One or more buffers have reached the flush size
+ *         <li>Periodical background flush when a time interval has reached
+ *       </ul>
+ *   <li>The max chunks in blob is 1 (interleaving is not allowed), a channel will be flushed if
+ *       one of the following conditions is met:
+ *       <ul>
+ *         <li>Flush is forced by the users
+ *         <li>One or more buffers with the same target table as the channel have reached the
+ *             flush size
+ *         <li>Periodical background flush of the target table when a time interval has reached
+ *       </ul>
+ * </ul>
    * * @param isForce * @return Completable future that will return when the blobs are registered successfully, or null * if none of the conditions is met above */ CompletableFuture flush(boolean isForce) { - long timeDiffMillis = System.currentTimeMillis() - this.lastFlushTime; + final long flushStartTime = System.currentTimeMillis(); + final long flushingInterval = + this.owningClient.getParameterProvider().getCachedMaxClientLagInMs(); + + final Set tablesToFlush; + if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1) { + tablesToFlush = + this.channelCache.keySet().stream() + .filter( + key -> + isForce + || flushStartTime - this.channelCache.getLastFlushTime(key) + >= flushingInterval + || this.channelCache.getNeedFlush(key)) + .collect(Collectors.toSet()); + } else { + if (isForce + || (!DISABLE_BACKGROUND_FLUSH + && !isTestMode() + && (this.isNeedFlush || flushStartTime - this.lastFlushTime >= flushingInterval))) { + tablesToFlush = this.channelCache.keySet(); + } else { + tablesToFlush = null; + } + } if (isForce || (!DISABLE_BACKGROUND_FLUSH && !isTestMode() - && (this.isNeedFlush - || timeDiffMillis - >= this.owningClient.getParameterProvider().getCachedMaxClientLagInMs()))) { - + && tablesToFlush != null + && !tablesToFlush.isEmpty())) { return this.statsFuture() - .thenCompose((v) -> this.distributeFlush(isForce, timeDiffMillis)) + .thenCompose((v) -> this.distributeFlush(isForce, tablesToFlush, flushStartTime)) .thenCompose((v) -> this.registerFuture()); } return this.statsFuture(); @@ -351,19 +382,27 @@ private void createWorkers() { /** * Distribute the flush tasks by iterating through all the channels in the channel cache and kick * off a build blob work when certain size has reached or we have reached the end + * + * @param tablesToFlush list of tables to flush */ - void distributeFlushTasks() { + void distributeFlushTasks(Set tablesToFlush) { Iterator< Map.Entry< String, ConcurrentHashMap>>> - itr = this.channelCache.iterator(); + itr = + this.channelCache.entrySet().stream() + .filter(e -> tablesToFlush.contains(e.getKey())) + .iterator(); List, CompletableFuture>> blobs = new ArrayList<>(); List> leftoverChannelsDataPerTable = new ArrayList<>(); + // The API states that the number of available processors reported can change and therefore, we + // should poll it occasionally. + numProcessors = Runtime.getRuntime().availableProcessors(); while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) { List>> blobData = new ArrayList<>(); float totalBufferSizeInBytes = 0F; - final String blobPath = getBlobPath(this.targetStage.getClientPrefix()); + final String blobPath = this.storageManager.generateBlobPath(); // Distribute work at table level, split the blob if reaching the blob size limit or the // channel has different encryption key ids @@ -445,9 +484,9 @@ && shouldStopProcessing( // Kick off a build job if (blobData.isEmpty()) { - // we decrement the counter so that we do not have gaps in the blob names created by this - // client. See method getBlobPath() below. - this.counter.decrementAndGet(); + // we decrement the blob sequencer so that we do not have gaps in the blob names created by + // this client. 
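A condensed, standalone sketch of the per-table selection performed in flush() above when interleaving is disabled (max chunks in blob is 1): a table is flushed when the caller forces it, its configured max client lag has elapsed since its last flush, or it was explicitly marked as needing a flush. All names below are illustrative; maxClientLagMs corresponds to getCachedMaxClientLagInMs() in the diff.

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

class FlushSelectionSketch {
  static Set<String> selectTablesToFlush(
      List<String> allTables,
      boolean isForce,
      long nowMs,
      long maxClientLagMs,
      Function<String, Long> lastFlushTimeMs,
      Function<String, Boolean> needFlush) {
    return allTables.stream()
        .filter(
            table ->
                isForce
                    || nowMs - lastFlushTimeMs.apply(table) >= maxClientLagMs
                    || needFlush.apply(table))
        .collect(Collectors.toSet());
  }

  public static void main(String[] args) {
    long now = System.currentTimeMillis();
    Set<String> toFlush =
        selectTablesToFlush(
            Arrays.asList("DB.S.T1", "DB.S.T2"),
            /* isForce= */ false,
            now,
            /* maxClientLagMs= */ 1000,
            table -> table.equals("DB.S.T1") ? now - 2000 : now, // T1 is past its lag
            table -> false);
    System.out.println(toFlush); // prints: [DB.S.T1]
  }
}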
+ this.storageManager.decrementBlobSequencer(); } else { long flushStartMs = System.currentTimeMillis(); if (this.owningClient.flushLatency != null) { @@ -459,7 +498,13 @@ && shouldStopProcessing( CompletableFuture.supplyAsync( () -> { try { - BlobMetadata blobMetadata = buildAndUpload(blobPath, blobData); + // Get the fully qualified table name from the first channel in the blob. + // This only matters when the client is in Iceberg mode. In Iceberg mode, + // all channels in the blob belong to the same table. + String fullyQualifiedTableName = + blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName(); + BlobMetadata blobMetadata = + buildAndUpload(blobPath, blobData, fullyQualifiedTableName); blobMetadata.getBlobStats().setFlushStartMs(flushStartMs); return blobMetadata; } catch (Throwable e) { @@ -542,9 +587,12 @@ private boolean shouldStopProcessing( * @param blobPath Path of the destination blob in cloud storage * @param blobData All the data for one blob. Assumes that all ChannelData in the inner List * belongs to the same table. Will error if this is not the case + * @param fullyQualifiedTableName the table name of the first channel in the blob, only matters in + * Iceberg mode * @return BlobMetadata for FlushService.upload */ - BlobMetadata buildAndUpload(String blobPath, List>> blobData) + BlobMetadata buildAndUpload( + String blobPath, List>> blobData, String fullyQualifiedTableName) throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException, InvalidKeyException { @@ -555,12 +603,18 @@ BlobMetadata buildAndUpload(String blobPath, List>> blobData blob.blobStats.setBuildDurationMs(buildContext); - return upload(blobPath, blob.blobBytes, blob.chunksMetadataList, blob.blobStats); + return upload( + this.storageManager.getStorage(fullyQualifiedTableName), + blobPath, + blob.blobBytes, + blob.chunksMetadataList, + blob.blobStats); } /** * Upload a blob to Streaming Ingest dedicated stage * + * @param storage the storage to upload the blob * @param blobPath full path of the blob * @param blob blob data * @param metadata a list of chunk metadata @@ -568,13 +622,17 @@ BlobMetadata buildAndUpload(String blobPath, List>> blobData * @return BlobMetadata object used to create the register blob request */ BlobMetadata upload( - String blobPath, byte[] blob, List metadata, BlobStats blobStats) + StreamingIngestStorage storage, + String blobPath, + byte[] blob, + List metadata, + BlobStats blobStats) throws NoSuchAlgorithmException { logger.logInfo("Start uploading blob={}, size={}", blobPath, blob.length); long startTime = System.currentTimeMillis(); Timer.Context uploadContext = Utils.createTimerContext(this.owningClient.uploadLatency); - this.targetStage.put(blobPath, blob); + storage.put(blobPath, blob); if (uploadContext != null) { blobStats.setUploadDurationMs(uploadContext); @@ -626,48 +684,16 @@ void shutdown() throws InterruptedException { } } - /** Set the flag to indicate that a flush is needed */ - void setNeedFlush() { - this.isNeedFlush = true; - } - /** - * Generate a blob path, which is: "YEAR/MONTH/DAY_OF_MONTH/HOUR_OF_DAY/MINUTE/.BDEC" + * Set the flag to indicate that a flush is needed * - * @return the generated blob file path + * @param fullyQualifiedTableName the fully qualified table name */ - private String getBlobPath(String clientPrefix) { - Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); - return getBlobPath(calendar, clientPrefix); - } - 
- /** For TESTING */ - String getBlobPath(Calendar calendar, String clientPrefix) { - if (isTestMode && clientPrefix == null) { - clientPrefix = "testPrefix"; + void setNeedFlush(String fullyQualifiedTableName) { + this.isNeedFlush = true; + if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1) { + this.channelCache.setNeedFlush(fullyQualifiedTableName, true); } - - Utils.assertStringNotNullOrEmpty("client prefix", clientPrefix); - int year = calendar.get(Calendar.YEAR); - int month = calendar.get(Calendar.MONTH) + 1; // Gregorian calendar starts from 0 - int day = calendar.get(Calendar.DAY_OF_MONTH); - int hour = calendar.get(Calendar.HOUR_OF_DAY); - int minute = calendar.get(Calendar.MINUTE); - long time = TimeUnit.MILLISECONDS.toSeconds(calendar.getTimeInMillis()); - long threadId = Thread.currentThread().getId(); - // Create the blob short name, the clientPrefix contains the deployment id - String blobShortName = - Long.toString(time, 36) - + "_" - + clientPrefix - + "_" - + threadId - + "_" - + this.counter.getAndIncrement() - + "." - + BLOB_EXTENSION_TYPE; - return year + "/" + month + "/" + day + "/" + hour + "/" + minute + "/" + blobShortName; } /** @@ -693,19 +719,13 @@ void invalidateAllChannelsInBlob( })); } - /** Get the server generated unique prefix for this client */ - String getClientPrefix() { - return this.targetStage.getClientPrefix(); - } - /** * Throttle if the number of queued buildAndUpload tasks is bigger than the total number of * available processors */ boolean throttleDueToQueuedFlushTasks() { ThreadPoolExecutor buildAndUpload = (ThreadPoolExecutor) this.buildUploadWorkers; - boolean throttleOnQueuedTasks = - buildAndUpload.getQueue().size() > Runtime.getRuntime().availableProcessors(); + boolean throttleOnQueuedTasks = buildAndUpload.getQueue().size() > numProcessors; if (throttleOnQueuedTasks) { logger.logWarn( "Throttled due too many queue flush tasks (probably because of slow uploading speed)," diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java new file mode 100644 index 000000000..51f4a82de --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IStorageManager.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + +package net.snowflake.ingest.streaming.internal; + +import java.util.Optional; + +/** + * Interface to manage {@link StreamingIngestStorage} for {@link FlushService} + * + * @param The type of chunk data + * @param the type of location that's being managed (internal stage / external volume) + */ +interface IStorageManager { + /** Default max upload retries for streaming ingest storage */ + int DEFAULT_MAX_UPLOAD_RETRIES = 5; + + /** + * Given a fully qualified table name, return the target storage + * + * @param fullyQualifiedTableName the target fully qualified table name + * @return target stage + */ + StreamingIngestStorage getStorage(String fullyQualifiedTableName); + + /** + * Add a storage to the manager + * + * @param dbName the database name + * @param schemaName the schema name + * @param tableName the table name + * @param fileLocationInfo file location info from configure response + */ + void addStorage( + String dbName, String schemaName, String tableName, FileLocationInfo fileLocationInfo); + + /** + * Gets the latest file location info (with a renewed short-lived access token) for the specified + * location + * + * @param location A reference to the target location + * @param fileName optional filename for single-file signed URL fetch from server + * @return the new location information + */ + FileLocationInfo getRefreshedLocation(TLocation location, Optional fileName); + + /** + * Generate a unique blob path and increment the blob sequencer + * + * @return the blob path + */ + String generateBlobPath(); + + /** + * Decrement the blob sequencer, this method is needed to prevent gap between file name sequencer. + * See {@link IStorageManager#generateBlobPath()} for more details. + */ + void decrementBlobSequencer(); + + /** + * Get the unique client prefix generated by the Snowflake server + * + * @return the client prefix + */ + String getClientPrefix(); +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java new file mode 100644 index 000000000..a4b5e29d1 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IStreamingIngestRequest.java @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +/** + * The StreamingIngestRequest interface is a marker interface used for type safety in the {@link + * SnowflakeServiceClient} for streaming ingest API request. + */ +interface IStreamingIngestRequest { + String getStringForLogging(); +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java new file mode 100644 index 000000000..d33a80738 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/InternalStageManager.java @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + +package net.snowflake.ingest.streaming.internal; + +import static net.snowflake.ingest.utils.Constants.BLOB_EXTENSION_TYPE; + +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.util.Calendar; +import java.util.Optional; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.ingest.connection.IngestResponseException; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.SFException; +import net.snowflake.ingest.utils.Utils; + +class InternalStageLocation { + public InternalStageLocation() {} +} + +/** Class to manage single Snowflake internal stage */ +class InternalStageManager implements IStorageManager { + /** Target stage for the client */ + private final StreamingIngestStorage targetStage; + + /** Increasing counter to generate a unique blob name per client */ + private final AtomicLong counter; + + /** Whether the manager in test mode */ + private final boolean isTestMode; + + /** Snowflake service client used for configure calls */ + private final SnowflakeServiceClient snowflakeServiceClient; + + /** The name of the client */ + private final String clientName; + + /** The role of the client */ + private final String role; + + /** Client prefix generated by the Snowflake server */ + private String clientPrefix; + + /** Deployment ID generated by the Snowflake server */ + private Long deploymentId; + + /** + * Constructor for InternalStageManager + * + * @param isTestMode whether the manager in test mode + * @param role the role of the client + * @param clientName the name of the client + * @param snowflakeServiceClient the Snowflake service client to use for configure calls + */ + InternalStageManager( + boolean isTestMode, + String role, + String clientName, + SnowflakeServiceClient snowflakeServiceClient) { + this.snowflakeServiceClient = snowflakeServiceClient; + this.isTestMode = isTestMode; + this.clientName = clientName; + this.role = role; + this.counter = new AtomicLong(0); + try { + if (!isTestMode) { + ClientConfigureResponse response = + this.snowflakeServiceClient.clientConfigure(new ClientConfigureRequest(role)); + this.clientPrefix = response.getClientPrefix(); + this.deploymentId = response.getDeploymentId(); + this.targetStage = + new StreamingIngestStorage( + this, + clientName, + response.getStageLocation(), + new InternalStageLocation(), + DEFAULT_MAX_UPLOAD_RETRIES); + } else { + this.clientPrefix = null; + this.deploymentId = null; + this.targetStage = + new StreamingIngestStorage( + this, + "testClient", + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + new InternalStageLocation(), + DEFAULT_MAX_UPLOAD_RETRIES); + } + } catch (IngestResponseException | IOException e) { + throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); + } catch (SnowflakeSQLException e) { + throw new SFException(e, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE, e.getMessage()); + } + } + + /** + * Get the storage. In this case, the storage is always the target stage as there's only one stage + * in non-iceberg mode. 
* + * @param fullyQualifiedTableName the target fully qualified table name + * @return the target storage + */ + @Override + @SuppressWarnings("unused") + public StreamingIngestStorage<T, InternalStageLocation> getStorage( + String fullyQualifiedTableName) { + // There's always only one stage for the client in non-iceberg mode + return targetStage; + } + + /** Add storage to the manager. Do nothing as there's only one stage in non-Iceberg mode. */ + @Override + public void addStorage( + String dbName, String schemaName, String tableName, FileLocationInfo fileLocationInfo) {} + + /** + * Gets the latest file location info (with a renewed short-lived access token) for the specified + * location + * + * @param location A reference to the target location + * @param fileName optional filename for single-file signed URL fetch from server + * @return the new location information + */ + @Override + public FileLocationInfo getRefreshedLocation( + InternalStageLocation location, Optional<String> fileName) { + try { + ClientConfigureRequest request = new ClientConfigureRequest(this.role); + fileName.ifPresent(request::setFileName); + ClientConfigureResponse response = snowflakeServiceClient.clientConfigure(request); + if (this.clientPrefix == null) { + this.clientPrefix = response.getClientPrefix(); + this.deploymentId = response.getDeploymentId(); + } + if (this.deploymentId != null && !this.deploymentId.equals(response.getDeploymentId())) { + throw new SFException( + ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH, + this.deploymentId, + response.getDeploymentId(), + this.clientName); + } + return response.getStageLocation(); + } catch (IngestResponseException | IOException e) { + throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); + } + } + + /** + * Generate a blob path, which is: "YEAR/MONTH/DAY_OF_MONTH/HOUR_OF_DAY/MINUTE/<current utc timestamp + client unique prefix + thread id + counter>.BDEC" + * + * @return the generated blob file path + */ + @Override + public String generateBlobPath() { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + return getBlobPath(calendar, this.clientPrefix); + } + + @Override + public void decrementBlobSequencer() { + this.counter.decrementAndGet(); + } + + /** For TESTING */ + @VisibleForTesting + public String getBlobPath(Calendar calendar, String clientPrefix) { + if (this.isTestMode && clientPrefix == null) { + clientPrefix = "testPrefix"; + } + + Utils.assertStringNotNullOrEmpty("client prefix", clientPrefix); + int year = calendar.get(Calendar.YEAR); + int month = calendar.get(Calendar.MONTH) + 1; // Gregorian calendar starts from 0 + int day = calendar.get(Calendar.DAY_OF_MONTH); + int hour = calendar.get(Calendar.HOUR_OF_DAY); + int minute = calendar.get(Calendar.MINUTE); + long time = TimeUnit.MILLISECONDS.toSeconds(calendar.getTimeInMillis()); + long threadId = Thread.currentThread().getId(); + // Create the blob short name, the clientPrefix contains the deployment id + String blobShortName = + Long.toString(time, 36) + + "_" + + clientPrefix + + "_" + + threadId + + "_" + + this.counter.getAndIncrement() + + "."
+ + BLOB_EXTENSION_TYPE; + return year + "/" + month + "/" + day + "/" + hour + "/" + minute + "/" + blobShortName; + } + + /** + * Get the unique client prefix generated by the Snowflake server + * + * @return the client prefix + */ + @Override + public String getClientPrefix() { + return this.clientPrefix; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java index f426e898d..777ae4fdc 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java @@ -9,9 +9,6 @@ public interface MemoryInfoProvider { /** @return Max memory the JVM can allocate */ long getMaxMemory(); - /** @return Total allocated JVM memory so far */ - long getTotalMemory(); - /** @return Free JVM memory */ long getFreeMemory(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java index 3a957f225..d248ddfd9 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java @@ -4,20 +4,51 @@ package net.snowflake.ingest.streaming.internal; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + /** Reads memory information from JVM runtime */ public class MemoryInfoProviderFromRuntime implements MemoryInfoProvider { - @Override - public long getMaxMemory() { - return Runtime.getRuntime().maxMemory(); + private final long maxMemory; + private volatile long totalFreeMemory; + private final ScheduledExecutorService executorService; + private static final long FREE_MEMORY_UPDATE_INTERVAL_MS = 100; + private static final MemoryInfoProviderFromRuntime INSTANCE = + new MemoryInfoProviderFromRuntime(FREE_MEMORY_UPDATE_INTERVAL_MS); + + private MemoryInfoProviderFromRuntime(long freeMemoryUpdateIntervalMs) { + maxMemory = Runtime.getRuntime().maxMemory(); + totalFreeMemory = + Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory()); + executorService = + new ScheduledThreadPoolExecutor( + 1, + r -> { + Thread th = new Thread(r, "MemoryInfoProviderFromRuntime"); + th.setDaemon(true); + return th; + }); + executorService.scheduleAtFixedRate( + this::updateFreeMemory, 0, freeMemoryUpdateIntervalMs, TimeUnit.MILLISECONDS); + } + + private void updateFreeMemory() { + totalFreeMemory = + Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory()); + } + + public static MemoryInfoProviderFromRuntime getInstance() { + return INSTANCE; } @Override - public long getTotalMemory() { - return Runtime.getRuntime().totalMemory(); + public long getMaxMemory() { + return maxMemory; } @Override public long getFreeMemory() { - return Runtime.getRuntime().freeMemory(); + return totalFreeMemory; } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java new file mode 100644 index 000000000..ff53f6729 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelRequestInternal.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. 
All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.utils.Constants; + +/** Class used to serialize the {@link OpenChannelRequest} */ +class OpenChannelRequestInternal implements IStreamingIngestRequest { + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("role") + private String role; + + @JsonProperty("channel") + private String channel; + + @JsonProperty("table") + private String table; + + @JsonProperty("database") + private String database; + + @JsonProperty("schema") + private String schema; + + @JsonProperty("write_mode") + private String writeMode; + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonProperty("offset_token") + private String offsetToken; + + OpenChannelRequestInternal( + String requestId, + String role, + String database, + String schema, + String table, + String channel, + Constants.WriteMode writeMode, + String offsetToken) { + this.requestId = requestId; + this.role = role; + this.database = database; + this.schema = schema; + this.table = table; + this.channel = channel; + this.writeMode = writeMode.name(); + this.offsetToken = offsetToken; + } + + String getRequestId() { + return requestId; + } + + String getRole() { + return role; + } + + String getChannel() { + return channel; + } + + String getTable() { + return table; + } + + String getDatabase() { + return database; + } + + String getSchema() { + return schema; + } + + String getWriteMode() { + return writeMode; + } + + String getOffsetToken() { + return offsetToken; + } + + @Override + public String getStringForLogging() { + return String.format( + "OpenChannelRequestInternal(requestId=%s, role=%s, db=%s, schema=%s, table=%s, channel=%s," + + " writeMode=%s)", + requestId, role, database, schema, table, channel, writeMode); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java index 16b1ededa..9950c44aa 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java @@ -5,6 +5,7 @@ package net.snowflake.ingest.streaming.internal; import java.io.ByteArrayOutputStream; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.parquet.hadoop.BdecParquetWriter; @@ -34,6 +35,16 @@ public ParquetChunkData( this.rows = rows; this.parquetWriter = parquetWriter; this.output = output; - this.metadata = metadata; + // create a defensive copy of the parameter map because the argument map passed here + // may currently be shared across multiple threads. 
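The defensive-copy comment above is the whole point of the change: if the constructor kept a reference to the caller's map, a later mutation by the caller (possibly on another thread) would silently change the buffered chunk metadata. A minimal standalone sketch of the hazard and of the copy that guards against it, with hypothetical names:

import java.util.HashMap;
import java.util.Map;

// Hypothetical holder class; illustrates why the constructor copies the map it receives.
public class MetadataCopyDemo {
  static class Chunk {
    private final Map<String, String> metadata;

    Chunk(Map<String, String> metadata) {
      // Copying here decouples the chunk from later changes made by the caller.
      this.metadata = new HashMap<>(metadata);
    }

    String get(String key) {
      return metadata.get(key);
    }
  }

  public static void main(String[] args) {
    Map<String, String> shared = new HashMap<>();
    shared.put("primaryFileId", "f1");

    Chunk chunk = new Chunk(shared);
    shared.put("primaryFileId", "f2"); // caller keeps mutating its own map

    // The chunk still sees the value captured at construction time.
    System.out.println(chunk.get("primaryFileId")); // prints f1
  }
}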
+ this.metadata = createDefensiveCopy(metadata); + } + + private Map<String, String> createDefensiveCopy(final Map<String, String> metadata) { + final Map<String, String> copy = new HashMap<>(metadata); + for (String k : metadata.keySet()) { + copy.put(k, metadata.get(k)); + } + return copy; + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java index 39ec66dbb..3ad3db5f4 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.Logging; @@ -124,6 +125,12 @@ private SerializationResult serializeFromParquetWriteBuffers( if (mergedChannelWriter != null) { mergedChannelWriter.close(); + this.verifyRowCounts( + "serializeFromParquetWriteBuffers", + mergedChannelWriter, + rowCount, + channelsDataPerTable, + -1); } return new SerializationResult( channelsMetadataList, @@ -216,6 +223,9 @@ private SerializationResult serializeFromJavaObjects( rows.forEach(parquetWriter::writeRow); parquetWriter.close(); + this.verifyRowCounts( + "serializeFromJavaObjects", parquetWriter, rowCount, channelsDataPerTable, rows.size()); + return new SerializationResult( channelsMetadataList, columnEpStatsMapCombined, @@ -224,4 +234,71 @@ private SerializationResult serializeFromJavaObjects( mergedData, chunkMinMaxInsertTimeInMs); } + + /** + * Validates that the row count in metadata matches the row count in the Parquet footer and the row + * count written by the Parquet writer + * + * @param serializationType Serialization type, used for logging purposes only + * @param writer Parquet writer writing the data + * @param channelsDataPerTable Channel data + * @param totalMetadataRowCount Row count calculated during metadata collection + * @param javaSerializationTotalRowCount Total row count when java object serialization is used. + * Used only for logging purposes if there is a mismatch. + */ + private void verifyRowCounts( + String serializationType, + BdecParquetWriter writer, + long totalMetadataRowCount, + List<ChannelData<ParquetChunkData>> channelsDataPerTable, + long javaSerializationTotalRowCount) { + long parquetTotalRowsWritten = writer.getRowsWritten(); + + List<Long> parquetFooterRowsPerBlock = writer.getRowCountsFromFooter(); + long parquetTotalRowsInFooter = 0; + for (long perBlockCount : parquetFooterRowsPerBlock) { + parquetTotalRowsInFooter += perBlockCount; + } + + if (parquetTotalRowsInFooter != totalMetadataRowCount + || parquetTotalRowsWritten != totalMetadataRowCount) { + + final String perChannelRowCountsInMetadata = + channelsDataPerTable.stream() + .map(x -> String.valueOf(x.getRowCount())) + .collect(Collectors.joining(",")); + + final String channelNames = + channelsDataPerTable.stream() + .map(x -> String.valueOf(x.getChannelContext().getName())) + .collect(Collectors.joining(",")); + + final String perBlockRowCountsInFooter = + parquetFooterRowsPerBlock.stream().map(String::valueOf).collect(Collectors.joining(",")); + + final long channelsCountInMetadata = channelsDataPerTable.size(); + + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "[%s]The number of rows in Parquet does not match the number of rows in metadata.
" + + "parquetTotalRowsInFooter=%d " + + "totalMetadataRowCount=%d " + + "parquetTotalRowsWritten=%d " + + "perChannelRowCountsInMetadata=%s " + + "perBlockRowCountsInFooter=%s " + + "channelsCountInMetadata=%d " + + "countOfSerializedJavaObjects=%d " + + "channelNames=%s", + serializationType, + parquetTotalRowsInFooter, + totalMetadataRowCount, + parquetTotalRowsWritten, + perChannelRowCountsInMetadata, + perBlockRowCountsInFooter, + channelsCountInMetadata, + javaSerializationTotalRowCount, + channelNames)); + } + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java index 17aaa9136..f5835d2a1 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java @@ -207,7 +207,7 @@ private float addRow( ColumnMetadata column = parquetColumn.columnMetadata; ParquetValueParser.ParquetBufferValue valueWithSize = ParquetValueParser.parseColumnValueToParquet( - value, column, parquetColumn.type, forkedStats, defaultTimezone, insertRowsCurrIndex); + value, column, parquetColumn.type, forkedStats, defaultTimezone, insertRowsCurrIndex, clientBufferParameters.getBinaryStringEncoding()); indexedRow[colIndex] = valueWithSize.getValue(); size += valueWithSize.getSize(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetValueParser.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetValueParser.java index 282a007d4..15770d4f9 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetValueParser.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetValueParser.java @@ -10,6 +10,8 @@ import java.time.ZoneId; import java.util.Optional; import javax.annotation.Nullable; + +import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.Utils; @@ -80,12 +82,12 @@ float getSize() { * @return parsed value and byte size of Parquet internal representation */ static ParquetBufferValue parseColumnValueToParquet( - Object value, - ColumnMetadata columnMetadata, - PrimitiveType.PrimitiveTypeName typeName, - RowBufferStats stats, - ZoneId defaultTimezone, - long insertRowsCurrIndex) { + Object value, + ColumnMetadata columnMetadata, + PrimitiveType.PrimitiveTypeName typeName, + RowBufferStats stats, + ZoneId defaultTimezone, + long insertRowsCurrIndex, Constants.BinaryStringEncoding binaryStringEncoding) { Utils.assertNotNull("Parquet column stats", stats); float estimatedParquetSize = 0F; estimatedParquetSize += DEFINITION_LEVEL_ENCODING_BYTE_LEN; @@ -144,7 +146,7 @@ static ParquetBufferValue parseColumnValueToParquet( int length = 0; if (logicalType == AbstractRowBuffer.ColumnLogicalType.BINARY) { value = - getBinaryValueForLogicalBinary(value, stats, columnMetadata, insertRowsCurrIndex); + getBinaryValueForLogicalBinary(value, stats, columnMetadata, insertRowsCurrIndex, binaryStringEncoding); length = ((byte[]) value).length; } else { String str = getBinaryValue(value, stats, columnMetadata, insertRowsCurrIndex); @@ -414,14 +416,16 @@ private static byte[] getBinaryValueForLogicalBinary( Object value, RowBufferStats stats, ColumnMetadata columnMetadata, - final long insertRowsCurrIndex) { + final long insertRowsCurrIndex, + final Constants.BinaryStringEncoding binaryStringEncoding) { String maxLengthString = 
columnMetadata.getByteLength().toString(); byte[] bytes = DataValidationUtil.validateAndParseBinary( columnMetadata.getName(), value, Optional.of(maxLengthString).map(Integer::parseInt), - insertRowsCurrIndex); + insertRowsCurrIndex, binaryStringEncoding + ); stats.addBinaryValue(bytes); return bytes; } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java new file mode 100644 index 000000000..fcb7edf4f --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobRequest.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.stream.Collectors; + +/** Class used to serialize the blob register request */ +class RegisterBlobRequest implements IStreamingIngestRequest { + @JsonProperty("request_id") + private String requestId; + + @JsonProperty("role") + private String role; + + @JsonProperty("blobs") + private List blobs; + + RegisterBlobRequest(String requestId, String role, List blobs) { + this.requestId = requestId; + this.role = role; + this.blobs = blobs; + } + + String getRequestId() { + return requestId; + } + + String getRole() { + return role; + } + + List getBlobs() { + return blobs; + } + + @Override + public String getStringForLogging() { + return String.format( + "RegisterBlobRequest(requestId=%s, role=%s, blobs=[%s])", + requestId, + role, + blobs.stream().map(BlobMetadata::getPath).collect(Collectors.joining(", "))); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java new file mode 100644 index 000000000..67958618b --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
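RegisterBlobRequest.getStringForLogging() above flattens the blob list into a comma-separated list of paths so the request can be logged compactly without dumping full metadata. A rough standalone sketch of that pattern, using plain strings in place of BlobMetadata and made-up values:

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class LogStringDemo {
  public static void main(String[] args) {
    // Stand-in for the paths that would be pulled from BlobMetadata objects.
    List<String> blobPaths =
        Arrays.asList("2024/7/1/10/5/x_prefix_1_0.bdec", "2024/7/1/10/5/x_prefix_1_1.bdec");

    // Same shape as the request's log string: request id, role, then joined blob paths.
    String forLogging =
        String.format(
            "RegisterBlobRequest(requestId=%s, role=%s, blobs=[%s])",
            "prefix_42",
            "INGEST_ROLE",
            blobPaths.stream().collect(Collectors.joining(", ")));

    System.out.println(forLogging);
  }
}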
+ */ + +package net.snowflake.ingest.streaming.internal; + +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CHANNEL_STATUS; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_DROP_CHANNEL; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_OPEN_CHANNEL; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_REGISTER_BLOB; +import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; +import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.OPEN_CHANNEL_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; + +import java.io.IOException; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.ingest.connection.IngestResponseException; +import net.snowflake.ingest.connection.RequestBuilder; +import net.snowflake.ingest.connection.ServiceResponseHandler; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.Logging; +import net.snowflake.ingest.utils.SFException; + +/** + * The SnowflakeServiceClient class is responsible for making API requests to the Snowflake service. + */ +class SnowflakeServiceClient { + private static final Logging logger = new Logging(SnowflakeServiceClient.class); + + /** HTTP client used for making requests */ + private final CloseableHttpClient httpClient; + + /** Request builder for building streaming API request */ + private final RequestBuilder requestBuilder; + + /** + * Default constructor + * + * @param httpClient the HTTP client used for making requests + * @param requestBuilder the request builder for building streaming API requests + */ + SnowflakeServiceClient(CloseableHttpClient httpClient, RequestBuilder requestBuilder) { + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + } + + /** + * Configures the client given a {@link ClientConfigureRequest}. + * + * @param request the client configuration request + * @return the response from the configuration request + */ + ClientConfigureResponse clientConfigure(ClientConfigureRequest request) + throws IngestResponseException, IOException { + ClientConfigureResponse response = + executeApiRequestWithRetries( + ClientConfigureResponse.class, + request, + CLIENT_CONFIGURE_ENDPOINT, + "client configure", + STREAMING_CLIENT_CONFIGURE); + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Client configure request failed, request={}, message={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.CLIENT_CONFIGURE_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Opens a channel given a {@link OpenChannelRequestInternal}. 
+ * + * @param request the open channel request + * @return the response from the open channel request + */ + OpenChannelResponse openChannel(OpenChannelRequestInternal request) + throws IngestResponseException, IOException { + OpenChannelResponse response = + executeApiRequestWithRetries( + OpenChannelResponse.class, + request, + OPEN_CHANNEL_ENDPOINT, + "open channel", + STREAMING_OPEN_CHANNEL); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Open channel request failed, request={}, response={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.OPEN_CHANNEL_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Drops a channel given a {@link DropChannelRequestInternal}. + * + * @param request the drop channel request + * @return the response from the drop channel request + */ + DropChannelResponse dropChannel(DropChannelRequestInternal request) + throws IngestResponseException, IOException { + DropChannelResponse response = + executeApiRequestWithRetries( + DropChannelResponse.class, + request, + DROP_CHANNEL_ENDPOINT, + "drop channel", + STREAMING_DROP_CHANNEL); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Drop channel request failed, request={}, response={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.DROP_CHANNEL_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Gets the status of a channel given a {@link ChannelsStatusRequest}. + * + * @param request the channel status request + * @return the response from the channel status request + */ + ChannelsStatusResponse getChannelStatus(ChannelsStatusRequest request) + throws IngestResponseException, IOException { + ChannelsStatusResponse response = + executeApiRequestWithRetries( + ChannelsStatusResponse.class, + request, + CHANNEL_STATUS_ENDPOINT, + "channel status", + STREAMING_CHANNEL_STATUS); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Channel status request failed, request={}, response={}", + request.getStringForLogging(), + response.getMessage()); + throw new SFException(ErrorCode.CHANNEL_STATUS_FAILURE, response.getMessage()); + } + return response; + } + + /** + * Registers a blob given a {@link RegisterBlobRequest}. 
+ * + * @param request the register blob request + * @param executionCount the number of times the request has been executed, used for logging + * @return the response from the register blob request + */ + RegisterBlobResponse registerBlob(RegisterBlobRequest request, final int executionCount) + throws IngestResponseException, IOException { + RegisterBlobResponse response = + executeApiRequestWithRetries( + RegisterBlobResponse.class, + request, + REGISTER_BLOB_ENDPOINT, + "register blob", + STREAMING_REGISTER_BLOB); + + if (response.getStatusCode() != RESPONSE_SUCCESS) { + logger.logDebug( + "Register blob request failed, request={}, response={}, executionCount={}", + request.getStringForLogging(), + response.getMessage(), + executionCount); + throw new SFException(ErrorCode.REGISTER_BLOB_FAILURE, response.getMessage()); + } + return response; + } + + private T executeApiRequestWithRetries( + Class responseClass, + IStreamingIngestRequest request, + String endpoint, + String operation, + ServiceResponseHandler.ApiName apiName) + throws IngestResponseException, IOException { + return executeWithRetries( + responseClass, endpoint, request, operation, apiName, this.httpClient, this.requestBuilder); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java index 58e81d116..ca0bbe782 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. 
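The executeApiRequestWithRetries helper above is generic over the concrete response type, so each endpoint method gets back its own response class after a shared status-code check. A simplified standalone sketch of that shape, in which the response classes, the exception, and the Supplier-based transport are stand-ins rather than SDK types:

import java.util.function.Supplier;

// Self-contained sketch of a status-checked generic dispatch helper; names are made up
// and only mirror the shape of the pattern, not the SDK internals.
public class GenericDispatchDemo {
  static final long SUCCESS = 0L;

  static class ApiException extends RuntimeException {
    ApiException(String message) {
      super(message);
    }
  }

  abstract static class DemoResponse {
    abstract Long getStatusCode();

    abstract String getMessage();
  }

  static class ConfigureResponse extends DemoResponse {
    Long getStatusCode() { return SUCCESS; }

    String getMessage() { return "OK"; }

    String getClientPrefix() { return "prefix_123"; }
  }

  // Generic over the concrete response type: the caller gets the concrete class back
  // without casting, and every call shares the same status check.
  static <T extends DemoResponse> T execute(Supplier<T> transport, String operation) {
    T response = transport.get();
    if (response.getStatusCode() != SUCCESS) {
      throw new ApiException(operation + " failed: " + response.getMessage());
    }
    return response;
  }

  public static void main(String[] args) {
    ConfigureResponse response = execute(ConfigureResponse::new, "client configure");
    System.out.println(response.getClientPrefix());
  }
}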
*/ package net.snowflake.ingest.streaming.internal; @@ -45,6 +45,10 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn // Reference to the row buffer private final RowBuffer rowBuffer; + private final long insertThrottleIntervalInMs; + private final int insertThrottleThresholdInBytes; + private final int insertThrottleThresholdInPercentage; + private final long maxMemoryLimitInBytes; // Indicates whether the channel is closed private volatile boolean isClosed; @@ -61,6 +65,9 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn // The latest cause of channel invalidation private String invalidationCause; + private final MemoryInfoProvider memoryInfoProvider; + private volatile long freeMemoryInBytes = 0; + /** * Constructor for TESTING ONLY which allows us to set the test mode * @@ -121,6 +128,17 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn OffsetTokenVerificationFunction offsetTokenVerificationFunction) { this.isClosed = false; this.owningClient = client; + + this.insertThrottleIntervalInMs = + this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs(); + this.insertThrottleThresholdInBytes = + this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes(); + this.insertThrottleThresholdInPercentage = + this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage(); + this.maxMemoryLimitInBytes = + this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes(); + + this.memoryInfoProvider = MemoryInfoProviderFromRuntime.getInstance(); this.channelFlushContext = new ChannelFlushContext( name, dbName, schemaName, tableName, channelSequencer, encryptionKey, encryptionKeyId); @@ -373,7 +391,7 @@ public InsertValidationResponse insertRows( Iterable> rows, @Nullable String startOffsetToken, @Nullable String endOffsetToken) { - throttleInsertIfNeeded(new MemoryInfoProviderFromRuntime()); + throttleInsertIfNeeded(memoryInfoProvider); checkValidation(); if (isClosed()) { @@ -395,7 +413,7 @@ public InsertValidationResponse insertRows( // if a large number of rows are inserted if (this.rowBuffer.getSize() >= this.owningClient.getParameterProvider().getMaxChannelSizeInBytes()) { - this.owningClient.setNeedFlush(); + this.owningClient.setNeedFlush(this.channelFlushContext.getFullyQualifiedTableName()); } return response; @@ -448,8 +466,6 @@ public Map getTableSchema() { /** Check whether we need to throttle the insertRows API */ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) { int retry = 0; - long insertThrottleIntervalInMs = - this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs(); while ((hasLowRuntimeMemory(memoryInfoProvider) || (this.owningClient.getFlushService() != null && this.owningClient.getFlushService().throttleDueToQueuedFlushTasks())) @@ -473,22 +489,14 @@ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) { /** Check whether we have a low runtime memory condition */ private boolean hasLowRuntimeMemory(MemoryInfoProvider memoryInfoProvider) { - int insertThrottleThresholdInBytes = - this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes(); - int insertThrottleThresholdInPercentage = - this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage(); - long maxMemoryLimitInBytes = - this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes(); long maxMemory = maxMemoryLimitInBytes == MAX_MEMORY_LIMIT_IN_BYTES_DEFAULT ? 
memoryInfoProvider.getMaxMemory() : maxMemoryLimitInBytes; - long freeMemory = - memoryInfoProvider.getFreeMemory() - + (memoryInfoProvider.getMaxMemory() - memoryInfoProvider.getTotalMemory()); + freeMemoryInBytes = memoryInfoProvider.getFreeMemory(); boolean hasLowRuntimeMemory = - freeMemory < insertThrottleThresholdInBytes - && freeMemory * 100 / maxMemory < insertThrottleThresholdInPercentage; + freeMemoryInBytes < insertThrottleThresholdInBytes + && freeMemoryInBytes * 100 / maxMemory < insertThrottleThresholdInPercentage; if (hasLowRuntimeMemory) { logger.logWarn( "Throttled due to memory pressure, client={}, channel={}.", diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index 2990b49d8..080b4f87d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -1,23 +1,14 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CHANNEL_STATUS; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_DROP_CHANNEL; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_OPEN_CHANNEL; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_REGISTER_BLOB; -import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.sleepForRetry; -import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT; import static net.snowflake.ingest.utils.Constants.COMMIT_MAX_RETRY_COUNT; import static net.snowflake.ingest.utils.Constants.COMMIT_RETRY_INTERVAL_IN_MS; -import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT; import static net.snowflake.ingest.utils.Constants.ENABLE_TELEMETRY_TO_SF; import static net.snowflake.ingest.utils.Constants.MAX_STREAMING_INGEST_API_CHANNEL_RETRY; -import static net.snowflake.ingest.utils.Constants.OPEN_CHANNEL_ENDPOINT; -import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; import static net.snowflake.ingest.utils.Constants.RESPONSE_ERR_ENQUEUE_TABLE_CHUNK_QUEUE_FULL; import static net.snowflake.ingest.utils.Constants.RESPONSE_ERR_GENERAL_EXCEPTION_RETRY_REQUEST; import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; @@ -37,7 +28,6 @@ import com.codahale.metrics.jmx.JmxReporter; import com.codahale.metrics.jvm.MemoryUsageGaugeSet; import com.codahale.metrics.jvm.ThreadStatesGaugeSet; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.net.URI; @@ -94,9 +84,6 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea private static final Logging logger = new Logging(SnowflakeStreamingIngestClientInternal.class); - // Object mapper for all marshalling and unmarshalling - private static final ObjectMapper objectMapper = new ObjectMapper(); - // Counter to generate unique request ids per client private final AtomicLong counter = new AtomicLong(0); @@ 
-118,6 +105,9 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea // Reference to the flush service private final FlushService flushService; + // Reference to storage manager + private final IStorageManager storageManager; + // Indicates whether the client has closed private volatile boolean isClosed; @@ -145,6 +135,9 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea // Background thread that uploads telemetry data periodically private ScheduledExecutorService telemetryWorker; + // Snowflake service client to make API calls + private SnowflakeServiceClient snowflakeServiceClient; + /** * Constructor * @@ -228,8 +221,14 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea this.setupMetricsForClient(); } + this.snowflakeServiceClient = new SnowflakeServiceClient(this.httpClient, this.requestBuilder); + + this.storageManager = + new InternalStageManager(isTestMode, this.role, this.name, this.snowflakeServiceClient); + try { - this.flushService = new FlushService<>(this, this.channelCache, this.isTestMode); + this.flushService = + new FlushService<>(this, this.channelCache, this.storageManager, this.isTestMode); } catch (Exception e) { // Need to clean up the resources before throwing any exceptions cleanUpResources(); @@ -274,6 +273,7 @@ public SnowflakeStreamingIngestClientInternal( @VisibleForTesting public void injectRequestBuilder(RequestBuilder requestBuilder) { this.requestBuilder = requestBuilder; + this.snowflakeServiceClient = new SnowflakeServiceClient(this.httpClient, this.requestBuilder); } /** @@ -320,39 +320,17 @@ public SnowflakeStreamingIngestChannelInternal openChannel(OpenChannelRequest getName()); try { - Map payload = new HashMap<>(); - payload.put( - "request_id", this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - payload.put("channel", request.getChannelName()); - payload.put("table", request.getTableName()); - payload.put("database", request.getDBName()); - payload.put("schema", request.getSchemaName()); - payload.put("write_mode", Constants.WriteMode.CLOUD_STORAGE.name()); - payload.put("role", this.role); - if (request.isOffsetTokenProvided()) { - payload.put("offset_token", request.getOffsetToken()); - } - - OpenChannelResponse response = - executeWithRetries( - OpenChannelResponse.class, - OPEN_CHANNEL_ENDPOINT, - payload, - "open channel", - STREAMING_OPEN_CHANNEL, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - logger.logDebug( - "Open channel request failed, channel={}, table={}, client={}, message={}", - request.getChannelName(), - request.getFullyQualifiedTableName(), - getName(), - response.getMessage()); - throw new SFException(ErrorCode.OPEN_CHANNEL_FAILURE, response.getMessage()); - } + OpenChannelRequestInternal openChannelRequest = + new OpenChannelRequestInternal( + this.storageManager.getClientPrefix() + "_" + counter.getAndIncrement(), + this.role, + request.getDBName(), + request.getSchemaName(), + request.getTableName(), + request.getChannelName(), + Constants.WriteMode.CLOUD_STORAGE, + request.getOffsetToken()); + OpenChannelResponse response = snowflakeServiceClient.openChannel(openChannelRequest); logger.logInfo( "Open channel request succeeded, channel={}, table={}, clientSequencer={}," @@ -405,51 +383,28 @@ public void dropChannel(DropChannelRequest request) { getName()); try { - Map payload = new HashMap<>(); - payload.put( - "request_id", 
this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - payload.put("channel", request.getChannelName()); - payload.put("table", request.getTableName()); - payload.put("database", request.getDBName()); - payload.put("schema", request.getSchemaName()); - payload.put("role", this.role); - Long clientSequencer = null; - if (request instanceof DropChannelVersionRequest) { - clientSequencer = ((DropChannelVersionRequest) request).getClientSequencer(); - if (clientSequencer != null) { - payload.put("client_sequencer", clientSequencer); - } - } - - DropChannelResponse response = - executeWithRetries( - DropChannelResponse.class, - DROP_CHANNEL_ENDPOINT, - payload, - "drop channel", - STREAMING_DROP_CHANNEL, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - logger.logDebug( - "Drop channel request failed, channel={}, table={}, client={}, message={}", - request.getChannelName(), - request.getFullyQualifiedTableName(), - getName(), - response.getMessage()); - throw new SFException(ErrorCode.DROP_CHANNEL_FAILURE, response.getMessage()); - } + DropChannelRequestInternal dropChannelRequest = + new DropChannelRequestInternal( + this.storageManager.getClientPrefix() + "_" + counter.getAndIncrement(), + this.role, + request.getDBName(), + request.getSchemaName(), + request.getTableName(), + request.getChannelName(), + request instanceof DropChannelVersionRequest + ? ((DropChannelVersionRequest) request).getClientSequencer() + : null); + snowflakeServiceClient.dropChannel(dropChannelRequest); logger.logInfo( "Drop channel request succeeded, channel={}, table={}, clientSequencer={} client={}", request.getChannelName(), request.getFullyQualifiedTableName(), - clientSequencer, + request instanceof DropChannelVersionRequest + ? 
((DropChannelVersionRequest) request).getClientSequencer() + : null, getName()); - - } catch (IOException | IngestResponseException e) { + } catch (IngestResponseException | IOException e) { throw new SFException(e, ErrorCode.DROP_CHANNEL_FAILURE, e.getMessage()); } } @@ -493,24 +448,8 @@ ChannelsStatusResponse getChannelsStatus( .collect(Collectors.toList()); request.setChannels(requestDTOs); request.setRole(this.role); - request.setRequestId(this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - - String payload = objectMapper.writeValueAsString(request); - - ChannelsStatusResponse response = - executeWithRetries( - ChannelsStatusResponse.class, - CHANNEL_STATUS_ENDPOINT, - payload, - "channel status", - STREAMING_CHANNEL_STATUS, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - throw new SFException(ErrorCode.CHANNEL_STATUS_FAILURE, response.getMessage()); - } + + ChannelsStatusResponse response = snowflakeServiceClient.getChannelStatus(request); for (int idx = 0; idx < channels.size(); idx++) { SnowflakeStreamingIngestChannelInternal channel = channels.get(idx); @@ -607,32 +546,12 @@ void registerBlobs(List blobs, final int executionCount) { RegisterBlobResponse response = null; try { - Map payload = new HashMap<>(); - payload.put( - "request_id", this.flushService.getClientPrefix() + "_" + counter.getAndIncrement()); - payload.put("blobs", blobs); - payload.put("role", this.role); - - response = - executeWithRetries( - RegisterBlobResponse.class, - REGISTER_BLOB_ENDPOINT, - payload, - "register blob", - STREAMING_REGISTER_BLOB, - httpClient, - requestBuilder); - - // Check for Snowflake specific response code - if (response.getStatusCode() != RESPONSE_SUCCESS) { - logger.logDebug( - "Register blob request failed for blob={}, client={}, message={}, executionCount={}", - blobs.stream().map(BlobMetadata::getPath).collect(Collectors.toList()), - this.name, - response.getMessage(), - executionCount); - throw new SFException(ErrorCode.REGISTER_BLOB_FAILURE, response.getMessage()); - } + RegisterBlobRequest request = + new RegisterBlobRequest( + this.storageManager.getClientPrefix() + "_" + counter.getAndIncrement(), + this.role, + blobs); + response = snowflakeServiceClient.registerBlob(request, executionCount); } catch (IOException | IngestResponseException e) { throw new SFException(e, ErrorCode.REGISTER_BLOB_FAILURE, e.getMessage()); } @@ -821,8 +740,8 @@ CompletableFuture flush(boolean closing) { } /** Set the flag to indicate that a flush is needed */ - void setNeedFlush() { - this.flushService.setNeedFlush(); + void setNeedFlush(String fullyQualifiedTableName) { + this.flushService.setNeedFlush(fullyQualifiedTableName); } /** Remove the channel in the channel cache if the channel sequencer matches */ diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java index 1ec01fceb..6c4df8c6d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestResponse.java @@ -1,9 +1,16 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. 
*/ package net.snowflake.ingest.streaming.internal; +/** + * The StreamingIngestResponse class is an abstract class that represents a response from the + * Snowflake streaming ingest API. This class provides a common structure for all types of responses + * that can be received from the {@link SnowflakeServiceClient}. + */ abstract class StreamingIngestResponse { abstract Long getStatusCode(); + + abstract String getMessage(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorage.java similarity index 64% rename from src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java rename to src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorage.java index e8e56f383..242b5cc43 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStage.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorage.java @@ -1,28 +1,27 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; -import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; -import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; -import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; import static net.snowflake.ingest.utils.HttpUtil.generateProxyPropertiesForJDBC; import static net.snowflake.ingest.utils.Utils.getStackTrace; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.file.Paths; -import java.util.HashMap; -import java.util.Map; +import java.time.Duration; +import java.time.Instant; import java.util.Optional; import java.util.Properties; import java.util.concurrent.TimeUnit; -import java.util.function.Function; import net.snowflake.client.core.OCSPMode; import net.snowflake.client.jdbc.SnowflakeFileTransferAgent; import net.snowflake.client.jdbc.SnowflakeFileTransferConfig; @@ -30,24 +29,35 @@ import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.cloud.storage.StageInfo; import net.snowflake.client.jdbc.internal.apache.commons.io.FileUtils; -import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper; -import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.node.ObjectNode; -import net.snowflake.ingest.connection.IngestResponseException; -import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.Logging; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.Utils; -/** Handles uploading files to the Snowflake Streaming Ingest Stage */ -class StreamingIngestStage { +/** Handles uploading files to the Snowflake Streaming Ingest 
Storage */ +class StreamingIngestStorage { private static final ObjectMapper mapper = new ObjectMapper(); + + /** + * Object mapper for parsing the client/configure response to Jackson version the same as + * jdbc.internal.fasterxml.jackson. We need two different versions of ObjectMapper because {@link + * SnowflakeFileTransferAgent#getFileTransferMetadatas(net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode)} + * expects a different version of json object than {@link StreamingIngestResponse}. TODO: + * SNOW-1493470 Align Jackson version + */ + private static final net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper + parseConfigureResponseMapper = + new net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper(); + private static final long REFRESH_THRESHOLD_IN_MS = TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES); - private static final Logging logger = new Logging(StreamingIngestStage.class); + // Stage credential refresh interval, currently the token will expire in 1hr for GCS and 2hr for + // AWS/Azure, so set it a bit smaller than 1hr + private static final Duration refreshDuration = Duration.ofMinutes(58); + private static Instant prevRefresh = Instant.EPOCH; + + private static final Logging logger = new Logging(StreamingIngestStorage.class); /** * Wrapper class containing SnowflakeFileTransferMetadata and the timestamp at which the metadata @@ -79,60 +89,61 @@ state to record unknown age. } private SnowflakeFileTransferMetadataWithAge fileTransferMetadataWithAge; - private final CloseableHttpClient httpClient; - private final RequestBuilder requestBuilder; - private final String role; + private final IStorageManager owningManager; + private final TLocation location; private final String clientName; - private String clientPrefix; private final int maxUploadRetries; // Proxy parameters that we set while calling the Snowflake JDBC to upload the streams private final Properties proxyProperties; - StreamingIngestStage( - boolean isTestMode, - String role, - CloseableHttpClient httpClient, - RequestBuilder requestBuilder, + /** + * Default constructor + * + * @param owningManager the storage manager owning this storage + * @param clientName The client name + * @param fileLocationInfo The file location information from open channel response + * @param location A reference to the target location + * @param maxUploadRetries The maximum number of retries to attempt + */ + StreamingIngestStorage( + IStorageManager owningManager, String clientName, + FileLocationInfo fileLocationInfo, + TLocation location, int maxUploadRetries) throws SnowflakeSQLException, IOException { - this.httpClient = httpClient; - this.role = role; - this.requestBuilder = requestBuilder; - this.clientName = clientName; - this.proxyProperties = generateProxyPropertiesForJDBC(); - this.maxUploadRetries = maxUploadRetries; - - if (!isTestMode) { - refreshSnowflakeMetadata(); - } + this( + owningManager, + clientName, + (SnowflakeFileTransferMetadataWithAge) null, + location, + maxUploadRetries); + createFileTransferMetadataWithAge(fileLocationInfo); } /** * Constructor for TESTING that takes SnowflakeFileTransferMetadataWithAge as input * - * @param isTestMode must be true - * @param role Snowflake role used by the Client - * @param httpClient http client reference - * @param requestBuilder request builder to build the HTTP request + * @param owningManager the storage manager owning this storage * @param clientName the client name * @param testMetadata 
SnowflakeFileTransferMetadataWithAge to test with + * @param location A reference to the target location + * @param maxUploadRetries the maximum number of retries to attempt */ - StreamingIngestStage( - boolean isTestMode, - String role, - CloseableHttpClient httpClient, - RequestBuilder requestBuilder, + StreamingIngestStorage( + IStorageManager owningManager, String clientName, SnowflakeFileTransferMetadataWithAge testMetadata, - int maxRetryCount) + TLocation location, + int maxUploadRetries) throws SnowflakeSQLException, IOException { - this(isTestMode, role, httpClient, requestBuilder, clientName, maxRetryCount); - if (!isTestMode) { - throw new SFException(ErrorCode.INTERNAL_ERROR); - } + this.owningManager = owningManager; + this.clientName = clientName; + this.maxUploadRetries = maxUploadRetries; + this.proxyProperties = generateProxyPropertiesForJDBC(); + this.location = location; this.fileTransferMetadataWithAge = testMetadata; } @@ -180,13 +191,19 @@ private void putRemote(String fullFilePath, byte[] data, int retryCount) InputStream inStream = new ByteArrayInputStream(data); try { + // Proactively refresh the credential if it's going to expire, to avoid the token expiration + // error from JDBC which confuses customer + if (Instant.now().isAfter(prevRefresh.plus(refreshDuration))) { + refreshSnowflakeMetadata(); + } + SnowflakeFileTransferAgent.uploadWithoutConnection( SnowflakeFileTransferConfig.Builder.newInstance() .setSnowflakeFileTransferMetadata(fileTransferMetadataCopy) .setUploadStream(inStream) .setRequireCompress(false) .setOcspMode(OCSPMode.FAIL_OPEN) - .setStreamingIngestClientKey(this.clientPrefix) + .setStreamingIngestClientKey(this.owningManager.getClientPrefix()) .setStreamingIngestClientName(this.clientName) .setProxyProperties(this.proxyProperties) .setDestFileName(fullFilePath) @@ -194,9 +211,6 @@ private void putRemote(String fullFilePath, byte[] data, int retryCount) } catch (Exception e) { if (retryCount == 0) { // for the first exception, we always perform a metadata refresh. 
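The proactive refresh above is a plain time-based guard: record when credentials were last fetched and refresh once a fixed window (58 minutes in the diff, slightly under the shortest token lifetime) has elapsed, instead of waiting for the storage layer to reject an expired token. A compact standalone sketch of that guard, with hypothetical names and a placeholder refresh action:

import java.time.Duration;
import java.time.Instant;

// Self-contained sketch of a "refresh before expiry" guard; only the Instant/Duration
// bookkeeping matters, the names and refresh body are placeholders.
public class TokenRefreshGuard {
  // Refresh a bit earlier than the shortest real token lifetime (e.g. 1 hour).
  private static final Duration REFRESH_INTERVAL = Duration.ofMinutes(58);

  private Instant prevRefresh = Instant.EPOCH; // forces a refresh on first use

  void uploadWithFreshCredentials(byte[] data) {
    if (Instant.now().isAfter(prevRefresh.plus(REFRESH_INTERVAL))) {
      refreshCredentials();
      prevRefresh = Instant.now();
    }
    // ... perform the upload with the (now fresh) credentials ...
  }

  private void refreshCredentials() {
    System.out.println("refreshing storage credentials");
  }

  public static void main(String[] args) {
    TokenRefreshGuard guard = new TokenRefreshGuard();
    guard.uploadWithFreshCredentials(new byte[0]); // first call always refreshes
    guard.uploadWithFreshCredentials(new byte[0]); // within the window: no refresh
  }
}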
- logger.logInfo( - "Stage metadata need to be refreshed due to upload error: {} on first retry attempt", - e.getMessage()); this.refreshSnowflakeMetadata(); } if (retryCount >= maxUploadRetries) { @@ -244,32 +258,27 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole return fileTransferMetadataWithAge; } - Map payload = new HashMap<>(); - payload.put("role", this.role); - Map response = this.makeClientConfigureCall(payload); + FileLocationInfo location = + this.owningManager.getRefreshedLocation(this.location, Optional.empty()); + return createFileTransferMetadataWithAge(location); + } - JsonNode responseNode = this.parseClientConfigureResponse(response); - // Do not change the prefix everytime we have to refresh credentials - if (Utils.isNullOrEmpty(this.clientPrefix)) { - this.clientPrefix = createClientPrefix(responseNode); - } - Utils.assertStringNotNullOrEmpty("client prefix", this.clientPrefix); + private SnowflakeFileTransferMetadataWithAge createFileTransferMetadataWithAge( + FileLocationInfo fileLocationInfo) + throws JsonProcessingException, + net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException, + SnowflakeSQLException { + Utils.assertStringNotNullOrEmpty("client prefix", this.owningManager.getClientPrefix()); - if (responseNode - .get("data") - .get("stageInfo") - .get("locationType") - .toString() + if (fileLocationInfo + .getLocationType() .replaceAll( "^[\"]|[\"]$", "") // Replace the first and last character if they're double quotes .equals(StageInfo.StageType.LOCAL_FS.name())) { this.fileTransferMetadataWithAge = new SnowflakeFileTransferMetadataWithAge( - responseNode - .get("data") - .get("stageInfo") - .get("location") - .toString() + fileLocationInfo + .getLocation() .replaceAll( "^[\"]|[\"]$", ""), // Replace the first and last character if they're double quotes @@ -278,26 +287,14 @@ synchronized SnowflakeFileTransferMetadataWithAge refreshSnowflakeMetadata(boole this.fileTransferMetadataWithAge = new SnowflakeFileTransferMetadataWithAge( (SnowflakeFileTransferMetadataV1) - SnowflakeFileTransferAgent.getFileTransferMetadatas(responseNode).get(0), + SnowflakeFileTransferAgent.getFileTransferMetadatas( + parseFileLocationInfo(fileLocationInfo)) + .get(0), Optional.of(System.currentTimeMillis())); } - return this.fileTransferMetadataWithAge; - } - /** - * Creates a client-specific prefix that will be also part of the files registered by this client. - * The prefix will include a server-side generated string and the GlobalID of the deployment the - * client is registering blobs to. The latter (deploymentId) is needed in order to guarantee that - * blob filenames are unique across deployments even with replication enabled. - * - * @param response the client/configure response from the server - * @return the client prefix. - */ - private String createClientPrefix(final JsonNode response) { - final String prefix = response.get("prefix").textValue(); - final String deploymentId = - response.has("deployment_id") ? 
"_" + response.get("deployment_id").longValue() : ""; - return prefix + deploymentId; + prevRefresh = Instant.now(); + return this.fileTransferMetadataWithAge; } /** @@ -309,74 +306,39 @@ private String createClientPrefix(final JsonNode response) { SnowflakeFileTransferMetadataV1 fetchSignedURL(String fileName) throws SnowflakeSQLException, IOException { - Map payload = new HashMap<>(); - payload.put("role", this.role); - payload.put("file_name", fileName); - Map response = this.makeClientConfigureCall(payload); - - JsonNode responseNode = this.parseClientConfigureResponse(response); + FileLocationInfo location = + this.owningManager.getRefreshedLocation(this.location, Optional.of(fileName)); SnowflakeFileTransferMetadataV1 metadata = (SnowflakeFileTransferMetadataV1) - SnowflakeFileTransferAgent.getFileTransferMetadatas(responseNode).get(0); + SnowflakeFileTransferAgent.getFileTransferMetadatas(parseFileLocationInfo(location)) + .get(0); // Transfer agent trims path for fileName metadata.setPresignedUrlFileName(fileName); return metadata; } - private static class MapStatusGetter implements Function { - public MapStatusGetter() {} - - public Long apply(T input) { - try { - return ((Integer) ((Map) input).get("status_code")).longValue(); - } catch (Exception e) { - throw new SFException(ErrorCode.INTERNAL_ERROR, "failed to get status_code from response"); - } - } - } - - private static final MapStatusGetter statusGetter = new MapStatusGetter(); - - private JsonNode parseClientConfigureResponse(Map response) { - JsonNode responseNode = mapper.valueToTree(response); + private net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode + parseFileLocationInfo(FileLocationInfo fileLocationInfo) + throws JsonProcessingException, + net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException { + JsonNode fileLocationInfoNode = mapper.valueToTree(fileLocationInfo); // Currently there are a few mismatches between the client/configure response and what // SnowflakeFileTransferAgent expects - ObjectNode mutable = (ObjectNode) responseNode; - mutable.putObject("data"); - ObjectNode dataNode = (ObjectNode) mutable.get("data"); - dataNode.set("stageInfo", responseNode.get("stage_location")); + + ObjectNode node = mapper.createObjectNode(); + node.putObject("data"); + ObjectNode dataNode = (ObjectNode) node.get("data"); + dataNode.set("stageInfo", fileLocationInfoNode); // JDBC expects this field which maps to presignedFileUrlName. 
We will set this later dataNode.putArray("src_locations").add("placeholder"); - return responseNode; - } - private Map makeClientConfigureCall(Map payload) - throws IOException { - try { - - Map response = - executeWithRetries( - Map.class, - CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), - "client configure", - STREAMING_CLIENT_CONFIGURE, - httpClient, - requestBuilder, - statusGetter); - - // Check for Snowflake specific response code - if (!response.get("status_code").equals((int) RESPONSE_SUCCESS)) { - throw new SFException( - ErrorCode.CLIENT_CONFIGURE_FAILURE, response.get("message").toString()); - } - return response; - } catch (IngestResponseException e) { - throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); - } + // use String as intermediate object to avoid Jackson version mismatch + // TODO: SNOW-1493470 Align Jackson version + String responseString = mapper.writeValueAsString(node); + return parseConfigureResponseMapper.readTree(responseString); } /** @@ -422,9 +384,4 @@ void putLocal(String fullFilePath, byte[] data) { throw new SFException(ex, ErrorCode.BLOB_UPLOAD_FAILURE); } } - - /** Get the server generated unique prefix for this client */ - String getClientPrefix() { - return this.clientPrefix; - } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java index 56e960064..538283b4e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtils.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.utils.Constants.MAX_STREAMING_INGEST_API_CHANNEL_RETRY; @@ -6,7 +10,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; -import java.util.Map; import java.util.function.Function; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpUriRequest; @@ -77,7 +80,7 @@ public static void sleepForRetry(int executionCount) { static T executeWithRetries( Class targetClass, String endpoint, - Map payload, + IStreamingIngestRequest payload, String message, ServiceResponseHandler.ApiName apiName, CloseableHttpClient httpClient, diff --git a/src/main/java/net/snowflake/ingest/utils/Constants.java b/src/main/java/net/snowflake/ingest/utils/Constants.java index 3579e4d24..6df7b6727 100644 --- a/src/main/java/net/snowflake/ingest/utils/Constants.java +++ b/src/main/java/net/snowflake/ingest/utils/Constants.java @@ -141,6 +141,24 @@ public static BdecParquetCompression fromName(String name) { name, Arrays.asList(BdecParquetCompression.values()))); } } + + public enum BinaryStringEncoding { + HEX, + BASE64; + + public static BinaryStringEncoding fromName(String name) { + for (BinaryStringEncoding e : BinaryStringEncoding.values()) { + if (e.name().toLowerCase().equals(name.toLowerCase())) { + return e; + } + } + throw new IllegalArgumentException( + String.format( + "Unsupported BinaryStringEncoding = '%s', allowed values are %s", + name, Arrays.asList(BinaryStringEncoding.values()))); + } + } + // Parameters public static final boolean DISABLE_BACKGROUND_FLUSH = false; public static final boolean 
COMPRESS_BLOB_TWICE = false; diff --git a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java index b863717e9..a9aab9c3b 100644 --- a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java +++ b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.utils; @@ -41,7 +41,8 @@ public enum ErrorCode { OAUTH_REFRESH_TOKEN_ERROR("0033"), INVALID_CONFIG_PARAMETER("0034"), CRYPTO_PROVIDER_ERROR("0035"), - DROP_CHANNEL_FAILURE("0036"); + DROP_CHANNEL_FAILURE("0036"), + CLIENT_DEPLOYMENT_ID_MISMATCH("0037"); public static final String errorMessageResource = "net.snowflake.ingest.ingest_error_messages"; diff --git a/src/main/java/net/snowflake/ingest/utils/HttpUtil.java b/src/main/java/net/snowflake/ingest/utils/HttpUtil.java index 1ff65a095..0da6370f7 100644 --- a/src/main/java/net/snowflake/ingest/utils/HttpUtil.java +++ b/src/main/java/net/snowflake/ingest/utils/HttpUtil.java @@ -79,7 +79,7 @@ public class HttpUtil { private static final ReentrantLock idleConnectionMonitorThreadLock = new ReentrantLock(true); private static final int DEFAULT_CONNECTION_TIMEOUT_MINUTES = 1; - private static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES = 5; + private static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES = 1; /** * After how many seconds of inactivity should be idle connections evicted from the connection @@ -294,7 +294,8 @@ static HttpRequestRetryHandler getHttpRequestRetryHandler() { if (exception instanceof NoHttpResponseException || exception instanceof javax.net.ssl.SSLException || exception instanceof java.net.SocketException - || exception instanceof java.net.UnknownHostException) { + || exception instanceof java.net.UnknownHostException + || exception instanceof java.net.SocketTimeoutException) { LOGGER.info( "Retrying request which caused {} with " + "URI:{}, retryCount:{} and maxRetryCount:{}", exception.getClass().getName(), diff --git a/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java b/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java index b98972a7d..8d09a62a6 100644 --- a/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java +++ b/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java @@ -39,6 +39,9 @@ public class ParameterProvider { public static final String BDEC_PARQUET_COMPRESSION_ALGORITHM = "BDEC_PARQUET_COMPRESSION_ALGORITHM".toLowerCase(); + public static final String BINARY_STRING_ENCODING = + "BINARY_STRING_ENCODING".toLowerCase(); + // Default values public static final long BUFFER_FLUSH_CHECK_INTERVAL_IN_MILLIS_DEFAULT = 100; public static final long INSERT_THROTTLE_INTERVAL_IN_MILLIS_DEFAULT = 1000; @@ -50,8 +53,8 @@ public class ParameterProvider { public static final int IO_TIME_CPU_RATIO_DEFAULT = 2; public static final int BLOB_UPLOAD_MAX_RETRY_COUNT_DEFAULT = 24; public static final long MAX_MEMORY_LIMIT_IN_BYTES_DEFAULT = -1L; - public static final long MAX_CHANNEL_SIZE_IN_BYTES_DEFAULT = 128 * 1024 * 1024; - public static final long MAX_CHUNK_SIZE_IN_BYTES_DEFAULT = 512 * 1024 * 1024; + public static final long MAX_CHANNEL_SIZE_IN_BYTES_DEFAULT = 64 * 1024 * 1024; + public static final long MAX_CHUNK_SIZE_IN_BYTES_DEFAULT = 256 * 1024 * 1024; // Lag related parameters public static final long MAX_CLIENT_LAG_DEFAULT = 1000; // 1 second @@ -64,6 +67,9 
@@ public class ParameterProvider { public static final Constants.BdecParquetCompression BDEC_PARQUET_COMPRESSION_ALGORITHM_DEFAULT = Constants.BdecParquetCompression.GZIP; + public static final Constants.BinaryStringEncoding BINARY_STRING_ENCODING_DEFAULT = + Constants.BinaryStringEncoding.HEX; + /* Parameter that enables using internal Parquet buffers for buffering of rows before serializing. It reduces memory consumption compared to using Java Objects for buffering.*/ public static final boolean ENABLE_PARQUET_INTERNAL_BUFFERING_DEFAULT = false; @@ -188,6 +194,13 @@ private void setParameterMap(Map parameterOverrides, Properties BDEC_PARQUET_COMPRESSION_ALGORITHM_DEFAULT, parameterOverrides, props); + + this.updateValue( + BINARY_STRING_ENCODING, + BINARY_STRING_ENCODING_DEFAULT, + parameterOverrides, + props); + } /** @return Longest interval in milliseconds between buffer flushes */ @@ -407,6 +420,18 @@ public Constants.BdecParquetCompression getBdecParquetCompressionAlgorithm() { return Constants.BdecParquetCompression.fromName((String) val); } + /** @return binary string encoding */ + public Constants.BinaryStringEncoding getBinaryStringEncoding() { + Object val = + this.parameterMap.getOrDefault( + BINARY_STRING_ENCODING, BINARY_STRING_ENCODING_DEFAULT); + if (val instanceof Constants.BinaryStringEncoding) { + return (Constants.BinaryStringEncoding) val; + } + return Constants.BinaryStringEncoding.fromName((String) val); + } + + @Override public String toString() { return "ParameterProvider{" + "parameterMap=" + parameterMap + '}'; diff --git a/src/main/java/net/snowflake/ingest/utils/Utils.java b/src/main/java/net/snowflake/ingest/utils/Utils.java index a06df4027..5220625da 100644 --- a/src/main/java/net/snowflake/ingest/utils/Utils.java +++ b/src/main/java/net/snowflake/ingest/utils/Utils.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. 
*/ package net.snowflake.ingest.utils; @@ -384,4 +384,31 @@ public static String getStackTrace(Throwable e) { } return stackTrace.toString(); } + + /** + * Get the fully qualified table name + * + * @param dbName the database name + * @param schemaName the schema name + * @param tableName the table name + * @return the fully qualified table name + */ + public static String getFullyQualifiedTableName( + String dbName, String schemaName, String tableName) { + return String.format("%s.%s.%s", dbName, schemaName, tableName); + } + + /** + * Get the fully qualified channel name + * + * @param dbName the database name + * @param schemaName the schema name + * @param tableName the table name + * @param channelName the channel name + * @return the fully qualified channel name + */ + public static String getFullyQualifiedChannelName( + String dbName, String schemaName, String tableName, String channelName) { + return String.format("%s.%s.%s.%s", dbName, schemaName, tableName, channelName); + } } diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java index 8b71cfd0e..58e7df4f3 100644 --- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java +++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java @@ -6,6 +6,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import net.snowflake.ingest.utils.Constants; @@ -17,6 +18,7 @@ import org.apache.parquet.column.values.factory.DefaultV1ValuesWriterFactory; import org.apache.parquet.crypto.FileEncryptionProperties; import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.io.DelegatingPositionOutputStream; import org.apache.parquet.io.OutputFile; import org.apache.parquet.io.ParquetEncodingException; @@ -35,6 +37,7 @@ public class BdecParquetWriter implements AutoCloseable { private final InternalParquetRecordWriter> writer; private final CodecFactory codecFactory; + private long rowsWritten = 0; /** * Creates a BDEC specific parquet writer. @@ -100,14 +103,28 @@ public BdecParquetWriter( encodingProps); } + /** @return List of row counts per block stored in the parquet footer */ + public List getRowCountsFromFooter() { + final List blockRowCounts = new ArrayList<>(); + for (BlockMetaData metadata : writer.getFooter().getBlocks()) { + blockRowCounts.add(metadata.getRowCount()); + } + return blockRowCounts; + } + public void writeRow(List row) { try { writer.write(row); + rowsWritten++; } catch (InterruptedException | IOException e) { throw new SFException(ErrorCode.INTERNAL_ERROR, "parquet row write failed", e); } } + public long getRowsWritten() { + return rowsWritten; + } + @Override public void close() throws IOException { try { diff --git a/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties b/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties index d2fea0b0d..03e50d9b6 100644 --- a/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties +++ b/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties @@ -37,5 +37,6 @@ 0032=URI builder fail to build url: {0} 0033=OAuth token refresh failure: {0} 0034=Invalid config parameter: {0} -0035=Too large batch of rows passed to insertRows, the batch size cannot exceed {0} bytes, recommended batch size for optimal performance and memory utilization is {1} bytes. 
We recommend splitting large batches into multiple smaller ones and call insertRows for each smaller batch separately. -0036=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1} \ No newline at end of file +0035=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1} +0036=Failed to drop channel: {0} +0037=Deployment ID mismatch. The client was created on deployment: {0} but got an upload location for deployment: {1}. Please restart the client: {2}. \ No newline at end of file diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java new file mode 100644 index 000000000..e220aec79 --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java @@ -0,0 +1,133 @@ +package net.snowflake.ingest.streaming.internal; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import net.snowflake.ingest.utils.Constants; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.Pair; +import net.snowflake.ingest.utils.SFException; +import org.apache.parquet.hadoop.BdecParquetWriter; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +public class BlobBuilderTest { + + @Test + public void testSerializationErrors() throws Exception { + // Construction succeeds if both data and metadata contain 1 row + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(1, false)), + Constants.BdecVersion.THREE); + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(1, true)), + Constants.BdecVersion.THREE); + + // Construction fails if metadata contains 0 rows and data 1 row + try { + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(0, false)), + Constants.BdecVersion.THREE); + Assert.fail("Should not pass enableParquetInternalBuffering=false"); + } catch (SFException e) { + Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); + Assert.assertTrue(e.getMessage().contains("serializeFromJavaObjects")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("totalMetadataRowCount=0")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsWritten=1")); + Assert.assertTrue(e.getMessage().contains("perChannelRowCountsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("perBlockRowCountsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("channelsCountInMetadata=1")); + Assert.assertTrue(e.getMessage().contains("countOfSerializedJavaObjects=1")); + } + + try { + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(0, true)), + Constants.BdecVersion.THREE); + Assert.fail("Should not pass enableParquetInternalBuffering=true"); + } catch (SFException e) { + Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); + Assert.assertTrue(e.getMessage().contains("serializeFromParquetWriteBuffers")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("totalMetadataRowCount=0")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsWritten=1"));
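
The assertions in this test pin the diagnostic fields that blob construction reports when the Parquet footer and the channel metadata disagree on row counts, which is what the new BdecParquetWriter accessors (getRowsWritten and getRowCountsFromFooter) enable. A minimal sketch of that kind of cross-check, using only the accessors added in this patch; the class and method names below are illustrative and are not the SDK's actual BlobBuilder code:

import java.util.List;

import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.SFException;
import org.apache.parquet.hadoop.BdecParquetWriter;

final class RowCountConsistencySketch {
  /**
   * Cross-checks the physical Parquet row counts against the row count tracked in channel
   * metadata, reporting the same diagnostic fields asserted in the test.
   */
  static void verifyRowCounts(BdecParquetWriter writer, long totalMetadataRowCount) {
    List<Long> perBlockRowCountsInFooter = writer.getRowCountsFromFooter();
    long parquetTotalRowsInFooter =
        perBlockRowCountsInFooter.stream().mapToLong(Long::longValue).sum();
    long parquetTotalRowsWritten = writer.getRowsWritten();

    if (parquetTotalRowsInFooter != totalMetadataRowCount
        || parquetTotalRowsWritten != totalMetadataRowCount) {
      throw new SFException(
          ErrorCode.INTERNAL_ERROR,
          "Row count mismatch, parquetTotalRowsInFooter="
              + parquetTotalRowsInFooter
              + " parquetTotalRowsWritten="
              + parquetTotalRowsWritten
              + " totalMetadataRowCount="
              + totalMetadataRowCount);
    }
  }
}
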
+ Assert.assertTrue(e.getMessage().contains("perChannelRowCountsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("perBlockRowCountsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("channelsCountInMetadata=1")); + Assert.assertTrue(e.getMessage().contains("countOfSerializedJavaObjects=-1")); + } + } + + /** + * Creates a channel data configurable number of rows in metadata and 1 physical row (using both + * with and without internal buffering optimization) + */ + private List> createChannelDataPerTable( + int metadataRowCount, boolean enableParquetInternalBuffering) throws IOException { + String columnName = "C1"; + ChannelData channelData = Mockito.spy(new ChannelData<>()); + MessageType schema = createSchema(columnName); + Mockito.doReturn( + new ParquetFlusher( + schema, + enableParquetInternalBuffering, + 100L, + Constants.BdecParquetCompression.GZIP)) + .when(channelData) + .createFlusher(); + + channelData.setRowSequencer(1L); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + BdecParquetWriter bdecParquetWriter = + new BdecParquetWriter( + stream, + schema, + new HashMap<>(), + "CHANNEL", + 1000, + Constants.BdecParquetCompression.GZIP); + bdecParquetWriter.writeRow(Collections.singletonList("1")); + channelData.setVectors( + new ParquetChunkData( + Collections.singletonList(Collections.singletonList("A")), + bdecParquetWriter, + stream, + new HashMap<>())); + channelData.setColumnEps(new HashMap<>()); + channelData.setRowCount(metadataRowCount); + channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L)); + + channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1)); + channelData.setChannelContext( + new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L)); + return Collections.singletonList(channelData); + } + + private static MessageType createSchema(String columnName) { + ParquetTypeGenerator.ParquetTypeInfo c1 = + ParquetTypeGenerator.generateColumnParquetTypeInfo(createTestTextColumn(columnName), 1); + return new MessageType("bdec", Collections.singletonList(c1.getParquetType())); + } + + private static ColumnMetadata createTestTextColumn(String name) { + ColumnMetadata colChar = new ColumnMetadata(); + colChar.setOrdinal(1); + colChar.setName(name); + colChar.setPhysicalType("LOB"); + colChar.setNullable(true); + colChar.setLogicalType("TEXT"); + colChar.setByteLength(14); + colChar.setLength(11); + colChar.setScale(0); + return colChar; + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java index 947908ef9..db1d737ba 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + package net.snowflake.ingest.streaming.internal; import static java.time.ZoneOffset.UTC; @@ -95,7 +99,7 @@ public void testAddChannel() { UTC); cache.addChannel(channel); Assert.assertEquals(1, cache.getSize()); - Assert.assertTrue(channel == cache.iterator().next().getValue().get(channelName)); + Assert.assertTrue(channel == cache.entrySet().iterator().next().getValue().get(channelName)); SnowflakeStreamingIngestChannelInternal channelDup = new SnowflakeStreamingIngestChannelInternal<>( @@ -117,7 +121,7 @@ public void testAddChannel() { Assert.assertTrue(channelDup.isValid()); Assert.assertEquals(1, cache.getSize()); ConcurrentHashMap> channels = - cache.iterator().next().getValue(); + cache.entrySet().iterator().next().getValue(); Assert.assertEquals(1, channels.size()); Assert.assertTrue(channelDup == channels.get(channelName)); Assert.assertFalse(channel == channelDup); @@ -130,7 +134,7 @@ public void testIterator() { Map.Entry< String, ConcurrentHashMap>>> - iter = cache.iterator(); + iter = cache.entrySet().iterator(); Map.Entry< String, ConcurrentHashMap>> @@ -160,7 +164,7 @@ public void testCloseAllChannels() { Map.Entry< String, ConcurrentHashMap>>> - iter = cache.iterator(); + iter = cache.entrySet().iterator(); while (iter.hasNext()) { for (SnowflakeStreamingIngestChannelInternal channel : iter.next().getValue().values()) { Assert.assertTrue(channel.isClosed()); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java index 8ab22619f..99bbbfd92 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java @@ -45,6 +45,8 @@ import java.util.Map; import java.util.Optional; import java.util.TimeZone; + +import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import org.apache.commons.codec.DecoderException; @@ -888,76 +890,76 @@ public void testValidateAndParseBinary() throws DecoderException { assertArrayEquals( "honk".getBytes(StandardCharsets.UTF_8), validateAndParseBinary( - "COL", "honk".getBytes(StandardCharsets.UTF_8), Optional.empty(), 0)); + "COL", "honk".getBytes(StandardCharsets.UTF_8), Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); assertArrayEquals( new byte[] {-1, 0, 1}, - validateAndParseBinary("COL", new byte[] {-1, 0, 1}, Optional.empty(), 0)); + validateAndParseBinary("COL", new byte[] {-1, 0, 1}, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); assertArrayEquals( Hex.decodeHex("1234567890abcdef"), // pragma: allowlist secret NOT A SECRET validateAndParseBinary( "COL", "1234567890abcdef", // pragma: allowlist secret NOT A SECRET Optional.empty(), - 0)); // pragma: allowlist secret NOT A SECRET + 0, Constants.BinaryStringEncoding.HEX)); // pragma: allowlist secret NOT A SECRET assertArrayEquals( Hex.decodeHex("1234567890abcdef"), // pragma: allowlist secret NOT A SECRET validateAndParseBinary( "COL", " 1234567890abcdef \t\n", Optional.empty(), - 0)); // pragma: allowlist secret NOT A SECRET + 0, Constants.BinaryStringEncoding.HEX)); // pragma: allowlist secret NOT A SECRET assertArrayEquals( - maxAllowedArray, validateAndParseBinary("COL", maxAllowedArray, Optional.empty(), 0)); + maxAllowedArray, validateAndParseBinary("COL", maxAllowedArray, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); 
assertArrayEquals( maxAllowedArrayMinusOne, - validateAndParseBinary("COL", maxAllowedArrayMinusOne, Optional.empty(), 0)); + validateAndParseBinary("COL", maxAllowedArrayMinusOne, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); // Too large arrays should be rejected expectError( ErrorCode.INVALID_VALUE_ROW, - () -> validateAndParseBinary("COL", new byte[1], Optional.of(0), 0)); + () -> validateAndParseBinary("COL", new byte[1], Optional.of(0), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_VALUE_ROW, - () -> validateAndParseBinary("COL", new byte[BYTES_8_MB + 1], Optional.empty(), 0)); + () -> validateAndParseBinary("COL", new byte[BYTES_8_MB + 1], Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_VALUE_ROW, - () -> validateAndParseBinary("COL", new byte[8], Optional.of(7), 0)); + () -> validateAndParseBinary("COL", new byte[8], Optional.of(7), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_VALUE_ROW, - () -> validateAndParseBinary("COL", "aabb", Optional.of(1), 0)); + () -> validateAndParseBinary("COL", "aabb", Optional.of(1), 0, Constants.BinaryStringEncoding.HEX)); // unsupported data types should fail expectError( ErrorCode.INVALID_VALUE_ROW, - () -> validateAndParseBinary("COL", "000", Optional.empty(), 0)); + () -> validateAndParseBinary("COL", "000", Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_VALUE_ROW, - () -> validateAndParseBinary("COL", "abcg", Optional.empty(), 0)); + () -> validateAndParseBinary("COL", "abcg", Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( - ErrorCode.INVALID_VALUE_ROW, () -> validateAndParseBinary("COL", "c", Optional.empty(), 0)); + ErrorCode.INVALID_VALUE_ROW, () -> validateAndParseBinary("COL", "c", Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseBinary( - "COL", Arrays.asList((byte) 1, (byte) 2, (byte) 3), Optional.empty(), 0)); + "COL", Arrays.asList((byte) 1, (byte) 2, (byte) 3), Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( - ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseBinary("COL", 1, Optional.empty(), 0)); + ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseBinary("COL", 1, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( - ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseBinary("COL", 12, Optional.empty(), 0)); + ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseBinary("COL", 12, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_FORMAT_ROW, - () -> validateAndParseBinary("COL", 1.5, Optional.empty(), 0)); + () -> validateAndParseBinary("COL", 1.5, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_FORMAT_ROW, - () -> validateAndParseBinary("COL", BigInteger.ONE, Optional.empty(), 0)); + () -> validateAndParseBinary("COL", BigInteger.ONE, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_FORMAT_ROW, - () -> validateAndParseBinary("COL", false, Optional.empty(), 0)); + () -> validateAndParseBinary("COL", false, Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectError( ErrorCode.INVALID_FORMAT_ROW, - () -> validateAndParseBinary("COL", new Object(), Optional.empty(), 0)); + () -> validateAndParseBinary("COL", new Object(), Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); } @Test @@ -1179,19 +1181,19 @@ 
public void testExceptionMessages() { "The given row cannot be converted to the internal format: Object of type java.lang.Object" + " cannot be ingested into Snowflake column COL of type BINARY, rowIndex:0. Allowed" + " Java types: byte[], String", - () -> validateAndParseBinary("COL", new Object(), Optional.empty(), 0)); + () -> validateAndParseBinary("COL", new Object(), Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); expectErrorCodeAndMessage( ErrorCode.INVALID_VALUE_ROW, "The given row cannot be converted to the internal format due to invalid value: Value" + " cannot be ingested into Snowflake column COL of type BINARY, rowIndex:0, reason:" + " Binary too long: length=2 maxLength=1", - () -> validateAndParseBinary("COL", new byte[] {1, 2}, Optional.of(1), 0)); + () -> validateAndParseBinary("COL", new byte[] {1, 2}, Optional.of(1), 0, Constants.BinaryStringEncoding.HEX)); expectErrorCodeAndMessage( ErrorCode.INVALID_VALUE_ROW, "The given row cannot be converted to the internal format due to invalid value: Value" + " cannot be ingested into Snowflake column COL of type BINARY, rowIndex:0, reason:" + " Not a valid hex string", - () -> validateAndParseBinary("COL", "ghi", Optional.empty(), 0)); + () -> validateAndParseBinary("COL", "ghi", Optional.empty(), 0, Constants.BinaryStringEncoding.HEX)); // VARIANT expectErrorCodeAndMessage( diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java index f200c7177..5ab500f9e 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java @@ -36,7 +36,9 @@ import java.util.Map; import java.util.TimeZone; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; @@ -51,6 +53,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; +import org.mockito.stubbing.Answer; public class FlushServiceTest { public FlushServiceTest() { @@ -77,22 +80,31 @@ private abstract static class TestContext implements AutoCloseable { ChannelCache channelCache; final Map> channels = new HashMap<>(); FlushService flushService; - StreamingIngestStage stage; + IStorageManager storageManager; + StreamingIngestStorage storage; ParameterProvider parameterProvider; RegisterService registerService; final List> channelData = new ArrayList<>(); TestContext() { - stage = Mockito.mock(StreamingIngestStage.class); - Mockito.when(stage.getClientPrefix()).thenReturn("client_prefix"); + storage = Mockito.mock(StreamingIngestStorage.class); parameterProvider = new ParameterProvider(); client = Mockito.mock(SnowflakeStreamingIngestClientInternal.class); Mockito.when(client.getParameterProvider()).thenReturn(parameterProvider); + storageManager = Mockito.spy(new InternalStageManager<>(true, "role", "client", null)); + Mockito.doReturn(storage).when(storageManager).getStorage(ArgumentMatchers.any()); + Mockito.when(storageManager.getClientPrefix()).thenReturn("client_prefix"); + Mockito.when(client.getParameterProvider()) + .thenAnswer((Answer) (i) -> parameterProvider); channelCache = new ChannelCache<>(); Mockito.when(client.getChannelCache()).thenReturn(channelCache); registerService = Mockito.spy(new 
RegisterService(client, client.isTestMode())); - flushService = Mockito.spy(new FlushService<>(client, channelCache, stage, true)); + flushService = Mockito.spy(new FlushService<>(client, channelCache, storageManager, true)); + } + + void setParameterOverride(Map parameterOverride) { + this.parameterProvider = new ParameterProvider(parameterOverride, null); } ChannelData flushChannel(String name) { @@ -105,7 +117,10 @@ ChannelData flushChannel(String name) { BlobMetadata buildAndUpload() throws Exception { List>> blobData = Collections.singletonList(channelData); - return flushService.buildAndUpload("file_name", blobData); + return flushService.buildAndUpload( + "file_name", + blobData, + blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName()); } abstract SnowflakeStreamingIngestChannelInternal createChannel( @@ -389,10 +404,11 @@ private static ColumnMetadata createLargeTestTextColumn(String name) { @Test public void testGetFilePath() { TestContext testContext = testContextFactory.create(); - FlushService flushService = testContext.flushService; + IStorageManager storageManager = testContext.storageManager; Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); String clientPrefix = "honk"; - String outputString = flushService.getBlobPath(calendar, clientPrefix); + String outputString = + ((InternalStageManager) storageManager).getBlobPath(calendar, clientPrefix); Path outputPath = Paths.get(outputString); Assert.assertTrue(outputPath.getFileName().toString().contains(clientPrefix)); Assert.assertTrue( @@ -422,30 +438,125 @@ public void testGetFilePath() { @Test public void testFlush() throws Exception { - TestContext testContext = testContextFactory.create(); + int numChannels = 4; + Long maxLastFlushTime = Long.MAX_VALUE - 1000L; // -1000L to avoid jitter overflow + TestContext>> testContext = testContextFactory.create(); + addChannel1(testContext); FlushService flushService = testContext.flushService; + ChannelCache channelCache = testContext.channelCache; Mockito.when(flushService.isTestMode()).thenReturn(false); // Nothing to flush flushService.flush(false).get(); - Mockito.verify(flushService, Mockito.times(0)).distributeFlushTasks(); + Mockito.verify(flushService, Mockito.times(0)).distributeFlushTasks(Mockito.any()); // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService).distributeFlushTasks(); - Mockito.verify(flushService, Mockito.times(1)).distributeFlushTasks(); + Mockito.verify(flushService, Mockito.times(1)).distributeFlushTasks(Mockito.any()); + + IntStream.range(0, numChannels) + .forEach( + i -> { + addChannel(testContext, i, 1L); + channelCache.setLastFlushTime(getFullyQualifiedTableName(i), maxLastFlushTime); + }); // isNeedFlush = true flushes - flushService.isNeedFlush = true; + flushService.setNeedFlush(getFullyQualifiedTableName(0)); flushService.flush(false).get(); - Mockito.verify(flushService, Mockito.times(2)).distributeFlushTasks(); + Mockito.verify(flushService, Mockito.times(2)).distributeFlushTasks(Mockito.any()); Assert.assertFalse(flushService.isNeedFlush); + Assert.assertNotEquals( + maxLastFlushTime, channelCache.getLastFlushTime(getFullyQualifiedTableName(0))); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + assertTimeDiffwithinThreshold( + channelCache.getLastFlushTime(getFullyQualifiedTableName(0)), + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)), + 1000L); + }); // lastFlushTime 
causes flush - flushService.lastFlushTime = 0; + flushService.lastFlushTime = 0L; flushService.flush(false).get(); - Mockito.verify(flushService, Mockito.times(3)).distributeFlushTasks(); + Mockito.verify(flushService, Mockito.times(3)).distributeFlushTasks(Mockito.any()); Assert.assertTrue(flushService.lastFlushTime > 0); + Assert.assertNotEquals( + maxLastFlushTime, channelCache.getLastFlushTime(getFullyQualifiedTableName(0))); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + assertTimeDiffwithinThreshold( + channelCache.getLastFlushTime(getFullyQualifiedTableName(0)), + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)), + 1000L); + }); + } + + @Test + public void testNonInterleaveFlush() throws ExecutionException, InterruptedException { + int numChannels = 4; + Long maxLastFlushTime = Long.MAX_VALUE - 1000L; // -1000L to avoid jitter overflow + TestContext>> testContext = testContextFactory.create(); + FlushService flushService = testContext.flushService; + ChannelCache channelCache = testContext.channelCache; + Mockito.when(flushService.isTestMode()).thenReturn(false); + testContext.setParameterOverride( + Collections.singletonMap(ParameterProvider.MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST, 1)); + + // Test need flush + IntStream.range(0, numChannels) + .forEach( + i -> { + addChannel(testContext, i, 1L); + channelCache.setLastFlushTime(getFullyQualifiedTableName(i), maxLastFlushTime); + if (i % 2 == 0) { + flushService.setNeedFlush(getFullyQualifiedTableName(i)); + } + }); + flushService.flush(false).get(); + Mockito.verify(flushService, Mockito.times(1)).distributeFlushTasks(Mockito.any()); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + if (i % 2 == 0) { + Assert.assertNotEquals( + maxLastFlushTime, channelCache.getLastFlushTime(getFullyQualifiedTableName(i))); + } else { + assertTimeDiffwithinThreshold( + maxLastFlushTime, + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)), + 1000L); + } + }); + + // Test time based flush + IntStream.range(0, numChannels) + .forEach( + i -> { + channelCache.setLastFlushTime( + getFullyQualifiedTableName(i), i % 2 == 0 ? 
0L : maxLastFlushTime); + }); + flushService.flush(false).get(); + Mockito.verify(flushService, Mockito.times(2)).distributeFlushTasks(Mockito.any()); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + if (i % 2 == 0) { + Assert.assertNotEquals( + 0L, channelCache.getLastFlushTime(getFullyQualifiedTableName(i)).longValue()); + } else { + assertTimeDiffwithinThreshold( + maxLastFlushTime, + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)), + 1000L); + } + }); } @Test @@ -480,7 +591,8 @@ public void testBlobCreation() throws Exception { // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService, Mockito.atLeast(2)).buildAndUpload(Mockito.any(), Mockito.any()); + Mockito.verify(flushService, Mockito.atLeast(2)) + .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any()); } @Test @@ -529,7 +641,8 @@ public void testBlobSplitDueToDifferentSchema() throws Exception { // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService, Mockito.atLeast(2)).buildAndUpload(Mockito.any(), Mockito.any()); + Mockito.verify(flushService, Mockito.atLeast(2)) + .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any()); } @Test @@ -563,7 +676,8 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception { // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService, Mockito.times(2)).buildAndUpload(Mockito.any(), Mockito.any()); + Mockito.verify(flushService, Mockito.times(2)) + .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any()); } @Test @@ -603,7 +717,7 @@ public void runTestBlobSplitDueToNumberOfChunks(int numberOfRows) throws Excepti ArgumentCaptor>>>>> blobDataCaptor = ArgumentCaptor.forClass(List.class); Mockito.verify(flushService, Mockito.times(expectedBlobs)) - .buildAndUpload(Mockito.any(), blobDataCaptor.capture()); + .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any()); // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns List>>>>> allUploadedBlobs = @@ -646,7 +760,7 @@ public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Except ArgumentCaptor>>>>> blobDataCaptor = ArgumentCaptor.forClass(List.class); Mockito.verify(flushService, Mockito.atLeast(2)) - .buildAndUpload(Mockito.any(), blobDataCaptor.capture()); + .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any()); // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. 
list => columns List>>>>> allUploadedBlobs = @@ -764,12 +878,15 @@ public void testBuildAndUpload() throws Exception { .build(); // Check FlushService.upload called with correct arguments + final ArgumentCaptor storageCaptor = + ArgumentCaptor.forClass(StreamingIngestStorage.class); final ArgumentCaptor nameCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor blobCaptor = ArgumentCaptor.forClass(byte[].class); final ArgumentCaptor> metadataCaptor = ArgumentCaptor.forClass(List.class); Mockito.verify(testContext.flushService) .upload( + storageCaptor.capture(), nameCaptor.capture(), blobCaptor.capture(), metadataCaptor.capture(), @@ -912,10 +1029,10 @@ public void testInvalidateChannels() { innerData.add(channel1Data); innerData.add(channel2Data); - StreamingIngestStage stage = Mockito.mock(StreamingIngestStage.class); - Mockito.when(stage.getClientPrefix()).thenReturn("client_prefix"); + IStorageManager storageManager = + Mockito.spy(new InternalStageManager<>(true, "role", "client", null)); FlushService flushService = - new FlushService<>(client, channelCache, stage, false); + new FlushService<>(client, channelCache, storageManager, false); flushService.invalidateAllChannelsInBlob(blobData, "Invalidated by test"); Assert.assertFalse(channel1.isValid()); @@ -1063,4 +1180,12 @@ private Timer setupTimer(long expectedLatencyMs) { return timer; } + + private String getFullyQualifiedTableName(int tableId) { + return String.format("db1.PUBLIC.table%d", tableId); + } + + private void assertTimeDiffwithinThreshold(Long time1, Long time2, long threshold) { + Assert.assertTrue(Math.abs(time1 - time2) <= threshold); + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java new file mode 100644 index 000000000..5b28e9c45 --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java @@ -0,0 +1,122 @@ +package net.snowflake.ingest.streaming.internal; + +import static java.time.ZoneOffset.UTC; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import net.snowflake.ingest.streaming.InsertValidationResponse; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.utils.Utils; +import org.junit.Assert; +import org.junit.Test; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.TimeValue; + +@State(Scope.Thread) +public class InsertRowsBenchmarkTest { + + private SnowflakeStreamingIngestChannelInternal channel; + private SnowflakeStreamingIngestClientInternal client; + + @Param({"100000"}) + private int numRows; + + @Setup(Level.Trial) + public void setUpBeforeAll() { + client = new SnowflakeStreamingIngestClientInternal("client_PARQUET"); + channel = + new SnowflakeStreamingIngestChannelInternal<>( + "channel", + "db", + "schema", + "table", + "0", + 0L, + 0L, + client, + "key", + 1234L, + 
OpenChannelRequest.OnErrorOption.CONTINUE, + UTC); + // Setup column fields and vectors + ColumnMetadata col = new ColumnMetadata(); + col.setOrdinal(1); + col.setName("COL"); + col.setPhysicalType("SB16"); + col.setNullable(false); + col.setLogicalType("FIXED"); + col.setPrecision(38); + col.setScale(0); + + channel.setupSchema(Collections.singletonList(col)); + assert Utils.getProvider() != null; + } + + @TearDown(Level.Trial) + public void tearDownAfterAll() throws Exception { + channel.close(); + client.close(); + } + + @Benchmark + public void testInsertRow() { + Map row = new HashMap<>(); + row.put("col", 1); + + for (int i = 0; i < numRows; i++) { + InsertValidationResponse response = channel.insertRow(row, String.valueOf(i)); + Assert.assertFalse(response.hasErrors()); + } + } + + @Test + public void insertRow() throws Exception { + setUpBeforeAll(); + Map row = new HashMap<>(); + row.put("col", 1); + + for (int i = 0; i < 1000000; i++) { + InsertValidationResponse response = channel.insertRow(row, String.valueOf(i)); + Assert.assertFalse(response.hasErrors()); + } + tearDownAfterAll(); + } + + @Test + public void launchBenchmark() throws RunnerException { + Options opt = + new OptionsBuilder() + // Specify which benchmarks to run. + // You can be more specific if you'd like to run only one benchmark per test. + .include(this.getClass().getName() + ".*") + // Set the following options as needed + .mode(Mode.AverageTime) + .timeUnit(TimeUnit.MICROSECONDS) + .warmupTime(TimeValue.seconds(1)) + .warmupIterations(2) + .measurementTime(TimeValue.seconds(1)) + .measurementIterations(10) + .threads(2) + .forks(1) + .shouldFailOnError(true) + .shouldDoGC(true) + // .jvmArgs("-XX:+UnlockDiagnosticVMOptions", "-XX:+PrintInlining") + // .addProfiler(WinPerfAsmProfiler.class) + .build(); + + new Runner(opt).run(); + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserTest.java index 6878478e2..8be5b31e7 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserTest.java @@ -10,6 +10,8 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; + +import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.SFException; import org.apache.parquet.schema.PrimitiveType; import org.junit.Assert; @@ -31,7 +33,7 @@ public void parseValueFixedSB1ToInt32() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - 12, testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0); + 12, testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -57,7 +59,7 @@ public void parseValueFixedSB2ToInt32() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - 1234, testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0); + 1234, testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -83,7 +85,7 @@ public void parseValueFixedSB4ToInt32() { RowBufferStats 
rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - 123456789, testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0); + 123456789, testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -114,7 +116,7 @@ public void parseValueFixedSB8ToInt64() { PrimitiveType.PrimitiveTypeName.INT64, rowBufferStats, UTC, - 0); + 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -145,7 +147,7 @@ public void parseValueFixedSB16ToByteArray() { PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, rowBufferStats, UTC, - 0); + 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -178,7 +180,7 @@ public void parseValueFixedDecimalToInt32() { PrimitiveType.PrimitiveTypeName.DOUBLE, rowBufferStats, UTC, - 0); + 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -202,7 +204,7 @@ public void parseValueDouble() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - 12345.54321d, testCol, PrimitiveType.PrimitiveTypeName.DOUBLE, rowBufferStats, UTC, 0); + 12345.54321d, testCol, PrimitiveType.PrimitiveTypeName.DOUBLE, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -226,7 +228,7 @@ public void parseValueBoolean() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - true, testCol, PrimitiveType.PrimitiveTypeName.BOOLEAN, rowBufferStats, UTC, 0); + true, testCol, PrimitiveType.PrimitiveTypeName.BOOLEAN, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -255,7 +257,7 @@ public void parseValueBinary() { PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, - 0); + 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -292,7 +294,7 @@ private void testJsonWithLogicalType(String logicalType) { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - var, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0); + var, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -333,7 +335,7 @@ private void testNullJsonWithLogicalType(String var) { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - var, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0); + var, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -363,7 +365,7 @@ public void parseValueArrayToBinary() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = 
ParquetValueParser.parseColumnValueToParquet( - input, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0); + input, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); String resultArray = "[{\"a\":\"1\",\"b\":\"2\",\"c\":\"3\"}]"; @@ -395,7 +397,7 @@ public void parseValueTextToBinary() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - text, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0); + text, testCol, PrimitiveType.PrimitiveTypeName.BINARY, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); String result = text; @@ -434,7 +436,7 @@ public void parseValueTimestampNtzSB4Error() { PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, - 0)); + 0, Constants.BinaryStringEncoding.HEX)); Assert.assertEquals( "Unknown data type for logical: TIMESTAMP_NTZ, physical: SB4.", exception.getMessage()); } @@ -458,7 +460,7 @@ public void parseValueTimestampNtzSB8ToINT64() { PrimitiveType.PrimitiveTypeName.INT64, rowBufferStats, UTC, - 0); + 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -488,7 +490,7 @@ public void parseValueTimestampNtzSB16ToByteArray() { PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, rowBufferStats, UTC, - 0); + 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -514,7 +516,7 @@ public void parseValueDateToInt32() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - "2021-01-01", testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0); + "2021-01-01", testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -539,7 +541,7 @@ public void parseValueTimeSB4ToInt32() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - "01:00:00", testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0); + "01:00:00", testCol, PrimitiveType.PrimitiveTypeName.INT32, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -564,7 +566,7 @@ public void parseValueTimeSB8ToInt64() { RowBufferStats rowBufferStats = new RowBufferStats("COL1"); ParquetValueParser.ParquetBufferValue pv = ParquetValueParser.parseColumnValueToParquet( - "01:00:00.123", testCol, PrimitiveType.PrimitiveTypeName.INT64, rowBufferStats, UTC, 0); + "01:00:00.123", testCol, PrimitiveType.PrimitiveTypeName.INT64, rowBufferStats, UTC, 0, Constants.BinaryStringEncoding.HEX); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) @@ -597,7 +599,7 @@ public void parseValueTimeSB16Error() { PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, rowBufferStats, UTC, - 0)); + 0, Constants.BinaryStringEncoding.HEX)); Assert.assertEquals( "Unknown data type for logical: TIME, physical: SB16.", exception.getMessage()); } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java index 2a3bb7edd..5b573c6b0 100644 --- 
a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java @@ -3,6 +3,7 @@ import static java.time.ZoneOffset.UTC; import static net.snowflake.ingest.utils.ParameterProvider.MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT; import static net.snowflake.ingest.utils.ParameterProvider.MAX_CHUNK_SIZE_IN_BYTES_DEFAULT; +import static org.junit.Assert.fail; import java.math.BigDecimal; import java.math.BigInteger; @@ -14,6 +15,8 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.utils.Constants; @@ -126,7 +129,8 @@ private AbstractRowBuffer createTestBuffer(OpenChannelRequest.OnErrorOption o enableParquetMemoryOptimization, MAX_CHUNK_SIZE_IN_BYTES_DEFAULT, MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT, - Constants.BdecParquetCompression.GZIP), + Constants.BdecParquetCompression.GZIP, + Constants.BinaryStringEncoding.HEX), null, null); } @@ -144,7 +148,7 @@ public void testCollatedColumnsAreRejected() { collatedColumn.setCollation("en-ci"); try { this.rowBufferOnErrorAbort.setupSchema(Collections.singletonList(collatedColumn)); - Assert.fail("Collated columns are not supported"); + fail("Collated columns are not supported"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNSUPPORTED_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -164,7 +168,7 @@ public void buildFieldErrorStates() { testCol.setPrecision(4); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -176,7 +180,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("FIXED"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -188,7 +192,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIMESTAMP_NTZ"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -200,7 +204,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIMESTAMP_TZ"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -212,7 +216,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIME"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -244,7 +248,7 @@ public void testInvalidLogicalType() { try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(colInvalidLogical)); - Assert.fail("Setup should fail if 
invalid column metadata is provided"); + fail("Setup should fail if invalid column metadata is provided"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); // Do nothing @@ -264,7 +268,7 @@ public void testInvalidPhysicalType() { try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(colInvalidPhysical)); - Assert.fail("Setup should fail if invalid column metadata is provided"); + fail("Setup should fail if invalid column metadata is provided"); } catch (SFException e) { Assert.assertEquals(e.getVendorCode(), ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode()); } @@ -630,7 +634,7 @@ public void testInvalidEPInfo() { try { AbstractRowBuffer.buildEpInfoFromStats(1, colStats); - Assert.fail("should fail when row count is smaller than null count."); + fail("should fail when row count is smaller than null count."); } catch (SFException e) { Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); } @@ -1725,4 +1729,78 @@ public void testOnErrorAbortRowsWithError() { Assert.assertEquals(1, snapshotAbortParquet.size()); Assert.assertEquals(Arrays.asList("a"), snapshotAbortParquet.get(0)); } + + @Test + public void testParquetChunkMetadataCreationIsThreadSafe() throws InterruptedException { + final String testFileA = "testFileA"; + final String testFileB = "testFileB"; + + final ParquetRowBuffer bufferUnderTest = + (ParquetRowBuffer) createTestBuffer(OpenChannelRequest.OnErrorOption.CONTINUE); + + final ColumnMetadata colChar = new ColumnMetadata(); + colChar.setOrdinal(1); + colChar.setName("COLCHAR"); + colChar.setPhysicalType("LOB"); + colChar.setNullable(true); + colChar.setLogicalType("TEXT"); + colChar.setByteLength(14); + colChar.setLength(11); + colChar.setScale(0); + + bufferUnderTest.setupSchema(Collections.singletonList(colChar)); + + loadData(bufferUnderTest, Collections.singletonMap("colChar", "a")); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference> firstFlushResult = new AtomicReference<>(); + final Thread t = + getThreadThatWaitsForLockReleaseAndFlushes( + bufferUnderTest, testFileA, latch, firstFlushResult); + t.start(); + + final ChannelData secondFlushResult = bufferUnderTest.flush(testFileB); + Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult)); + + latch.countDown(); + t.join(); + + Assert.assertNotNull(firstFlushResult.get()); + Assert.assertEquals(testFileA, getPrimaryFileId(firstFlushResult.get())); + Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult)); + } + + private static Thread getThreadThatWaitsForLockReleaseAndFlushes( + final ParquetRowBuffer bufferUnderTest, + final String filenameToFlush, + final CountDownLatch latch, + final AtomicReference> flushResult) { + return new Thread( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + fail("Thread was unexpectedly interrupted"); + } + + final ChannelData flush = + loadData(bufferUnderTest, Collections.singletonMap("colChar", "b")) + .flush(filenameToFlush); + flushResult.set(flush); + }); + } + + private static ParquetRowBuffer loadData( + final ParquetRowBuffer bufferToLoad, final Map data) { + final List> validRows = new ArrayList<>(); + validRows.add(data); + + final InsertValidationResponse nResponse = bufferToLoad.insertRows(validRows, "1", "1"); + Assert.assertFalse(nResponse.hasErrors()); + return bufferToLoad; + } + + private static String getPrimaryFileId(final ChannelData chunkData) { + return 
chunkData.getVectors().metadata.get(Constants.PRIMARY_FILE_ID_KEY); + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java index 87e3f8f11..5d8d8d36a 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static java.time.ZoneOffset.UTC; @@ -61,11 +65,6 @@ public long getMaxMemory() { return maxMemory; } - @Override - public long getTotalMemory() { - return maxMemory; - } - @Override public long getFreeMemory() { return freeMemory; @@ -264,7 +263,7 @@ public void testOpenChannelRequestCreationSuccess() { Assert.assertEquals( "STREAMINGINGEST_TEST.PUBLIC.T_STREAMINGINGEST", request.getFullyQualifiedTableName()); - Assert.assertFalse(request.isOffsetTokenProvided()); + Assert.assertNull(request.getOffsetToken()); } @Test @@ -281,7 +280,6 @@ public void testOpenChannelRequesCreationtWithOffsetToken() { Assert.assertEquals( "STREAMINGINGEST_TEST.PUBLIC.T_STREAMINGINGEST", request.getFullyQualifiedTableName()); Assert.assertEquals("TEST_TOKEN", request.getOffsetToken()); - Assert.assertTrue(request.isOffsetTokenProvided()); } @Test @@ -313,8 +311,14 @@ public void testOpenChannelPostRequest() throws Exception { requestBuilder.generateStreamingIngestPostRequest( payload, OPEN_CHANNEL_ENDPOINT, "open channel"); - Assert.assertEquals( - String.format("%s%s", urlStr, OPEN_CHANNEL_ENDPOINT), request.getRequestLine().getUri()); + String expectedUrlPattern = + String.format("%s%s", urlStr, OPEN_CHANNEL_ENDPOINT) + "(\\?requestId=[a-f0-9\\-]{36})?"; + + Assert.assertTrue( + String.format( + "Expected URL to match pattern: %s but was: %s", + expectedUrlPattern, request.getRequestLine().getUri()), + request.getRequestLine().getUri().matches(expectedUrlPattern)); Assert.assertNotNull(request.getFirstHeader(HttpHeaders.USER_AGENT)); Assert.assertNotNull(request.getFirstHeader(HttpHeaders.AUTHORIZATION)); Assert.assertEquals("POST", request.getMethod()); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index 1693e1520..553efbd31 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java @@ -80,9 +80,27 @@ public class SnowflakeStreamingIngestClientTest { SnowflakeStreamingIngestChannelInternal channel4; @Before - public void setup() { + public void setup() throws Exception { objectMapper.setVisibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.ANY); objectMapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY); + Properties prop = new Properties(); + prop.put(USER, TestUtils.getUser()); + prop.put(ACCOUNT_URL, TestUtils.getHost()); + prop.put(PRIVATE_KEY, TestUtils.getPrivateKey()); + prop.put(ROLE, TestUtils.getRole()); + + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + RequestBuilder requestBuilder = + new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); + 
SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); channel1 = new SnowflakeStreamingIngestChannelInternal<>( "channel1", @@ -92,7 +110,7 @@ public void setup() { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -108,7 +126,7 @@ public void setup() { "0", 2L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -124,7 +142,7 @@ public void setup() { "0", 3L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -140,7 +158,7 @@ public void setup() { "0", 3L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -357,7 +375,7 @@ public void testGetChannelsStatusWithRequest() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -368,7 +386,6 @@ public void testGetChannelsStatusWithRequest() throws Exception { ChannelsStatusRequest.ChannelStatusRequestDTO dto = new ChannelsStatusRequest.ChannelStatusRequestDTO(channel); ChannelsStatusRequest request = new ChannelsStatusRequest(); - request.setRequestId("null_0"); request.setChannels(Collections.singletonList(dto)); ChannelsStatusResponse result = client.getChannelsStatus(Collections.singletonList(channel)); Assert.assertEquals(response.getMessage(), result.getMessage()); @@ -462,7 +479,7 @@ public void testGetChannelsStatusWithRequestError() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -495,6 +512,16 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { RequestBuilder requestBuilder = new RequestBuilder(url, prop.get(USER).toString(), keyPair, null, null); + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -504,7 +531,7 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -547,9 +574,15 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { HttpPost request = requestBuilder.generateStreamingIngestPostRequest( payload, REGISTER_BLOB_ENDPOINT, "register blob"); + String expectedUrlPattern = + String.format("%s%s", urlStr, REGISTER_BLOB_ENDPOINT) + "(\\?requestId=[a-f0-9\\-]{36})?"; + + Assert.assertTrue( + String.format( + "Expected URL to match pattern: %s but was: %s", + expectedUrlPattern, request.getRequestLine().getUri()), + request.getRequestLine().getUri().matches(expectedUrlPattern)); - Assert.assertEquals( - String.format("%s%s", urlStr, REGISTER_BLOB_ENDPOINT), request.getRequestLine().getUri()); Assert.assertNotNull(request.getFirstHeader(HttpHeaders.USER_AGENT)); Assert.assertNotNull(request.getFirstHeader(HttpHeaders.AUTHORIZATION)); Assert.assertEquals("POST", request.getMethod()); @@ -1421,7 +1454,7 @@ public void testGetLatestCommittedOffsetTokens() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -1432,7 +1465,6 @@ public void testGetLatestCommittedOffsetTokens() throws 
Exception { ChannelsStatusRequest.ChannelStatusRequestDTO dto = new ChannelsStatusRequest.ChannelStatusRequestDTO(channel); ChannelsStatusRequest request = new ChannelsStatusRequest(); - request.setRequestId("null_0"); request.setChannels(Collections.singletonList(dto)); Map result = client.getLatestCommittedOffsetTokens(Collections.singletonList(channel)); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java index ead26acd6..b7f9e6829 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java @@ -130,7 +130,7 @@ private void ingestRandomRowsToTable( boolean isNullable) throws ExecutionException, InterruptedException { - List> rows = new ArrayList<>(); + final List> rows = Collections.synchronizedList(new ArrayList<>()); for (int i = 0; i < batchSize; i++) { Random r = new Random(); rows.add(TestUtils.getRandomRow(r, isNullable)); @@ -138,7 +138,8 @@ private void ingestRandomRowsToTable( ExecutorService testThreadPool = Executors.newFixedThreadPool(numChannels); CompletableFuture[] futures = new CompletableFuture[numChannels]; - List channelList = new ArrayList<>(); + List channelList = + Collections.synchronizedList(new ArrayList<>()); for (int i = 0; i < numChannels; i++) { final String channelName = "CHANNEL" + i; int finalI = i; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java index 90f8862c6..1941a48f5 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestIT.java @@ -8,6 +8,7 @@ import static net.snowflake.ingest.utils.ParameterProvider.BDEC_PARQUET_COMPRESSION_ALGORITHM; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; +import static org.mockito.Mockito.atLeastOnce; import java.math.BigDecimal; import java.math.BigInteger; @@ -18,6 +19,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -186,7 +188,7 @@ public void testSimpleIngest() throws Exception { // verify expected request sent to server String[] expectedPayloadParams = {"request_id", "blobs", "role", "blob_stats"}; for (String expectedParam : expectedPayloadParams) { - Mockito.verify(requestBuilder) + Mockito.verify(requestBuilder, atLeastOnce()) .generateStreamingIngestPostRequest( ArgumentMatchers.contains(expectedParam), ArgumentMatchers.refEq(REGISTER_BLOB_ENDPOINT), @@ -274,7 +276,7 @@ public void testDropChannel() throws Exception { @Test public void testParameterOverrides() throws Exception { Map parameterMap = new HashMap<>(); - parameterMap.put(ParameterProvider.MAX_CLIENT_LAG, "3 sec"); + parameterMap.put(ParameterProvider.MAX_CLIENT_LAG, "3 seconds"); parameterMap.put(ParameterProvider.BUFFER_FLUSH_CHECK_INTERVAL_IN_MILLIS, 50L); parameterMap.put(ParameterProvider.INSERT_THROTTLE_THRESHOLD_IN_PERCENTAGE, 1); parameterMap.put(ParameterProvider.INSERT_THROTTLE_THRESHOLD_IN_BYTES, 1024); @@ -744,7 +746,8 @@ public void testMultiThread() throws Exception { int numRows = 10000; ExecutorService testThreadPool = 
Executors.newFixedThreadPool(numThreads); CompletableFuture[] futures = new CompletableFuture[numThreads]; - List channelList = new ArrayList<>(); + List channelList = + Collections.synchronizedList(new ArrayList<>()); for (int i = 0; i < numThreads; i++) { final String channelName = "CHANNEL" + i; futures[i] = diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorageTest.java similarity index 75% rename from src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java rename to src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorageTest.java index 1ba9f98df..d4c3f0374 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStageTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestStorageTest.java @@ -1,6 +1,7 @@ package net.snowflake.ingest.streaming.internal; import static net.snowflake.client.core.Constants.CLOUD_STORAGE_CREDENTIALS_EXPIRED; +import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.HttpUtil.HTTP_PROXY_PASSWORD; import static net.snowflake.ingest.utils.HttpUtil.HTTP_PROXY_USER; import static net.snowflake.ingest.utils.HttpUtil.NON_PROXY_HOSTS; @@ -33,6 +34,7 @@ import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.cloud.storage.StageInfo; import net.snowflake.client.jdbc.internal.amazonaws.util.IOUtils; +import net.snowflake.client.jdbc.internal.apache.http.HttpEntity; import net.snowflake.client.jdbc.internal.apache.http.StatusLine; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; import net.snowflake.client.jdbc.internal.apache.http.entity.BasicHttpEntity; @@ -42,8 +44,7 @@ import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; -import net.snowflake.ingest.utils.Constants; -import net.snowflake.ingest.utils.ParameterProvider; +import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import org.junit.Assert; import org.junit.Test; @@ -56,7 +57,7 @@ @RunWith(PowerMockRunner.class) @PrepareForTest({TestUtils.class, HttpUtil.class, SnowflakeFileTransferAgent.class}) -public class StreamingIngestStageTest { +public class StreamingIngestStorageTest { private final String prefix = "EXAMPLE_PREFIX"; @@ -99,6 +100,21 @@ public class StreamingIngestStageTest { + " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":" + " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null," + " \"endPoint\": null}}"; + String remoteMetaResponseDifferentDeployment = + "{\"src_locations\": [\"foo/\"]," + + " \"deployment_id\": " + + (deploymentId + 1) + + "," + + " \"status_code\": 0, \"message\": \"Success\", \"prefix\":" + + " \"" + + prefix + + "\", \"stage_location\": {\"locationType\": \"S3\", \"location\":" + + " \"foo/streaming_ingest/\", \"path\": \"streaming_ingest/\", \"region\":" + + " \"us-east-1\", \"storageAccount\": null, \"isClientSideEncrypted\": true," + + " \"creds\": {\"AWS_KEY_ID\": \"EXAMPLE_AWS_KEY_ID\", \"AWS_SECRET_KEY\":" + + " \"EXAMPLE_AWS_SECRET_KEY\", \"AWS_TOKEN\": \"EXAMPLE_AWS_TOKEN\", \"AWS_ID\":" + + " \"EXAMPLE_AWS_ID\", \"AWS_KEY\": \"EXAMPLE_AWS_KEY\"}, \"presignedUrl\": null," + + " 
\"endPoint\": null}}"; private void setupMocksForRefresh() throws Exception { PowerMockito.mockStatic(HttpUtil.class); @@ -114,15 +130,16 @@ public void testPutRemote() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - null, - null, + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = + new StreamingIngestStorage( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, 1); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); @@ -156,16 +173,14 @@ public void testPutLocal() throws Exception { String fullFilePath = "testOutput"; String fileName = "putLocalOutput"; - StreamingIngestStage stage = + StreamingIngestStorage stage = Mockito.spy( - new StreamingIngestStage( - true, - "role", - null, + new StreamingIngestStorage( null, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( fullFilePath, Optional.of(System.currentTimeMillis())), + null, 1)); Mockito.doReturn(true).when(stage).isLocalFS(); @@ -186,15 +201,16 @@ public void doTestPutRemoteRefreshes() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - null, - null, + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, maxUploadRetryCount); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeSQLException e = @@ -240,16 +256,17 @@ public void testPutRemoteGCS() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = Mockito.spy( - new StreamingIngestStage( - true, - "role", - null, - null, + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, 1)); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeFileTransferMetadataV1 metaMock = Mockito.mock(SnowflakeFileTransferMetadataV1.class); @@ -265,22 +282,30 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + SnowflakeStreamingIngestClientInternal mockClientInternal = + Mockito.mock(SnowflakeStreamingIngestClientInternal.class); + 
Mockito.when(mockClientInternal.getRole()).thenReturn("role"); StatusLine mockStatusLine = Mockito.mock(StatusLine.class); Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); - BasicHttpEntity entity = new BasicHttpEntity(); - entity.setContent( - new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); - Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); - Mockito.when(mockResponse.getEntity()).thenReturn(entity); + Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse)); Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); - ParameterProvider parameterProvider = new ParameterProvider(); - StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "client", snowflakeServiceClient); + + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, + "clientName", + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + null, + 1); - StreamingIngestStage.SnowflakeFileTransferMetadataWithAge metadataWithAge = + StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge metadataWithAge = stage.refreshSnowflakeMetadata(true); final ArgumentCaptor endpointCaptor = ArgumentCaptor.forClass(String.class); @@ -288,7 +313,7 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { Mockito.verify(mockBuilder) .generateStreamingIngestPostRequest( stringCaptor.capture(), endpointCaptor.capture(), Mockito.any()); - Assert.assertEquals(Constants.CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); + Assert.assertEquals(CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); Assert.assertTrue(metadataWithAge.timestamp.isPresent()); Assert.assertEquals( StageInfo.StageType.S3, metadataWithAge.fileTransferMetadata.getStageInfo().getStageType()); @@ -299,27 +324,83 @@ public void testRefreshSnowflakeMetadataRemote() throws Exception { Assert.assertEquals( Paths.get("placeholder").toAbsolutePath(), Paths.get(metadataWithAge.fileTransferMetadata.getPresignedUrlFileName()).toAbsolutePath()); - Assert.assertEquals(prefix + "_" + deploymentId, stage.getClientPrefix()); + Assert.assertEquals(prefix + "_" + deploymentId, storageManager.getClientPrefix()); } @Test - public void testFetchSignedURL() throws Exception { + public void testRefreshSnowflakeMetadataDeploymentIdMismatch() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine mockStatusLine = Mockito.mock(StatusLine.class); Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); + Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); BasicHttpEntity entity = new BasicHttpEntity(); entity.setContent( new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); + BasicHttpEntity entityFromDifferentDeployment = new BasicHttpEntity(); + entityFromDifferentDeployment.setContent( + new ByteArrayInputStream( + remoteMetaResponseDifferentDeployment.getBytes(StandardCharsets.UTF_8))); + Mockito.when(mockResponse.getEntity()) + .thenReturn(entity) + .thenReturn(entityFromDifferentDeployment); + 
Mockito.when(mockClient.execute(Mockito.any())) + .thenReturn(mockResponse) + .thenReturn(mockResponse); + + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "clientName", snowflakeServiceClient); + + StreamingIngestStorage storage = storageManager.getStorage(""); + storage.refreshSnowflakeMetadata(true); + + Assert.assertEquals(prefix + "_" + deploymentId, storageManager.getClientPrefix()); + + SFException exception = + Assert.assertThrows(SFException.class, () -> storage.refreshSnowflakeMetadata(true)); + Assert.assertEquals( + ErrorCode.CLIENT_DEPLOYMENT_ID_MISMATCH.getMessageCode(), exception.getVendorCode()); + Assert.assertEquals( + "Deployment ID mismatch, Client was created on: " + + deploymentId + + ", Got upload location for: " + + (deploymentId + 1) + + ". Please" + + " restart client: clientName.", + exception.getMessage()); + } + + @Test + public void testFetchSignedURL() throws Exception { + RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); + CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); + CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + SnowflakeStreamingIngestClientInternal mockClientInternal = + Mockito.mock(SnowflakeStreamingIngestClientInternal.class); + Mockito.when(mockClientInternal.getRole()).thenReturn("role"); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "client", snowflakeServiceClient); + StatusLine mockStatusLine = Mockito.mock(StatusLine.class); + Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); + Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); - Mockito.when(mockResponse.getEntity()).thenReturn(entity); + Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse)); Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); - StreamingIngestStage stage = - new StreamingIngestStage(true, "role", mockClient, mockBuilder, "clientName", 1); + StreamingIngestStorage stage = + new StreamingIngestStorage( + storageManager, + "clientName", + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + null, + 1); SnowflakeFileTransferMetadataV1 metadata = stage.fetchSignedURL("path/fileName"); @@ -328,7 +409,7 @@ public void testFetchSignedURL() throws Exception { Mockito.verify(mockBuilder) .generateStreamingIngestPostRequest( stringCaptor.capture(), endpointCaptor.capture(), Mockito.any()); - Assert.assertEquals(Constants.CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); + Assert.assertEquals(CLIENT_CONFIGURE_ENDPOINT, endpointCaptor.getValue()); Assert.assertEquals(StageInfo.StageType.S3, metadata.getStageInfo().getStageType()); Assert.assertEquals("foo/streaming_ingest/", metadata.getStageInfo().getLocation()); Assert.assertEquals("path/fileName", metadata.getPresignedUrlFileName()); @@ -345,26 +426,26 @@ public void testRefreshSnowflakeMetadataSynchronized() throws Exception { RequestBuilder mockBuilder = Mockito.mock(RequestBuilder.class); CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = Mockito.mock(CloseableHttpResponse.class); + SnowflakeStreamingIngestClientInternal mockClientInternal = + Mockito.mock(SnowflakeStreamingIngestClientInternal.class); + 
Mockito.when(mockClientInternal.getRole()).thenReturn("role"); + SnowflakeServiceClient snowflakeServiceClient = + new SnowflakeServiceClient(mockClient, mockBuilder); + IStorageManager storageManager = + new InternalStageManager<>(true, "role", "client", snowflakeServiceClient); StatusLine mockStatusLine = Mockito.mock(StatusLine.class); Mockito.when(mockStatusLine.getStatusCode()).thenReturn(200); - BasicHttpEntity entity = new BasicHttpEntity(); - entity.setContent( - new ByteArrayInputStream(exampleRemoteMetaResponse.getBytes(StandardCharsets.UTF_8))); - Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine); - Mockito.when(mockResponse.getEntity()).thenReturn(entity); + Mockito.when(mockResponse.getEntity()).thenReturn(createHttpEntity(exampleRemoteMetaResponse)); Mockito.when(mockClient.execute(Mockito.any())).thenReturn(mockResponse); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - mockClient, - mockBuilder, + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( - originalMetadata, Optional.of(0L)), + (StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge) null, + null, 1); ThreadFactory buildUploadThreadFactory = @@ -493,15 +574,16 @@ public void testRefreshMetadataOnFirstPutException() throws Exception { byte[] dataBytes = "Hello Upload".getBytes(StandardCharsets.UTF_8); - StreamingIngestStage stage = - new StreamingIngestStage( - true, - "role", - null, - null, + IStorageManager storageManager = Mockito.mock(IStorageManager.class); + Mockito.when(storageManager.getClientPrefix()).thenReturn("testPrefix"); + + StreamingIngestStorage stage = + new StreamingIngestStorage<>( + storageManager, "clientName", - new StreamingIngestStage.SnowflakeFileTransferMetadataWithAge( + new StreamingIngestStorage.SnowflakeFileTransferMetadataWithAge( originalMetadata, Optional.of(System.currentTimeMillis())), + null, maxUploadRetryCount); PowerMockito.mockStatic(SnowflakeFileTransferAgent.class); SnowflakeSQLException e = @@ -546,4 +628,10 @@ public Object answer(org.mockito.invocation.InvocationOnMock invocation) InputStream capturedInput = capturedConfig.getUploadStream(); Assert.assertEquals("Hello Upload", IOUtils.toString(capturedInput)); } + + private HttpEntity createHttpEntity(String content) { + BasicHttpEntity entity = new BasicHttpEntity(); + entity.setContent(new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8))); + return entity; + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java index 4e054c209..b40cdad82 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestUtilsIT.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; @@ -5,9 +9,6 @@ import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; -import com.fasterxml.jackson.databind.ObjectMapper; -import java.util.HashMap; -import java.util.Map; import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.IngestResponseException; @@ -53,11 +54,11 @@ public void testJWTRetries() throws Exception { "testJWTRetries")); // build payload - Map payload = new HashMap<>(); - if (!TestUtils.getRole().isEmpty() && !TestUtils.getRole().equals("DEFAULT_ROLE")) { - payload.put("role", TestUtils.getRole()); - } - ObjectMapper mapper = new ObjectMapper(); + ClientConfigureRequest request = + new ClientConfigureRequest( + !TestUtils.getRole().isEmpty() && !TestUtils.getRole().equals("DEFAULT_ROLE") + ? TestUtils.getRole() + : null); // request wih invalid token, should get 401 3 times PowerMockito.doReturn("invalid_token").when(spyManager).getToken(); @@ -66,7 +67,7 @@ public void testJWTRetries() throws Exception { executeWithRetries( ChannelsStatusResponse.class, CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), + request, "client configure", STREAMING_CLIENT_CONFIGURE, httpClient, @@ -84,7 +85,7 @@ public void testJWTRetries() throws Exception { executeWithRetries( ChannelsStatusResponse.class, CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), + request, "client configure", STREAMING_CLIENT_CONFIGURE, httpClient, @@ -101,7 +102,7 @@ public void testJWTRetries() throws Exception { executeWithRetries( ChannelsStatusResponse.class, CLIENT_CONFIGURE_ENDPOINT, - mapper.writeValueAsString(payload), + request, "client configure", STREAMING_CLIENT_CONFIGURE, httpClient,
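// Editor's note (not part of the original patch): the hunks above replace the hand-built
// Map payload and ObjectMapper with a typed ClientConfigureRequest passed straight to
// executeWithRetries. The SDK's actual ClientConfigureRequest is not shown in this diff, so
// the class below is only a hedged sketch of the minimal shape needed for the "role" field
// to serialize the way the removed Map-based payload did; the name
// ClientConfigureRequestSketch and the Jackson annotations are illustrative assumptions.
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

@JsonInclude(JsonInclude.Include.NON_NULL) // skip "role" when null, like the old conditional put()
class ClientConfigureRequestSketch {
  @JsonProperty("role")
  private final String role;

  ClientConfigureRequestSketch(String role) {
    this.role = role;
  }

  String getRole() {
    return role;
  }
}
// Serializing new ClientConfigureRequestSketch(null) with Jackson should yield "{}", while a
// non-null role should yield {"role":"..."}, matching the conditional payload.put("role", ...)
// logic that the patch removes.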