feat(ingest): add escape hatch methods to SqlParsingAggregator #9860

Merged: 6 commits, Feb 21, 2024
@@ -5,15 +5,18 @@
import logging
import pathlib
import tempfile
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Callable, Dict, Iterable, List, Optional, Set, cast
from typing import Callable, Dict, Iterable, List, Optional, Set, Union, cast

import datahub.emitter.mce_builder as builder
import datahub.metadata.schema_classes as models
from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.sql_parsing_builder import compute_upstream_fields
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.usage.usage_common import BaseUsageConfig, UsageAggregator
from datahub.metadata.urns import (
@@ -32,7 +35,7 @@
infer_output_schema,
sqlglot_lineage,
)
from datahub.sql_parsing.sqlglot_utils import generate_hash
from datahub.sql_parsing.sqlglot_utils import generate_hash, get_query_fingerprint
from datahub.utilities.file_backed_collections import (
ConnectionWrapper,
FileBackedDict,
@@ -57,8 +60,6 @@ class QueryLogSetting(enum.Enum):

@dataclasses.dataclass
class ViewDefinition:
# TODO view urn?

view_definition: str
default_db: Optional[str] = None
default_schema: Optional[str] = None
@@ -95,6 +96,18 @@ def make_last_modified_audit_stamp(self) -> models.AuditStampClass:
)


@dataclasses.dataclass
class KnownQueryLineageInfo:
query_text: str

downstream: UrnStr
upstreams: List[UrnStr]
column_lineage: Optional[List[ColumnLineageInfo]] = None

timestamp: Optional[datetime] = None
query_type: QueryType = QueryType.UNKNOWN


@dataclasses.dataclass
class SqlAggregatorReport(Report):
_aggregator: "SqlParsingAggregator"
@@ -103,12 +116,16 @@ class SqlAggregatorReport(Report):
num_observed_queries: int = 0
num_observed_queries_failed: int = 0
num_observed_queries_column_failed: int = 0
observed_query_parse_failures = LossyList[str]()
observed_query_parse_failures: LossyList[str] = dataclasses.field(
default_factory=LossyList
)

num_view_definitions: int = 0
num_views_failed: int = 0
num_views_column_failed: int = 0
views_parse_failures = LossyDict[UrnStr, str]()
views_parse_failures: LossyDict[UrnStr, str] = dataclasses.field(
default_factory=LossyDict
)

num_queries_with_temp_tables_in_session: int = 0

@@ -142,8 +159,8 @@ def __init__(
self,
*,
platform: str,
platform_instance: Optional[str],
env: str,
platform_instance: Optional[str] = None,
env: str = builder.DEFAULT_ENV,
graph: Optional[DataHubGraph] = None,
generate_lineage: bool = True,
generate_queries: bool = True,
@@ -246,7 +263,7 @@ def _need_schemas(self) -> bool:
return self.generate_lineage or self.generate_usage_statistics

def register_schema(
self, urn: DatasetUrn, schema: models.SchemaMetadataClass
self, urn: Union[str, DatasetUrn], schema: models.SchemaMetadataClass
) -> None:
# If lineage or usage is enabled, adds the schema to the schema resolver
# by putting the condition in here, we can avoid all the conditional
@@ -255,6 +272,16 @@ def register_schema(
if self._need_schemas:
self._schema_resolver.add_schema_metadata(str(urn), schema)

def register_schemas_from_stream(
self, stream: Iterable[MetadataWorkUnit]
) -> Iterable[MetadataWorkUnit]:
for wu in stream:
schema_metadata = wu.get_aspect_of_type(models.SchemaMetadataClass)
if schema_metadata:
self.register_schema(wu.get_urn(), schema_metadata)

yield wu
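
For illustration only (not part of this diff): a minimal sketch of wrapping a source's workunit stream so that any SchemaMetadata aspects are registered with the aggregator as they stream past. Here source_workunits and emit are hypothetical stand-ins for whatever the caller already has.

# source_workunits: any Iterable[MetadataWorkUnit]; emit: whatever consumes the workunits.
for wu in aggregator.register_schemas_from_stream(source_workunits):
    emit(wu)  # schemas are registered as a side effect; workunits pass through unchanged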

def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None:
# requires a graph instance
# if no schemas are currently registered in the schema resolver
@@ -284,6 +311,96 @@ def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None:
env=self.env,
)

def add_known_query_lineage(
self, known_query_lineage: KnownQueryLineageInfo
) -> None:
"""Add a query and it's precomputed lineage to the aggregator.

This is useful for cases where we have lineage information that was
computed outside of the SQL parsing aggregator, e.g. from a data
warehouse's system tables.

This will also generate an operation aspect for the query if there is
a timestamp and the query type field is set to a mutation type.

Args:
known_query_lineage: The known query lineage information.
"""

# Generate a fingerprint for the query.
query_fingerprint = get_query_fingerprint(
known_query_lineage.query_text, self.platform.platform_name
)
# TODO format the query text?
Review comment (Contributor): I think formatting the query might be useful if we don't do it on the frontend side when we are showing the queries, but it should be optional.

Reply (Collaborator Author): yeah, I will do it in a future PR.


# Register the query.
self._add_to_query_map(
QueryMetadata(
query_id=query_fingerprint,
formatted_query_string=known_query_lineage.query_text,
session_id=_MISSING_SESSION_ID,
Review comment (Contributor): Is there value in being able to pass a session id into the known lineage as well?

Reply (Contributor): In some cases I think we should be able to capture it.

Reply (Collaborator Author, hsheth2, Feb 21, 2024): yup, I'll add that in the next PR.

query_type=known_query_lineage.query_type,
lineage_type=models.DatasetLineageTypeClass.TRANSFORMED,
latest_timestamp=known_query_lineage.timestamp,
actor=None,
upstreams=known_query_lineage.upstreams,
column_lineage=known_query_lineage.column_lineage or [],
confidence_score=1.0,
)
)

# Register the lineage.
self._lineage_map.for_mutation(
known_query_lineage.downstream, OrderedSet()
).add(query_fingerprint)
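
For illustration only (not part of this diff): a minimal sketch of feeding precomputed lineage into the aggregator, e.g. lineage pulled from a warehouse's query history. The URNs and query text are made up, QueryType.INSERT is assumed to be an available mutation type, and constructing the aggregator with only platform relies on the new platform_instance/env defaults.

aggregator = SqlParsingAggregator(platform="snowflake")
aggregator.add_known_query_lineage(
    KnownQueryLineageInfo(
        query_text="INSERT INTO db.schema.orders_summary SELECT * FROM db.schema.orders",
        downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.orders_summary,PROD)",
        upstreams=["urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.orders,PROD)"],
        query_type=QueryType.INSERT,  # a mutation type, so an operation aspect can be emitted
        timestamp=datetime.now(tz=timezone.utc),
    )
)

Because the lineage is trusted, it is registered with a confidence score of 1.0 and keyed by the query's fingerprint.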

def add_known_lineage_mapping(
self,
upstream_urn: UrnStr,
downstream_urn: UrnStr,
lineage_type: str = models.DatasetLineageTypeClass.COPY,
) -> None:
"""Add a known lineage mapping to the aggregator.

By mapping, we mean that the downstream is effectively a copy or
alias of the upstream. This is useful for things like external tables
(e.g. Redshift Spectrum, Redshift UNLOADs, Snowflake external tables).

Because this method takes in urns, it does not require that the urns
are part of the platform that the aggregator is configured for.

TODO: In the future, this method will also generate CLL if we have
schemas for either the upstream or downstream.

The known lineage mapping does not contribute to usage statistics or operations.

Args:
upstream_urn: The upstream dataset URN.
downstream_urn: The downstream dataset URN.
"""

# We generate a fake "query" object to hold the lineage.
query_id = self._known_lineage_query_id()

# Register the query.
self._add_to_query_map(
QueryMetadata(
query_id=query_id,
formatted_query_string="-skip-",
session_id=_MISSING_SESSION_ID,
query_type=QueryType.UNKNOWN,
lineage_type=lineage_type,
latest_timestamp=None,
actor=None,
upstreams=[upstream_urn],
column_lineage=[],
confidence_score=1.0,
)
)

# Register the lineage.
self._lineage_map.for_mutation(downstream_urn, OrderedSet()).add(query_id)
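
A matching sketch for the copy/alias case, again with made-up URNs: e.g. a Redshift external table whose data lives in S3. The default lineage_type of COPY applies, and because the method takes plain URNs, the upstream does not have to belong to the aggregator's platform.

aggregator.add_known_lineage_mapping(
    upstream_urn="urn:li:dataset:(urn:li:dataPlatform:s3,my-bucket/orders,PROD)",
    downstream_urn="urn:li:dataset:(urn:li:dataPlatform:redshift,spectrum_db.orders,PROD)",
)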

def add_view_definition(
self,
view_urn: DatasetUrn,
@@ -449,6 +566,10 @@ def _make_schema_resolver_for_session(
def _process_view_definition(
self, view_urn: UrnStr, view_definition: ViewDefinition
) -> None:
# Note that in some cases, the view definition will be a SELECT statement
# instead of a CREATE VIEW ... AS SELECT statement. In those cases, we can't
# trust the parsed query type or downstream urn.

# Run the SQL parser.
parsed = self._run_sql_parser(
view_definition.view_definition,
@@ -464,10 +585,6 @@ def _process_view_definition(
elif parsed.debug_info.error:
self.report.num_views_column_failed += 1

# Note that in some cases, the view definition will be a SELECT statement
# instead of a CREATE VIEW ... AS SELECT statement. In those cases, we can't
# trust the parsed query type or downstream urn.

query_fingerprint = self._view_query_id(view_urn)

# Register the query.
Expand Down Expand Up @@ -540,15 +657,6 @@ def _add_to_query_map(self, new: QueryMetadata) -> None:
else:
self._query_map[query_fingerprint] = new

"""
def add_lineage(self) -> None:
# A secondary mechanism for adding non-SQL-based lineage
# e.g. redshift external tables might use this when pointing at s3

# TODO Add this once we have a use case for it
pass
"""

def gen_metadata(self) -> Iterable[MetadataChangeProposalWrapper]:
# diff from v1 - we generate operations here, and it also
# generates MCPWs instead of workunits
@@ -569,7 +677,7 @@ def _gen_lineage_mcps(self) -> Iterable[MetadataChangeProposalWrapper]:

# Generate lineage and queries.
queries_generated: Set[QueryId] = set()
for downstream_urn in self._lineage_map:
for downstream_urn in sorted(self._lineage_map):
yield from self._gen_lineage_for_downstream(
downstream_urn, queries_generated=queries_generated
)
@@ -640,7 +748,9 @@ def _gen_lineage_for_downstream(
dataset=upstream_urn,
type=queries_map[query_id].lineage_type,
query=(
self._query_urn(query_id) if self.generate_queries else None
self._query_urn(query_id)
if self.can_generate_query(query_id)
else None
),
created=query.make_created_audit_stamp(),
auditStamp=models.AuditStampClass(
@@ -671,7 +781,9 @@ def _gen_lineage_for_downstream(
SchemaFieldUrn(downstream_urn, downstream_column).urn()
],
query=(
self._query_urn(query_id) if self.generate_queries else None
self._query_urn(query_id)
if self.can_generate_query(query_id)
else None
),
confidenceScore=queries_map[query_id].confidence_score,
)
@@ -682,9 +794,10 @@ def _gen_lineage_for_downstream(
aspect=upstream_aspect,
)

if not self.generate_queries:
return
for query_id in required_queries:
if not self.can_generate_query(query_id):
continue

# Avoid generating the same query twice.
if query_id in queries_generated:
continue
@@ -696,6 +809,7 @@
entityUrn=self._query_urn(query_id),
aspects=[
models.QueryPropertiesClass(
dataPlatform=self.platform.urn(),
statement=models.QueryStatementClass(
value=query.formatted_query_string,
language=models.QueryLanguageClass.SQL,
@@ -729,6 +843,19 @@ def _composite_query_id(cls, composed_of_queries: Iterable[QueryId]) -> str:
def _view_query_id(cls, view_urn: UrnStr) -> str:
return f"view_{DatasetUrn.url_encode(view_urn)}"

@classmethod
def _known_lineage_query_id(cls) -> str:
return f"known_{uuid.uuid4()}"

@classmethod
def _is_known_lineage_query_id(cls, query_id: QueryId) -> bool:
# Our query fingerprints are hex and won't have underscores, so this will
# never conflict with a real query fingerprint.
return query_id.startswith("known_")

def can_generate_query(self, query_id: QueryId) -> bool:
return self.generate_queries and not self._is_known_lineage_query_id(query_id)
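
Illustrative behaviour of the new helpers (the UUID shown is invented): known-lineage query IDs carry a known_ prefix, so they can never collide with the hex fingerprints of real queries, and can_generate_query excludes them from query entity generation even when generate_queries is True.

query_id = aggregator._known_lineage_query_id()        # e.g. "known_1b4e28ba-..."
assert SqlParsingAggregator._is_known_lineage_query_id(query_id)
assert not aggregator.can_generate_query(query_id)     # no Query entity is emitted for it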

def _resolve_query_with_temp_tables(
self,
base_query: QueryMetadata,
@@ -895,8 +1022,10 @@ def _gen_operation_for_downstream(
operationType=operation_type,
lastUpdatedTimestamp=make_ts_millis(query.latest_timestamp),
actor=query.actor.urn() if query.actor else None,
customProperties={
"query_urn": self._query_urn(query_id),
},
customProperties=(
{"query_urn": self._query_urn(query_id)}
if self.can_generate_query(query_id)
else None
),
)
yield MetadataChangeProposalWrapper(entityUrn=downstream_urn, aspect=aspect)
metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py (16 changes: 13 additions & 3 deletions)
@@ -1,8 +1,11 @@
import hashlib
import logging
from typing import Dict, Iterable, Optional, Union

import sqlglot
import sqlglot.errors

logger = logging.getLogger(__name__)
DialectOrStr = Union[sqlglot.Dialect, str]


@@ -139,10 +142,17 @@ def get_query_fingerprint(
The fingerprint for the SQL query.
"""

dialect = get_dialect(dialect)
expression_sql = generalize_query(expression, dialect=dialect)
fingerprint = generate_hash(expression_sql)
try:
dialect = get_dialect(dialect)
expression_sql = generalize_query(expression, dialect=dialect)
except (ValueError, sqlglot.errors.SqlglotError) as e:
if not isinstance(expression, str):
raise

logger.debug("Failed to generalize query for fingerprinting: %s", e)
expression_sql = expression

fingerprint = generate_hash(expression_sql)
return fingerprint
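
Illustrative effect of the new fallback, assuming the second input fails sqlglot parsing: a plain-string query that cannot be generalized no longer raises; it simply gets fingerprinted from its raw text.

# Parseable SQL: the fingerprint is computed over the generalized query.
fp = get_query_fingerprint("SELECT * FROM orders WHERE id = 5", "snowflake")
# Unparseable text (assumed): generalization fails, so the raw string is hashed instead.
fp_raw = get_query_fingerprint("not really sql at all", "snowflake")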

