Skip to content

Commit 461025b

Browse files
test: added tests for sqlglot
1 parent 1aa0a24 commit 461025b

File tree

1 file changed

+353
-0
lines changed

1 file changed

+353
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
import doctest
2+
from typing import List
3+
4+
from datahub.sql_parsing.schema_resolver import SchemaResolver
5+
from datahub.sql_parsing.sqlglot_lineage import SqlParsingDebugInfo, sqlglot_lineage
6+
from datahub.utilities.delayed_iter import delayed_iter
7+
from datahub.utilities.is_pytest import is_pytest_running
8+
from datahub.utilities.urns.dataset_urn import DatasetUrn
9+
10+
11+
class SqlglotSQLParser:
12+
def __init__(self, sql_query: str, platform: str = "bigquery") -> None:
13+
self.result = sqlglot_lineage(sql_query, SchemaResolver(platform=platform))
14+
15+
def get_tables(self) -> List[str]:
16+
ans = []
17+
for urn in self.result.in_tables:
18+
table_ref = DatasetUrn.from_string(urn)
19+
ans.append(str(table_ref.name))
20+
return ans
21+
22+
def get_columns(self) -> List[str]:
23+
ans = set()
24+
for col_info in self.result.column_lineage or []:
25+
for col_ref in col_info.upstreams:
26+
ans.add(col_ref.column)
27+
return list(ans)
28+
29+
def get_downstream_columns(self) -> List[str]:
30+
ans = set()
31+
for col_info in self.result.column_lineage or []:
32+
ans.add(col_info.downstream.column)
33+
return list(ans)
34+
35+
def debug_info(self) -> SqlParsingDebugInfo:
36+
return self.result.debug_info
37+
38+
39+
def test_delayed_iter():
40+
events = []
41+
42+
def maker(n):
43+
for i in range(n):
44+
events.append(("add", i))
45+
yield i
46+
47+
for i in delayed_iter(maker(4), 2):
48+
events.append(("remove", i))
49+
50+
assert events == [
51+
("add", 0),
52+
("add", 1),
53+
("add", 2),
54+
("remove", 0),
55+
("add", 3),
56+
("remove", 1),
57+
("remove", 2),
58+
("remove", 3),
59+
]
60+
61+
events.clear()
62+
for i in delayed_iter(maker(2), None):
63+
events.append(("remove", i))
64+
65+
assert events == [
66+
("add", 0),
67+
("add", 1),
68+
("remove", 0),
69+
("remove", 1),
70+
]
71+
72+
73+
def test_sqllineage_sql_parser_get_tables_from_simple_query():
74+
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
75+
76+
tables_list = SqlglotSQLParser(sql_query).get_tables()
77+
tables_list.sort()
78+
assert tables_list == ["bar", "foo"]
79+
80+
81+
def test_sqllineage_sql_parser_get_tables_from_complex_query():
82+
sql_query = """
83+
(
84+
SELECT
85+
CAST(substring(e, 1, 10) AS date) AS __d_a_t_e,
86+
e AS e,
87+
u AS u,
88+
x,
89+
c,
90+
count(*)
91+
FROM
92+
schema1.foo
93+
WHERE
94+
datediff('day',
95+
substring(e, 1, 10) :: date,
96+
date :: date) <= 7
97+
AND CAST(substring(e, 1, 10) AS date) >= date('2010-01-01')
98+
AND CAST(substring(e, 1, 10) AS date) < getdate()
99+
GROUP BY
100+
1,
101+
2,
102+
3,
103+
4,
104+
5)
105+
UNION ALL(
106+
SELECT
107+
CAST(substring(e, 1, 10) AS date) AS date,
108+
e AS e,
109+
u AS u,
110+
x,
111+
c,
112+
count(*)
113+
FROM
114+
schema2.bar
115+
WHERE
116+
datediff('day',
117+
substring(e, 1, 10) :: date,
118+
date :: date) <= 7
119+
AND CAST(substring(e, 1, 10) AS date) >= date('2020-08-03')
120+
AND CAST(substring(e, 1, 10) AS date) < getdate()
121+
GROUP BY
122+
1,
123+
2,
124+
3,
125+
4,
126+
5)
127+
"""
128+
129+
tables_list = SqlglotSQLParser(sql_query).get_tables()
130+
tables_list.sort()
131+
assert tables_list == ["schema1.foo", "schema2.bar"]
132+
133+
134+
def test_sqllineage_sql_parser_get_columns_with_join():
135+
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
136+
137+
columns_list = SqlglotSQLParser(sql_query).get_columns()
138+
columns_list.sort()
139+
assert columns_list == ["a", "b", "c"]
140+
141+
142+
def test_sqllineage_sql_parser_get_columns_from_simple_query():
143+
sql_query = "SELECT foo.a, foo.b FROM foo;"
144+
145+
parser = SqlglotSQLParser(sql_query)
146+
assert sorted(parser.get_columns()) == ["a", "b"]
147+
148+
149+
def test_sqllineage_sql_parser_get_columns_with_alias_and_count_star():
150+
sql_query = "SELECT foo.a, foo.b, bar.c as test, count(*) as count FROM foo JOIN bar ON (foo.a == bar.b);"
151+
parser = SqlglotSQLParser(sql_query)
152+
assert sorted(parser.get_columns()) == ["a", "b", "c"]
153+
assert sorted(parser.get_downstream_columns()) == ["a", "b", "count", "test"]
154+
155+
156+
def test_sqllineage_sql_parser_get_columns_with_more_complex_join():
157+
sql_query = """
158+
INSERT
159+
INTO
160+
foo
161+
SELECT
162+
pl.pi pi,
163+
REGEXP_REPLACE(pl.tt, '_', ' ') pt,
164+
pl.tt pu,
165+
fp.v,
166+
fp.bs
167+
FROM
168+
bar pl
169+
JOIN baz fp ON
170+
fp.rt = pl.rt
171+
WHERE
172+
fp.dt = '2018-01-01'
173+
"""
174+
parser = SqlglotSQLParser(sql_query)
175+
assert sorted(parser.get_columns()) == ["bs", "pi", "tt", "v"]
176+
assert sorted(parser.get_downstream_columns()) == ["bs", "pi", "pt", "pu", "v"]
177+
178+
179+
def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
180+
sql_query = """
181+
(
182+
SELECT
183+
CAST(substring(e, 1, 10) AS date) AS date ,
184+
e AS e,
185+
u AS u,
186+
x,
187+
c,
188+
count(*)
189+
FROM
190+
foo
191+
WHERE
192+
datediff('day',
193+
substring(e, 1, 10) :: date,
194+
date :: date) <= 7
195+
AND CAST(substring(e, 1, 10) AS date) >= date('2010-01-01')
196+
AND CAST(substring(e, 1, 10) AS date) < getdate()
197+
GROUP BY
198+
1,
199+
2,
200+
3,
201+
4,
202+
5)
203+
UNION ALL(
204+
SELECT
205+
CAST(substring(e, 1, 10) AS date) AS date,
206+
e AS e,
207+
u AS u,
208+
x,
209+
c,
210+
count(*)
211+
FROM
212+
bar
213+
WHERE
214+
datediff('day',
215+
substring(e, 1, 10) :: date,
216+
date :: date) <= 7
217+
AND CAST(substring(e, 1, 10) AS date) >= date('2020-08-03')
218+
AND CAST(substring(e, 1, 10) AS date) < getdate()
219+
GROUP BY
220+
1,
221+
2,
222+
3,
223+
4,
224+
5)
225+
"""
226+
parser = SqlglotSQLParser(sql_query)
227+
columns_list = parser.get_columns()
228+
assert sorted(columns_list) == ["c", "e", "u", "x"]
229+
assert sorted(parser.get_downstream_columns()) == [
230+
"c",
231+
"count(*)",
232+
"date",
233+
"e",
234+
"u",
235+
"x",
236+
]
237+
238+
239+
def test_sqllineage_sql_parser_get_tables_from_templated_query():
240+
sql_query = """
241+
SELECT
242+
country,
243+
city,
244+
timestamp,
245+
measurement
246+
FROM
247+
${my_view.SQL_TABLE_NAME} AS my_view
248+
"""
249+
parser = SqlglotSQLParser(sql_query)
250+
tables_list = parser.get_tables()
251+
tables_list.sort()
252+
assert tables_list == []
253+
assert parser.debug_info().table_error is None
254+
255+
256+
def test_sqllineage_sql_parser_get_columns_from_templated_query():
257+
sql_query = """
258+
SELECT
259+
country,
260+
city,
261+
timestamp,
262+
measurement
263+
FROM
264+
${my_view.SQL_TABLE_NAME} AS my_view
265+
"""
266+
parser = SqlglotSQLParser(sql_query)
267+
assert sorted(parser.get_columns()) == []
268+
assert sorted(parser.get_downstream_columns()) == [
269+
"city",
270+
"country",
271+
"measurement",
272+
"timestamp",
273+
]
274+
assert parser.debug_info().column_error is None
275+
276+
277+
def test_sqllineage_sql_parser_with_weird_lookml_query():
278+
sql_query = """
279+
SELECT date DATE,
280+
platform VARCHAR(20) AS aliased_platform,
281+
country VARCHAR(20) FROM fragment_derived_view'
282+
"""
283+
parser = SqlglotSQLParser(sql_query)
284+
columns_list = parser.get_columns()
285+
columns_list.sort()
286+
assert columns_list == []
287+
assert (
288+
str(parser.debug_info().table_error)
289+
== "Error tokenizing 'untry VARCHAR(20) FROM fragment_derived_view'\n ': Missing ' from 5:143"
290+
)
291+
292+
293+
def test_sqllineage_sql_parser_tables_from_redash_query():
294+
sql_query = """SELECT
295+
name,
296+
SUM(quantity * list_price * (1 - discount)) AS total,
297+
YEAR(order_date) as order_year
298+
FROM
299+
`orders` o
300+
INNER JOIN `order_items` i ON i.order_id = o.order_id
301+
INNER JOIN `staffs` s ON s.staff_id = o.staff_id
302+
GROUP BY
303+
name,
304+
year(order_date)"""
305+
table_list = SqlglotSQLParser(sql_query).get_tables()
306+
table_list.sort()
307+
assert table_list == ["order_items", "orders", "staffs"]
308+
309+
310+
def test_sqllineage_sql_parser_tables_with_special_names():
311+
# The hyphen appears after the special token in tables names, and before the special token in the column names.
312+
sql_query = """
313+
SELECT `column-date`, `column-hour`, `column-timestamp`, `column-data`, `column-admin`
314+
FROM `date-table` d
315+
JOIN `hour-table` h on d.`column-date`= h.`column-hour`
316+
JOIN `timestamp-table` t on d.`column-date` = t.`column-timestamp`
317+
JOIN `data-table` da on d.`column-date` = da.`column-data`
318+
JOIN `admin-table` a on d.`column-date` = a.`column-admin`
319+
"""
320+
expected_tables = [
321+
"admin-table",
322+
"data-table",
323+
"date-table",
324+
"hour-table",
325+
"timestamp-table",
326+
]
327+
expected_columns = [
328+
"column-admin",
329+
"column-data",
330+
"column-date",
331+
"column-hour",
332+
"column-timestamp",
333+
]
334+
assert sorted(SqlglotSQLParser(sql_query).get_tables()) == expected_tables
335+
assert sorted(SqlglotSQLParser(sql_query).get_columns()) == []
336+
assert (
337+
sorted(SqlglotSQLParser(sql_query).get_downstream_columns()) == expected_columns
338+
)
339+
340+
341+
def test_logging_name_extraction():
342+
import datahub.utilities.logging_manager
343+
344+
assert (
345+
doctest.testmod(
346+
datahub.utilities.logging_manager, raise_on_error=True
347+
).attempted
348+
> 0
349+
)
350+
351+
352+
def test_is_pytest_running() -> None:
353+
assert is_pytest_running()

0 commit comments

Comments
 (0)