1
1
import doctest
2
+ import re
2
3
from typing import List
3
4
4
5
from datahub .sql_parsing .schema_resolver import SchemaResolver
5
- from datahub .sql_parsing .sqlglot_lineage import SqlParsingDebugInfo , sqlglot_lineage
6
+ from datahub .sql_parsing .sqlglot_lineage import sqlglot_lineage
6
7
from datahub .utilities .delayed_iter import delayed_iter
7
8
from datahub .utilities .is_pytest import is_pytest_running
8
9
from datahub .utilities .urns .dataset_urn import DatasetUrn
9
10
10
11
11
- class SqlglotSQLParser :
12
+ class SqlLineageSQLParser :
13
+ """
14
+ It uses `sqlglot_lineage` to extract tables and columns, serving as a replacement for the `sqllineage` implementation, similar to BigQuery.
15
+ Reference: [BigQuery SQL Lineage Test](https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/tests/unit/bigquery/test_bigquery_sql_lineage.py#L8).
16
+ """
17
+
18
+ _MYVIEW_SQL_TABLE_NAME_TOKEN = "__my_view__.__sql_table_name__"
19
+ _MYVIEW_LOOKER_TOKEN = "my_view.SQL_TABLE_NAME"
20
+
12
21
def __init__ (self , sql_query : str , platform : str = "bigquery" ) -> None :
13
- self .result = sqlglot_lineage (sql_query , SchemaResolver (platform = platform ))
22
+ # SqlLineageParser lowercarese tablenames and we need to replace Looker specific token which should be uppercased
23
+ sql_query = re .sub (
24
+ rf"(\${{{ self ._MYVIEW_LOOKER_TOKEN } }})" ,
25
+ rf"{ self ._MYVIEW_SQL_TABLE_NAME_TOKEN } " ,
26
+ sql_query ,
27
+ )
28
+ self .sql_query = sql_query
29
+ self .schema_resolver = SchemaResolver (platform = platform )
30
+ self .result = sqlglot_lineage (sql_query , self .schema_resolver )
14
31
15
32
def get_tables (self ) -> List [str ]:
16
33
ans = []
17
34
for urn in self .result .in_tables :
18
35
table_ref = DatasetUrn .from_string (urn )
19
36
ans .append (str (table_ref .name ))
20
- return ans
37
+
38
+ result = [
39
+ self ._MYVIEW_LOOKER_TOKEN if c == self ._MYVIEW_SQL_TABLE_NAME_TOKEN else c
40
+ for c in ans
41
+ ]
42
+ # Sort tables to make the list deterministic
43
+ result .sort ()
44
+
45
+ return result
21
46
22
47
def get_columns (self ) -> List [str ]:
23
- ans = set ()
48
+ ans = []
24
49
for col_info in self .result .column_lineage or []:
25
50
for col_ref in col_info .upstreams :
26
- ans .add (col_ref .column )
27
- return list (ans )
28
-
29
- def get_downstream_columns (self ) -> List [str ]:
30
- ans = set ()
31
- for col_info in self .result .column_lineage or []:
32
- ans .add (col_info .downstream .column )
33
- return list (ans )
34
-
35
- def debug_info (self ) -> SqlParsingDebugInfo :
36
- return self .result .debug_info
51
+ ans .append (col_ref .column )
52
+ return ans
37
53
38
54
39
55
def test_delayed_iter ():
@@ -73,7 +89,7 @@ def maker(n):
73
89
def test_sqllineage_sql_parser_get_tables_from_simple_query ():
74
90
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
75
91
76
- tables_list = SqlglotSQLParser (sql_query ).get_tables ()
92
+ tables_list = SqlLineageSQLParser (sql_query ).get_tables ()
77
93
tables_list .sort ()
78
94
assert tables_list == ["bar" , "foo" ]
79
95
@@ -126,31 +142,33 @@ def test_sqllineage_sql_parser_get_tables_from_complex_query():
126
142
5)
127
143
"""
128
144
129
- tables_list = SqlglotSQLParser (sql_query ).get_tables ()
145
+ tables_list = SqlLineageSQLParser (sql_query ).get_tables ()
130
146
tables_list .sort ()
131
147
assert tables_list == ["schema1.foo" , "schema2.bar" ]
132
148
133
149
134
150
def test_sqllineage_sql_parser_get_columns_with_join ():
135
151
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
136
152
137
- columns_list = SqlglotSQLParser (sql_query ).get_columns ()
153
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
138
154
columns_list .sort ()
139
155
assert columns_list == ["a" , "b" , "c" ]
140
156
141
157
142
158
def test_sqllineage_sql_parser_get_columns_from_simple_query ():
143
159
sql_query = "SELECT foo.a, foo.b FROM foo;"
144
160
145
- parser = SqlglotSQLParser (sql_query )
146
- assert sorted (parser .get_columns ()) == ["a" , "b" ]
161
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
162
+ columns_list .sort ()
163
+ assert columns_list == ["a" , "b" ]
147
164
148
165
149
166
def test_sqllineage_sql_parser_get_columns_with_alias_and_count_star ():
150
167
sql_query = "SELECT foo.a, foo.b, bar.c as test, count(*) as count FROM foo JOIN bar ON (foo.a == bar.b);"
151
- parser = SqlglotSQLParser (sql_query )
152
- assert sorted (parser .get_columns ()) == ["a" , "b" , "c" ]
153
- assert sorted (parser .get_downstream_columns ()) == ["a" , "b" , "count" , "test" ]
168
+
169
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
170
+ columns_list .sort ()
171
+ assert columns_list == ["a" , "b" , "c" ]
154
172
155
173
156
174
def test_sqllineage_sql_parser_get_columns_with_more_complex_join ():
@@ -171,9 +189,10 @@ def test_sqllineage_sql_parser_get_columns_with_more_complex_join():
171
189
WHERE
172
190
fp.dt = '2018-01-01'
173
191
"""
174
- parser = SqlglotSQLParser (sql_query )
175
- assert sorted (parser .get_columns ()) == ["bs" , "pi" , "tt" , "v" ]
176
- assert sorted (parser .get_downstream_columns ()) == ["bs" , "pi" , "pt" , "pu" , "v" ]
192
+
193
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
194
+ columns_list .sort ()
195
+ assert columns_list == ["bs" , "pi" , "tt" , "tt" , "v" ]
177
196
178
197
179
198
def test_sqllineage_sql_parser_get_columns_complex_query_with_union ():
@@ -223,17 +242,10 @@ def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
223
242
4,
224
243
5)
225
244
"""
226
- parser = SqlglotSQLParser (sql_query )
227
- columns_list = parser .get_columns ()
228
- assert sorted (columns_list ) == ["c" , "e" , "u" , "x" ]
229
- assert sorted (parser .get_downstream_columns ()) == [
230
- "c" ,
231
- "count(*)" ,
232
- "date" ,
233
- "e" ,
234
- "u" ,
235
- "x" ,
236
- ]
245
+
246
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
247
+ columns_list .sort ()
248
+ assert columns_list == ["c" , "c" , "e" , "e" , "e" , "e" , "u" , "u" , "x" , "x" ]
237
249
238
250
239
251
def test_sqllineage_sql_parser_get_tables_from_templated_query ():
@@ -246,11 +258,9 @@ def test_sqllineage_sql_parser_get_tables_from_templated_query():
246
258
FROM
247
259
${my_view.SQL_TABLE_NAME} AS my_view
248
260
"""
249
- parser = SqlglotSQLParser (sql_query )
250
- tables_list = parser .get_tables ()
261
+ tables_list = SqlLineageSQLParser (sql_query ).get_tables ()
251
262
tables_list .sort ()
252
- assert tables_list == []
253
- assert parser .debug_info ().table_error is None
263
+ assert tables_list == ["my_view.SQL_TABLE_NAME" ]
254
264
255
265
256
266
def test_sqllineage_sql_parser_get_columns_from_templated_query ():
@@ -263,15 +273,9 @@ def test_sqllineage_sql_parser_get_columns_from_templated_query():
263
273
FROM
264
274
${my_view.SQL_TABLE_NAME} AS my_view
265
275
"""
266
- parser = SqlglotSQLParser (sql_query )
267
- assert sorted (parser .get_columns ()) == []
268
- assert sorted (parser .get_downstream_columns ()) == [
269
- "city" ,
270
- "country" ,
271
- "measurement" ,
272
- "timestamp" ,
273
- ]
274
- assert parser .debug_info ().column_error is None
276
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
277
+ columns_list .sort ()
278
+ assert columns_list == ["city" , "country" , "measurement" , "timestamp" ]
275
279
276
280
277
281
def test_sqllineage_sql_parser_with_weird_lookml_query ():
@@ -280,14 +284,9 @@ def test_sqllineage_sql_parser_with_weird_lookml_query():
280
284
platform VARCHAR(20) AS aliased_platform,
281
285
country VARCHAR(20) FROM fragment_derived_view'
282
286
"""
283
- parser = SqlglotSQLParser (sql_query )
284
- columns_list = parser .get_columns ()
287
+ columns_list = SqlLineageSQLParser (sql_query ).get_columns ()
285
288
columns_list .sort ()
286
289
assert columns_list == []
287
- assert (
288
- str (parser .debug_info ().table_error )
289
- == "Error tokenizing 'untry VARCHAR(20) FROM fragment_derived_view'\n ': Missing ' from 5:143"
290
- )
291
290
292
291
293
292
def test_sqllineage_sql_parser_tables_from_redash_query ():
@@ -302,7 +301,7 @@ def test_sqllineage_sql_parser_tables_from_redash_query():
302
301
GROUP BY
303
302
name,
304
303
year(order_date)"""
305
- table_list = SqlglotSQLParser (sql_query ).get_tables ()
304
+ table_list = SqlLineageSQLParser (sql_query ).get_tables ()
306
305
table_list .sort ()
307
306
assert table_list == ["order_items" , "orders" , "staffs" ]
308
307
@@ -324,18 +323,9 @@ def test_sqllineage_sql_parser_tables_with_special_names():
324
323
"hour-table" ,
325
324
"timestamp-table" ,
326
325
]
327
- expected_columns = [
328
- "column-admin" ,
329
- "column-data" ,
330
- "column-date" ,
331
- "column-hour" ,
332
- "column-timestamp" ,
333
- ]
334
- assert sorted (SqlglotSQLParser (sql_query ).get_tables ()) == expected_tables
335
- assert sorted (SqlglotSQLParser (sql_query ).get_columns ()) == []
336
- assert (
337
- sorted (SqlglotSQLParser (sql_query ).get_downstream_columns ()) == expected_columns
338
- )
326
+ expected_columns : List [str ] = []
327
+ assert sorted (SqlLineageSQLParser (sql_query ).get_tables ()) == expected_tables
328
+ assert sorted (SqlLineageSQLParser (sql_query ).get_columns ()) == expected_columns
339
329
340
330
341
331
def test_logging_name_extraction ():
0 commit comments