Skip to content

Commit c1d7596

Browse files
fix: PR comments
1 parent 4427621 commit c1d7596

File tree

4 files changed

+67
-88
lines changed

4 files changed

+67
-88
lines changed

metadata-ingestion-modules/gx-plugin/setup.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,6 @@ def get_long_description():
1515

1616
rest_common = {"requests", "requests_file"}
1717

18-
sqlglot_lib = {
19-
# We heavily monkeypatch sqlglot.
20-
# Prior to the patching, we originally maintained an acryl-sqlglot fork:
21-
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
22-
"sqlglot[rs]==25.26.0",
23-
"patchy==2.8.0",
24-
}
25-
2618
_version: str = package_metadata["__version__"]
2719
_self_pin = (
2820
f"=={_version}"
@@ -42,8 +34,7 @@ def get_long_description():
4234
# https://github.com/ipython/traitlets/issues/741
4335
"traitlets<5.2.2",
4436
*rest_common,
45-
*sqlglot_lib,
46-
f"acryl-datahub[datahub-rest]{_self_pin}",
37+
f"acryl-datahub[datahub-rest,sql-parser]{_self_pin}",
4738
}
4839

4940
mypy_stubs = {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Note: the integration can use a SQL parser to try to determine the tables the chart depends on. This parsing is disabled by default,
2+
but can be enabled by setting `parse_table_names_from_sql: true`. The parser is based on the [`sqlglot`](https://pypi.org/project/sqlglot/) package.

metadata-ingestion/src/datahub/ingestion/source/unity/usage.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def _get_workunits_internal(
8080
) -> Iterable[MetadataWorkUnit]:
8181
table_map = defaultdict(list)
8282
query_hashes = set()
83+
print("table_refs", table_refs)
8384
for ref in table_refs:
8485
table_map[ref.table].append(ref)
8586
table_map[f"{ref.schema}.{ref.table}"].append(ref)
@@ -175,6 +176,7 @@ def _parse_query(
175176
self, query: Query, table_map: TableMap
176177
) -> Optional[QueryTableInfo]:
177178
with self.report.usage_perf_report.sql_parsing_timer:
179+
breakpoint()
178180
table_info = self._parse_query_via_sqlglot(query.query_text)
179181
if table_info is None and query.statement_type == QueryStatementType.SELECT:
180182
with self.report.usage_perf_report.spark_sql_parsing_timer:
@@ -218,20 +220,14 @@ def _parse_query_via_sqlglot(self, query: str) -> Optional[StringTableInfo]:
218220
return None
219221

220222
@staticmethod
221-
def _parse_sqllineage_table(sqllineage_table: str) -> str:
222-
full_table_name = str(sqllineage_table)
223+
def _parse_sqlglot_table(table_urn: str) -> str:
224+
full_table_name = DatasetUrn.from_string(table_urn).name
223225
default_schema = "<default>."
224226
if full_table_name.startswith(default_schema):
225227
return full_table_name[len(default_schema) :]
226228
else:
227229
return full_table_name
228230

229-
@staticmethod
230-
def _parse_sqlglot_table(table_urn: str) -> str:
231-
return UnityCatalogUsageExtractor._parse_sqllineage_table(
232-
DatasetUrn.from_string(table_urn).name
233-
)
234-
235231
def _parse_query_via_spark_sql_plan(self, query: str) -> Optional[StringTableInfo]:
236232
"""Parse query source tables via Spark SQL plan. This is a fallback option."""
237233
# Would be more effective if we upgrade pyspark

metadata-ingestion/tests/unit/utilities/test_utilities.py

+60-70
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,55 @@
11
import doctest
2+
import re
23
from typing import List
34

45
from datahub.sql_parsing.schema_resolver import SchemaResolver
5-
from datahub.sql_parsing.sqlglot_lineage import SqlParsingDebugInfo, sqlglot_lineage
6+
from datahub.sql_parsing.sqlglot_lineage import sqlglot_lineage
67
from datahub.utilities.delayed_iter import delayed_iter
78
from datahub.utilities.is_pytest import is_pytest_running
89
from datahub.utilities.urns.dataset_urn import DatasetUrn
910

1011

11-
class SqlglotSQLParser:
12+
class SqlLineageSQLParser:
13+
"""
14+
It uses `sqlglot_lineage` to extract tables and columns, serving as a replacement for the `sqllineage` implementation, similar to BigQuery.
15+
Reference: [BigQuery SQL Lineage Test](https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/tests/unit/bigquery/test_bigquery_sql_lineage.py#L8).
16+
"""
17+
18+
_MYVIEW_SQL_TABLE_NAME_TOKEN = "__my_view__.__sql_table_name__"
19+
_MYVIEW_LOOKER_TOKEN = "my_view.SQL_TABLE_NAME"
20+
1221
def __init__(self, sql_query: str, platform: str = "bigquery") -> None:
13-
self.result = sqlglot_lineage(sql_query, SchemaResolver(platform=platform))
22+
# SqlLineageParser lowercases table names, so we need to replace the Looker-specific token, which should stay uppercased
23+
sql_query = re.sub(
24+
rf"(\${{{self._MYVIEW_LOOKER_TOKEN}}})",
25+
rf"{self._MYVIEW_SQL_TABLE_NAME_TOKEN}",
26+
sql_query,
27+
)
28+
self.sql_query = sql_query
29+
self.schema_resolver = SchemaResolver(platform=platform)
30+
self.result = sqlglot_lineage(sql_query, self.schema_resolver)
1431

1532
def get_tables(self) -> List[str]:
1633
ans = []
1734
for urn in self.result.in_tables:
1835
table_ref = DatasetUrn.from_string(urn)
1936
ans.append(str(table_ref.name))
20-
return ans
37+
38+
result = [
39+
self._MYVIEW_LOOKER_TOKEN if c == self._MYVIEW_SQL_TABLE_NAME_TOKEN else c
40+
for c in ans
41+
]
42+
# Sort tables to make the list deterministic
43+
result.sort()
44+
45+
return result
2146

2247
def get_columns(self) -> List[str]:
23-
ans = set()
48+
ans = []
2449
for col_info in self.result.column_lineage or []:
2550
for col_ref in col_info.upstreams:
26-
ans.add(col_ref.column)
27-
return list(ans)
28-
29-
def get_downstream_columns(self) -> List[str]:
30-
ans = set()
31-
for col_info in self.result.column_lineage or []:
32-
ans.add(col_info.downstream.column)
33-
return list(ans)
34-
35-
def debug_info(self) -> SqlParsingDebugInfo:
36-
return self.result.debug_info
51+
ans.append(col_ref.column)
52+
return ans
3753

3854

3955
def test_delayed_iter():
@@ -73,7 +89,7 @@ def maker(n):
7389
def test_sqllineage_sql_parser_get_tables_from_simple_query():
7490
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
7591

76-
tables_list = SqlglotSQLParser(sql_query).get_tables()
92+
tables_list = SqlLineageSQLParser(sql_query).get_tables()
7793
tables_list.sort()
7894
assert tables_list == ["bar", "foo"]
7995

@@ -126,31 +142,33 @@ def test_sqllineage_sql_parser_get_tables_from_complex_query():
126142
5)
127143
"""
128144

129-
tables_list = SqlglotSQLParser(sql_query).get_tables()
145+
tables_list = SqlLineageSQLParser(sql_query).get_tables()
130146
tables_list.sort()
131147
assert tables_list == ["schema1.foo", "schema2.bar"]
132148

133149

134150
def test_sqllineage_sql_parser_get_columns_with_join():
135151
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"
136152

137-
columns_list = SqlglotSQLParser(sql_query).get_columns()
153+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
138154
columns_list.sort()
139155
assert columns_list == ["a", "b", "c"]
140156

141157

142158
def test_sqllineage_sql_parser_get_columns_from_simple_query():
143159
sql_query = "SELECT foo.a, foo.b FROM foo;"
144160

145-
parser = SqlglotSQLParser(sql_query)
146-
assert sorted(parser.get_columns()) == ["a", "b"]
161+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
162+
columns_list.sort()
163+
assert columns_list == ["a", "b"]
147164

148165

149166
def test_sqllineage_sql_parser_get_columns_with_alias_and_count_star():
150167
sql_query = "SELECT foo.a, foo.b, bar.c as test, count(*) as count FROM foo JOIN bar ON (foo.a == bar.b);"
151-
parser = SqlglotSQLParser(sql_query)
152-
assert sorted(parser.get_columns()) == ["a", "b", "c"]
153-
assert sorted(parser.get_downstream_columns()) == ["a", "b", "count", "test"]
168+
169+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
170+
columns_list.sort()
171+
assert columns_list == ["a", "b", "c"]
154172

155173

156174
def test_sqllineage_sql_parser_get_columns_with_more_complex_join():
@@ -171,9 +189,10 @@ def test_sqllineage_sql_parser_get_columns_with_more_complex_join():
171189
WHERE
172190
fp.dt = '2018-01-01'
173191
"""
174-
parser = SqlglotSQLParser(sql_query)
175-
assert sorted(parser.get_columns()) == ["bs", "pi", "tt", "v"]
176-
assert sorted(parser.get_downstream_columns()) == ["bs", "pi", "pt", "pu", "v"]
192+
193+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
194+
columns_list.sort()
195+
assert columns_list == ["bs", "pi", "tt", "tt", "v"]
177196

178197

179198
def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
@@ -223,17 +242,10 @@ def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
223242
4,
224243
5)
225244
"""
226-
parser = SqlglotSQLParser(sql_query)
227-
columns_list = parser.get_columns()
228-
assert sorted(columns_list) == ["c", "e", "u", "x"]
229-
assert sorted(parser.get_downstream_columns()) == [
230-
"c",
231-
"count(*)",
232-
"date",
233-
"e",
234-
"u",
235-
"x",
236-
]
245+
246+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
247+
columns_list.sort()
248+
assert columns_list == ["c", "c", "e", "e", "e", "e", "u", "u", "x", "x"]
237249

238250

239251
def test_sqllineage_sql_parser_get_tables_from_templated_query():
@@ -246,11 +258,9 @@ def test_sqllineage_sql_parser_get_tables_from_templated_query():
246258
FROM
247259
${my_view.SQL_TABLE_NAME} AS my_view
248260
"""
249-
parser = SqlglotSQLParser(sql_query)
250-
tables_list = parser.get_tables()
261+
tables_list = SqlLineageSQLParser(sql_query).get_tables()
251262
tables_list.sort()
252-
assert tables_list == []
253-
assert parser.debug_info().table_error is None
263+
assert tables_list == ["my_view.SQL_TABLE_NAME"]
254264

