1
+ from __future__ import annotations
2
+
1
3
import functools
2
4
import json
3
5
import logging
4
6
import os
5
7
from json .decoder import JSONDecodeError
6
- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Sequence , Union
8
+ from typing import (
9
+ TYPE_CHECKING ,
10
+ Any ,
11
+ Callable ,
12
+ Dict ,
13
+ List ,
14
+ Optional ,
15
+ Sequence ,
16
+ Tuple ,
17
+ Union ,
18
+ )
7
19
8
20
import requests
9
21
from deprecated import deprecated
12
24
13
25
from datahub import nice_version_name
14
26
from datahub .cli import config_utils
15
- from datahub .cli .cli_utils import ensure_has_system_metadata , fixup_gms_url
27
+ from datahub .cli .cli_utils import ensure_has_system_metadata , fixup_gms_url , get_or_else
16
28
from datahub .cli .env_utils import get_boolean_env_variable
17
- from datahub .configuration .common import ConfigurationError , OperationalError
29
+ from datahub .configuration .common import (
30
+ ConfigModel ,
31
+ ConfigurationError ,
32
+ OperationalError ,
33
+ )
18
34
from datahub .emitter .generic_emitter import Emitter
19
35
from datahub .emitter .mcp import MetadataChangeProposalWrapper
20
36
from datahub .emitter .request_helper import make_curl_command
31
47
32
48
logger = logging .getLogger (__name__ )
33
49
34
- _DEFAULT_CONNECT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect
35
- _DEFAULT_READ_TIMEOUT_SEC = (
36
- 30 # Any ingest call taking longer than 30 seconds should be abandoned
37
- )
50
+ _DEFAULT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect
51
+ _TIMEOUT_LOWER_BOUND_SEC = 1 # if below this, we log a warning
38
52
_DEFAULT_RETRY_STATUS_CODES = [ # Additional status codes to retry on
39
53
429 ,
40
54
500 ,
63
77
)
64
78
65
79
80
+ class RequestsSessionConfig (ConfigModel ):
81
+ timeout : Union [float , Tuple [float , float ], None ] = _DEFAULT_TIMEOUT_SEC
82
+
83
+ retry_status_codes : List [int ] = _DEFAULT_RETRY_STATUS_CODES
84
+ retry_methods : List [str ] = _DEFAULT_RETRY_METHODS
85
+ retry_max_times : int = _DEFAULT_RETRY_MAX_TIMES
86
+
87
+ extra_headers : Dict [str , str ] = {}
88
+
89
+ ca_certificate_path : Optional [str ] = None
90
+ client_certificate_path : Optional [str ] = None
91
+ disable_ssl_verification : bool = False
92
+
93
+ def build_session (self ) -> requests .Session :
94
+ session = requests .Session ()
95
+
96
+ if self .extra_headers :
97
+ session .headers .update (self .extra_headers )
98
+
99
+ if self .client_certificate_path :
100
+ session .cert = self .client_certificate_path
101
+
102
+ if self .ca_certificate_path :
103
+ session .verify = self .ca_certificate_path
104
+
105
+ if self .disable_ssl_verification :
106
+ session .verify = False
107
+
108
+ try :
109
+ # Set raise_on_status to False to propagate errors:
110
+ # https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception
111
+ # Must call `raise_for_status` after making a request, which we do
112
+ retry_strategy = Retry (
113
+ total = self .retry_max_times ,
114
+ status_forcelist = self .retry_status_codes ,
115
+ backoff_factor = 2 ,
116
+ allowed_methods = self .retry_methods ,
117
+ raise_on_status = False ,
118
+ )
119
+ except TypeError :
120
+ # Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`.
121
+ retry_strategy = Retry (
122
+ total = self .retry_max_times ,
123
+ status_forcelist = self .retry_status_codes ,
124
+ backoff_factor = 2 ,
125
+ method_whitelist = self .retry_methods ,
126
+ raise_on_status = False ,
127
+ )
128
+
129
+ adapter = HTTPAdapter (
130
+ pool_connections = 100 , pool_maxsize = 100 , max_retries = retry_strategy
131
+ )
132
+ session .mount ("http://" , adapter )
133
+ session .mount ("https://" , adapter )
134
+
135
+ if self .timeout is not None :
136
+ # Shim session.request to apply default timeout values.
137
+ # Via https://stackoverflow.com/a/59317604.
138
+ session .request = functools .partial ( # type: ignore
139
+ session .request ,
140
+ timeout = self .timeout ,
141
+ )
142
+
143
+ return session
144
+
145
+
66
146
class DataHubRestEmitter (Closeable , Emitter ):
67
147
_gms_server : str
68
148
_token : Optional [str ]
69
149
_session : requests .Session
70
- _connect_timeout_sec : float = _DEFAULT_CONNECT_TIMEOUT_SEC
71
- _read_timeout_sec : float = _DEFAULT_READ_TIMEOUT_SEC
72
- _retry_status_codes : List [int ] = _DEFAULT_RETRY_STATUS_CODES
73
- _retry_methods : List [str ] = _DEFAULT_RETRY_METHODS
74
- _retry_max_times : int = _DEFAULT_RETRY_MAX_TIMES
75
150
76
151
def __init__ (
77
152
self ,
@@ -102,15 +177,13 @@ def __init__(
102
177
103
178
self ._session = requests .Session ()
104
179
105
- self ._session .headers .update (
106
- {
107
- "X-RestLi-Protocol-Version" : "2.0.0" ,
108
- "X-DataHub-Py-Cli-Version" : nice_version_name (),
109
- "Content-Type" : "application/json" ,
110
- }
111
- )
180
+ headers = {
181
+ "X-RestLi-Protocol-Version" : "2.0.0" ,
182
+ "X-DataHub-Py-Cli-Version" : nice_version_name (),
183
+ "Content-Type" : "application/json" ,
184
+ }
112
185
if token :
113
- self . _session . headers . update ({ "Authorization" : f"Bearer { token } " })
186
+ headers [ "Authorization" ] = f"Bearer { token } "
114
187
else :
115
188
# HACK: When no token is provided but system auth env variables are set, we use them.
116
189
# Ideally this should simply get passed in as config, instead of being sneakily injected
@@ -119,75 +192,43 @@ def __init__(
119
192
# rest emitter, and the rest sink uses the rest emitter under the hood.
120
193
system_auth = config_utils .get_system_auth ()
121
194
if system_auth is not None :
122
- self ._session .headers .update ({"Authorization" : system_auth })
123
-
124
- if extra_headers :
125
- self ._session .headers .update (extra_headers )
126
-
127
- if client_certificate_path :
128
- self ._session .cert = client_certificate_path
129
-
130
- if ca_certificate_path :
131
- self ._session .verify = ca_certificate_path
132
-
133
- if disable_ssl_verification :
134
- self ._session .verify = False
135
-
136
- self ._connect_timeout_sec = (
137
- connect_timeout_sec or timeout_sec or _DEFAULT_CONNECT_TIMEOUT_SEC
138
- )
139
- self ._read_timeout_sec = (
140
- read_timeout_sec or timeout_sec or _DEFAULT_READ_TIMEOUT_SEC
141
- )
142
-
143
- if self ._connect_timeout_sec < 1 or self ._read_timeout_sec < 1 :
144
- logger .warning (
145
- f"Setting timeout values lower than 1 second is not recommended. Your configuration is connect_timeout:{ self ._connect_timeout_sec } s, read_timeout:{ self ._read_timeout_sec } s"
146
- )
147
-
148
- if retry_status_codes is not None : # Only if missing. Empty list is allowed
149
- self ._retry_status_codes = retry_status_codes
150
-
151
- if retry_methods is not None :
152
- self ._retry_methods = retry_methods
153
-
154
- if retry_max_times :
155
- self ._retry_max_times = retry_max_times
195
+ headers ["Authorization" ] = system_auth
156
196
157
- try :
158
- # Set raise_on_status to False to propagate errors:
159
- # https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception
160
- # Must call `raise_for_status` after making a request, which we do
161
- retry_strategy = Retry (
162
- total = self ._retry_max_times ,
163
- status_forcelist = self ._retry_status_codes ,
164
- backoff_factor = 2 ,
165
- allowed_methods = self ._retry_methods ,
166
- raise_on_status = False ,
167
- )
168
- except TypeError :
169
- # Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`.
170
- retry_strategy = Retry (
171
- total = self ._retry_max_times ,
172
- status_forcelist = self ._retry_status_codes ,
173
- backoff_factor = 2 ,
174
- method_whitelist = self ._retry_methods ,
175
- raise_on_status = False ,
197
+ timeout : float | tuple [float , float ]
198
+ if connect_timeout_sec is not None or read_timeout_sec is not None :
199
+ timeout = (
200
+ connect_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC ,
201
+ read_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC ,
176
202
)
203
+ if (
204
+ timeout [0 ] < _TIMEOUT_LOWER_BOUND_SEC
205
+ or timeout [1 ] < _TIMEOUT_LOWER_BOUND_SEC
206
+ ):
207
+ logger .warning (
208
+ f"Setting timeout values lower than { _TIMEOUT_LOWER_BOUND_SEC } second is not recommended. Your configuration is (connect_timeout, read_timeout) = { timeout } seconds"
209
+ )
210
+ else :
211
+ timeout = get_or_else (timeout_sec , _DEFAULT_TIMEOUT_SEC )
212
+ if timeout < _TIMEOUT_LOWER_BOUND_SEC :
213
+ logger .warning (
214
+ f"Setting timeout values lower than { _TIMEOUT_LOWER_BOUND_SEC } second is not recommended. Your configuration is timeout = { timeout } seconds"
215
+ )
177
216
178
- adapter = HTTPAdapter (
179
- pool_connections = 100 , pool_maxsize = 100 , max_retries = retry_strategy
180
- )
181
- self . _session . mount ( "http://" , adapter )
182
- self . _session . mount ( "https://" , adapter )
183
-
184
- # Shim session.request to apply default timeout values.
185
- # Via https://stackoverflow.com/a/59317604.
186
- self . _session . request = functools . partial ( # type: ignore
187
- self . _session . request ,
188
- timeout = ( self . _connect_timeout_sec , self . _read_timeout_sec ) ,
217
+ self . _session_config = RequestsSessionConfig (
218
+ timeout = timeout ,
219
+ retry_status_codes = get_or_else (
220
+ retry_status_codes , _DEFAULT_RETRY_STATUS_CODES
221
+ ),
222
+ retry_methods = get_or_else ( retry_methods , _DEFAULT_RETRY_METHODS ),
223
+ retry_max_times = get_or_else ( retry_max_times , _DEFAULT_RETRY_MAX_TIMES ),
224
+ extra_headers = { ** headers , ** ( extra_headers or {})},
225
+ ca_certificate_path = ca_certificate_path ,
226
+ client_certificate_path = client_certificate_path ,
227
+ disable_ssl_verification = disable_ssl_verification ,
189
228
)
190
229
230
+ self ._session = self ._session_config .build_session ()
231
+
191
232
def test_connection (self ) -> None :
192
233
url = f"{ self ._gms_server } /config"
193
234
response = self ._session .get (url )
0 commit comments