Commit f6aa9e5

Destination Snowflake: CDK T+D initial state refactor (#35456)

Signed-off-by: Gireesh Sreepathi <[email protected]>
1 parent a13bd80 commit f6aa9e5

13 files changed (+238, -199)
airbyte-integrations/connectors/destination-snowflake/build.gradle (+1, -1)

@@ -3,7 +3,7 @@ plugins {
 }
 
 airbyteJavaConnector {
-    cdkVersionRequired = '0.20.9'
+    cdkVersionRequired = '0.23.2'
     features = ['db-destinations', 's3-destinations', 'typing-deduping']
     useLocalCdk = false
 }

airbyte-integrations/connectors/destination-snowflake/metadata.yaml (+1, -1)

@@ -5,7 +5,7 @@ data:
   connectorSubtype: database
   connectorType: destination
   definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
-  dockerImageTag: 3.5.13
+  dockerImageTag: 3.5.14
   dockerRepository: airbyte/destination-snowflake
   documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
   githubIssueLabel: destination-snowflake

airbyte-integrations/connectors/destination-snowflake/src/main/java/io/airbyte/integrations/destination/snowflake/SnowflakeInternalStagingDestination.java (+8, -4)

@@ -13,6 +13,7 @@
 import io.airbyte.cdk.integrations.base.TypingAndDedupingFlag;
 import io.airbyte.cdk.integrations.destination.NamingConventionTransformer;
 import io.airbyte.cdk.integrations.destination.jdbc.AbstractJdbcDestination;
+import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler;
 import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcSqlGenerator;
 import io.airbyte.cdk.integrations.destination.staging.StagingConsumerFactory;
 import io.airbyte.commons.json.Jsons;
@@ -129,6 +130,11 @@ protected JdbcSqlGenerator getSqlGenerator() {
     throw new UnsupportedOperationException("Snowflake does not yet use the native JDBC DV2 interface");
   }
 
+  @Override
+  protected JdbcDestinationHandler getDestinationHandler(String databaseName, JdbcDatabase database) {
+    throw new UnsupportedOperationException("Snowflake does not yet use the native JDBC DV2 interface");
+  }
+
   @Override
   public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonNode config,
                                                                        final ConfiguredAirbyteCatalog catalog,
@@ -156,13 +162,11 @@ public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonNode config,
     final SnowflakeV1V2Migrator migrator = new SnowflakeV1V2Migrator(getNamingResolver(), database, databaseName);
     final SnowflakeV2TableMigrator v2TableMigrator = new SnowflakeV2TableMigrator(database, databaseName, sqlGenerator, snowflakeDestinationHandler);
     final boolean disableTypeDedupe = config.has(DISABLE_TYPE_DEDUPE) && config.get(DISABLE_TYPE_DEDUPE).asBoolean(false);
-    final int defaultThreadCount = 8;
     if (disableTypeDedupe) {
-      typerDeduper = new NoOpTyperDeduperWithV1V2Migrations<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog, migrator, v2TableMigrator,
-          defaultThreadCount);
+      typerDeduper = new NoOpTyperDeduperWithV1V2Migrations(sqlGenerator, snowflakeDestinationHandler, parsedCatalog, migrator, v2TableMigrator);
     } else {
       typerDeduper =
-          new DefaultTyperDeduper<>(sqlGenerator, snowflakeDestinationHandler, parsedCatalog, migrator, v2TableMigrator, defaultThreadCount);
+          new DefaultTyperDeduper(sqlGenerator, snowflakeDestinationHandler, parsedCatalog, migrator, v2TableMigrator);
     }
 
     return StagingConsumerFactory.builder(
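
Note: with the CDK 0.23.x constructors above, DefaultTyperDeduper and NoOpTyperDeduperWithV1V2Migrations no longer take the explicit defaultThreadCount argument. The disable flag itself is read with a tolerant Jackson lookup in which a missing key falls back to false. Below is a minimal, self-contained sketch of just that flag check; the class name and the config string are hypothetical, not from this commit.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class DisableTypeDedupeFlagSketch {

  private static final String DISABLE_TYPE_DEDUPE = "disable_type_dedupe";

  public static void main(final String[] args) throws Exception {
    // Hypothetical connector config; a real config carries many more fields.
    final JsonNode config = new ObjectMapper().readTree("{\"disable_type_dedupe\": true}");

    // Same pattern as the hunk above: a missing key or null value defaults to false.
    final boolean disableTypeDedupe =
        config.has(DISABLE_TYPE_DEDUPE) && config.get(DISABLE_TYPE_DEDUPE).asBoolean(false);

    // The connector would choose NoOpTyperDeduperWithV1V2Migrations vs DefaultTyperDeduper here.
    System.out.println(disableTypeDedupe ? "typing/deduping disabled" : "typing/deduping enabled");
  }
}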

airbyte-integrations/connectors/destination-snowflake/src/main/java/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeColumn.java (-11)

This file was deleted.

airbyte-integrations/connectors/destination-snowflake/src/main/java/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeColumnDefinition.java (-19)

This file was deleted.

airbyte-integrations/connectors/destination-snowflake/src/main/java/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeDestinationHandler.java (+193, -53)

@@ -4,24 +4,46 @@
 
 package io.airbyte.integrations.destination.snowflake.typing_deduping;
 
+import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT;
+import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_META;
+import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_RAW_ID;
+import static io.airbyte.cdk.integrations.base.JavaBaseConstants.V2_FINAL_TABLE_METADATA_COLUMNS;
+
+import com.fasterxml.jackson.databind.JsonNode;
 import io.airbyte.cdk.db.jdbc.JdbcDatabase;
-import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler;
+import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition;
+import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition;
+import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler;
+import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType;
+import io.airbyte.integrations.base.destination.typing_deduping.AirbyteType;
+import io.airbyte.integrations.base.destination.typing_deduping.Array;
+import io.airbyte.integrations.base.destination.typing_deduping.ColumnId;
+import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialState;
+import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStateImpl;
+import io.airbyte.integrations.base.destination.typing_deduping.InitialRawTableState;
 import io.airbyte.integrations.base.destination.typing_deduping.Sql;
+import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig;
 import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
+import io.airbyte.integrations.base.destination.typing_deduping.Struct;
+import io.airbyte.integrations.base.destination.typing_deduping.Union;
+import io.airbyte.integrations.base.destination.typing_deduping.UnsupportedOneOf;
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.time.Instant;
+import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.UUID;
+import java.util.stream.Collectors;
 import net.snowflake.client.jdbc.SnowflakeSQLException;
 import org.apache.commons.text.StringSubstitutor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public class SnowflakeDestinationHandler implements DestinationHandler<SnowflakeTableDefinition> {
+public class SnowflakeDestinationHandler extends JdbcDestinationHandler {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeDestinationHandler.class);
   public static final String EXCEPTION_COMMON_PREFIX = "JavaScript execution error: Uncaught Execution of multiple statements failed on statement";
@@ -30,60 +52,74 @@ public class SnowflakeDestinationHandler implements DestinationHandler<SnowflakeTableDefinition>
   private final JdbcDatabase database;
 
   public SnowflakeDestinationHandler(final String databaseName, final JdbcDatabase database) {
+    super(databaseName, database);
     this.databaseName = databaseName;
     this.database = database;
   }
 
-  @Override
-  public Optional<SnowflakeTableDefinition> findExistingTable(final StreamId id) throws SQLException {
-    // The obvious database.getMetaData().getColumns() solution doesn't work, because JDBC translates
-    // VARIANT as VARCHAR
-    final LinkedHashMap<String, SnowflakeColumnDefinition> columns = database.queryJsons(
-        """
-        SELECT column_name, data_type, is_nullable
-        FROM information_schema.columns
-        WHERE table_catalog = ?
-          AND table_schema = ?
-          AND table_name = ?
-        ORDER BY ordinal_position;
-        """,
-        databaseName.toUpperCase(),
-        id.finalNamespace().toUpperCase(),
-        id.finalName().toUpperCase()).stream()
-        .collect(LinkedHashMap::new,
-            (map, row) -> map.put(
-                row.get("COLUMN_NAME").asText(),
-                new SnowflakeColumnDefinition(row.get("DATA_TYPE").asText(), fromSnowflakeBoolean(row.get("IS_NULLABLE").asText()))),
-            LinkedHashMap::putAll);
-    if (columns.isEmpty()) {
-      return Optional.empty();
-    } else {
-      return Optional.of(new SnowflakeTableDefinition(columns));
+  public static LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> findExistingTables(final JdbcDatabase database,
+                                                                                                 final String databaseName,
+                                                                                                 final List<StreamId> streamIds)
+      throws SQLException {
+    final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = new LinkedHashMap<>();
+    final String paramHolder = String.join(",", Collections.nCopies(streamIds.size(), "?"));
+    // convert list stream to array
+    final String[] namespaces = streamIds.stream().map(StreamId::finalNamespace).toArray(String[]::new);
+    final String[] names = streamIds.stream().map(StreamId::finalName).toArray(String[]::new);
+    final String query = """
+                         SELECT table_schema, table_name, column_name, data_type, is_nullable
+                         FROM information_schema.columns
+                         WHERE table_catalog = ?
+                           AND table_schema IN (%s)
+                           AND table_name IN (%s)
+                         ORDER BY table_schema, table_name, ordinal_position;
+                         """.formatted(paramHolder, paramHolder);
+    final String[] bindValues = new String[streamIds.size() * 2 + 1];
+    bindValues[0] = databaseName.toUpperCase();
+    System.arraycopy(namespaces, 0, bindValues, 1, namespaces.length);
+    System.arraycopy(names, 0, bindValues, namespaces.length + 1, names.length);
+    final List<JsonNode> results = database.queryJsons(query, bindValues);
+    for (final JsonNode result : results) {
+      final String tableSchema = result.get("TABLE_SCHEMA").asText();
+      final String tableName = result.get("TABLE_NAME").asText();
+      final String columnName = result.get("COLUMN_NAME").asText();
+      final String dataType = result.get("DATA_TYPE").asText();
+      final String isNullable = result.get("IS_NULLABLE").asText();
+      final TableDefinition tableDefinition = existingTables
+          .computeIfAbsent(tableSchema, k -> new LinkedHashMap<>())
+          .computeIfAbsent(tableName, k -> new TableDefinition(new LinkedHashMap<>()));
+      tableDefinition.columns().put(columnName, new ColumnDefinition(columnName, dataType, 0, fromIsNullableIsoString(isNullable)));
     }
+    return existingTables;
   }
 
-  @Override
-  public LinkedHashMap<String, SnowflakeTableDefinition> findExistingFinalTables(final List<StreamId> list) throws Exception {
-    return null;
-  }
-
-  @Override
-  public boolean isFinalTableEmpty(final StreamId id) throws SQLException {
-    final int rowCount = database.queryInt(
-        """
-        SELECT row_count
-        FROM information_schema.tables
-        WHERE table_catalog = ?
-          AND table_schema = ?
-          AND table_name = ?
-        """,
-        databaseName.toUpperCase(),
-        id.finalNamespace().toUpperCase(),
-        id.finalName().toUpperCase());
-    return rowCount == 0;
+  private LinkedHashMap<String, LinkedHashMap<String, Integer>> getFinalTableRowCount(final List<StreamId> streamIds) throws SQLException {
+    final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = new LinkedHashMap<>();
+    final String paramHolder = String.join(",", Collections.nCopies(streamIds.size(), "?"));
+    // convert list stream to array
+    final String[] namespaces = streamIds.stream().map(StreamId::finalNamespace).toArray(String[]::new);
+    final String[] names = streamIds.stream().map(StreamId::finalName).toArray(String[]::new);
+    final String query = """
+                         SELECT table_schema, table_name, row_count
+                         FROM information_schema.tables
+                         WHERE table_catalog = ?
+                           AND table_schema IN (%s)
+                           AND table_name IN (%s)
+                         """.formatted(paramHolder, paramHolder);
+    final String[] bindValues = new String[streamIds.size() * 2 + 1];
+    bindValues[0] = databaseName.toUpperCase();
+    System.arraycopy(namespaces, 0, bindValues, 1, namespaces.length);
+    System.arraycopy(names, 0, bindValues, namespaces.length + 1, names.length);
+    final List<JsonNode> results = database.queryJsons(query, bindValues);
+    for (final JsonNode result : results) {
+      final String tableSchema = result.get("TABLE_SCHEMA").asText();
+      final String tableName = result.get("TABLE_NAME").asText();
+      final int rowCount = result.get("ROW_COUNT").asInt();
+      tableRowCounts.computeIfAbsent(tableSchema, k -> new LinkedHashMap<>()).put(tableName, rowCount);
+    }
+    return tableRowCounts;
   }
 
-  @Override
   public InitialRawTableState getInitialRawTableState(final StreamId id) throws Exception {
     final ResultSet tables = database.getMetaData().getTables(
         databaseName,
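
Both replacement helpers batch every stream into a single information_schema query instead of issuing one query per stream: the IN (...) clauses are expanded to one "?" per stream via Collections.nCopies, and the bind values are packed as catalog first, then all schemas, then all table names. A self-contained sketch of just that placeholder/bind-array construction follows; the catalog, namespace, and table names are hypothetical and no database connection is involved.

import java.util.Collections;
import java.util.List;

public class InClausePlaceholderSketch {

  public static void main(final String[] args) {
    // Hypothetical final-table namespaces/names standing in for StreamId values.
    final List<String> namespaces = List.of("PUBLIC", "STAGING");
    final List<String> names = List.of("USERS", "ORDERS");

    // One "?" per stream, joined with commas: "?,?"
    final String paramHolder = String.join(",", Collections.nCopies(names.size(), "?"));
    final String query = """
                         SELECT table_schema, table_name, row_count
                         FROM information_schema.tables
                         WHERE table_catalog = ?
                           AND table_schema IN (%s)
                           AND table_name IN (%s)
                         """.formatted(paramHolder, paramHolder);

    // Bind values: catalog first, then all namespaces, then all names.
    final String[] bindValues = new String[names.size() * 2 + 1];
    bindValues[0] = "MY_DATABASE";
    System.arraycopy(namespaces.toArray(new String[0]), 0, bindValues, 1, namespaces.size());
    System.arraycopy(names.toArray(new String[0]), 0, bindValues, namespaces.size() + 1, names.size());

    System.out.println(query);
    System.out.println(String.join(", ", bindValues));
  }
}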
@@ -158,12 +194,116 @@ public void execute(final Sql sql) throws Exception {
     }
   }
 
-  /**
-   * In snowflake information_schema tables, booleans return "YES" and "NO", which DataBind doesn't
-   * know how to use
-   */
-  private boolean fromSnowflakeBoolean(final String input) {
-    return input.equalsIgnoreCase("yes");
+  private Set<String> getPks(final StreamConfig stream) {
+    return stream.primaryKey() != null ? stream.primaryKey().stream().map(ColumnId::name).collect(Collectors.toSet()) : Collections.emptySet();
+  }
+
+  private boolean isAirbyteRawIdColumnMatch(final TableDefinition existingTable) {
+    final String abRawIdColumnName = COLUMN_NAME_AB_RAW_ID.toUpperCase();
+    return existingTable.columns().containsKey(abRawIdColumnName) &&
+        toJdbcTypeName(AirbyteProtocolType.STRING).equals(existingTable.columns().get(abRawIdColumnName).type());
+  }
+
+  private boolean isAirbyteExtractedAtColumnMatch(final TableDefinition existingTable) {
+    final String abExtractedAtColumnName = COLUMN_NAME_AB_EXTRACTED_AT.toUpperCase();
+    return existingTable.columns().containsKey(abExtractedAtColumnName) &&
+        toJdbcTypeName(AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE).equals(existingTable.columns().get(abExtractedAtColumnName).type());
+  }
+
+  private boolean isAirbyteMetaColumnMatch(TableDefinition existingTable) {
+    final String abMetaColumnName = COLUMN_NAME_AB_META.toUpperCase();
+    return existingTable.columns().containsKey(abMetaColumnName) &&
+        "VARIANT".equals(existingTable.columns().get(abMetaColumnName).type());
+  }
+
+  protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final TableDefinition existingTable) {
+    final Set<String> pks = getPks(stream);
+    // This is same as JdbcDestinationHandler#existingSchemaMatchesStreamConfig with upper case
+    // conversion.
+    // TODO: Unify this using name transformer or something.
+    if (!isAirbyteRawIdColumnMatch(existingTable) ||
+        !isAirbyteExtractedAtColumnMatch(existingTable) ||
+        !isAirbyteMetaColumnMatch(existingTable)) {
+      // Missing AB meta columns from final table, we need them to do proper T+D so trigger soft-reset
+      return false;
+    }
+    final LinkedHashMap<String, String> intendedColumns = stream.columns().entrySet().stream()
+        .collect(LinkedHashMap::new,
+            (map, column) -> map.put(column.getKey().name(), toJdbcTypeName(column.getValue())),
+            LinkedHashMap::putAll);
+
+    // Filter out Meta columns since they don't exist in stream config.
+    final LinkedHashMap<String, String> actualColumns = existingTable.columns().entrySet().stream()
+        .filter(column -> V2_FINAL_TABLE_METADATA_COLUMNS.stream().map(String::toUpperCase)
+            .noneMatch(airbyteColumnName -> airbyteColumnName.equals(column.getKey())))
+        .collect(LinkedHashMap::new,
+            (map, column) -> map.put(column.getKey(), column.getValue().type()),
+            LinkedHashMap::putAll);
+    // soft-resetting https://github.com/airbytehq/airbyte/pull/31082
+    @SuppressWarnings("deprecation")
+    final boolean hasPksWithNonNullConstraint = existingTable.columns().entrySet().stream()
+        .anyMatch(c -> pks.contains(c.getKey()) && !c.getValue().isNullable());
+
+    return !hasPksWithNonNullConstraint
+        && actualColumns.equals(intendedColumns);
+
+  }
+
+  @Override
+  public List<DestinationInitialState> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
+    List<StreamId> streamIds = streamConfigs.stream().map(StreamConfig::id).toList();
+    final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = findExistingTables(database, databaseName, streamIds);
+    final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = getFinalTableRowCount(streamIds);
+    return streamConfigs.stream().map(streamConfig -> {
+      try {
+        final String namespace = streamConfig.id().finalNamespace().toUpperCase();
+        final String name = streamConfig.id().finalName().toUpperCase();
+        boolean isSchemaMismatch = false;
+        boolean isFinalTableEmpty = true;
+        boolean isFinalTablePresent = existingTables.containsKey(namespace) && existingTables.get(namespace).containsKey(name);
+        boolean hasRowCount = tableRowCounts.containsKey(namespace) && tableRowCounts.get(namespace).containsKey(name);
+        if (isFinalTablePresent) {
+          final TableDefinition existingTable = existingTables.get(namespace).get(name);
+          isSchemaMismatch = !existingSchemaMatchesStreamConfig(streamConfig, existingTable);
+          isFinalTableEmpty = hasRowCount && tableRowCounts.get(namespace).get(name) == 0;
+        }
+        final InitialRawTableState initialRawTableState = getInitialRawTableState(streamConfig.id());
+        return new DestinationInitialStateImpl(streamConfig, isFinalTablePresent, initialRawTableState, isSchemaMismatch, isFinalTableEmpty);
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }).collect(Collectors.toList());
+  }
+
+  @Override
+  protected String toJdbcTypeName(AirbyteType airbyteType) {
+    if (airbyteType instanceof final AirbyteProtocolType p) {
+      return toJdbcTypeName(p);
+    }
+
+    return switch (airbyteType.getTypeName()) {
+      case Struct.TYPE -> "OBJECT";
+      case Array.TYPE -> "ARRAY";
+      case UnsupportedOneOf.TYPE -> "VARIANT";
+      case Union.TYPE -> toJdbcTypeName(((Union) airbyteType).chooseType());
+      default -> throw new IllegalArgumentException("Unrecognized type: " + airbyteType.getTypeName());
+    };
+  }
+
+  private String toJdbcTypeName(final AirbyteProtocolType airbyteProtocolType) {
+    return switch (airbyteProtocolType) {
+      case STRING -> "TEXT";
+      case NUMBER -> "FLOAT";
+      case INTEGER -> "NUMBER";
+      case BOOLEAN -> "BOOLEAN";
+      case TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ";
+      case TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ";
+      // If you change this - also change the logic in extractAndCast
+      case TIME_WITH_TIMEZONE -> "TEXT";
+      case TIME_WITHOUT_TIMEZONE -> "TIME";
+      case DATE -> "DATE";
+      case UNKNOWN -> "VARIANT";
+    };
   }
 
 }
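
The schema-match check above compares column types by the Snowflake-side names that toJdbcTypeName returns. For reference, a standalone mirror of the protocol-type branch of that mapping; the local enum stands in for the CDK's AirbyteProtocolType so the example compiles on its own.

public class SnowflakeTypeNameSketch {

  // Local mirror of the CDK enum, for a self-contained example.
  enum ProtocolType {
    STRING, NUMBER, INTEGER, BOOLEAN,
    TIMESTAMP_WITH_TIMEZONE, TIMESTAMP_WITHOUT_TIMEZONE,
    TIME_WITH_TIMEZONE, TIME_WITHOUT_TIMEZONE,
    DATE, UNKNOWN
  }

  // Same mapping as the diff's private toJdbcTypeName(AirbyteProtocolType).
  static String toSnowflakeTypeName(final ProtocolType type) {
    return switch (type) {
      case STRING, TIME_WITH_TIMEZONE -> "TEXT"; // time-with-tz is stored as TEXT
      case NUMBER -> "FLOAT";
      case INTEGER -> "NUMBER";
      case BOOLEAN -> "BOOLEAN";
      case TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ";
      case TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ";
      case TIME_WITHOUT_TIMEZONE -> "TIME";
      case DATE -> "DATE";
      case UNKNOWN -> "VARIANT";
    };
  }

  public static void main(final String[] args) {
    for (final ProtocolType t : ProtocolType.values()) {
      System.out.println(t + " -> " + toSnowflakeTypeName(t));
    }
  }
}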
