Skip to content

Commit 95b9d1b

Browse files
authored
feat(ingest/aws-common): improved instance profile support (#12139)
for ec2, ecs, eks, lambda, beanstalk, app runner and cft roles
1 parent 0b4d96e commit 95b9d1b

File tree

2 files changed

+559
-27
lines changed

2 files changed

+559
-27
lines changed

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

+231-27
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import logging
2+
import os
13
from datetime import datetime, timedelta, timezone
2-
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
4+
from enum import Enum
5+
from http import HTTPStatus
6+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
37

48
import boto3
9+
import requests
510
from boto3.session import Session
611
from botocore.config import DEFAULT_TIMEOUT, Config
712
from botocore.utils import fix_s3_host
@@ -14,6 +19,8 @@
1419
)
1520
from datahub.configuration.source_common import EnvConfigMixin
1621

22+
logger = logging.getLogger(__name__)
23+
1724
if TYPE_CHECKING:
1825
from mypy_boto3_dynamodb import DynamoDBClient
1926
from mypy_boto3_glue import GlueClient
@@ -22,6 +29,26 @@
2229
from mypy_boto3_sts import STSClient
2330

2431

32+
class AwsEnvironment(Enum):
    """Runtime environments reported by ``detect_aws_environment``.

    Each member's value is a plain string label; ``get_session`` logs it
    when it decides to assume a role.
    """

    EC2 = "EC2"
    ECS = "ECS"
    EKS = "EKS"
    LAMBDA = "LAMBDA"
    APP_RUNNER = "APP_RUNNER"
    # The member name is abbreviated; the value spells the service out.
    BEANSTALK = "ELASTIC_BEANSTALK"
    CLOUD_FORMATION = "CLOUD_FORMATION"
    # Returned by detect_aws_environment() when no indicator matched.
    UNKNOWN = "UNKNOWN"
41+
42+
43+
class AwsServicePrincipal(Enum):
    """Service principal names, one per supported AWS compute service.

    ``get_current_identity`` returns the ``.value`` of one of these members
    as the "credential source" half of its result.
    """

    LAMBDA = "lambda.amazonaws.com"
    EKS = "eks.amazonaws.com"
    APP_RUNNER = "apprunner.amazonaws.com"
    ECS = "ecs.amazonaws.com"
    ELASTIC_BEANSTALK = "elasticbeanstalk.amazonaws.com"
    EC2 = "ec2.amazonaws.com"
50+
51+
2552
class AwsAssumeRoleConfig(PermissiveConfigModel):
2653
# Using the PermissiveConfigModel to allow the user to pass additional arguments.
2754

@@ -34,6 +61,163 @@ class AwsAssumeRoleConfig(PermissiveConfigModel):
3461
)
3562

3663

64+
def get_instance_metadata_token() -> Optional[str]:
    """Request an IMDSv2 session token from the instance metadata endpoint.

    Returns:
        The response body (the token) when the PUT request answers with
        HTTP 200. ``None`` when the endpoint answers with any other status
        or the request raises a ``requests`` exception.
    """
    token_url = "http://169.254.169.254/latest/api/token"
    ttl_header = {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}

    try:
        # timeout=1 bounds how long the probe waits when nothing answers.
        reply = requests.put(token_url, headers=ttl_header, timeout=1)
        token = reply.text if reply.status_code == HTTPStatus.OK else None
    except requests.exceptions.RequestException:
        logger.debug("Failed to get IMDSv2 token")
        token = None
    return token
77+
78+
79+
def is_running_on_ec2() -> bool:
    """Report whether the instance metadata service answers an IMDSv2 query.

    Returns:
        ``True`` only when a token could be obtained *and* the follow-up
        ``instance-id`` lookup answers with HTTP 200. ``False`` when no token
        is available, the lookup returns another status, or the request
        raises a ``requests`` exception.
    """
    imds_token = get_instance_metadata_token()
    if not imds_token:
        return False

    instance_id_url = "http://169.254.169.254/latest/meta-data/instance-id"
    try:
        reply = requests.get(
            instance_id_url,
            headers={"X-aws-ec2-metadata-token": imds_token},
            timeout=1,
        )
    except requests.exceptions.RequestException:
        return False
    return reply.status_code == HTTPStatus.OK
