Skip to content

Commit eb7c0cf

Browse files
committed
Adding option to control retry for any aws source
Adding logs to make it transparent what is going on in SageMaker
1 parent eef2077 commit eb7c0cf

File tree

5 files changed

+49
-5
lines changed

5 files changed

+49
-5
lines changed

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

+12
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ class AwsConnectionConfig(ConfigModel):
107107
default=None,
108108
description="A set of proxy configs to use with AWS. See the [botocore.config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) docs for details.",
109109
)
110+
aws_retry_num: int = Field(
111+
default=5,
112+
description="Number of times to retry failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
113+
)
114+
aws_retry_mode: Literal["legacy", "standard", "adaptive"] = Field(
115+
default="standard",
116+
description="Retry mode to use for failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
117+
)
110118

111119
read_timeout: float = Field(
112120
default=DEFAULT_TIMEOUT,
@@ -199,6 +207,10 @@ def _aws_config(self) -> Config:
199207
return Config(
200208
proxies=self.aws_proxy,
201209
read_timeout=self.read_timeout,
210+
retries={
211+
"max_attempts": self.aws_retry_num,
212+
"mode": self.aws_retry_mode,
213+
},
202214
**self.aws_advanced_config,
203215
)
204216

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from collections import defaultdict
23
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional
34

@@ -36,6 +37,8 @@
3637
if TYPE_CHECKING:
3738
from mypy_boto3_sagemaker import SageMakerClient
3839

40+
logger = logging.getLogger(__name__)
41+
3942

4043
@platform_name("SageMaker")
4144
@config_class(SagemakerSourceConfig)
@@ -75,6 +78,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
7578
]
7679

