Skip to content

Commit 13c577b

Browse files
authored
chore: separate plugins into their own files (#33)
- plugins.py was split up as follows: - Plugin/PluginFactory moved to plugin.py - each concrete plugin and plugin factory has its own file - PluginService/PluginManager/PluginServiceManagerContainer moved to plugin_service.py - splitting the plugins into their own files was initially causing circular import errors. This PR adds the flake8-type-checking dependency to help avoid/resolve circular imports due to imports that are only used for type checking. Fixed the errors that this dependency pointed out - corrected some dialect logic that did not match the JDBC wrapper logic
1 parent af8f843 commit 13c577b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1444
-1201
lines changed

.flake8

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[flake8]
22
max-line-length = 150
3+
extend-select = TC, TC1
+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
if TYPE_CHECKING:
20+
from boto3 import Session
21+
from aws_wrapper.hostinfo import HostInfo
22+
from aws_wrapper.pep249 import Connection
23+
from aws_wrapper.plugin_service import PluginService
24+
25+
from json import loads
26+
from logging import getLogger
27+
from re import search
28+
from types import SimpleNamespace
29+
from typing import Callable, Dict, Optional, Set, Tuple
30+
31+
import boto3
32+
from botocore.exceptions import ClientError
33+
34+
from aws_wrapper.errors import AwsWrapperError
35+
from aws_wrapper.plugin import Plugin, PluginFactory
36+
from aws_wrapper.utils.messages import Messages
37+
from aws_wrapper.utils.properties import Properties, WrapperProperties
38+
39+
logger = getLogger(__name__)
40+
41+
42+
class AwsSecretsManagerPlugin(Plugin):
43+
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
44+
45+
_SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P<region>[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$"
46+
47+
_secrets_cache: Dict[Tuple, SimpleNamespace] = {}
48+
_secret_key: Tuple = ()
49+
50+
@property
51+
def subscribed_methods(self) -> Set[str]:
52+
return self._SUBSCRIBED_METHODS
53+
54+
def __init__(self, plugin_service: PluginService, props: Properties, session: Optional[Session] = None):
55+
self._plugin_service = plugin_service
56+
self._session = session
57+
58+
secret_id = WrapperProperties.SECRETS_MANAGER_SECRET_ID.get(props)
59+
if not secret_id:
60+
raise AwsWrapperError(
61+
Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter",
62+
WrapperProperties.SECRETS_MANAGER_SECRET_ID.name))
63+
64+
region: str = self._get_rds_region(secret_id, props)
65+
66+
self._secret_key: Tuple = (secret_id, region)
67+
68+
def connect(self, host_info: HostInfo, props: Properties, initial: bool, connect_func: Callable) -> Connection:
69+
return self._connect(props, connect_func)
70+
71+
def force_connect(self, host_info: HostInfo, props: Properties, initial: bool,
72+
force_connect_func: Callable) -> Connection:
73+
return self._connect(props, force_connect_func)
74+
75+
def _connect(self, props: Properties, connect_func: Callable) -> Connection:
76+
secret_fetched: bool = self._update_secret()
77+
78+
try:
79+
self._apply_secret_to_properties(props)
80+
return connect_func()
81+
82+
except Exception as e:
83+
if not self._plugin_service.is_login_exception(error=e) or secret_fetched:
84+
raise AwsWrapperError(
85+
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e
86+
87+
secret_fetched = self._update_secret(True)
88+
89+
if secret_fetched:
90+
try:
91+
self._apply_secret_to_properties(props)
92+
return connect_func()
93+
except Exception as unhandled_error:
94+
raise AwsWrapperError(
95+
Messages.get_formatted("AwsSecretsManagerPlugin.UnhandledException",
96+
unhandled_error)) from unhandled_error
97+
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e
98+
99+
def _update_secret(self, force_refetch: bool = False) -> bool:
100+
fetched: bool = False
101+
102+
self._secret: Optional[SimpleNamespace] = AwsSecretsManagerPlugin._secrets_cache.get(self._secret_key)
103+
104+
if not self._secret or force_refetch:
105+
try:
106+
self._secret = self._fetch_latest_credentials()
107+
if self._secret:
108+
AwsSecretsManagerPlugin._secrets_cache[self._secret_key] = self._secret
109+
fetched = True
110+
except (ClientError, AttributeError) as e:
111+
logger.debug(Messages.get_formatted("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e))
112+
raise AwsWrapperError(
113+
Messages.get_formatted("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)) from e
114+
115+
return fetched
116+
117+
def _fetch_latest_credentials(self):
118+
session = self._session if self._session else boto3.Session()
119+
client = session.client(
120+
'secretsmanager',
121+
region_name=self._secret_key[1],
122+
)
123+
124+
secret = client.get_secret_value(
125+
SecretId=self._secret_key[0],
126+
)
127+
128+
client.close()
129+
130+
return loads(secret.get("SecretString"), object_hook=lambda d: SimpleNamespace(**d))
131+
132+
def _apply_secret_to_properties(self, properties: Properties):
133+
if self._secret:
134+
WrapperProperties.USER.set(properties, self._secret.username)
135+
WrapperProperties.PASSWORD.set(properties, self._secret.password)
136+
137+
def _get_rds_region(self, secret_id: str, props: Properties) -> str:
138+
region: Optional[str] = props.get(WrapperProperties.SECRETS_MANAGER_REGION.name)
139+
if not region:
140+
match = search(self._SECRETS_ARN_PATTERN, secret_id)
141+
if match:
142+
region = match.group("region")
143+
else:
144+
raise AwsWrapperError(
145+
Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter",
146+
WrapperProperties.SECRETS_MANAGER_REGION.name))
147+
148+
session = self._session if self._session else boto3.Session()
149+
if region not in session.get_available_regions("rds"):
150+
exception_message = Messages.get_formatted("AwsSdk.UnsupportedRegion", region)
151+
logger.debug(exception_message)
152+
raise AwsWrapperError(exception_message)
153+
154+
return region
155+
156+
157+
class AwsSecretsManagerPluginFactory(PluginFactory):
158+
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
159+
return AwsSecretsManagerPlugin(plugin_service, props)

aws_wrapper/connection_provider.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
if TYPE_CHECKING:
20+
from .hostinfo import HostInfo, HostRole
21+
from .pep249 import Connection
22+
from .utils.properties import Properties
1523

1624
import threading
1725
from typing import Callable, Dict, List, Optional, Protocol
1826

1927
from aws_wrapper.errors import AwsWrapperError
20-
from .hostinfo import HostInfo, HostRole
2128
from .hostselector import HostSelector, RandomHostSelector
22-
from .pep249 import Connection
2329
from .utils.messages import Messages
24-
from .utils.properties import Properties
2530

2631

2732
class ConnectionProvider(Protocol):

aws_wrapper/default_plugin.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
if TYPE_CHECKING:
20+
from aws_wrapper.host_list_provider import HostListProviderService
21+
from aws_wrapper.plugin_service import PluginService
22+
from aws_wrapper.pep249 import Connection
23+
24+
import copy
25+
from typing import Any, Callable, Set
26+
27+
from aws_wrapper.connection_provider import (ConnectionProvider,
28+
ConnectionProviderManager)
29+
from aws_wrapper.errors import AwsWrapperError
30+
from aws_wrapper.hostinfo import HostInfo, HostRole
31+
from aws_wrapper.plugin import Plugin
32+
from aws_wrapper.utils.messages import Messages
33+
from aws_wrapper.utils.properties import Properties, PropertiesUtils
34+
35+
36+
class DefaultPlugin(Plugin):
37+
_SUBSCRIBED_METHODS: Set[str] = {"*"}
38+
39+
def __init__(self, plugin_service: PluginService, default_conn_provider: ConnectionProvider):
40+
self._plugin_service: PluginService = plugin_service
41+
self._connection_provider_manager = ConnectionProviderManager(default_conn_provider)
42+
43+
def connect(self, host_info: HostInfo, props: Properties,
44+
initial: bool, connect_func: Callable) -> Any:
45+
target_driver_props = copy.copy(props)
46+
PropertiesUtils.remove_wrapper_props(target_driver_props)
47+
connection_provider: ConnectionProvider = \
48+
self._connection_provider_manager.get_connection_provider(host_info, target_driver_props)
49+
# logger.debug("Default plugin: connect before")
50+
result = self._connect(host_info, target_driver_props, connection_provider)
51+
# logger.debug("Default plugin: connect after")
52+
return result
53+
54+
def _connect(self, host_info: HostInfo, props: Properties, conn_provider: ConnectionProvider):
55+
result = conn_provider.connect(host_info, props)
56+
return result
57+
58+
def force_connect(self, host_info: HostInfo, props: Properties,
59+
initial: bool, force_connect_func: Callable) -> Connection:
60+
target_driver_props = copy.copy(props)
61+
PropertiesUtils.remove_wrapper_props(target_driver_props)
62+
return self._connect(host_info, target_driver_props, self._connection_provider_manager.default_provider)
63+
64+
def execute(self, target: object, method_name: str, execute_func: Callable, *args: tuple) -> Any:
65+
# logger.debug("Default plugin: execute before")
66+
result = execute_func()
67+
# logger.debug("Default plugin: execute after")
68+
return result
69+
70+
def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
71+
if HostRole.UNKNOWN == role:
72+
return False
73+
return self._connection_provider_manager.accepts_strategy(role, strategy)
74+
75+
def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
76+
if HostRole.UNKNOWN == role:
77+
raise AwsWrapperError(Messages.get("Plugins.UnknownHosts"))
78+
79+
hosts = self._plugin_service.hosts
80+
81+
if len(hosts) < 1:
82+
raise AwsWrapperError(Messages.get("Plugins.EmptyHosts"))
83+
84+
return self._connection_provider_manager.get_host_info_by_strategy(hosts, role, strategy)
85+
86+
@property
87+
def subscribed_methods(self) -> Set[str]:
88+
return DefaultPlugin._SUBSCRIBED_METHODS
89+
90+
def init_host_provider(
91+
self,
92+
props: Properties,
93+
host_list_provider_service: HostListProviderService,
94+
init_host_provider_func: Callable):
95+
# Do nothing
96+
# This is the last plugin in the plugin chain.
97+
# So init_host_provider_func will be a no-op and does not need to be called.
98+
pass

0 commit comments

Comments
 (0)