94+
95+
96+
def detect_aws_environment() -> AwsEnvironment:
    """Work out which AWS runtime this process is running in.

    The checks run in a fixed order because several indicators can be present
    at once; the first match wins. Every check except the last only reads
    environment variables; the EC2 check performs metadata-service requests
    and therefore runs last.

    Returns:
        The matching ``AwsEnvironment`` member, or ``AwsEnvironment.UNKNOWN``
        when none of the indicators is present.
    """
    # Lambda is checked first as it is the most specific indicator.
    if os.getenv("AWS_LAMBDA_FUNCTION_NAME"):
        execution_env = os.getenv("AWS_EXECUTION_ENV", "")
        if execution_env.startswith("CloudFormation"):
            return AwsEnvironment.CLOUD_FORMATION
        return AwsEnvironment.LAMBDA

    # EKS (IRSA): both variables have to be set.
    if os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE") and os.getenv("AWS_ROLE_ARN"):
        return AwsEnvironment.EKS

    # Environments recognised by any one of the listed variables, in
    # priority order: App Runner, then ECS, then Elastic Beanstalk.
    env_var_indicators = (
        (("AWS_APP_RUNNER_SERVICE_ID",), AwsEnvironment.APP_RUNNER),
        (
            ("ECS_CONTAINER_METADATA_URI_V4", "ECS_CONTAINER_METADATA_URI"),
            AwsEnvironment.ECS,
        ),
        (("ELASTIC_BEANSTALK_ENVIRONMENT_NAME",), AwsEnvironment.BEANSTALK),
    )
    for var_names, environment in env_var_indicators:
        if any(os.getenv(var_name) for var_name in var_names):
            return environment

    return AwsEnvironment.EC2 if is_running_on_ec2() else AwsEnvironment.UNKNOWN
129+
130+
131+
def get_instance_role_arn() -> Optional[str]:
    """Look up the identity ARN of the role attached to this EC2 instance.

    The IMDSv2 ``iam/security-credentials/`` listing is used only as a
    presence check. When it returns a non-blank body, the value handed back
    is the ``Arn`` field of an STS ``get_caller_identity`` call.

    NOTE(review): for instance-profile credentials STS presumably reports an
    ``assumed-role`` session ARN rather than the IAM role ARN -- confirm
    before comparing the result against configured role ARNs.

    Returns:
        The caller identity ARN, or ``None`` when no IMDSv2 token is
        available, the listing is empty or not HTTP 200, or any step raises
        (the exception is logged at debug level).
    """
    imds_token = get_instance_metadata_token()
    if not imds_token:
        return None

    try:
        reply = requests.get(
            "http://169.254.169.254/latest/meta-data/iam/security-credentials/",
            headers={"X-aws-ec2-metadata-token": imds_token},
            timeout=1,
        )
        role_is_attached = reply.status_code == HTTPStatus.OK and bool(
            reply.text.strip()
        )
        if role_is_attached:
            return boto3.client("sts").get_caller_identity().get("Arn")
    except Exception as e:
        logger.debug(f"Failed to get instance role ARN: {e}")
    return None
152+
153+
154+
def get_lambda_role_arn() -> Optional[str]:
    """Fetch the execution role ARN of the current Lambda function.

    The function name is read from ``AWS_LAMBDA_FUNCTION_NAME`` and the role
    is taken from the ``Role`` field of the Lambda
    ``get_function_configuration`` response.

    Returns:
        The ``Role`` value, or ``None`` when the environment variable is
        unset/empty or the lookup raises (the exception is logged at debug
        level).
    """
    fn_name = os.getenv("AWS_LAMBDA_FUNCTION_NAME")
    if not fn_name:
        return None

    try:
        fn_config = boto3.client("lambda").get_function_configuration(
            FunctionName=fn_name
        )
        return fn_config.get("Role")
    except Exception as e:
        logger.debug(f"Failed to get Lambda role ARN: {e}")
        return None
