Commit f13ae77

feat(ingest): add escape hatch methods to SqlParsingAggregator (#9860)

1 parent: ac1ee6c
11 files changed: +730 -135 lines

metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator_v2.py → metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py (renamed, +159 -30)
@@ -5,15 +5,18 @@
 import logging
 import pathlib
 import tempfile
+import uuid
 from collections import defaultdict
 from datetime import datetime, timezone
-from typing import Callable, Dict, Iterable, List, Optional, Set, cast
+from typing import Callable, Dict, Iterable, List, Optional, Set, Union, cast

+import datahub.emitter.mce_builder as builder
 import datahub.metadata.schema_classes as models
 from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.emitter.sql_parsing_builder import compute_upstream_fields
 from datahub.ingestion.api.report import Report
+from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.graph.client import DataHubGraph
 from datahub.ingestion.source.usage.usage_common import BaseUsageConfig, UsageAggregator
 from datahub.metadata.urns import (
@@ -32,7 +35,7 @@
     infer_output_schema,
     sqlglot_lineage,
 )
-from datahub.sql_parsing.sqlglot_utils import generate_hash
+from datahub.sql_parsing.sqlglot_utils import generate_hash, get_query_fingerprint
 from datahub.utilities.file_backed_collections import (
     ConnectionWrapper,
     FileBackedDict,
@@ -57,8 +60,6 @@ class QueryLogSetting(enum.Enum):

 @dataclasses.dataclass
 class ViewDefinition:
-    # TODO view urn?
-
     view_definition: str
     default_db: Optional[str] = None
     default_schema: Optional[str] = None
@@ -95,6 +96,18 @@ def make_last_modified_audit_stamp(self) -> models.AuditStampClass:
     )


+@dataclasses.dataclass
+class KnownQueryLineageInfo:
+    query_text: str
+
+    downstream: UrnStr
+    upstreams: List[UrnStr]
+    column_lineage: Optional[List[ColumnLineageInfo]] = None
+
+    timestamp: Optional[datetime] = None
+    query_type: QueryType = QueryType.UNKNOWN
+
+
 @dataclasses.dataclass
 class SqlAggregatorReport(Report):
     _aggregator: "SqlParsingAggregator"
@@ -103,12 +116,16 @@ class SqlAggregatorReport(Report):
     num_observed_queries: int = 0
     num_observed_queries_failed: int = 0
     num_observed_queries_column_failed: int = 0
-    observed_query_parse_failures = LossyList[str]()
+    observed_query_parse_failures: LossyList[str] = dataclasses.field(
+        default_factory=LossyList
+    )

     num_view_definitions: int = 0
     num_views_failed: int = 0
     num_views_column_failed: int = 0
-    views_parse_failures = LossyDict[UrnStr, str]()
+    views_parse_failures: LossyDict[UrnStr, str] = dataclasses.field(
+        default_factory=LossyDict
+    )

     num_queries_with_temp_tables_in_session: int = 0

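Context for this fix (not part of the diff): the old class-level assignment `observed_query_parse_failures = LossyList[str]()` has no type annotation, so dataclasses treats it as a plain class attribute and every report instance shares a single list. Annotating the field and using `dataclasses.field(default_factory=...)` gives each instance its own collection. A minimal illustration of the difference:

    import dataclasses
    from typing import List

    @dataclasses.dataclass
    class BuggyReport:
        failures = []  # class attribute: shared across ALL instances

    @dataclasses.dataclass
    class FixedReport:
        failures: List[str] = dataclasses.field(default_factory=list)

    b1, b2 = BuggyReport(), BuggyReport()
    b1.failures.append("oops")
    assert b2.failures == ["oops"]  # shared state leaks between instances

    f1, f2 = FixedReport(), FixedReport()
    f1.failures.append("oops")
    assert f2.failures == []  # each instance gets a fresh list
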
@@ -142,8 +159,8 @@ def __init__(
         self,
         *,
         platform: str,
-        platform_instance: Optional[str],
-        env: str,
+        platform_instance: Optional[str] = None,
+        env: str = builder.DEFAULT_ENV,
         graph: Optional[DataHubGraph] = None,
         generate_lineage: bool = True,
         generate_queries: bool = True,
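
With these defaults, only `platform` is required at construction time. A minimal sketch, assuming the remaining keyword arguments keep the defaults shown above (the module path reflects the rename in this commit):

    from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator

    aggregator = SqlParsingAggregator(
        platform="redshift",
        # platform_instance now defaults to None, env to builder.DEFAULT_ENV
        generate_lineage=True,
        generate_queries=True,
    )
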
@@ -246,7 +263,7 @@ def _need_schemas(self) -> bool:
         return self.generate_lineage or self.generate_usage_statistics

     def register_schema(
-        self, urn: DatasetUrn, schema: models.SchemaMetadataClass
+        self, urn: Union[str, DatasetUrn], schema: models.SchemaMetadataClass
     ) -> None:
         # If lineage or usage is enabled, adds the schema to the schema resolver
         # by putting the condition in here, we can avoid all the conditional
@@ -255,6 +272,16 @@ def register_schema(
         if self._need_schemas:
             self._schema_resolver.add_schema_metadata(str(urn), schema)

+    def register_schemas_from_stream(
+        self, stream: Iterable[MetadataWorkUnit]
+    ) -> Iterable[MetadataWorkUnit]:
+        for wu in stream:
+            schema_metadata = wu.get_aspect_of_type(models.SchemaMetadataClass)
+            if schema_metadata:
+                self.register_schema(wu.get_urn(), schema_metadata)
+
+            yield wu
+
     def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None:
         # requires a graph instance
         # if no schemas are currently registered in the schema resolver
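
The new `register_schemas_from_stream` helper is a pass-through generator: it registers any SchemaMetadata aspects it sees and yields every workunit unchanged. A hedged sketch of how a source might use it (the class and inner generator are hypothetical):

    from typing import Iterable

    from datahub.ingestion.api.workunit import MetadataWorkUnit

    class MySource:  # hypothetical source
        def get_workunits(self) -> Iterable[MetadataWorkUnit]:
            # Schemas are captured as a side effect; the stream is unchanged.
            yield from self.aggregator.register_schemas_from_stream(
                self._get_workunits_internal()  # hypothetical inner generator
            )

Note that `register_schema` itself now accepts either a `DatasetUrn` or a plain string urn, which is what lets the stream helper pass `wu.get_urn()` straight through.
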
@@ -284,6 +311,96 @@ def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None:
             env=self.env,
         )

+    def add_known_query_lineage(
+        self, known_query_lineage: KnownQueryLineageInfo
+    ) -> None:
+        """Add a query and its precomputed lineage to the aggregator.
+
+        This is useful for cases where we have lineage information that was
+        computed outside of the SQL parsing aggregator, e.g. from a data
+        warehouse's system tables.
+
+        This will also generate an operation aspect for the query if there is
+        a timestamp and the query type field is set to a mutation type.
+
+        Args:
+            known_query_lineage: The known query lineage information.
+        """
+
+        # Generate a fingerprint for the query.
+        query_fingerprint = get_query_fingerprint(
+            known_query_lineage.query_text, self.platform.platform_name
+        )
+        # TODO format the query text?
+
+        # Register the query.
+        self._add_to_query_map(
+            QueryMetadata(
+                query_id=query_fingerprint,
+                formatted_query_string=known_query_lineage.query_text,
+                session_id=_MISSING_SESSION_ID,
+                query_type=known_query_lineage.query_type,
+                lineage_type=models.DatasetLineageTypeClass.TRANSFORMED,
+                latest_timestamp=known_query_lineage.timestamp,
+                actor=None,
+                upstreams=known_query_lineage.upstreams,
+                column_lineage=known_query_lineage.column_lineage or [],
+                confidence_score=1.0,
+            )
+        )
+
+        # Register the lineage.
+        self._lineage_map.for_mutation(
+            known_query_lineage.downstream, OrderedSet()
+        ).add(query_fingerprint)
+
+    def add_known_lineage_mapping(
+        self,
+        upstream_urn: UrnStr,
+        downstream_urn: UrnStr,
+        lineage_type: str = models.DatasetLineageTypeClass.COPY,
+    ) -> None:
+        """Add a known lineage mapping to the aggregator.
+
+        By mapping, we mean that the downstream is effectively a copy or
+        alias of the upstream. This is useful for things like external tables
+        (e.g. Redshift Spectrum, Redshift UNLOADs, Snowflake external tables).
+
+        Because this method takes in urns, it does not require that the urns
+        are part of the platform that the aggregator is configured for.
+
+        TODO: In the future, this method will also generate CLL if we have
+        schemas for either the upstream or downstream.
+
+        The known lineage mapping does not contribute to usage statistics or operations.
+
+        Args:
+            upstream_urn: The upstream dataset URN.
+            downstream_urn: The downstream dataset URN.
+        """
+
+        # We generate a fake "query" object to hold the lineage.
+        query_id = self._known_lineage_query_id()
+
+        # Register the query.
+        self._add_to_query_map(
+            QueryMetadata(
+                query_id=query_id,
+                formatted_query_string="-skip-",
+                session_id=_MISSING_SESSION_ID,
+                query_type=QueryType.UNKNOWN,
+                lineage_type=lineage_type,
+                latest_timestamp=None,
+                actor=None,
+                upstreams=[upstream_urn],
+                column_lineage=[],
+                confidence_score=1.0,
+            )
+        )
+
+        # Register the lineage.
+        self._lineage_map.for_mutation(downstream_urn, OrderedSet()).add(query_id)
+
     def add_view_definition(
         self,
         view_urn: DatasetUrn,
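
Taken together, these two methods are the commit's escape hatches: precomputed per-query lineage goes through `add_known_query_lineage`, and copy-style lineage with no meaningful query goes through `add_known_lineage_mapping`. A hedged usage sketch (URNs and SQL are illustrative; the imports assume the module layout shown in this diff):

    from datetime import datetime, timezone

    from datahub.sql_parsing.sql_parsing_aggregator import KnownQueryLineageInfo
    from datahub.sql_parsing.sqlglot_lineage import QueryType

    # Lineage computed elsewhere, e.g. from warehouse system tables.
    aggregator.add_known_query_lineage(
        KnownQueryLineageInfo(
            query_text="INSERT INTO sales_summary SELECT * FROM sales_raw",
            downstream="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.sales_summary,PROD)",
            upstreams=[
                "urn:li:dataset:(urn:li:dataPlatform:redshift,dev.public.sales_raw,PROD)"
            ],
            query_type=QueryType.INSERT,
            timestamp=datetime.now(tz=timezone.utc),
        )
    )

    # Copy-style lineage, e.g. a Redshift Spectrum table backed by S3.
    aggregator.add_known_lineage_mapping(
        upstream_urn="urn:li:dataset:(urn:li:dataPlatform:s3,my-bucket/sales,PROD)",
        downstream_urn="urn:li:dataset:(urn:li:dataPlatform:redshift,dev.spectrum.sales,PROD)",
    )

Since the known-query path runs through the normal query map, it can also emit an operation aspect when a timestamp and a mutation-type query type are provided; the mapping path deliberately skips usage statistics and operations.
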
@@ -449,6 +566,10 @@ def _make_schema_resolver_for_session(
     def _process_view_definition(
         self, view_urn: UrnStr, view_definition: ViewDefinition
     ) -> None:
+        # Note that in some cases, the view definition will be a SELECT statement
+        # instead of a CREATE VIEW ... AS SELECT statement. In those cases, we can't
+        # trust the parsed query type or downstream urn.
+
         # Run the SQL parser.
         parsed = self._run_sql_parser(
             view_definition.view_definition,
@@ -464,10 +585,6 @@
         elif parsed.debug_info.error:
             self.report.num_views_column_failed += 1

-        # Note that in some cases, the view definition will be a SELECT statement
-        # instead of a CREATE VIEW ... AS SELECT statement. In those cases, we can't
-        # trust the parsed query type or downstream urn.
-
         query_fingerprint = self._view_query_id(view_urn)

         # Register the query.
@@ -540,15 +657,6 @@ def _add_to_query_map(self, new: QueryMetadata) -> None:
         else:
             self._query_map[query_fingerprint] = new

-    """
-    def add_lineage(self) -> None:
-        # A secondary mechanism for adding non-SQL-based lineage
-        # e.g. redshift external tables might use this when pointing at s3
-
-        # TODO Add this once we have a use case for it
-        pass
-    """
-
     def gen_metadata(self) -> Iterable[MetadataChangeProposalWrapper]:
         # diff from v1 - we generate operations here, and it also
         # generates MCPWs instead of workunits
@@ -569,7 +677,7 @@ def _gen_lineage_mcps(self) -> Iterable[MetadataChangeProposalWrapper]:

         # Generate lineage and queries.
         queries_generated: Set[QueryId] = set()
-        for downstream_urn in self._lineage_map:
+        for downstream_urn in sorted(self._lineage_map):
             yield from self._gen_lineage_for_downstream(
                 downstream_urn, queries_generated=queries_generated
             )
@@ -640,7 +748,9 @@ def _gen_lineage_for_downstream(
                     dataset=upstream_urn,
                     type=queries_map[query_id].lineage_type,
                     query=(
-                        self._query_urn(query_id) if self.generate_queries else None
+                        self._query_urn(query_id)
+                        if self.can_generate_query(query_id)
+                        else None
                     ),
                     created=query.make_created_audit_stamp(),
                     auditStamp=models.AuditStampClass(
@@ -671,7 +781,9 @@
                         SchemaFieldUrn(downstream_urn, downstream_column).urn()
                     ],
                     query=(
-                        self._query_urn(query_id) if self.generate_queries else None
+                        self._query_urn(query_id)
+                        if self.can_generate_query(query_id)
+                        else None
                     ),
                     confidenceScore=queries_map[query_id].confidence_score,
                 )
@@ -682,9 +794,10 @@
             aspect=upstream_aspect,
         )

-        if not self.generate_queries:
-            return
         for query_id in required_queries:
+            if not self.can_generate_query(query_id):
+                continue
+
             # Avoid generating the same query twice.
             if query_id in queries_generated:
                 continue
@@ -696,6 +809,7 @@
                 entityUrn=self._query_urn(query_id),
                 aspects=[
                     models.QueryPropertiesClass(
+                        dataPlatform=self.platform.urn(),
                         statement=models.QueryStatementClass(
                             value=query.formatted_query_string,
                             language=models.QueryLanguageClass.SQL,
@@ -729,6 +843,19 @@ def _composite_query_id(cls, composed_of_queries: Iterable[QueryId]) -> str:
     def _view_query_id(cls, view_urn: UrnStr) -> str:
         return f"view_{DatasetUrn.url_encode(view_urn)}"

+    @classmethod
+    def _known_lineage_query_id(cls) -> str:
+        return f"known_{uuid.uuid4()}"
+
+    @classmethod
+    def _is_known_lineage_query_id(cls, query_id: QueryId) -> bool:
+        # Our query fingerprints are hex and won't have underscores, so this will
+        # never conflict with a real query fingerprint.
+        return query_id.startswith("known_")
+
+    def can_generate_query(self, query_id: QueryId) -> bool:
+        return self.generate_queries and not self._is_known_lineage_query_id(query_id)
+
     def _resolve_query_with_temp_tables(
         self,
         base_query: QueryMetadata,
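
For context: query fingerprints are hex digests, while the synthetic IDs minted by `_known_lineage_query_id` carry a `known_` prefix, so the two namespaces cannot collide. A hedged sketch of the resulting behavior (`aggregator` as constructed earlier; values illustrative):

    qid = aggregator._known_lineage_query_id()  # e.g. "known_1c9e67..."
    assert aggregator._is_known_lineage_query_id(qid)
    # Known-lineage entries are never emitted as standalone Query entities.
    assert not aggregator.can_generate_query(qid)
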
@@ -895,8 +1022,10 @@ def _gen_operation_for_downstream(
             operationType=operation_type,
             lastUpdatedTimestamp=make_ts_millis(query.latest_timestamp),
             actor=query.actor.urn() if query.actor else None,
-            customProperties={
-                "query_urn": self._query_urn(query_id),
-            },
+            customProperties=(
+                {"query_urn": self._query_urn(query_id)}
+                if self.can_generate_query(query_id)
+                else None
+            ),
         )
         yield MetadataChangeProposalWrapper(entityUrn=downstream_urn, aspect=aspect)

metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py (+13 -3)
@@ -1,8 +1,11 @@
 import hashlib
+import logging
 from typing import Dict, Iterable, Optional, Union

 import sqlglot
+import sqlglot.errors

+logger = logging.getLogger(__name__)
 DialectOrStr = Union[sqlglot.Dialect, str]

@@ -139,10 +142,17 @@ def get_query_fingerprint(
         The fingerprint for the SQL query.
     """

-    dialect = get_dialect(dialect)
-    expression_sql = generalize_query(expression, dialect=dialect)
-    fingerprint = generate_hash(expression_sql)
+    try:
+        dialect = get_dialect(dialect)
+        expression_sql = generalize_query(expression, dialect=dialect)
+    except (ValueError, sqlglot.errors.SqlglotError) as e:
+        if not isinstance(expression, str):
+            raise

+        logger.debug("Failed to generalize query for fingerprinting: %s", e)
+        expression_sql = expression
+
+    fingerprint = generate_hash(expression_sql)
     return fingerprint

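The net effect: fingerprinting no longer raises on SQL that sqlglot cannot parse; for raw strings it falls back to hashing the text as-is. A hedged sketch (dialect and statements illustrative):

    from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint

    # Parseable SQL is generalized before hashing, so formatting variants
    # map to the same fingerprint.
    fp_ok = get_query_fingerprint("SELECT a FROM t WHERE a = 5", "redshift")

    # Unparseable SQL previously raised; now the raw string itself is hashed.
    fp_bad = get_query_fingerprint("TOTALLY ((( NOT SQL", "redshift")

    assert fp_ok and fp_bad  # both are stable hex digests

Note the isinstance guard: the fallback only applies to raw strings; a pre-parsed sqlglot expression that fails to generalize still raises.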