package io.airbyte.integrations.destination.snowflake.typing_deduping;

+ import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT;
+ import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_META;
+ import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_RAW_ID;
+ import static io.airbyte.cdk.integrations.base.JavaBaseConstants.V2_FINAL_TABLE_METADATA_COLUMNS;
+
+ import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.cdk.db.jdbc.JdbcDatabase;
- import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler;
+ import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition;
+ import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition;
+ import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler;
+ import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType;
+ import io.airbyte.integrations.base.destination.typing_deduping.AirbyteType;
+ import io.airbyte.integrations.base.destination.typing_deduping.Array;
+ import io.airbyte.integrations.base.destination.typing_deduping.ColumnId;
+ import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialState;
+ import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStateImpl;
+ import io.airbyte.integrations.base.destination.typing_deduping.InitialRawTableState;
import io.airbyte.integrations.base.destination.typing_deduping.Sql;
+ import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig;
import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
+ import io.airbyte.integrations.base.destination.typing_deduping.Struct;
+ import io.airbyte.integrations.base.destination.typing_deduping.Union;
+ import io.airbyte.integrations.base.destination.typing_deduping.UnsupportedOneOf;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Instant;
+ import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+ import java.util.Set;
import java.util.UUID;
+ import java.util.stream.Collectors;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.commons.text.StringSubstitutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

- public class SnowflakeDestinationHandler implements DestinationHandler<SnowflakeTableDefinition> {
+ public class SnowflakeDestinationHandler extends JdbcDestinationHandler {

  private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeDestinationHandler.class);
  public static final String EXCEPTION_COMMON_PREFIX = "JavaScript execution error: Uncaught Execution of multiple statements failed on statement";
@@ -30,60 +52,74 @@ public class SnowflakeDestinationHandler implements DestinationHandler<Snowflake
  private final JdbcDatabase database;

  public SnowflakeDestinationHandler(final String databaseName, final JdbcDatabase database) {
+   super(databaseName, database);
    this.databaseName = databaseName;
    this.database = database;
  }

- @Override
- public Optional<SnowflakeTableDefinition> findExistingTable(final StreamId id) throws SQLException {
-   // The obvious database.getMetaData().getColumns() solution doesn't work, because JDBC translates
-   // VARIANT as VARCHAR
-   final LinkedHashMap<String, SnowflakeColumnDefinition> columns = database.queryJsons(
-       """
-       SELECT column_name, data_type, is_nullable
-       FROM information_schema.columns
-       WHERE table_catalog = ?
-         AND table_schema = ?
-         AND table_name = ?
-       ORDER BY ordinal_position;
-       """,
-       databaseName.toUpperCase(),
-       id.finalNamespace().toUpperCase(),
-       id.finalName().toUpperCase()).stream()
-       .collect(LinkedHashMap::new,
-           (map, row) -> map.put(
-               row.get("COLUMN_NAME").asText(),
-               new SnowflakeColumnDefinition(row.get("DATA_TYPE").asText(), fromSnowflakeBoolean(row.get("IS_NULLABLE").asText()))),
-           LinkedHashMap::putAll);
-   if (columns.isEmpty()) {
-     return Optional.empty();
-   } else {
-     return Optional.of(new SnowflakeTableDefinition(columns));
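+ /**
+  * Batch-fetches the column definitions for every stream's final table in a single
+  * information_schema.columns query, keyed by schema and then by table name. As with the
+  * per-stream lookup this replaces, information_schema is queried directly because JDBC
+  * metadata reports VARIANT columns as VARCHAR.
+  */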
+ public static LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> findExistingTables(final JdbcDatabase database,
+                                                                                                final String databaseName,
+                                                                                                final List<StreamId> streamIds)
+     throws SQLException {
+   final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = new LinkedHashMap<>();
+   final String paramHolder = String.join(",", Collections.nCopies(streamIds.size(), "?"));
+   // Convert the stream IDs into arrays of namespaces and names for parameter binding.
+   final String[] namespaces = streamIds.stream().map(StreamId::finalNamespace).toArray(String[]::new);
+   final String[] names = streamIds.stream().map(StreamId::finalName).toArray(String[]::new);
+   final String query = """
+                        SELECT table_schema, table_name, column_name, data_type, is_nullable
+                        FROM information_schema.columns
+                        WHERE table_catalog = ?
+                          AND table_schema IN (%s)
+                          AND table_name IN (%s)
+                        ORDER BY table_schema, table_name, ordinal_position;
+                        """.formatted(paramHolder, paramHolder);
+   final String[] bindValues = new String[streamIds.size() * 2 + 1];
+   bindValues[0] = databaseName.toUpperCase();
+   System.arraycopy(namespaces, 0, bindValues, 1, namespaces.length);
+   System.arraycopy(names, 0, bindValues, namespaces.length + 1, names.length);
+   final List<JsonNode> results = database.queryJsons(query, bindValues);
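+   // Group the result rows into schema -> table -> ordered column definitions.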
+   for (final JsonNode result : results) {
+     final String tableSchema = result.get("TABLE_SCHEMA").asText();
+     final String tableName = result.get("TABLE_NAME").asText();
+     final String columnName = result.get("COLUMN_NAME").asText();
+     final String dataType = result.get("DATA_TYPE").asText();
+     final String isNullable = result.get("IS_NULLABLE").asText();
+     final TableDefinition tableDefinition = existingTables
+         .computeIfAbsent(tableSchema, k -> new LinkedHashMap<>())
+         .computeIfAbsent(tableName, k -> new TableDefinition(new LinkedHashMap<>()));
+     tableDefinition.columns().put(columnName, new ColumnDefinition(columnName, dataType, 0, fromIsNullableIsoString(isNullable)));
    }
+   return existingTables;
  }

- @Override
- public LinkedHashMap<String, SnowflakeTableDefinition> findExistingFinalTables(final List<StreamId> list) throws Exception {
-   return null;
- }
-
- @Override
- public boolean isFinalTableEmpty(final StreamId id) throws SQLException {
-   final int rowCount = database.queryInt(
-       """
-       SELECT row_count
-       FROM information_schema.tables
-       WHERE table_catalog = ?
-         AND table_schema = ?
-         AND table_name = ?
-       """,
-       databaseName.toUpperCase(),
-       id.finalNamespace().toUpperCase(),
-       id.finalName().toUpperCase());
-   return rowCount == 0;
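+ /**
+  * Batch-fetches the row count of every stream's final table from information_schema.tables,
+  * keyed by schema and then by table name. Replaces the per-stream isFinalTableEmpty query.
+  */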
+ private LinkedHashMap<String, LinkedHashMap<String, Integer>> getFinalTableRowCount(final List<StreamId> streamIds) throws SQLException {
+   final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = new LinkedHashMap<>();
+   final String paramHolder = String.join(",", Collections.nCopies(streamIds.size(), "?"));
+   // Convert the stream IDs into arrays of namespaces and names for parameter binding.
+   final String[] namespaces = streamIds.stream().map(StreamId::finalNamespace).toArray(String[]::new);
+   final String[] names = streamIds.stream().map(StreamId::finalName).toArray(String[]::new);
+   final String query = """
+                        SELECT table_schema, table_name, row_count
+                        FROM information_schema.tables
+                        WHERE table_catalog = ?
+                          AND table_schema IN (%s)
+                          AND table_name IN (%s)
+                        """.formatted(paramHolder, paramHolder);
+   final String[] bindValues = new String[streamIds.size() * 2 + 1];
+   bindValues[0] = databaseName.toUpperCase();
+   System.arraycopy(namespaces, 0, bindValues, 1, namespaces.length);
+   System.arraycopy(names, 0, bindValues, namespaces.length + 1, names.length);
+   final List<JsonNode> results = database.queryJsons(query, bindValues);
+   for (final JsonNode result : results) {
+     final String tableSchema = result.get("TABLE_SCHEMA").asText();
+     final String tableName = result.get("TABLE_NAME").asText();
+     final int rowCount = result.get("ROW_COUNT").asInt();
+     tableRowCounts.computeIfAbsent(tableSchema, k -> new LinkedHashMap<>()).put(tableName, rowCount);
+   }
+   return tableRowCounts;
  }

- @Override
  public InitialRawTableState getInitialRawTableState(final StreamId id) throws Exception {
    final ResultSet tables = database.getMetaData().getTables(
        databaseName,
@@ -158,12 +194,116 @@ public void execute(final Sql sql) throws Exception {
    }
  }

- /**
-  * In snowflake information_schema tables, booleans return "YES" and "NO", which DataBind doesn't
-  * know how to use
-  */
- private boolean fromSnowflakeBoolean(final String input) {
-   return input.equalsIgnoreCase("yes");
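+ /** Returns the stream's primary-key column names, or an empty set if no primary key is configured. */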
+ private Set<String> getPks(final StreamConfig stream) {
+   return stream.primaryKey() != null ? stream.primaryKey().stream().map(ColumnId::name).collect(Collectors.toSet()) : Collections.emptySet();
+ }
+
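+ // The next three helpers verify that the Airbyte metadata columns exist in the final table
+ // with the expected Snowflake types. Names are upper-cased to match the casing reported by
+ // Snowflake's information_schema.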
+ private boolean isAirbyteRawIdColumnMatch(final TableDefinition existingTable) {
+   final String abRawIdColumnName = COLUMN_NAME_AB_RAW_ID.toUpperCase();
+   return existingTable.columns().containsKey(abRawIdColumnName) &&
+       toJdbcTypeName(AirbyteProtocolType.STRING).equals(existingTable.columns().get(abRawIdColumnName).type());
+ }
+
+ private boolean isAirbyteExtractedAtColumnMatch(final TableDefinition existingTable) {
+   final String abExtractedAtColumnName = COLUMN_NAME_AB_EXTRACTED_AT.toUpperCase();
+   return existingTable.columns().containsKey(abExtractedAtColumnName) &&
+       toJdbcTypeName(AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE).equals(existingTable.columns().get(abExtractedAtColumnName).type());
+ }
+
+ private boolean isAirbyteMetaColumnMatch(final TableDefinition existingTable) {
+   final String abMetaColumnName = COLUMN_NAME_AB_META.toUpperCase();
+   return existingTable.columns().containsKey(abMetaColumnName) &&
+       "VARIANT".equals(existingTable.columns().get(abMetaColumnName).type());
+ }
+
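+ /**
+  * Checks the declared stream schema against the observed final-table schema: the Airbyte
+  * metadata columns must match, every declared column must be present with its expected
+  * Snowflake type, and no primary-key column may carry a NOT NULL constraint. A false return
+  * causes the caller to trigger a soft reset.
+  */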
+ protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final TableDefinition existingTable) {
+   final Set<String> pks = getPks(stream);
+   // This is the same as JdbcDestinationHandler#existingSchemaMatchesStreamConfig, but with
+   // upper-case conversion.
+   // TODO: Unify this using a name transformer or something similar.
+   if (!isAirbyteRawIdColumnMatch(existingTable) ||
+       !isAirbyteExtractedAtColumnMatch(existingTable) ||
+       !isAirbyteMetaColumnMatch(existingTable)) {
+     // The final table is missing Airbyte metadata columns, which are required for proper T+D,
+     // so trigger a soft reset.
+     return false;
+   }
+   final LinkedHashMap<String, String> intendedColumns = stream.columns().entrySet().stream()
+       .collect(LinkedHashMap::new,
+           (map, column) -> map.put(column.getKey().name(), toJdbcTypeName(column.getValue())),
+           LinkedHashMap::putAll);
+
+   // Filter out metadata columns, since they don't exist in the stream config.
+   final LinkedHashMap<String, String> actualColumns = existingTable.columns().entrySet().stream()
+       .filter(column -> V2_FINAL_TABLE_METADATA_COLUMNS.stream().map(String::toUpperCase)
+           .noneMatch(airbyteColumnName -> airbyteColumnName.equals(column.getKey())))
+       .collect(LinkedHashMap::new,
+           (map, column) -> map.put(column.getKey(), column.getValue().type()),
+           LinkedHashMap::putAll);
+   // Primary keys with a NOT NULL constraint force a soft reset; see
+   // https://github.com/airbytehq/airbyte/pull/31082.
+   @SuppressWarnings("deprecation")
+   final boolean hasPksWithNonNullConstraint = existingTable.columns().entrySet().stream()
+       .anyMatch(c -> pks.contains(c.getKey()) && !c.getValue().isNullable());
+
+   return !hasPksWithNonNullConstraint
+       && actualColumns.equals(intendedColumns);
+ }
+
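+ /**
+  * Gathers the initial state for all streams in one pass: the two batched information_schema
+  * lookups above (column definitions and row counts) replace per-stream round trips, while
+  * the raw-table state is still fetched per stream.
+  */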
+ @Override
+ public List<DestinationInitialState> gatherInitialState(final List<StreamConfig> streamConfigs) throws Exception {
+   final List<StreamId> streamIds = streamConfigs.stream().map(StreamConfig::id).toList();
+   final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = findExistingTables(database, databaseName, streamIds);
+   final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = getFinalTableRowCount(streamIds);
+   return streamConfigs.stream().map(streamConfig -> {
+     try {
+       final String namespace = streamConfig.id().finalNamespace().toUpperCase();
+       final String name = streamConfig.id().finalName().toUpperCase();
+       boolean isSchemaMismatch = false;
+       boolean isFinalTableEmpty = true;
+       final boolean isFinalTablePresent = existingTables.containsKey(namespace) && existingTables.get(namespace).containsKey(name);
+       final boolean hasRowCount = tableRowCounts.containsKey(namespace) && tableRowCounts.get(namespace).containsKey(name);
+       if (isFinalTablePresent) {
+         final TableDefinition existingTable = existingTables.get(namespace).get(name);
+         isSchemaMismatch = !existingSchemaMatchesStreamConfig(streamConfig, existingTable);
+         isFinalTableEmpty = hasRowCount && tableRowCounts.get(namespace).get(name) == 0;
+       }
+       final InitialRawTableState initialRawTableState = getInitialRawTableState(streamConfig.id());
+       return new DestinationInitialStateImpl(streamConfig, isFinalTablePresent, initialRawTableState, isSchemaMismatch, isFinalTableEmpty);
+     } catch (final Exception e) {
+       throw new RuntimeException(e);
+     }
+   }).collect(Collectors.toList());
+ }
+
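+ /**
+  * Maps an Airbyte type to the Snowflake type name that information_schema reports, using
+  * OBJECT/ARRAY/VARIANT for the composite types.
+  */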
+ @Override
+ protected String toJdbcTypeName(final AirbyteType airbyteType) {
+   if (airbyteType instanceof final AirbyteProtocolType p) {
+     return toJdbcTypeName(p);
+   }
+
+   return switch (airbyteType.getTypeName()) {
+     case Struct.TYPE -> "OBJECT";
+     case Array.TYPE -> "ARRAY";
+     case UnsupportedOneOf.TYPE -> "VARIANT";
+     case Union.TYPE -> toJdbcTypeName(((Union) airbyteType).chooseType());
+     default -> throw new IllegalArgumentException("Unrecognized type: " + airbyteType.getTypeName());
+   };
+ }
+
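+ // Snowflake has no time-with-timezone type, so TIME_WITH_TIMEZONE is stored as TEXT.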
+ private String toJdbcTypeName(final AirbyteProtocolType airbyteProtocolType) {
+   return switch (airbyteProtocolType) {
+     case STRING -> "TEXT";
+     case NUMBER -> "FLOAT";
+     case INTEGER -> "NUMBER";
+     case BOOLEAN -> "BOOLEAN";
+     case TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ";
+     case TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ";
+     // If you change this, also change the logic in extractAndCast.
+     case TIME_WITH_TIMEZONE -> "TEXT";
+     case TIME_WITHOUT_TIMEZONE -> "TIME";
+     case DATE -> "DATE";
+     case UNKNOWN -> "VARIANT";
+   };
+ }
  }

}