169+
170+
171+
def _iam_role_arn_from_caller_arn(caller_arn: Optional[str]) -> Optional[str]:
    """Map an STS assumed-role session ARN to the underlying IAM role ARN.

    ``arn:<partition>:sts::<account>:assumed-role/<role-name>/<session>``
    becomes ``arn:<partition>:iam::<account>:role/<role-name>``. Any other
    input (``None``, an empty string, an IAM role/user ARN, a malformed
    string) is returned unchanged.

    Limitation: a session ARN does not carry the role's path, so a role that
    was created with a path cannot be reconstructed exactly.
    """
    if not caller_arn:
        return caller_arn

    # arn : partition : service : region : account : resource
    fields = caller_arn.split(":", 5)
    if len(fields) != 6 or fields[0] != "arn" or fields[2] != "sts":
        return caller_arn

    resource = fields[5].split("/")
    if len(resource) < 2 or resource[0] != "assumed-role" or not resource[1]:
        return caller_arn

    partition, account_id, role_name = fields[1], fields[4], resource[1]
    return f"arn:{partition}:iam::{account_id}:role/{role_name}"


def get_current_identity() -> Tuple[Optional[str], Optional[str]]:
    """Get the current role ARN and source type based on the runtime environment.

    The environment comes from ``detect_aws_environment()``:

    * Lambda: role from ``get_lambda_role_arn()``.
    * EKS: role from the ``AWS_ROLE_ARN`` environment variable.
    * App Runner: STS caller identity, normalised to an IAM role ARN.
    * ECS: the ``TaskARN`` field of the task metadata endpoint.
    * Elastic Beanstalk / EC2: ``get_instance_role_arn()``, normalised to an
      IAM role ARN.

    STS reports role sessions as ``...:assumed-role/<role>/<session>``. Those
    are converted to ``...:role/<role>`` so that every branch hands back the
    same ARN shape as the Lambda and EKS branches, and so that callers can
    compare the result against configured role ARNs.

    Returns:
        ``(role_arn, credential_source)`` where ``credential_source`` is an
        ``AwsServicePrincipal`` value. ``role_arn`` may be ``None`` when the
        lookup for a detected environment yields nothing. ``(None, None)`` is
        returned for every other environment and when the App Runner or ECS
        lookup fails.
    """
    env = detect_aws_environment()

    if env == AwsEnvironment.LAMBDA:
        return get_lambda_role_arn(), AwsServicePrincipal.LAMBDA.value

    if env == AwsEnvironment.EKS:
        return os.getenv("AWS_ROLE_ARN"), AwsServicePrincipal.EKS.value

    if env == AwsEnvironment.APP_RUNNER:
        try:
            identity = boto3.client("sts").get_caller_identity()
            return (
                _iam_role_arn_from_caller_arn(identity.get("Arn")),
                AwsServicePrincipal.APP_RUNNER.value,
            )
        except Exception as e:
            # Best effort: fall through to the "unknown identity" result.
            logger.debug(f"Failed to get App Runner role: {e}")

    elif env == AwsEnvironment.ECS:
        try:
            metadata_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv(
                "ECS_CONTAINER_METADATA_URI"
            )
            if metadata_uri:
                response = requests.get(f"{metadata_uri}/task", timeout=1)
                if response.status_code == HTTPStatus.OK:
                    task_metadata = response.json()
                    if "TaskARN" in task_metadata:
                        # NOTE(review): this is the task's ARN, not a role
                        # ARN, so it presumably never equals a configured
                        # role ARN -- confirm whether the task role should be
                        # resolved instead.
                        return (
                            task_metadata.get("TaskARN"),
                            AwsServicePrincipal.ECS.value,
                        )
        except Exception as e:
            # Best effort: fall through to the "unknown identity" result.
            logger.debug(f"Failed to get ECS task role: {e}")

    elif env == AwsEnvironment.BEANSTALK:
        # Beanstalk uses EC2 instance metadata
        return (
            _iam_role_arn_from_caller_arn(get_instance_role_arn()),
            AwsServicePrincipal.ELASTIC_BEANSTALK.value,
        )

    elif env == AwsEnvironment.EC2:
        return (
            _iam_role_arn_from_caller_arn(get_instance_role_arn()),
            AwsServicePrincipal.EC2.value,
        )

    return None, None
