
Commit 53a205c

Merge branch 'master' into cus3379-tableau-ingestion-node-limit-exceeded
2 parents bb01bc3 + 4811de1

36 files changed: +116085, -292 lines

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py (+13, -1)

@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import boto3
 from boto3.session import Session
@@ -107,6 +107,14 @@ class AwsConnectionConfig(ConfigModel):
         default=None,
         description="A set of proxy configs to use with AWS. See the [botocore.config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) docs for details.",
     )
+    aws_retry_num: int = Field(
+        default=5,
+        description="Number of times to retry failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
+    )
+    aws_retry_mode: Literal["legacy", "standard", "adaptive"] = Field(
+        default="standard",
+        description="Retry mode to use for failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
+    )

     read_timeout: float = Field(
         default=DEFAULT_TIMEOUT,
@@ -199,6 +207,10 @@ def _aws_config(self) -> Config:
         return Config(
             proxies=self.aws_proxy,
             read_timeout=self.read_timeout,
+            retries={
+                "max_attempts": self.aws_retry_num,
+                "mode": self.aws_retry_mode,
+            },
             **self.aws_advanced_config,
         )
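
For context, a minimal sketch (not part of this commit) of how the two new fields feed botocore's retry handling. The literal values below stand in for AwsConnectionConfig.aws_retry_num and aws_retry_mode; the client call at the end is illustrative only.

# Minimal sketch: wiring retry settings into a botocore Config, as the diff above does.
import boto3
from botocore.config import Config

aws_retry_num = 5            # stand-in for AwsConnectionConfig.aws_retry_num
aws_retry_mode = "standard"  # stand-in for aws_retry_mode: "legacy" | "standard" | "adaptive"

config = Config(
    read_timeout=60,
    retries={
        "max_attempts": aws_retry_num,
        "mode": aws_retry_mode,
    },
)

# Any client built from this Config retries throttled or failed API calls
# up to max_attempts times using the selected retry mode, e.g.:
sagemaker = boto3.client("sagemaker", region_name="us-east-1", config=config)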

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py (+8, -0)

@@ -1,3 +1,4 @@
+import logging
 from collections import defaultdict
 from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional

@@ -36,6 +37,8 @@
 if TYPE_CHECKING:
     from mypy_boto3_sagemaker import SageMakerClient

+logger = logging.getLogger(__name__)
+

 @platform_name("SageMaker")
 @config_class(SagemakerSourceConfig)
@@ -75,6 +78,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
         ]

     def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
+        logger.info("Starting SageMaker ingestion...")
         # get common lineage graph
         lineage_processor = LineageProcessor(
             sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
@@ -83,6 +87,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

         # extract feature groups if specified
         if self.source_config.extract_feature_groups:
+            logger.info("Extracting feature groups...")
             feature_group_processor = FeatureGroupProcessor(
                 sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
             )
@@ -95,6 +100,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

         # extract jobs if specified
         if self.source_config.extract_jobs is not False:
+            logger.info("Extracting jobs...")
             job_processor = JobProcessor(
                 sagemaker_client=self.client_factory.get_client,
                 env=self.env,
@@ -109,6 +115,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

         # extract models if specified
         if self.source_config.extract_models:
+            logger.info("Extracting models...")
+
             model_processor = ModelProcessor(
                 sagemaker_client=self.sagemaker_client,
                 env=self.env,

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py (+6, -0)

@@ -40,8 +40,11 @@ class SagemakerSourceReport(StaleEntityRemovalSourceReport):
     groups_scanned = 0
     models_scanned = 0
     jobs_scanned = 0
+    jobs_processed = 0
     datasets_scanned = 0
     filtered: List[str] = field(default_factory=list)
+    model_endpoint_lineage = 0
+    model_group_lineage = 0

     def report_feature_group_scanned(self) -> None:
         self.feature_groups_scanned += 1
@@ -58,6 +61,9 @@ def report_group_scanned(self) -> None:
     def report_model_scanned(self) -> None:
         self.models_scanned += 1

+    def report_job_processed(self) -> None:
+        self.jobs_processed += 1
+
     def report_job_scanned(self) -> None:
         self.jobs_scanned += 1

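
A rough sketch of how the new counters accumulate at runtime. The MiniSagemakerReport class below is a simplified stand-in for SagemakerSourceReport, not the real class; only the field and method names added in this diff are taken from it.

# Simplified stand-in showing the counter pattern used by the new
# jobs_processed / model_endpoint_lineage / model_group_lineage fields.
from dataclasses import dataclass, field
from typing import List


@dataclass
class MiniSagemakerReport:
    jobs_scanned: int = 0
    jobs_processed: int = 0
    model_endpoint_lineage: int = 0
    model_group_lineage: int = 0
    filtered: List[str] = field(default_factory=list)

    def report_job_scanned(self) -> None:
        self.jobs_scanned += 1

    def report_job_processed(self) -> None:
        self.jobs_processed += 1


report = MiniSagemakerReport()
report.report_job_scanned()
report.report_job_processed()
report.model_endpoint_lineage += 1  # bumped once per ModelDeployment action (see lineage.py below)
print(report)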

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py (+12, -1)

@@ -1,3 +1,4 @@
+import logging
 from collections import defaultdict
 from dataclasses import dataclass, field
 from enum import Enum
@@ -49,6 +50,8 @@
 if TYPE_CHECKING:
     from mypy_boto3_sagemaker import SageMakerClient

+logger = logging.getLogger(__name__)
+
 JobInfo = TypeVar(
     "JobInfo",
     AutoMlJobInfo,
@@ -274,15 +277,18 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
         )

     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
+        logger.info("Getting all SageMaker jobs")
        jobs = self.get_all_jobs()

        processed_jobs: Dict[str, SageMakerJob] = {}

+        logger.info("Processing SageMaker jobs")
         # first pass: process jobs and collect datasets used
+        logger.info("first pass: process jobs and collect datasets used")
         for job in jobs:
             job_type = job_type_to_info[job["type"]]
             job_name = job[job_type.list_name_key]
-
+            logger.debug(f"Processing job {job_name} with type {job_type}")
             job_details = self.get_job_details(job_name, job["type"])

             processed_job = getattr(self, job_type.processor)(job_details)
@@ -293,6 +299,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         # second pass:
         # - move output jobs to inputs
         # - aggregate i/o datasets
+        logger.info(
+            "second pass: move output jobs to inputs and aggregate i/o datasets"
+        )
         for job_urn in sorted(processed_jobs):
             processed_job = processed_jobs[job_urn]

@@ -301,6 +310,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:

             all_datasets.update(processed_job.input_datasets)
             all_datasets.update(processed_job.output_datasets)
+            self.report.report_job_processed()

         # yield datasets
         for dataset_urn, dataset in all_datasets.items():
@@ -322,6 +332,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
             self.report.report_dataset_scanned()

         # third pass: construct and yield MCEs
+        logger.info("third pass: construct and yield MCEs")
         for job_urn in sorted(processed_jobs):
             processed_job = processed_jobs[job_urn]
             job_snapshot = processed_job.job_snapshot
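
To see the three passes in one place, here is a condensed, illustrative sketch of the get_workunits flow. The job dict shape and helper names are assumptions; only the pass ordering and the new log messages mirror the diff.

# Condensed sketch of the three-pass structure in JobProcessor.get_workunits.
import logging
from typing import Dict, Iterable, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_workunits(jobs: List[dict]) -> Iterable[dict]:
    processed_jobs: Dict[str, dict] = {}
    all_datasets: Dict[str, dict] = {}

    # first pass: process jobs and collect datasets used
    logger.info("first pass: process jobs and collect datasets used")
    for job in jobs:
        logger.debug("Processing job %s", job["name"])
        processed_jobs[job["urn"]] = job

    # second pass: move output jobs to inputs and aggregate i/o datasets
    logger.info("second pass: move output jobs to inputs and aggregate i/o datasets")
    for job_urn in sorted(processed_jobs):
        job = processed_jobs[job_urn]
        all_datasets.update(job.get("input_datasets", {}))
        all_datasets.update(job.get("output_datasets", {}))

    # third pass: construct and yield the final metadata objects
    logger.info("third pass: construct and yield MCEs")
    for job_urn in sorted(processed_jobs):
        yield {"urn": job_urn, "datasets": sorted(all_datasets)}


jobs = [
    {"urn": "urn:li:job:1", "name": "train", "input_datasets": {"s3://in": {}}},
    {"urn": "urn:li:job:2", "name": "transform", "output_datasets": {"s3://out": {}}},
]
print(list(get_workunits(jobs)))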

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py (+11, -4)

@@ -1,3 +1,4 @@
+import logging
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
@@ -6,6 +7,8 @@
     SagemakerSourceReport,
 )

+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from mypy_boto3_sagemaker import SageMakerClient
     from mypy_boto3_sagemaker.type_defs import (
@@ -88,7 +91,6 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
         paginator = self.sagemaker_client.get_paginator("list_contexts")
         for page in paginator.paginate():
             contexts += page["ContextSummaries"]
-
         return contexts

     def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
@@ -225,27 +227,32 @@ def get_lineage(self) -> LineageInfo:
         """
         Get the lineage of all artifacts in SageMaker.
         """
-
+        logger.info("Getting lineage for SageMaker artifacts...")
+        logger.info("Getting all actions")
         for action in self.get_all_actions():
             self.nodes[action["ActionArn"]] = {**action, "node_type": "action"}
+        logger.info("Getting all artifacts")
         for artifact in self.get_all_artifacts():
             self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"}
+        logger.info("Getting all contexts")
         for context in self.get_all_contexts():
             self.nodes[context["ContextArn"]] = {**context, "node_type": "context"}

+        logger.info("Getting lineage for model deployments and model groups")
         for node_arn, node in self.nodes.items():
+            logger.debug(f"Getting lineage for node {node_arn}")
             # get model-endpoint lineage
             if (
                 node["node_type"] == "action"
                 and node.get("ActionType") == "ModelDeployment"
             ):
                 self.get_model_deployment_lineage(node_arn)
-
+                self.report.model_endpoint_lineage += 1
             # get model-group lineage
             if (
                 node["node_type"] == "context"
                 and node.get("ContextType") == "ModelGroup"
             ):
                 self.get_model_group_lineage(node_arn, node)
-
+                self.report.model_group_lineage += 1
         return self.lineage_info

metadata-ingestion/src/datahub/ingestion/source/gc/dataprocess_cleanup.py (+20, -11)

@@ -207,6 +207,9 @@ def fetch_dpis(self, job_urn: str, batch_size: int) -> List[dict]:
         assert self.ctx.graph
         dpis = []
         start = 0
+        # This graphql endpoint doesn't support scrolling and therefore after 10k DPIs it causes performance issues on ES
+        # Therefore, we are limiting the max DPIs to 9000
+        max_item = 9000
         while True:
             try:
                 job_query_result = self.ctx.graph.execute_graphql(
@@ -226,10 +229,12 @@
                 runs = runs_data.get("runs")
                 dpis.extend(runs)
                 start += batch_size
-                if len(runs) < batch_size:
+                if len(runs) < batch_size or start >= max_item:
                     break
             except Exception as e:
-                logger.error(f"Exception while fetching DPIs for job {job_urn}: {e}")
+                self.report.failure(
+                    f"Exception while fetching DPIs for job {job_urn}:", exc=e
+                )
                 break
         return dpis

@@ -254,8 +259,9 @@ def keep_last_n_dpi(
                 deleted_count_last_n += 1
                 futures[future]["deleted"] = True
             except Exception as e:
-                logger.error(f"Exception while deleting DPI: {e}")
-
+                self.report.report_failure(
+                    f"Exception while deleting DPI: {e}", exc=e
+                )
             if deleted_count_last_n % self.config.batch_size == 0:
                 logger.info(f"Deleted {deleted_count_last_n} DPIs from {job.urn}")
                 if self.config.delay:
@@ -289,7 +295,7 @@ def delete_dpi_from_datajobs(self, job: DataJobEntity) -> None:
         dpis = self.fetch_dpis(job.urn, self.config.batch_size)
         dpis.sort(
             key=lambda x: x["created"]["time"]
-            if "created" in x and "time" in x["created"]
+            if x.get("created") and x["created"].get("time")
             else 0,
             reverse=True,
         )
@@ -325,8 +331,8 @@
                 continue

            if (
-                "created" not in dpi
-                or "time" not in dpi["created"]
+                not dpi.get("created")
+                or not dpi["created"].get("time")
                 or dpi["created"]["time"] < retention_time * 1000
            ):
                future = executor.submit(
@@ -340,7 +346,7 @@
                 deleted_count_retention += 1
                 futures[future]["deleted"] = True
             except Exception as e:
-                logger.error(f"Exception while deleting DPI: {e}")
+                self.report.report_failure(f"Exception while deleting DPI: {e}", exc=e)

             if deleted_count_retention % self.config.batch_size == 0:
                 logger.info(
@@ -351,9 +357,12 @@
                     logger.info(f"Sleeping for {self.config.delay} seconds")
                     time.sleep(self.config.delay)

-        logger.info(
-            f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention"
-        )
+        if deleted_count_retention > 0:
+            logger.info(
+                f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention"
+            )
+        else:
+            logger.debug(f"No DPIs to delete from {job.urn} due to retention")

     def get_data_flows(self) -> Iterable[DataFlowEntity]:
         assert self.ctx.graph
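
A condensed sketch of the capped fetch loop and the defensive timestamp handling introduced above. fetch_page is a hypothetical stand-in for the GraphQL call; the 9000-item cap and the .get()-based sort key mirror the changes in this file.

# Condensed sketch (not the real DataHub code) of the capped DPI fetch loop.
from typing import Callable, List


def fetch_dpis(fetch_page: Callable[[int, int], List[dict]], batch_size: int) -> List[dict]:
    dpis: List[dict] = []
    start = 0
    # The endpoint doesn't support scrolling, so cap the total number of items fetched.
    max_item = 9000
    while True:
        runs = fetch_page(start, batch_size)
        dpis.extend(runs)
        start += batch_size
        if len(runs) < batch_size or start >= max_item:
            break

    # Sort newest-first; .get() chains tolerate missing or None "created" blocks.
    dpis.sort(
        key=lambda x: x["created"]["time"]
        if x.get("created") and x["created"].get("time")
        else 0,
        reverse=True,
    )
    return dpis


# Example with fake pages of runs, 3 per page, 7 runs total.
fake = [{"created": {"time": t}} for t in range(7)]
print(len(fetch_dpis(lambda s, b: fake[s : s + b], batch_size=3)))  # -> 7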

metadata-ingestion/src/datahub/ingestion/source/powerbi/m_query/data_classes.py (+2, -13)

@@ -1,5 +1,4 @@
 import os
-from abc import ABC
 from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, List, Optional
@@ -12,18 +11,8 @@
 TRACE_POWERBI_MQUERY_PARSER = os.getenv("DATAHUB_TRACE_POWERBI_MQUERY_PARSER", False)


-class AbstractIdentifierAccessor(ABC):  # To pass lint
-    pass
-
-
-# @dataclass
-# class ItemSelector:
-#     items: Dict[str, Any]
-#     next: Optional[AbstractIdentifierAccessor]
-
-
 @dataclass
-class IdentifierAccessor(AbstractIdentifierAccessor):
+class IdentifierAccessor:
     """
     statement
     public_order_date = Source{[Schema="public",Item="order_date"]}[Data]
@@ -40,7 +29,7 @@ class IdentifierAccessor(AbstractIdentifierAccessor):

     identifier: str
     items: Dict[str, Any]
-    next: Optional[AbstractIdentifierAccessor]
+    next: Optional["IdentifierAccessor"]


 @dataclass
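
The dropped ABC is replaced by a string forward reference, which lets the dataclass field refer to its own class before the class is fully defined. A minimal sketch with illustrative values (not taken from a real M-query parse):

# Minimal sketch of a self-referencing dataclass via a string forward
# reference, mirroring the new IdentifierAccessor.next annotation.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class IdentifierAccessor:
    identifier: str
    items: Dict[str, Any]
    next: Optional["IdentifierAccessor"]  # forward reference to the same class


inner = IdentifierAccessor(identifier="order_date", items={"Item": "order_date"}, next=None)
outer = IdentifierAccessor(identifier="Source", items={"Schema": "public"}, next=inner)

# Walk the chain, just to show the linkage.
node: Optional[IdentifierAccessor] = outer
while node is not None:
    print(node.identifier)
    node = node.next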
