Commit a646185

feat(ingest): improve extract-sql-agg-log command (#12803)
1 parent 41b0629 commit a646185

3 files changed: +98 -36 lines

metadata-ingestion/src/datahub/cli/check_cli.py (+72 -19)

@@ -5,7 +5,8 @@
 import pprint
 import shutil
 import tempfile
-from typing import Dict, List, Optional, Union
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
 
 import click
 
@@ -20,7 +21,10 @@
 from datahub.ingestion.source.source_registry import source_registry
 from datahub.ingestion.transformer.transform_registry import transform_registry
 from datahub.telemetry import telemetry
-from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
+from datahub.utilities.file_backed_collections import (
+    ConnectionWrapper,
+    FileBackedDict,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -391,29 +395,78 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
         raise e
 
 
+def _jsonify(data: Any) -> Any:
+    if dataclasses.is_dataclass(data):
+        # dataclasses.asdict() is recursive. We're doing the recursion
+        # manually here via _jsonify calls, so we can't use
+        # dataclasses.asdict() here.
+        return {
+            f.name: _jsonify(getattr(data, f.name)) for f in dataclasses.fields(data)
+        }
+    elif isinstance(data, list):
+        return [_jsonify(item) for item in data]
+    elif isinstance(data, dict):
+        return {_jsonify(k): _jsonify(v) for k, v in data.items()}
+    elif isinstance(data, datetime):
+        return data.isoformat()
+    else:
+        return data
+
+
 @check.command()
-@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
-@click.option("--output", type=click.Path())
-def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
+@click.argument("db-file", type=click.Path(exists=True, dir_okay=False))
+def extract_sql_agg_log(db_file: str) -> None:
     """Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
 
-    from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
+    if pathlib.Path(db_file).suffix != ".db":
+        raise click.UsageError("DB file must be a sqlite db")
+
+    output_dir = pathlib.Path(db_file).with_suffix("")
+    output_dir.mkdir(exist_ok=True)
+
+    shared_connection = ConnectionWrapper(pathlib.Path(db_file))
+
+    tables: List[str] = [
+        row[0]
+        for row in shared_connection.execute(
+            """\
+SELECT
+    name
+FROM
+    sqlite_schema
+WHERE
+    type ='table' AND
+    name NOT LIKE 'sqlite_%';
+""",
+            parameters={},
+        )
+    ]
+    logger.info(f"Extracting {len(tables)} tables from {db_file}: {tables}")
+
+    for table in tables:
+        table_output_path = output_dir / f"{table}.json"
+        if table_output_path.exists():
+            logger.info(f"Skipping {table_output_path} because it already exists")
+            continue
 
-    assert dataclasses.is_dataclass(LoggedQuery)
+        # Some of the tables might actually be FileBackedList. Because
+        # the list is built on top of the FileBackedDict, we don't
+        # need to distinguish between the two cases.
 
-    shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
-    query_log = FileBackedList[LoggedQuery](
-        shared_connection=shared_connection, tablename="stored_queries"
-    )
-    logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
-    queries = [dataclasses.asdict(query) for query in query_log]
+        table_data: FileBackedDict[Any] = FileBackedDict(
+            shared_connection=shared_connection, tablename=table
+        )
 
-    if output:
-        with open(output, "w") as f:
-            json.dump(queries, f, indent=2, default=str)
-        logger.info(f"Extracted {len(queries)} queries to {output}")
-    else:
-        click.echo(json.dumps(queries, indent=2))
+        data = {}
+        with click.progressbar(
+            table_data.items(), length=len(table_data), label=f"Extracting {table}"
+        ) as items:
+            for k, v in items:
+                data[k] = _jsonify(v)
+
+        with open(table_output_path, "w") as f:
+            json.dump(data, f, indent=2, default=str)
+        logger.info(f"Extracted {len(data)} entries to {table_output_path}")
 
 
 @check.command()
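
With this change the command takes only the sqlite path (datahub check extract-sql-agg-log path/to/query_log.db, no --output flag) and dumps every user table into a sibling directory, one JSON file per table, instead of writing a single query-log JSON. The table discovery can be reproduced with the standard library alone; the sketch below is illustrative only (example.db is a made-up path, and it dumps raw rows rather than going through FileBackedDict and _jsonify the way the real command does):

    import json
    import pathlib
    import sqlite3

    db_file = pathlib.Path("example.db")  # hypothetical path, not from this commit
    output_dir = db_file.with_suffix("")  # mirrors the command: foo.db -> foo/
    output_dir.mkdir(exist_ok=True)

    conn = sqlite3.connect(db_file)
    # Same filter the command uses: user tables only, skipping sqlite internals.
    # (sqlite_schema needs SQLite >= 3.33; older builds expose it as sqlite_master.)
    tables = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%'"
        )
    ]

    for table in tables:
        rows = conn.execute(f"SELECT * FROM {table}").fetchall()
        (output_dir / f"{table}.json").write_text(json.dumps(rows, indent=2, default=str))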
(Second changed file: a golden JSON test resource. Its full path is not captured on this page, but the diff matches the test_basic_lineage_query_log.json golden referenced in the test below.) (+3 -3)

@@ -1,10 +1,10 @@
-[
-  {
+{
+  "0": {
     "query": "create table foo as select a, b from bar",
     "session_id": null,
     "timestamp": null,
     "user": null,
     "default_db": "dev",
     "default_schema": "public"
   }
-]
+}
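
Note on the golden file above: the top-level shape changes from a JSON array to an object because the command now dumps each table as a dict of FileBackedDict entries keyed by their string keys; the stored_queries list is built on top of a FileBackedDict, so its single logged query surfaces under the stringified index "0".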

metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py (+23 -14)

@@ -41,14 +41,13 @@ def _ts(ts: int) -> datetime:
     return datetime.fromtimestamp(ts, tz=timezone.utc)
 
 
-@freeze_time(FROZEN_TIME)
-def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+def make_basic_aggregator(store: bool = False) -> SqlParsingAggregator:
     aggregator = SqlParsingAggregator(
         platform="redshift",
         generate_lineage=True,
         generate_usage_statistics=False,
         generate_operations=False,
-        query_log=QueryLogSetting.STORE_ALL,
+        query_log=QueryLogSetting.STORE_ALL if store else QueryLogSetting.DISABLED,
     )
 
     aggregator.add_observed_query(
@@ -59,26 +58,36 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
         )
     )
 
+    return aggregator
+
+
+@freeze_time(FROZEN_TIME)
+def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    aggregator = make_basic_aggregator()
     mcps = list(aggregator.gen_metadata())
 
     check_goldens_stream(
         outputs=mcps,
         golden_path=RESOURCE_DIR / "test_basic_lineage.json",
     )
 
-    # This test also validates the query log storage functionality.
+
+@freeze_time(FROZEN_TIME)
+def test_aggregator_dump(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    # Validates the query log storage + extraction functionality.
+    aggregator = make_basic_aggregator(store=True)
     aggregator.close()
+
     query_log_db = aggregator.report.query_log_path
-    query_log_json = tmp_path / "query_log.json"
-    run_datahub_cmd(
-        [
-            "check",
-            "extract-sql-agg-log",
-            str(query_log_db),
-            "--output",
-            str(query_log_json),
-        ]
-    )
+    assert query_log_db is not None
+
+    run_datahub_cmd(["check", "extract-sql-agg-log", query_log_db])
+
+    output_json_dir = pathlib.Path(query_log_db).with_suffix("")
+    assert (
+        len(list(output_json_dir.glob("*.json"))) > 5
+    )  # 5 is arbitrary, but should have at least a couple tables
+    query_log_json = output_json_dir / "stored_queries.json"
     mce_helpers.check_golden_file(
         pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
     )
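
The updated test no longer passes --output; it relies on the command deriving its output directory from the db path and then checks the per-table JSON files. A minimal sketch of inspecting such a dump outside the test (query_log.db is a placeholder path, not taken from this commit):

    import json
    import pathlib

    query_log_db = pathlib.Path("query_log.db")  # placeholder, not from this commit
    output_json_dir = query_log_db.with_suffix("")  # the command writes foo.db -> foo/

    for table_json in sorted(output_json_dir.glob("*.json")):
        entries = json.loads(table_json.read_text())
        print(f"{table_json.name}: {len(entries)} entries")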
