4
4
import json
5
5
import logging
6
6
import os
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass
9
+ from enum import auto
7
10
from json .decoder import JSONDecodeError
8
11
from typing import (
9
12
TYPE_CHECKING ,
17
20
Union ,
18
21
)
19
22
23
+ import pydantic
20
24
import requests
21
25
from deprecated import deprecated
22
26
from requests .adapters import HTTPAdapter , Retry
27
31
from datahub .cli .cli_utils import ensure_has_system_metadata , fixup_gms_url , get_or_else
28
32
from datahub .cli .env_utils import get_boolean_env_variable
29
33
from datahub .configuration .common import (
34
+ ConfigEnum ,
30
35
ConfigModel ,
31
36
ConfigurationError ,
32
37
OperationalError ,
33
38
)
39
+ from datahub .emitter .aspect import JSON_CONTENT_TYPE
34
40
from datahub .emitter .generic_emitter import Emitter
35
41
from datahub .emitter .mcp import MetadataChangeProposalWrapper
36
42
from datahub .emitter .request_helper import make_curl_command
77
83
)
78
84
79
85
86
class RestSinkEndpoint(ConfigEnum):
    # Supported transports for the REST sink: the legacy Rest.li aspect
    # endpoint and the newer OpenAPI v3 entity endpoint.
    RESTLI = auto()
    OPENAPI = auto()


# Default endpoint for the REST sink, overridable via the
# DATAHUB_REST_SINK_DEFAULT_ENDPOINT environment variable. Parsed through
# pydantic so that a raw string from the environment is validated against
# the enum (invalid values raise at import time).
DEFAULT_REST_SINK_ENDPOINT = pydantic.parse_obj_as(
    RestSinkEndpoint,
    os.getenv("DATAHUB_REST_SINK_DEFAULT_ENDPOINT", RestSinkEndpoint.RESTLI),
)
95
+
96
+
80
97
class RequestsSessionConfig (ConfigModel ):
81
98
timeout : Union [float , Tuple [float , float ], None ] = _DEFAULT_TIMEOUT_SEC
82
99
@@ -143,10 +160,31 @@ def build_session(self) -> requests.Session:
143
160
return session
144
161
145
162
163
+ @dataclass
164
+ class _Chunk :
165
+ items : List [str ]
166
+ total_bytes : int = 0
167
+
168
+ def add_item (self , item : str ) -> bool :
169
+ item_bytes = len (item .encode ())
170
+ if not self .items : # Always add at least one item even if over byte limit
171
+ self .items .append (item )
172
+ self .total_bytes += item_bytes
173
+ return True
174
+ self .items .append (item )
175
+ self .total_bytes += item_bytes
176
+ return True
177
+
178
+ @staticmethod
179
+ def join (chunk : "_Chunk" ) -> str :
180
+ return "[" + "," .join (chunk .items ) + "]"
181
+
182
+
146
183
class DataHubRestEmitter(Closeable, Emitter):
    # Base URL of the DataHub GMS server (normalized via fixup_gms_url in
    # __init__).
    _gms_server: str
    # Optional access token supplied at construction time.
    _token: Optional[str]
    # requests.Session used for all HTTP calls made by this emitter.
    _session: requests.Session
    # When True, emit via the OpenAPI v3 entity endpoints instead of the
    # legacy Rest.li aspect endpoints.
    _openapi_ingestion: bool
