|
| 1 | +import logging |
1 | 2 | from collections import defaultdict
|
2 | 3 | from dataclasses import dataclass, field
|
3 | 4 | from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
|
|
6 | 7 | SagemakerSourceReport,
|
7 | 8 | )
|
8 | 9 |
|
| 10 | +logger = logging.getLogger(__name__) |
| 11 | + |
9 | 12 | if TYPE_CHECKING:
|
10 | 13 | from mypy_boto3_sagemaker import SageMakerClient
|
11 | 14 | from mypy_boto3_sagemaker.type_defs import (
|
@@ -88,7 +91,6 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
|
88 | 91 | paginator = self.sagemaker_client.get_paginator("list_contexts")
|
89 | 92 | for page in paginator.paginate():
|
90 | 93 | contexts += page["ContextSummaries"]
|
91 |
| - |
92 | 94 | return contexts
|
93 | 95 |
|
94 | 96 | def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
|
@@ -225,27 +227,32 @@ def get_lineage(self) -> LineageInfo:
|
225 | 227 | """
|
226 | 228 | Get the lineage of all artifacts in SageMaker.
|
227 | 229 | """
|
228 |
| - |
| 230 | + logger.info("Getting lineage for SageMaker artifacts...") |
| 231 | + logger.info("Getting all actions") |
229 | 232 | for action in self.get_all_actions():
|
230 | 233 | self.nodes[action["ActionArn"]] = {**action, "node_type": "action"}
|
| 234 | + logger.info("Getting all artifacts") |
231 | 235 | for artifact in self.get_all_artifacts():
|
232 | 236 | self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"}
|
| 237 | + logger.info("Getting all contexts") |
233 | 238 | for context in self.get_all_contexts():
|
234 | 239 | self.nodes[context["ContextArn"]] = {**context, "node_type": "context"}
|
235 | 240 |
|
| 241 | + logger.info("Getting lineage for model deployments and model groups") |
236 | 242 | for node_arn, node in self.nodes.items():
|
| 243 | + logger.debug(f"Getting lineage for node {node_arn}") |
237 | 244 | # get model-endpoint lineage
|
238 | 245 | if (
|
239 | 246 | node["node_type"] == "action"
|
240 | 247 | and node.get("ActionType") == "ModelDeployment"
|
241 | 248 | ):
|
242 | 249 | self.get_model_deployment_lineage(node_arn)
|
243 |
| - |
| 250 | + self.report.model_endpoint_lineage += 1 |
244 | 251 | # get model-group lineage
|
245 | 252 | if (
|
246 | 253 | node["node_type"] == "context"
|
247 | 254 | and node.get("ContextType") == "ModelGroup"
|
248 | 255 | ):
|
249 | 256 | self.get_model_group_lineage(node_arn, node)
|
250 |
| - |
| 257 | + self.report.model_group_lineage += 1 |
251 | 258 | return self.lineage_info
|
0 commit comments