1
+ import logging
2
+ import os
1
3
from datetime import datetime , timedelta , timezone
2
- from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Union
4
+ from enum import Enum
5
+ from http import HTTPStatus
6
+ from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Tuple , Union
3
7
4
8
import boto3
9
+ import requests
5
10
from boto3 .session import Session
6
11
from botocore .config import DEFAULT_TIMEOUT , Config
7
12
from botocore .utils import fix_s3_host
14
19
)
15
20
from datahub .configuration .source_common import EnvConfigMixin
16
21
22
+ logger = logging .getLogger (__name__ )
23
+
17
24
if TYPE_CHECKING :
18
25
from mypy_boto3_dynamodb import DynamoDBClient
19
26
from mypy_boto3_glue import GlueClient
22
29
from mypy_boto3_sts import STSClient
23
30
24
31
32
class AwsEnvironment(Enum):
    """Kinds of AWS runtime environment this module tells apart.

    Each member's value is the upper-case label used when the environment
    is reported (for example in debug log messages).
    """

    EC2 = "EC2"
    ECS = "ECS"
    EKS = "EKS"
    LAMBDA = "LAMBDA"
    APP_RUNNER = "APP_RUNNER"
    # Member name is abbreviated; the value spells the service name out.
    BEANSTALK = "ELASTIC_BEANSTALK"
    CLOUD_FORMATION = "CLOUD_FORMATION"
    # Fallback for when no specific environment could be identified.
    UNKNOWN = "UNKNOWN"
41
+
42
+
43
class AwsServicePrincipal(Enum):
    """AWS service principal identifiers, keyed by the service they denote.

    Every value is the ``<service>.amazonaws.com`` string for that service.
    """

    LAMBDA = "lambda.amazonaws.com"
    EKS = "eks.amazonaws.com"
    APP_RUNNER = "apprunner.amazonaws.com"
    ECS = "ecs.amazonaws.com"
    ELASTIC_BEANSTALK = "elasticbeanstalk.amazonaws.com"
    EC2 = "ec2.amazonaws.com"
50
+
51
+
25
52
class AwsAssumeRoleConfig (PermissiveConfigModel ):
26
53
# Using the PermissiveConfigModel to allow the user to pass additional arguments.
27
54
@@ -34,6 +61,163 @@ class AwsAssumeRoleConfig(PermissiveConfigModel):
34
61
)
35
62
36
63
64
def get_instance_metadata_token() -> Optional[str]:
    """Request an IMDSv2 session token from the instance metadata endpoint.

    Returns:
        The token text when the endpoint answers with HTTP 200; ``None`` when
        it answers with any other status or the request itself fails.
    """
    token_url = "http://169.254.169.254/latest/api/token"
    ttl_header = {"X-aws-ec2-metadata-token-ttl-seconds": "21600"}

    try:
        # Short timeout: off EC2 the link-local address is usually
        # unreachable and we do not want to stall the caller.
        reply = requests.put(token_url, headers=ttl_header, timeout=1)
        got_token = reply.status_code == HTTPStatus.OK
        return reply.text if got_token else None
    except requests.exceptions.RequestException:
        logger.debug("Failed to get IMDSv2 token")
        return None
77
+
78
+
79
def is_running_on_ec2() -> bool:
    """Report whether the instance metadata service answers as on EC2.

    Returns:
        ``True`` only if an IMDSv2 token could be obtained and the
        ``instance-id`` metadata path responds with HTTP 200; ``False`` on a
        missing token, any other status, or a request failure.
    """
    imds_token = get_instance_metadata_token()
    if not imds_token:
        return False

    instance_id_url = "http://169.254.169.254/latest/meta-data/instance-id"
    auth_header = {"X-aws-ec2-metadata-token": imds_token}
    try:
        reply = requests.get(instance_id_url, headers=auth_header, timeout=1)
        return reply.status_code == HTTPStatus.OK
    except requests.exceptions.RequestException:
        return False