150
188
151
189
def __init__ (
152
190
self ,
@@ -162,6 +200,7 @@ def __init__(
162
200
ca_certificate_path : Optional [str ] = None ,
163
201
client_certificate_path : Optional [str ] = None ,
164
202
disable_ssl_verification : bool = False ,
203
+ openapi_ingestion : bool = False ,
165
204
):
166
205
if not gms_server :
167
206
raise ConfigurationError ("gms server is required" )
@@ -174,9 +213,13 @@ def __init__(
174
213
self ._gms_server = fixup_gms_url (gms_server )
175
214
self ._token = token
176
215
self .server_config : Dict [str , Any ] = {}
177
-
216
+ self . _openapi_ingestion = openapi_ingestion
178
217
self ._session = requests .Session ()
179
218
219
+ logger .debug (
220
+ f"Using { 'OpenAPI' if self ._openapi_ingestion else 'Restli' } for ingestion."
221
+ )
222
+
180
223
headers = {
181
224
"X-RestLi-Protocol-Version" : "2.0.0" ,
182
225
"X-DataHub-Py-Cli-Version" : nice_version_name (),
@@ -264,6 +307,43 @@ def to_graph(self) -> "DataHubGraph":
264
307
265
308
return DataHubGraph .from_emitter (self )
266
309
310
    def _to_openapi_request(
        self,
        mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
        async_flag: Optional[bool] = None,
        async_default: bool = False,
    ) -> Optional[Tuple[str, List[Dict[str, Any]]]]:
        """Convert an MCP into an OpenAPI v3 ``(url, payload)`` pair.

        :param mcp: the metadata change proposal to convert
        :param async_flag: explicit async mode; overrides ``async_default``
            when not None
        :param async_default: fallback async mode (single emit uses False,
            batch emit uses True)
        :return: a (url, payload-list) tuple, or None when the MCP carries
            no aspect/aspectName (nothing to ingest)
        """
        if mcp.aspect and mcp.aspectName:
            # Explicit flag wins; otherwise use the caller-supplied default.
            resolved_async_flag = (
                async_flag if async_flag is not None else async_default
            )
            url = f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"

            if isinstance(mcp, MetadataChangeProposalWrapper):
                # Wrapper objects serialize their aspect under
                # ["aspect"]["json"] when simplified_structure is requested.
                aspect_value = pre_json_transform(
                    mcp.to_obj(simplified_structure=True)
                )["aspect"]["json"]
            else:
                obj = mcp.aspect.to_obj()
                # Generic aspects may arrive pre-serialized as
                # {"value": <json string>, "contentType": ...}; decode the
                # embedded JSON before transforming.
                if obj.get("value") and obj.get("contentType") == JSON_CONTENT_TYPE:
                    obj = json.loads(obj["value"])
                aspect_value = pre_json_transform(obj)
            return (
                url,
                [
                    {
                        "urn": mcp.entityUrn,
                        mcp.aspectName: {
                            "value": aspect_value,
                            "systemMetadata": mcp.systemMetadata.to_obj()
                            if mcp.systemMetadata
                            else None,
                        },
                    }
                ],
            )
        return None
346
+
267
347
def emit (
268
348
self ,
269
349
item : Union [
@@ -317,18 +397,24 @@ def emit_mcp(
317
397
mcp : Union [MetadataChangeProposal , MetadataChangeProposalWrapper ],
318
398
async_flag : Optional [bool ] = None ,
319
399
) -> None :
320
- url = f"{ self ._gms_server } /aspects?action=ingestProposal"
321
400
ensure_has_system_metadata (mcp )
322
401
323
- mcp_obj = pre_json_transform (mcp .to_obj ())
324
- payload_dict = {"proposal" : mcp_obj }
402
+ if self ._openapi_ingestion :
403
+ request = self ._to_openapi_request (mcp , async_flag , async_default = False )
404
+ if request :
405
+ self ._emit_generic (request [0 ], payload = request [1 ])
406
+ else :
407
+ url = f"{ self ._gms_server } /aspects?action=ingestProposal"
325
408
326
- if async_flag is not None :
327
- payload_dict [ "async" ] = "true" if async_flag else "false"
409
+ mcp_obj = pre_json_transform ( mcp . to_obj ())
410
+ payload_dict = { "proposal" : mcp_obj }
328
411
329
- payload = json .dumps (payload_dict )
412
+ if async_flag is not None :
413
+ payload_dict ["async" ] = "true" if async_flag else "false"
330
414
331
- self ._emit_generic (url , payload )
415
+ payload = json .dumps (payload_dict )
416
+
417
+ self ._emit_generic (url , payload )
332
418
333
419
def emit_mcps (
334
420
self ,
@@ -337,10 +423,75 @@ def emit_mcps(
337
423
) -> int :
338
424
if _DATAHUB_EMITTER_TRACE :
339
425
logger .debug (f"Attempting to emit MCP batch of size { len (mcps )} " )
340
- url = f" { self . _gms_server } /aspects?action=ingestProposalBatch"
426
+
341
427
for mcp in mcps :
342
428
ensure_has_system_metadata (mcp )
343
429
430
+ if self ._openapi_ingestion :
431
+ return self ._emit_openapi_mcps (mcps , async_flag )
432
+ else :
433
+ return self ._emit_restli_mcps (mcps , async_flag )
434
+
435
+ def _emit_openapi_mcps (
436
+ self ,
437
+ mcps : Sequence [Union [MetadataChangeProposal , MetadataChangeProposalWrapper ]],
438
+ async_flag : Optional [bool ] = None ,
439
+ ) -> int :
440
+ """
441
+ 1. Grouping MCPs by their entity URL
442
+ 2. Breaking down large batches into smaller chunks based on both:
443
+ * Total byte size (INGEST_MAX_PAYLOAD_BYTES)
444
+ * Maximum number of items (BATCH_INGEST_MAX_PAYLOAD_LENGTH)
445
+
446
+ The Chunk class encapsulates both the items and their byte size tracking
447
+ Serializing the items only once with json.dumps(request[1]) and reusing that
448
+ The chunking logic handles edge cases (always accepting at least one item per chunk)
449
+ The joining logic is efficient with a simple string concatenation
450
+
451
+ :param mcps: metadata change proposals to transmit
452
+ :param async_flag: the mode
453
+ :return: number of requests
454
+ """
455
+ # group by entity url
456
+ batches : Dict [str , List [_Chunk ]] = defaultdict (
457
+ lambda : [_Chunk (items = [])]
458
+ ) # Initialize with one empty Chunk
459
+
460
+ for mcp in mcps :
461
+ request = self ._to_openapi_request (mcp , async_flag , async_default = True )
462
+ if request :
463
+ current_chunk = batches [request [0 ]][- 1 ] # Get the last chunk
464
+ # Only serialize once
465
+ serialized_item = json .dumps (request [1 ][0 ])
466
+ item_bytes = len (serialized_item .encode ())
467
+
468
+ # If adding this item would exceed max_bytes, create a new chunk
469
+ # Unless the chunk is empty (always add at least one item)
470
+ if current_chunk .items and (
471
+ current_chunk .total_bytes + item_bytes > INGEST_MAX_PAYLOAD_BYTES
472
+ or len (current_chunk .items ) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
473
+ ):
474
+ new_chunk = _Chunk (items = [])
475
+ batches [request [0 ]].append (new_chunk )
476
+ current_chunk = new_chunk
477
+
478
+ current_chunk .add_item (serialized_item )
479
+
480
+ responses = []
481
+ for url , chunks in batches .items ():
482
+ for chunk in chunks :
483
+ response = self ._emit_generic (url , payload = _Chunk .join (chunk ))
484
+ responses .append (response )
485
+
486
+ return len (responses )
487
+
488
+ def _emit_restli_mcps (
489
+ self ,
490
+ mcps : Sequence [Union [MetadataChangeProposal , MetadataChangeProposalWrapper ]],
491
+ async_flag : Optional [bool ] = None ,
492
+ ) -> int :
493
+ url = f"{ self ._gms_server } /aspects?action=ingestProposalBatch"
494
+
344
495
mcp_objs = [pre_json_transform (mcp .to_obj ()) for mcp in mcps ]
345
496
346
497
# As a safety mechanism, we need to make sure we don't exceed the max payload size for GMS.
@@ -392,7 +543,10 @@ def emit_usage(self, usageStats: UsageAggregation) -> None:
392
543
payload = json .dumps (snapshot )
393
544
self ._emit_generic (url , payload )
394
545
395
- def _emit_generic (self , url : str , payload : str ) -> None :
546
+ def _emit_generic (self , url : str , payload : Union [str , Any ]) -> requests .Response :
547
+ if not isinstance (payload , str ):
548
+ payload = json .dumps (payload )
549
+
396
550
curl_command = make_curl_command (self ._session , "POST" , url , payload )
397
551
payload_size = len (payload )
398
552
if payload_size > INGEST_MAX_PAYLOAD_BYTES :
@@ -408,6 +562,7 @@ def _emit_generic(self, url: str, payload: str) -> None:
408
562
try :
409
563
response = self ._session .post (url , data = payload )
410
564
response .raise_for_status ()
565
+ return response
411
566
except HTTPError as e :
412
567
try :
413
568
info : Dict = response .json ()
0 commit comments