219+
220+
37221
def assume_role(
38222
role: AwsAssumeRoleConfig,
39223
aws_region: Optional[str],
@@ -95,7 +279,7 @@ class AwsConnectionConfig(ConfigModel):
95279
)
96280
aws_profile: Optional[str] = Field(
97281
default=None,
98-
description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used",
282+
description="The [named profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html) to use from AWS credentials. Falls back to default profile if not specified and no access keys provided. Profiles are configured in ~/.aws/credentials or ~/.aws/config.",
99283
)
100284
aws_region: Optional[str] = Field(None, description="AWS region code.")
101285

@@ -145,45 +329,65 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
145329

146330
def get_session(self) -> Session:
147331
if self.aws_access_key_id and self.aws_secret_access_key:
332+
# Explicit credentials take precedence
148333
session = Session(
149334
aws_access_key_id=self.aws_access_key_id,
150335
aws_secret_access_key=self.aws_secret_access_key,
151336
aws_session_token=self.aws_session_token,
152337
region_name=self.aws_region,
153338
)
154339
elif self.aws_profile:
340+
# Named profile is second priority
155341
session = Session(
156342
region_name=self.aws_region, profile_name=self.aws_profile
157343
)
158344
else:
159-
# Use boto3's credential autodetection.
345+
# Use boto3's credential autodetection
160346
session = Session(region_name=self.aws_region)
161347

162-
if self._normalized_aws_roles():
163-
# Use existing session credentials to start the chain of role assumption.
164-
current_credentials = session.get_credentials()
165-
credentials = {
166-
"AccessKeyId": current_credentials.access_key,
167-
"SecretAccessKey": current_credentials.secret_key,
168-
"SessionToken": current_credentials.token,
169-
}
170-
171-
for role in self._normalized_aws_roles():
172-
if self._should_refresh_credentials():
173-
credentials = assume_role(
174-
role,
175-
self.aws_region,
176-
credentials=credentials,
348+
target_roles = self._normalized_aws_roles()
349+
if target_roles:
350+
current_role_arn, credential_source = get_current_identity()
351+
352+
# Only assume role if:
353+
# 1. We're not in a known AWS environment with a role, or
354+
# 2. We need to assume a different role than our current one
355+
should_assume_role = current_role_arn is None or any(
356+
role.RoleArn != current_role_arn for role in target_roles
357+
)
358+
359+
if should_assume_role:
360+
env = detect_aws_environment()
361+
logger.debug(f"Assuming role(s) from {env.value} environment")
362+
363+
current_credentials = session.get_credentials()
364+
if current_credentials is None:
365+
raise ValueError("No credentials available for role assumption")
366+
367+
credentials = {
368+
"AccessKeyId": current_credentials.access_key,
369+
"SecretAccessKey": current_credentials.secret_key,
370+
"SessionToken": current_credentials.token,
371+
}
372+
373+
for role in target_roles:
374+
if self._should_refresh_credentials():
375+
credentials = assume_role(
376+
role=role,
377+
aws_region=self.aws_region,
378+
credentials=credentials,
379+
)
380+
if isinstance(credentials["Expiration"], datetime):
381+
self._credentials_expiration = credentials["Expiration"]
382+
383+
session = Session(
384+
aws_access_key_id=credentials["AccessKeyId"],
385+
aws_secret_access_key=credentials["SecretAccessKey"],
386+
aws_session_token=credentials["SessionToken"],
387+
region_name=self.aws_region,
177388
)
178-
if isinstance(credentials["Expiration"], datetime):
179-
self._credentials_expiration = credentials["Expiration"]
180-
181-
session = Session(
182-
aws_access_key_id=credentials["AccessKeyId"],
183-
aws_secret_access_key=credentials["SecretAccessKey"],
184-
aws_session_token=credentials["SessionToken"],
185-
region_name=self.aws_region,
186-
)
389+
else:
390+
logger.debug(f"Using existing role from {credential_source}")
187391

188392
return session
189393

0 commit comments

Comments
 (0)