94
+
95
+
96
def detect_aws_environment() -> AwsEnvironment:
    """Work out which AWS runtime the process is executing in.

    The checks run from most to least specific because one runtime can
    expose several indicators; the first match wins. All checks except the
    last read environment variables only; the final EC2 check queries the
    instance metadata service.

    Returns:
        The matching ``AwsEnvironment`` member, or ``AwsEnvironment.UNKNOWN``
        when no indicator is present.
    """
    # Lambda is checked first; a CloudFormation execution environment is a
    # special case of it.
    if os.getenv("AWS_LAMBDA_FUNCTION_NAME"):
        execution_env = os.getenv("AWS_EXECUTION_ENV", "")
        return (
            AwsEnvironment.CLOUD_FORMATION
            if execution_env.startswith("CloudFormation")
            else AwsEnvironment.LAMBDA
        )

    # EKS with IAM Roles for Service Accounts sets both variables.
    if os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE") and os.getenv("AWS_ROLE_ARN"):
        return AwsEnvironment.EKS

    if os.getenv("AWS_APP_RUNNER_SERVICE_ID"):
        return AwsEnvironment.APP_RUNNER

    # Either generation of the ECS container metadata endpoint counts.
    ecs_metadata_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv(
        "ECS_CONTAINER_METADATA_URI"
    )
    if ecs_metadata_uri:
        return AwsEnvironment.ECS

    if os.getenv("ELASTIC_BEANSTALK_ENVIRONMENT_NAME"):
        return AwsEnvironment.BEANSTALK

    # Last resort: the only check that needs a network round trip.
    return AwsEnvironment.EC2 if is_running_on_ec2() else AwsEnvironment.UNKNOWN
129
+
130
+
131
def get_instance_role_arn() -> Optional[str]:
    """Return the caller ARN when an instance-profile role is attached.

    Queries IMDSv2 for the attached role name. When one is present, the
    identity reported by STS ``get_caller_identity`` is returned. This is a
    best-effort lookup: every failure is logged at debug level and reported
    as ``None`` rather than raised.

    Returns:
        The ``Arn`` field of the STS caller identity, or ``None`` when no
        IMDSv2 token is available, no role name is listed, the metadata
        request does not return HTTP 200, or any step raises.
    """
    token = get_instance_metadata_token()
    if not token:
        return None

    try:
        response = requests.get(
            "http://169.254.169.254/latest/meta-data/iam/security-credentials/",
            headers={"X-aws-ec2-metadata-token": token},
            timeout=1,
        )
        # HTTPStatus.OK instead of a bare 200, consistent with the other
        # metadata helpers in this module.
        if response.status_code == HTTPStatus.OK:
            role_name = response.text.strip()
            if role_name:
                # NOTE(review): for role credentials STS reports an
                # ``arn:aws:sts::<acct>:assumed-role/<role>/<session>`` ARN,
                # not the ``arn:aws:iam::<acct>:role/<role>`` form; confirm
                # callers comparing against a configured RoleArn expect that.
                sts = boto3.client("sts")
                identity = sts.get_caller_identity()
                return identity.get("Arn")
    except Exception as e:
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("Failed to get instance role ARN: %s", e)
    return None
152
+
153
+
154
def get_lambda_role_arn() -> Optional[str]:
    """Look up the execution role ARN of the current Lambda function.

    The function name is read from ``AWS_LAMBDA_FUNCTION_NAME``; the role is
    taken from the ``Role`` field of the Lambda ``get_function_configuration``
    response. This is a best-effort lookup: failures are logged at debug
    level and reported as ``None`` rather than raised.

    Returns:
        The role ARN, or ``None`` when the environment variable is unset or
        empty, the response has no ``Role`` field, or the lookup raises.
    """
    function_name = os.getenv("AWS_LAMBDA_FUNCTION_NAME")
    if not function_name:
        # Not running in Lambda: nothing to query, and no client is created.
        return None

    # Only the AWS calls can fail, so only they sit inside the try.
    try:
        lambda_client = boto3.client("lambda")
        function_config = lambda_client.get_function_configuration(
            FunctionName=function_name
        )
        return function_config.get("Role")
    except Exception as e:
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("Failed to get Lambda role ARN: %s", e)
        return None
