Skip to content

Commit 4c44886

Browse files
committed
adding tests
1 parent 5b63991 commit 4c44886

File tree

3 files changed

+287
-3
lines changed

3 files changed

+287
-3
lines changed

metadata-ingestion/setup.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@
165165
# Deal with a version incompatibility between botocore (used by boto3) and urllib3.
166166
# See https://github.com/boto/botocore/pull/2563.
167167
"botocore!=1.23.0",
168-
"requests",
169168
}
170169

171170
path_spec_common = {
@@ -274,7 +273,7 @@
274273
"ujson>=5.2.0",
275274
"smart-open[s3]>=5.2.1",
276275
# moto 5.0.0 drops support for Python 3.7
277-
"moto[s3]<5.0.0",
276+
"moto[s3,sts,lambda,iam]<5.0.0",
278277
*path_spec_common,
279278
}
280279

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import os
33
from datetime import datetime, timedelta, timezone
44
from enum import Enum
5-
import requests
65
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
76

87
import boto3
8+
import requests
9+
910
from boto3.session import Session
1011
from botocore.config import DEFAULT_TIMEOUT, Config
1112
from botocore.utils import fix_s3_host
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
import json
2+
import os
3+
import pytest
4+
from unittest.mock import patch, MagicMock
5+
import boto3
6+
from moto import mock_sts, mock_lambda, mock_iam
7+
8+
from datahub.ingestion.source.aws.aws_common import (
9+
AwsEnvironment,
10+
detect_aws_environment,
11+
get_current_identity,
12+
get_instance_metadata_token,
13+
get_instance_role_arn,
14+
AwsConnectionConfig,
15+
get_lambda_role_arn,
16+
is_running_on_ec2,
17+
)
18+
19+
20+
@pytest.fixture
def mock_aws_config():
    """Provide an ``AwsConnectionConfig`` built from static dummy credentials.

    The configuration carries a fake access key / secret key pair and targets
    the ``us-east-1`` region; no session token or role is configured.
    """
    static_credentials = {
        "aws_access_key_id": "test-key",
        "aws_secret_access_key": "test-secret",
        "aws_region": "us-east-1",
    }
    return AwsConnectionConfig(**static_credentials)
