Skip to content

Commit 9f15aa8

Browse files
committed
fix test; enhance report
1 parent b7c84b3 commit 9f15aa8

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

metadata-ingestion/src/datahub/ingestion/source/aws/glue.py

+7 -5
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@
5252
platform_name,
5353
support_status,
5454
)
55+
from datahub.ingestion.api.report import EntityFilterReport
5556
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
5657
from datahub.ingestion.api.workunit import MetadataWorkUnit
5758
from datahub.ingestion.source.aws import s3_util
@@ -219,6 +220,7 @@ def platform_validator(cls, v: str) -> str:
219220
class GlueSourceReport(StaleEntityRemovalSourceReport):
220221
tables_scanned = 0
221222
filtered: List[str] = dataclass_field(default_factory=list)
223+
databases = EntityFilterReport.field(type="database")
222224

223225
num_job_script_location_missing: int = 0
224226
num_job_script_location_invalid: int = 0
@@ -684,15 +686,15 @@ def get_all_databases(self) -> Iterable[Mapping[str, Any]]:
684686
pattern += "[?!TargetDatabase]"
685687

686688
for database in paginator_response.search(pattern):
687-
if not self.source_config.database_pattern.allowed(database["Name"]):
688-
continue
689-
if (
689+
if (not self.source_config.database_pattern.allowed(database["Name"])) or (
690690
self.source_config.catalog_id
691691
and database.get("CatalogId")
692692
and database.get("CatalogId") != self.source_config.catalog_id
693693
):
694-
continue
695-
yield database
694+
self.report.databases.dropped(database["Name"])
695+
else:
696+
self.report.databases.processed(database["Name"])
697+
yield database
696698

697699
def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict]:
698700
logger.debug(f"Getting tables from database {database['Name']}")

metadata-ingestion/tests/unit/glue/test_glue_source.py

+9 -6
Original file line number | Diff line number | Diff line change
@@ -316,26 +316,29 @@ def format_databases(databases):
316316
return set(d["Name"] for d in databases)
317317

318318
all_catalogs_source: GlueSource = GlueSource(
319-
config=GlueSourceConfig(), ctx=PipelineContext(run_id="glue-source-test")
319+
config=GlueSourceConfig(aws_region="us-west-2"),
320+
ctx=PipelineContext(run_id="glue-source-test"),
320321
)
321322
with Stubber(all_catalogs_source.glue_client) as glue_stubber:
322323
glue_stubber.add_response("get_databases", get_databases_response, {})
323324

324-
expected = format_databases([flights_database, test_database, empty_database])
325-
assert format_databases(all_catalogs_source.get_all_databases()) == expected
325+
expected = [flights_database, test_database, empty_database]
326+
actual = all_catalogs_source.get_all_databases()
327+
assert format_databases(actual) == format_databases(expected)
326328

327329
catalog_id = "123412341234"
328330
single_catalog_source = GlueSource(
329-
config=GlueSourceConfig(catalog_id=catalog_id),
331+
config=GlueSourceConfig(catalog_id=catalog_id, aws_region="us-west-2"),
330332
ctx=PipelineContext(run_id="glue-source-test"),
331333
)
332334
with Stubber(single_catalog_source.glue_client) as glue_stubber:
333335
glue_stubber.add_response(
334336
"get_databases", get_databases_response, {"CatalogId": catalog_id}
335337
)
336338

337-
expected = format_databases([flights_database, test_database])
338-
assert format_databases(single_catalog_source.get_all_databases()) == expected
339+
expected = [flights_database, test_database]
340+
actual = single_catalog_source.get_all_databases()
341+
assert format_databases(actual) == format_databases(expected)
339342

340343

341344
@freeze_time(FROZEN_TIME)

0 commit comments

Comments
 (0)