169
+
170
+
171
def get_current_identity() -> Tuple[Optional[str], Optional[str]]:
    """Resolve the identity the process is currently running as.

    The runtime is detected first, then the identity is looked up in the way
    that runtime supports.

    Returns:
        A ``(arn, credential_source)`` pair, where ``credential_source`` is
        the service principal string for the detected runtime. The ARN may be
        ``None`` while the source is set if the lookup came back empty.
        ``(None, None)`` is returned when the runtime is not one of the
        handled ones, or when the App Runner / ECS lookup fails.
    """
    runtime = detect_aws_environment()

    if runtime == AwsEnvironment.LAMBDA:
        return get_lambda_role_arn(), AwsServicePrincipal.LAMBDA.value

    if runtime == AwsEnvironment.EKS:
        # With IRSA the role ARN is handed to the pod via the environment.
        return os.getenv("AWS_ROLE_ARN"), AwsServicePrincipal.EKS.value

    if runtime == AwsEnvironment.APP_RUNNER:
        try:
            caller = boto3.client("sts").get_caller_identity()
            return caller.get("Arn"), AwsServicePrincipal.APP_RUNNER.value
        except Exception as e:
            # Fall through to the (None, None) result below.
            logger.debug(f"Failed to get App Runner role: {e}")

    if runtime == AwsEnvironment.ECS:
        try:
            base_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv(
                "ECS_CONTAINER_METADATA_URI"
            )
            if base_uri:
                reply = requests.get(f"{base_uri}/task", timeout=1)
                if reply.status_code == HTTPStatus.OK:
                    task_info = reply.json()
                    if "TaskARN" in task_info:
                        # The task ARN from the task metadata is what gets
                        # reported for ECS.
                        return (
                            task_info.get("TaskARN"),
                            AwsServicePrincipal.ECS.value,
                        )
        except Exception as e:
            # Fall through to the (None, None) result below.
            logger.debug(f"Failed to get ECS task role: {e}")

    if runtime == AwsEnvironment.BEANSTALK:
        # Beanstalk is resolved through the EC2 instance metadata as well.
        return get_instance_role_arn(), AwsServicePrincipal.ELASTIC_BEANSTALK.value

    if runtime == AwsEnvironment.EC2:
        return get_instance_role_arn(), AwsServicePrincipal.EC2.value

    return None, None
219
+
220
+
37
221
def assume_role (
38
222
role : AwsAssumeRoleConfig ,
39
223
aws_region : Optional [str ],
@@ -95,7 +279,7 @@ class AwsConnectionConfig(ConfigModel):
95
279
)
96
280
aws_profile : Optional [str ] = Field (
97
281
default = None ,
98
- description = "Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used " ,
282
+ description = "The [named profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html) to use from AWS credentials. Falls back to default profile if not specified and no access keys provided. Profiles are configured in ~/.aws/credentials or ~/.aws/config. " ,
99
283
)
100
284
aws_region : Optional [str ] = Field (None , description = "AWS region code." )
101
285
@@ -145,45 +329,65 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
145
329
146
330
def get_session(self) -> Session:
    """Build a boto3 session for this configuration.

    The base session is chosen in this order: explicit access key / secret,
    then a named profile, then boto3's own credential autodetection. If any
    roles are configured, they are assumed in order starting from the base
    session's credentials, unless the current identity already matches
    every configured role.

    Returns:
        The base session, or a session built from the assumed-role
        credentials when role assumption took place.

    Raises:
        ValueError: if roles must be assumed but the base session has no
            credentials to start from.
    """
    if self.aws_access_key_id and self.aws_secret_access_key:
        # Explicit credentials take precedence
        session = Session(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
            region_name=self.aws_region,
        )
    elif self.aws_profile:
        # Named profile is second priority
        session = Session(
            region_name=self.aws_region, profile_name=self.aws_profile
        )
    else:
        # Use boto3's credential autodetection
        session = Session(region_name=self.aws_region)

    target_roles = self._normalized_aws_roles()
    if target_roles:
        current_role_arn, credential_source = get_current_identity()

        # Only assume role if:
        # 1. We're not in a known AWS environment with a role, or
        # 2. We need to assume a different role than our current one
        should_assume_role = current_role_arn is None or any(
            role.RoleArn != current_role_arn for role in target_roles
        )

        if should_assume_role:
            # Environment detection may probe the instance metadata service
            # over the network, and get_current_identity() has already run
            # it once. Repeat it only when the debug message will be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                env = detect_aws_environment()
                logger.debug("Assuming role(s) from %s environment", env.value)

            current_credentials = session.get_credentials()
            if current_credentials is None:
                raise ValueError("No credentials available for role assumption")

            # Seed the chain with the base session's credentials.
            credentials = {
                "AccessKeyId": current_credentials.access_key,
                "SecretAccessKey": current_credentials.secret_key,
                "SessionToken": current_credentials.token,
            }

            for role in target_roles:
                if self._should_refresh_credentials():
                    # Each hop uses the credentials from the previous one.
                    credentials = assume_role(
                        role=role,
                        aws_region=self.aws_region,
                        credentials=credentials,
                    )
                    # "Expiration" is only read from an assume_role result;
                    # the seed dict above has no such key.
                    if isinstance(credentials["Expiration"], datetime):
                        self._credentials_expiration = credentials["Expiration"]

            session = Session(
                aws_access_key_id=credentials["AccessKeyId"],
                aws_secret_access_key=credentials["SecretAccessKey"],
                aws_session_token=credentials["SessionToken"],
                region_name=self.aws_region,
            )
        else:
            logger.debug("Using existing role from %s", credential_source)

    return session
189
393
0 commit comments