27+
28+
29+
class TestAwsCommon:
    """Unit tests for the helpers exposed by ``aws_common``.

    Covers three areas:

    * environment detection driven by environment variables,
    * EC2 instance-metadata (IMDSv2) token / instance-id probing, with
      ``requests.put`` / ``requests.get`` mocked out,
    * ``AwsConnectionConfig.get_session`` with static credentials, session
      tokens and role assumption (STS/IAM/Lambda mocked through moto).

    Every environment-detection test replaces ``os.environ`` wholesale
    (``clear=True``) so that variables leaking in from the machine running
    the tests cannot change the detected environment.
    """

    # Dotted path of the EC2 probe. It is patched in the "nothing detected"
    # cases so those tests stay offline.
    # NOTE(review): assumes detect_aws_environment() may fall back to this
    # probe when no environment variable matches; patching is harmless if not.
    _EC2_PROBE = "datahub.ingestion.source.aws.aws_common.is_running_on_ec2"

    def test_environment_detection_no_environment(self):
        """Test environment detection when no AWS environment is present"""
        with patch.dict(os.environ, {}, clear=True), patch(
            self._EC2_PROBE, return_value=False
        ):
            assert detect_aws_environment() == AwsEnvironment.UNKNOWN

    def test_environment_detection_lambda(self):
        """Test Lambda environment detection"""
        env = {"AWS_LAMBDA_FUNCTION_NAME": "test-function"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.LAMBDA

    def test_environment_detection_lambda_cloudformation(self):
        """Test CloudFormation Lambda environment detection"""
        env = {
            "AWS_LAMBDA_FUNCTION_NAME": "test-function",
            "AWS_EXECUTION_ENV": "CloudFormation.xxx",
        }
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.CLOUD_FORMATION

    def test_environment_detection_eks(self):
        """Test EKS environment detection"""
        env = {
            "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/token",
            "AWS_ROLE_ARN": "arn:aws:iam::123456789012:role/test-role",
        }
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.EKS

    def test_environment_detection_app_runner(self):
        """Test App Runner environment detection"""
        env = {"AWS_APP_RUNNER_SERVICE_ID": "service-id"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.APP_RUNNER

    def test_environment_detection_ecs(self):
        """Test ECS environment detection"""
        env = {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2/v4"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.ECS

    def test_environment_detection_beanstalk(self):
        """Test Elastic Beanstalk environment detection"""
        env = {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"}
        with patch.dict(os.environ, env, clear=True):
            assert detect_aws_environment() == AwsEnvironment.BEANSTALK

    @patch("requests.put")
    def test_ec2_metadata_token(self, mock_put):
        """Test EC2 metadata token retrieval"""
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"

        token = get_instance_metadata_token()
        assert token == "token123"

        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )

    @patch("requests.put")
    def test_ec2_metadata_token_failure(self, mock_put):
        """Test EC2 metadata token failure case"""
        mock_put.return_value.status_code = 404

        token = get_instance_metadata_token()
        assert token is None

    @patch("requests.get")
    @patch("requests.put")
    def test_is_running_on_ec2(self, mock_put, mock_get):
        """Test EC2 instance detection with IMDSv2"""
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"
        mock_get.return_value.status_code = 200

        assert is_running_on_ec2() is True

        # The token request must come first, then the instance-id lookup
        # carrying the token that was just issued.
        mock_put.assert_called_once_with(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )
        mock_get.assert_called_once_with(
            "http://169.254.169.254/latest/meta-data/instance-id",
            headers={"X-aws-ec2-metadata-token": "token123"},
            timeout=1,
        )

    @patch("requests.get")
    @patch("requests.put")
    def test_is_running_on_ec2_failure(self, mock_put, mock_get):
        """Test EC2 instance detection failure"""
        # Case 1: the token endpoint itself rejects the request.
        mock_put.return_value.status_code = 404
        assert is_running_on_ec2() is False

        # Case 2: a token is issued but the instance-id lookup fails.
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"
        mock_get.return_value.status_code = 404
        assert is_running_on_ec2() is False

    @mock_sts
    @mock_lambda
    @mock_iam
    def test_get_current_identity_lambda(self):
        """Test getting identity in Lambda environment"""
        # clear=True is deliberately NOT used here: the boto3 clients created
        # below still need whatever credential variables are present in the
        # environment while the moto mocks are active.
        lambda_env = {
            "AWS_LAMBDA_FUNCTION_NAME": "test-function",
            "AWS_DEFAULT_REGION": "us-east-1",
        }
        with patch.dict(os.environ, lambda_env):
            # Create IAM role first with proper trust policy
            iam_client = boto3.client("iam", region_name="us-east-1")
            trust_policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "lambda.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    }
                ],
            }
            iam_client.create_role(
                RoleName="test-role",
                AssumeRolePolicyDocument=json.dumps(trust_policy),
            )

            lambda_client = boto3.client("lambda", region_name="us-east-1")
            lambda_client.create_function(
                FunctionName="test-function",
                Runtime="python3.8",
                Role="arn:aws:iam::123456789012:role/test-role",
                Handler="index.handler",
                Code={"ZipFile": b"def handler(event, context): pass"},
            )

            role_arn, source = get_current_identity()
            assert source == "lambda.amazonaws.com"
            assert role_arn == "arn:aws:iam::123456789012:role/test-role"

    @patch("requests.get")
    @patch("requests.put")
    @mock_sts
    def test_get_instance_role_arn_success(self, mock_put, mock_get):
        """Test getting EC2 instance role ARN"""
        mock_put.return_value.status_code = 200
        mock_put.return_value.text = "token123"
        mock_get.return_value.status_code = 200
        mock_get.return_value.text = "test-role"

        with patch("boto3.client") as boto_client_factory:
            # Named ``sts_client`` rather than ``mock_sts`` so the moto
            # decorator imported under that name is not shadowed.
            sts_client = MagicMock()
            sts_client.get_caller_identity.return_value = {
                "Arn": "arn:aws:sts::123456789012:assumed-role/test-role/instance"
            }
            boto_client_factory.return_value = sts_client

            role_arn = get_instance_role_arn()
            assert (
                role_arn
                == "arn:aws:sts::123456789012:assumed-role/test-role/instance"
            )

    @mock_sts
    def test_aws_connection_config_basic(self, mock_aws_config):
        """Test basic AWS connection configuration"""
        session = mock_aws_config.get_session()
        creds = session.get_credentials()
        assert creds.access_key == "test-key"
        assert creds.secret_key == "test-secret"

    @mock_sts
    def test_aws_connection_config_with_session_token(self):
        """Test AWS connection with session token"""
        config = AwsConnectionConfig(
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
            aws_session_token="test-token",
            aws_region="us-east-1",
        )

        session = config.get_session()
        creds = session.get_credentials()
        assert creds.token == "test-token"

    @mock_sts
    def test_aws_connection_config_role_assumption(self):
        """Test AWS connection with role assumption"""
        config = AwsConnectionConfig(
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
            aws_region="us-east-1",
            aws_role="arn:aws:iam::123456789012:role/test-role",
        )

        with patch(
            "datahub.ingestion.source.aws.aws_common.get_current_identity"
        ) as mock_identity:
            # No current identity is reported for this case.
            mock_identity.return_value = (None, None)
            session = config.get_session()
            creds = session.get_credentials()
            assert creds is not None

    @mock_sts
    def test_aws_connection_config_skip_role_assumption(self):
        """Test AWS connection skipping role assumption when already in role"""
        config = AwsConnectionConfig(
            aws_region="us-east-1",
            aws_role="arn:aws:iam::123456789012:role/current-role",
        )

        with patch(
            "datahub.ingestion.source.aws.aws_common.get_current_identity"
        ) as mock_identity:
            # The reported identity is exactly the configured role.
            mock_identity.return_value = (
                "arn:aws:iam::123456789012:role/current-role",
                "ec2.amazonaws.com",
            )
            session = config.get_session()
            assert session is not None

    @mock_sts
    def test_aws_connection_config_multiple_roles(self):
        """Test AWS connection with multiple role assumption"""
        config = AwsConnectionConfig(
            aws_access_key_id="test-key",
            aws_secret_access_key="test-secret",
            aws_region="us-east-1",
            aws_role=[
                "arn:aws:iam::123456789012:role/role1",
                "arn:aws:iam::123456789012:role/role2",
            ],
        )

        with patch(
            "datahub.ingestion.source.aws.aws_common.get_current_identity"
        ) as mock_identity:
            mock_identity.return_value = (None, None)
            session = config.get_session()
            assert session is not None

    def test_aws_connection_config_validation_error(self):
        """Test AWS connection validation"""
        partial_env = {
            "AWS_ACCESS_KEY_ID": "test-key",
            # Deliberately missing AWS_SECRET_ACCESS_KEY
            "AWS_DEFAULT_REGION": "us-east-1",
        }
        with patch.dict(os.environ, partial_env, clear=True):
            config = AwsConnectionConfig()  # Let it pick up from environment
            session = config.get_session()
            with pytest.raises(
                Exception,
                match="Partial credentials found in env, missing: AWS_SECRET_ACCESS_KEY",
            ):
                session.get_credentials()

    @pytest.mark.parametrize(
        "env_vars,expected_environment",
        [
            ({}, AwsEnvironment.UNKNOWN),
            ({"AWS_LAMBDA_FUNCTION_NAME": "test"}, AwsEnvironment.LAMBDA),
            (
                {
                    "AWS_LAMBDA_FUNCTION_NAME": "test",
                    "AWS_EXECUTION_ENV": "CloudFormation",
                },
                AwsEnvironment.CLOUD_FORMATION,
            ),
            (
                {
                    "AWS_WEB_IDENTITY_TOKEN_FILE": "/token",
                    "AWS_ROLE_ARN": "arn:aws:iam::123:role/test",
                },
                AwsEnvironment.EKS,
            ),
            ({"AWS_APP_RUNNER_SERVICE_ID": "service-123"}, AwsEnvironment.APP_RUNNER),
            (
                {"ECS_CONTAINER_METADATA_URI_V4": "http://169.254.170.2"},
                AwsEnvironment.ECS,
            ),
            (
                {"ELASTIC_BEANSTALK_ENVIRONMENT_NAME": "my-env"},
                AwsEnvironment.BEANSTALK,
            ),
        ],
    )
    def test_environment_detection_parametrized(self, env_vars, expected_environment):
        """Parametrized test for environment detection with different configurations"""
        # The EC2 probe is stubbed so the empty-environment case cannot reach
        # out to the metadata endpoint of the machine running the tests.
        with patch.dict(os.environ, env_vars, clear=True), patch(
            self._EC2_PROBE, return_value=False
        ):
            assert detect_aws_environment() == expected_environment

0 commit comments

Comments
 (0)