255265

256266
def test_sqllineage_sql_parser_get_columns_from_templated_query():
@@ -263,15 +273,9 @@ def test_sqllineage_sql_parser_get_columns_from_templated_query():
263273
FROM
264274
${my_view.SQL_TABLE_NAME} AS my_view
265275
"""
266-
parser = SqlglotSQLParser(sql_query)
267-
assert sorted(parser.get_columns()) == []
268-
assert sorted(parser.get_downstream_columns()) == [
269-
"city",
270-
"country",
271-
"measurement",
272-
"timestamp",
273-
]
274-
assert parser.debug_info().column_error is None
276+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
277+
columns_list.sort()
278+
assert columns_list == ["city", "country", "measurement", "timestamp"]
275279

276280

277281
def test_sqllineage_sql_parser_with_weird_lookml_query():
@@ -280,14 +284,9 @@ def test_sqllineage_sql_parser_with_weird_lookml_query():
280284
platform VARCHAR(20) AS aliased_platform,
281285
country VARCHAR(20) FROM fragment_derived_view'
282286
"""
283-
parser = SqlglotSQLParser(sql_query)
284-
columns_list = parser.get_columns()
287+
columns_list = SqlLineageSQLParser(sql_query).get_columns()
285288
columns_list.sort()
286289
assert columns_list == []
287-
assert (
288-
str(parser.debug_info().table_error)
289-
== "Error tokenizing 'untry VARCHAR(20) FROM fragment_derived_view'\n ': Missing ' from 5:143"
290-
)
291290

