1
+ import json
2
+ import os
3
+ import pytest
4
+ from unittest .mock import patch , MagicMock
5
+ import boto3
6
+ from moto import mock_sts , mock_lambda , mock_iam
7
+
8
+ from datahub .ingestion .source .aws .aws_common import (
9
+ AwsEnvironment ,
10
+ detect_aws_environment ,
11
+ get_current_identity ,
12
+ get_instance_metadata_token ,
13
+ get_instance_role_arn ,
14
+ AwsConnectionConfig ,
15
+ get_lambda_role_arn ,
16
+ is_running_on_ec2 ,
17
+ )
18
+
19
+
20
@pytest.fixture
def mock_aws_config():
    """Provide an ``AwsConnectionConfig`` built from static dummy credentials.

    The config carries a fixed access key, secret key and the ``us-east-1``
    region; no session token or role is set.
    """
    static_settings = {
        "aws_access_key_id": "test-key",
        "aws_secret_access_key": "test-secret",
        "aws_region": "us-east-1",
    }
    return AwsConnectionConfig(**static_settings)
27
+
28
+
29
class TestAwsCommon:
    """Tests for the environment, identity and session helpers in ``aws_common``.

    The environment-detection tests replace ``os.environ`` wholesale
    (``clear=True``) so that AWS-related variables already present on the
    machine running the tests cannot influence which environment is detected.
    """

    def test_environment_detection_no_environment(self):
        """Test environment detection when no AWS environment is present"""
        with patch.dict(os.environ, {}, clear=True):
            assert detect_aws_environment() == AwsEnvironment.UNKNOWN

    def test_environment_detection_lambda(self):
        """Test Lambda environment detection"""
        with patch.dict(
            os.environ, {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}, clear=True
        ):
            assert detect_aws_environment() == AwsEnvironment.LAMBDA

    def test_environment_detection_lambda_cloudformation(self):
        """Test CloudFormation Lambda environment detection"""
        with patch.dict(
            os.environ,
            {
                "AWS_LAMBDA_FUNCTION_NAME": "test-function",
                "AWS_EXECUTION_ENV": "CloudFormation.xxx",
            },
            clear=True,
        ):
            assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION

    def test_environment_detection_eks(self):
        """Test EKS environment detection"""
        with patch.dict(
            os.environ,
            {
                "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/token",
                "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/test-role",
            },
            clear=True,
        ):
            assert detect_aws_environment() == AwsEnvironment.EKS

    def test_environment_detection_app_runner(self):
        """Test App Runner environment detection"""
        with patch.dict(
            os.environ, {"AWS_APP_RUNNER_SERVICE_ID": "service-id"}, clear=True
        ):
            assert detect_aws_environment() == AwsEnvironment.APP_RUNNER

    def test_environment_detection_ecs(self):
        """Test ECS environment detection"""
        with patch.dict(
            os.environ,
            {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"},
            clear=True,
        ):
            assert detect_aws_environment() == AwsEnvironment.ECS

    def test_environment_detection_beanstalk(self):
        """Test Elastic Beanstalk environment detection"""
        with patch.dict(
            os.environ, {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"}, clear=True
        ):
            assert detect_aws_environment() == AwsEnvironment.BEANSTALK

    @patch("requests.put")
    def test_ec2_metadata_token(self, mock_put):
        """Test EC2 metadata token retrieval"""
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"

        token = get_instance_metadata_token()
        assert token == "token123"

        # The token must be requested exactly once, with the TTL header and a
        # one-second timeout.
        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )

    @patch("requests.put")
    def test_ec2_metadata_token_failure(self, mock_put):
        """Test EC2 metadata token failure case"""
        mock_put.return_value.status_code = 404

        token = get_instance_metadata_token()
        assert token is None

    # ``patch`` decorators are applied bottom-up, so the ``requests.put`` mock
    # is the first injected argument and the ``requests.get`` mock the second.
    @patch("requests.get")
    @patch("requests.put")
    def test_is_running_on_ec2(self, mock_put, mock_get):
        """Test EC2 instance detection with IMDSv2"""
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"
        mock_get.return_value.status_code = 200

        assert is_running_on_ec2() is True

        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )
        # The token returned by the PUT must be forwarded on the metadata GET.
        mock_get.assert_called_once_with(
            "http://169.254.169.254/latest/meta-data/instance-id",
            headers={"X-aws-ec2-metadata-token": "token123"},
            timeout=1,
        )

    @patch("requests.get")
    @patch("requests.put")
    def test_is_running_on_ec2_failure(self, mock_put, mock_get):
        """Test EC2 instance detection failure"""
        # Phase 1: the token request itself is rejected.
        mock_put.return_value.status_code = 404
        assert is_running_on_ec2() is False

        # Phase 2: a token is issued, but the instance-id lookup is rejected.
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"
        mock_get.return_value.status_code = 404
        assert is_running_on_ec2() is False

    @mock_sts
    @mock_lambda
    @mock_iam
    def test_get_current_identity_lambda(self):
        """Test getting identity in Lambda environment"""
        # NOTE(review): ``clear=True`` is deliberately not used here; the moto
        # mocks may rely on credential variables already present in the
        # environment -- confirm before tightening this patch.
        with patch.dict(
            os.environ,
            {
                "AWS_LAMBDA_FUNCTION_NAME": "test-function",
                "AWS_DEFAULT_REGION": "us-east-1",
            },
        ):
            # Create IAM role first with proper trust policy
            iam_client = boto3.client("iam", region_name="us-east-1")
            trust_policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "lambda.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    }
                ],
            }
            iam_client.create_role(
                RoleName="test-role",
                AssumeRolePolicyDocument=json.dumps(trust_policy),
            )

            # Register a function whose name matches AWS_LAMBDA_FUNCTION_NAME
            # and whose execution role is the role created above.
            lambda_client = boto3.client("lambda", region_name="us-east-1")
            lambda_client.create_function(
                FunctionName="test-function",
                Runtime="python3.8",
                Role="arn:aws:iam::123456789012:role/test-role",
                Handler="index.handler",
                Code={"ZipFile": b"def handler(event, context): pass"},
            )

            role_arn, source = get_current_identity()
            assert source == "lambda.amazonaws.com"
            assert role_arn == "arn:aws:iam::123456789012:role/test-role"

    @patch("requests.get")
    @patch("requests.put")
    @mock_sts
    def test_get_instance_role_arn_success(self, mock_put, mock_get):
        """Test getting EC2 instance role ARN"""
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"
        mock_get.return_value.status_code = 200
        mock_get.return_value.text = "test-role"

        expected_arn = "arn:aws:sts::123456789012:assumed-role/test-role/instance"
        with patch("boto3.client") as mock_boto:
            # Named ``sts_client`` so the local does not shadow the imported
            # moto ``mock_sts`` decorator applied to this method.
            sts_client = MagicMock()
            sts_client.get_caller_identity.return_value = {"Arn": expected_arn}
            mock_boto.return_value = sts_client

            role_arn = get_instance_role_arn()
            assert role_arn == expected_arn

    @mock_sts
    def test_aws_connection_config_basic(self, mock_aws_config):
        """Test basic AWS connection configuration"""
        session = mock_aws_config.get_session()
        creds = session.get_credentials()
        assert creds.access_key == "test-key"
        assert creds.secret_key == "test-secret"

    @mock_sts
    def test_aws_connection_config_with_session_token(self):
        """Test AWS connection with session token"""
        config = AwsConnectionConfig(
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
            aws_session_token="test-token",
            aws_region="us-east-1",
        )

        session = config.get_session()
        creds = session.get_credentials()
        assert creds.token == "test-token"

    @mock_sts
    def test_aws_connection_config_role_assumption(self):
        """Test AWS connection with role assumption"""
        config = AwsConnectionConfig(
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
            aws_region="us-east-1",
            aws_role="arn:aws:iam::123456789012:role/test-role",
        )

        # Report "no current identity" so the configured role differs from it.
        with patch(
            "datahub.ingestion.source.aws.aws_common.get_current_identity"
        ) as mock_identity:
            mock_identity.return_value = (None, None)
            session = config.get_session()
            creds = session.get_credentials()
            assert creds is not None

    @mock_sts
    def test_aws_connection_config_skip_role_assumption(self):
        """Test AWS connection skipping role assumption when already in role"""
        config = AwsConnectionConfig(
            aws_region="us-east-1",
            aws_role="arn:aws:iam::123456789012:role/current-role",
        )

        # The reported current identity is the same ARN as the configured role.
        with patch(
            "datahub.ingestion.source.aws.aws_common.get_current_identity"
        ) as mock_identity:
            mock_identity.return_value = (
                "arn:aws:iam::123456789012:role/current-role",
                "ec2.amazonaws.com",
            )
            session = config.get_session()
            assert session is not None

    @mock_sts
    def test_aws_connection_config_multiple_roles(self):
        """Test AWS connection with multiple role assumption"""
        config = AwsConnectionConfig(
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
            aws_region="us-east-1",
            aws_role=[
                "arn:aws:iam::123456789012:role/role1",
                "arn:aws:iam::123456789012:role/role2",
            ],
        )

        with patch(
            "datahub.ingestion.source.aws.aws_common.get_current_identity"
        ) as mock_identity:
            mock_identity.return_value = (None, None)
            session = config.get_session()
            assert session is not None

    def test_aws_connection_config_validation_error(self):
        """Test AWS connection validation"""
        with patch.dict(
            os.environ,
            {
                "AWS_ACCESS_KEY_ID": "test-key",
                # Deliberately missing AWS_SECRET_ACCESS_KEY
                "AWS_DEFAULT_REGION": "us-east-1",
            },
            clear=True,
        ):
            config = AwsConnectionConfig()  # Let it pick up from environment
            session = config.get_session()
            with pytest.raises(
                Exception,
                match="Partial credentials found in env, missing: AWS_SECRET_ACCESS_KEY",
            ):
                session.get_credentials()

    @pytest.mark.parametrize(
        "env_vars,expected_environment",
        [
            ({}, AwsEnvironment.UNKNOWN),
            ({"AWS_LAMBDA_FUNCTION_NAME": "test"}, AwsEnvironment.LAMBDA),
            (
                {
                    "AWS_LAMBDA_FUNCTION_NAME": "test",
                    "AWS_EXECUTION_ENV": "CloudFormation",
                },
                AwsEnvironment.CLOUD_FORMATION,
            ),
            (
                {
                    "AWS_WEB_IDENTITY_TOKEN_FILE": "/token",
                    "AWS_ROLE_ARN": "arn:aws:iam::123:role/test",
                },
                AwsEnvironment.EKS,
            ),
            ({"AWS_APP_RUNNER_SERVICE_ID": "service-123"}, AwsEnvironment.APP_RUNNER),
            (
                {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2"},
                AwsEnvironment.ECS,
            ),
            (
                {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"},
                AwsEnvironment.BEANSTALK,
            ),
        ],
    )
    def test_environment_detection_parametrized(self, env_vars, expected_environment):
        """Parametrized test for environment detection with different configurations"""
        with patch.dict(os.environ, env_vars, clear=True):
            assert detect_aws_environment() == expected_environment
0 commit comments