7780
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
81+
logger.info("Starting SageMaker ingestion...")
7882
# get common lineage graph
7983
lineage_processor = LineageProcessor(
8084
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
@@ -83,6 +87,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
8387

8488
# extract feature groups if specified
8589
if self.source_config.extract_feature_groups:
90+
logger.info("Extracting feature groups...")
8691
feature_group_processor = FeatureGroupProcessor(
8792
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
8893
)
@@ -95,6 +100,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
95100

96101
# extract jobs if specified
97102
if self.source_config.extract_jobs is not False:
103+
logger.info("Extracting jobs...")
98104
job_processor = JobProcessor(
99105
sagemaker_client=self.client_factory.get_client,
100106
env=self.env,
@@ -109,6 +115,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
109115

110116
# extract models if specified
111117
if self.source_config.extract_models:
118+
logger.info("Extracting models...")
119+
112120
model_processor = ModelProcessor(
113121
sagemaker_client=self.sagemaker_client,
114122
env=self.env,

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py

+6
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@ class SagemakerSourceReport(StaleEntityRemovalSourceReport):
4040
groups_scanned = 0
4141
models_scanned = 0
4242
jobs_scanned = 0
43+
jobs_processed = 0
4344
datasets_scanned = 0
4445
filtered: List[str] = field(default_factory=list)
46+
model_endpoint_lineage = 0
47+
model_group_lineage = 0
4548

4649
def report_feature_group_scanned(self) -> None:
4750
self.feature_groups_scanned += 1
@@ -58,6 +61,9 @@ def report_group_scanned(self) -> None:
5861
def report_model_scanned(self) -> None:
5962
self.models_scanned += 1
6063

64+
def report_job_processed(self) -> None:
65+
self.jobs_processed += 1
66+
6167
def report_job_scanned(self) -> None:
6268
self.jobs_scanned += 1
6369

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from collections import defaultdict
23
from dataclasses import dataclass, field
34
from enum import Enum
@@ -49,6 +50,8 @@
4950
if TYPE_CHECKING:
5051
from mypy_boto3_sagemaker import SageMakerClient
5152

53+
logger = logging.getLogger(__name__)
54+
5255
JobInfo = TypeVar(
5356
"JobInfo",
5457
AutoMlJobInfo,
@@ -274,15 +277,18 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
274277
)
275278

276279
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
280+
logger.info("Getting all SageMaker jobs")
277281
jobs = self.get_all_jobs()
278282

279283
processed_jobs: Dict[str, SageMakerJob] = {}
280284

285+
logger.info("Processing SageMaker jobs")
281286
# first pass: process jobs and collect datasets used
287+
logger.info("first pass: process jobs and collect datasets used")
282288
for job in jobs:
283289
job_type = job_type_to_info[job["type"]]
284290
job_name = job[job_type.list_name_key]
285-
291+
logger.debug(f"Processing job {job_name} with type {job_type}")
286292
job_details = self.get_job_details(job_name, job["type"])
287293

288294
processed_job = getattr(self, job_type.processor)(job_details)
@@ -293,6 +299,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
293299
# second pass:
294300
# - move output jobs to inputs
295301
# - aggregate i/o datasets
302+
logger.info(
303+
"second pass: move output jobs to inputs and aggregate i/o datasets"
304+
)
296305
for job_urn in sorted(processed_jobs):
297306
processed_job = processed_jobs[job_urn]
298307

@@ -301,6 +310,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
301310

302311
all_datasets.update(processed_job.input_datasets)
303312
all_datasets.update(processed_job.output_datasets)
313+
self.report.report_job_processed()
304314

305315
# yield datasets
306316
for dataset_urn, dataset in all_datasets.items():
@@ -322,6 +332,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
322332
self.report.report_dataset_scanned()
323333

324334
# third pass: construct and yield MCEs
335+
logger.info("third pass: construct and yield MCEs")
325336
for job_urn in sorted(processed_jobs):
326337
processed_job = processed_jobs[job_urn]
327338
job_snapshot = processed_job.job_snapshot

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from collections import defaultdict
23
from dataclasses import dataclass, field
34
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
@@ -6,6 +7,8 @@
67
SagemakerSourceReport,
78
)
89

10+
logger = logging.getLogger(__name__)
11+
912
if TYPE_CHECKING:
1013
from mypy_boto3_sagemaker import SageMakerClient
1114
from mypy_boto3_sagemaker.type_defs import (
@@ -88,7 +91,6 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
8891
paginator = self.sagemaker_client.get_paginator("list_contexts")
8992
for page in paginator.paginate():
9093
contexts += page["ContextSummaries"]
91-
9294
return contexts
9395

9496
def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
@@ -225,27 +227,32 @@ def get_lineage(self) -> LineageInfo:
225227
"""
226228
Get the lineage of all artifacts in SageMaker.
227229
"""
228-
230+
logger.info("Getting lineage for SageMaker artifacts...")
231+
logger.info("Getting all actions")
229232
for action in self.get_all_actions():
230233
self.nodes[action["ActionArn"]] = {**action, "node_type": "action"}
234+
logger.info("Getting all artifacts")
231235
for artifact in self.get_all_artifacts():
232236
self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"}
237+
logger.info("Getting all contexts")
233238
for context in self.get_all_contexts():
234239
self.nodes[context["ContextArn"]] = {**context, "node_type": "context"}
235240

241+
logger.info("Getting lineage for model deployments and model groups")
236242
for node_arn, node in self.nodes.items():
243+
logger.debug(f"Getting lineage for node {node_arn}")
237244
# get model-endpoint lineage
238245
if (
239246
node["node_type"] == "action"
240247
and node.get("ActionType") == "ModelDeployment"
241248
):
242249
self.get_model_deployment_lineage(node_arn)
243-
250+
self.report.model_endpoint_lineage += 1
244251
# get model-group lineage
245252
if (
246253
node["node_type"] == "context"
247254
and node.get("ContextType") == "ModelGroup"
248255
):
249256
self.get_model_group_lineage(node_arn, node)
250-
257+
self.report.model_group_lineage += 1
251258
return self.lineage_info

0 commit comments

Comments
 (0)