292291

293292
def test_sqllineage_sql_parser_tables_from_redash_query():
@@ -302,7 +301,7 @@ def test_sqllineage_sql_parser_tables_from_redash_query():
302301
GROUP BY
303302
name,
304303
year(order_date)"""
305-
table_list = SqlglotSQLParser(sql_query).get_tables()
304+
table_list = SqlLineageSQLParser(sql_query).get_tables()
306305
table_list.sort()
307306
assert table_list == ["order_items", "orders", "staffs"]
308307

@@ -324,18 +323,9 @@ def test_sqllineage_sql_parser_tables_with_special_names():
324323
"hour-table",
325324
"timestamp-table",
326325
]
327-
expected_columns = [
328-
"column-admin",
329-
"column-data",
330-
"column-date",
331-
"column-hour",
332-
"column-timestamp",
333-
]
334-
assert sorted(SqlglotSQLParser(sql_query).get_tables()) == expected_tables
335-
assert sorted(SqlglotSQLParser(sql_query).get_columns()) == []
336-
assert (
337-
sorted(SqlglotSQLParser(sql_query).get_downstream_columns()) == expected_columns
338-
)
326+
expected_columns: List[str] = []
327+
assert sorted(SqlLineageSQLParser(sql_query).get_tables()) == expected_tables
328+
assert sorted(SqlLineageSQLParser(sql_query).get_columns()) == expected_columns
339329

340330

341331
def test_logging_name_extraction():

0 commit comments

Comments
 (0)