Skip to content

Commit e88c46b

Browse files
committed
feat(openapi-ingestion): implement openapi ingestion
* not enabled by default
1 parent 48b6581 commit e88c46b

File tree

5 files changed

+400
-15
lines changed

5 files changed

+400
-15
lines changed

metadata-ingestion/src/datahub/emitter/rest_emitter.py

+151-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import json
55
import logging
66
import os
7+
from collections import defaultdict
8+
from dataclasses import dataclass
79
from json.decoder import JSONDecodeError
810
from typing import (
911
TYPE_CHECKING,
@@ -31,6 +33,7 @@
3133
ConfigurationError,
3234
OperationalError,
3335
)
36+
from datahub.emitter.aspect import JSON_CONTENT_TYPE
3437
from datahub.emitter.generic_emitter import Emitter
3538
from datahub.emitter.mcp import MetadataChangeProposalWrapper
3639
from datahub.emitter.request_helper import make_curl_command
@@ -143,10 +146,31 @@ def build_session(self) -> requests.Session:
143146
return session
144147

145148

149+
@dataclass
150+
class _Chunk:
151+
items: List[str]
152+
total_bytes: int = 0
153+
154+
def add_item(self, item: str) -> bool:
155+
item_bytes = len(item.encode())
156+
if not self.items: # Always add at least one item even if over byte limit
157+
self.items.append(item)
158+
self.total_bytes += item_bytes
159+
return True
160+
self.items.append(item)
161+
self.total_bytes += item_bytes
162+
return True
163+
164+
@staticmethod
165+
def join(chunk: "_Chunk") -> str:
166+
return "[" + ",".join(chunk.items) + "]"
167+
168+
146169
class DataHubRestEmitter(Closeable, Emitter):
147170
_gms_server: str
148171
_token: Optional[str]
149172
_session: requests.Session
173+
_openapi_ingestion: bool
150174

151175
def __init__(
152176
self,
@@ -162,6 +186,7 @@ def __init__(
162186
ca_certificate_path: Optional[str] = None,
163187
client_certificate_path: Optional[str] = None,
164188
disable_ssl_verification: bool = False,
189+
openapi_ingestion: bool = False,
165190
):
166191
if not gms_server:
167192
raise ConfigurationError("gms server is required")
@@ -174,9 +199,13 @@ def __init__(
174199
self._gms_server = fixup_gms_url(gms_server)
175200
self._token = token
176201
self.server_config: Dict[str, Any] = {}
177-
202+
self._openapi_ingestion = openapi_ingestion
178203
self._session = requests.Session()
179204

205+
logger.debug(
206+
f"Using {'OpenAPI' if openapi_ingestion else 'Restli'} for ingestion."
207+
)
208+
180209
headers = {
181210
"X-RestLi-Protocol-Version": "2.0.0",
182211
"X-DataHub-Py-Cli-Version": nice_version_name(),
@@ -264,6 +293,43 @@ def to_graph(self) -> "DataHubGraph":
264293

265294
return DataHubGraph.from_emitter(self)
266295

296+
def _to_openapi_request(
297+
self,
298+
mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
299+
async_flag: Optional[bool] = None,
300+
async_default: bool = False,
301+
) -> Optional[Tuple[str, List[Dict[str, Any]]]]:
302+
if mcp.aspect and mcp.aspectName:
303+
resolved_async_flag = (
304+
async_flag if async_flag is not None else async_default
305+
)
306+
url = f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"
307+
308+
if isinstance(mcp, MetadataChangeProposalWrapper):
309+
aspect_value = pre_json_transform(
310+
mcp.to_obj(simplified_structure=True)
311+
)["aspect"]["json"]
312+
else:
313+
obj = mcp.aspect.to_obj()
314+
if obj.get("value") and obj.get("contentType") == JSON_CONTENT_TYPE:
315+
obj = json.loads(obj["value"])
316+
aspect_value = pre_json_transform(obj)
317+
return (
318+
url,
319+
[
320+
{
321+
"urn": mcp.entityUrn,
322+
mcp.aspectName: {
323+
"value": aspect_value,
324+
"systemMetadata": mcp.systemMetadata.to_obj()
325+
if mcp.systemMetadata
326+
else None,
327+
},
328+
}
329+
],
330+
)
331+
return None
332+
267333
def emit(
268334
self,
269335
item: Union[
@@ -317,18 +383,24 @@ def emit_mcp(
317383
mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
318384
async_flag: Optional[bool] = None,
319385
) -> None:
320-
url = f"{self._gms_server}/aspects?action=ingestProposal"
321386
ensure_has_system_metadata(mcp)
322387

323-
mcp_obj = pre_json_transform(mcp.to_obj())
324-
payload_dict = {"proposal": mcp_obj}
388+
if self._openapi_ingestion:
389+
request = self._to_openapi_request(mcp, async_flag, async_default=False)
390+
if request:
391+
self._emit_generic(request[0], payload=request[1])
392+
else:
393+
url = f"{self._gms_server}/aspects?action=ingestProposal"
325394

326-
if async_flag is not None:
327-
payload_dict["async"] = "true" if async_flag else "false"
395+
mcp_obj = pre_json_transform(mcp.to_obj())
396+
payload_dict = {"proposal": mcp_obj}
328397

329-
payload = json.dumps(payload_dict)
398+
if async_flag is not None:
399+
payload_dict["async"] = "true" if async_flag else "false"
330400

331-
self._emit_generic(url, payload)
401+
payload = json.dumps(payload_dict)
402+
403+
self._emit_generic(url, payload)
332404

333405
def emit_mcps(
334406
self,
@@ -337,10 +409,75 @@ def emit_mcps(
337409
) -> int:
338410
if _DATAHUB_EMITTER_TRACE:
339411
logger.debug(f"Attempting to emit MCP batch of size {len(mcps)}")
340-
url = f"{self._gms_server}/aspects?action=ingestProposalBatch"
412+
341413
for mcp in mcps:
342414
ensure_has_system_metadata(mcp)
343415

416+
if self._openapi_ingestion:
417+
return self._emit_openapi_mcps(mcps, async_flag)
418+
else:
419+
return self._emit_restli_mcps(mcps, async_flag)
420+
421+
def _emit_openapi_mcps(
422+
self,
423+
mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
424+
async_flag: Optional[bool] = None,
425+
) -> int:
426+
"""
427+
1. Grouping MCPs by their entity URL
428+
2. Breaking down large batches into smaller chunks based on both:
429+
* Total byte size (INGEST_MAX_PAYLOAD_BYTES)
430+
* Maximum number of items (BATCH_INGEST_MAX_PAYLOAD_LENGTH)
431+
432+
The Chunk class encapsulates both the items and their byte size tracking
433+
Serializing the items only once with json.dumps(request[1]) and reusing that
434+
The chunking logic handles edge cases (always accepting at least one item per chunk)
435+
The joining logic is efficient with a simple string concatenation
436+
437+
:param mcps: metadata change proposals to transmit
438+
:param async_flag: the mode
439+
:return: number of requests
440+
"""
441+
# group by entity url
442+
batches: Dict[str, List[_Chunk]] = defaultdict(
443+
lambda: [_Chunk(items=[])]
444+
) # Initialize with one empty Chunk
445+
446+
for mcp in mcps:
447+
request = self._to_openapi_request(mcp, async_flag, async_default=True)
448+
if request:
449+
current_chunk = batches[request[0]][-1] # Get the last chunk
450+
# Only serialize once
451+
serialized_item = json.dumps(request[1][0])
452+
item_bytes = len(serialized_item.encode())
453+
454+
# If adding this item would exceed max_bytes, create a new chunk
455+
# Unless the chunk is empty (always add at least one item)
456+
if current_chunk.items and (
457+
current_chunk.total_bytes + item_bytes > INGEST_MAX_PAYLOAD_BYTES
458+
or len(current_chunk.items) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
459+
):
460+
new_chunk = _Chunk(items=[])
461+
batches[request[0]].append(new_chunk)
462+
current_chunk = new_chunk
463+
464+
current_chunk.add_item(serialized_item)
465+
466+
responses = []
467+
for url, chunks in batches.items():
468+
for chunk in chunks:
469+
response = self._emit_generic(url, payload=_Chunk.join(chunk))
470+
responses.append(response)
471+
472+
return len(responses)
473+
474+
def _emit_restli_mcps(
475+
self,
476+
mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
477+
async_flag: Optional[bool] = None,
478+
) -> int:
479+
url = f"{self._gms_server}/aspects?action=ingestProposalBatch"
480+
344481
mcp_objs = [pre_json_transform(mcp.to_obj()) for mcp in mcps]
345482

346483
# As a safety mechanism, we need to make sure we don't exceed the max payload size for GMS.
@@ -392,7 +529,10 @@ def emit_usage(self, usageStats: UsageAggregation) -> None:
392529
payload = json.dumps(snapshot)
393530
self._emit_generic(url, payload)
394531

395-
def _emit_generic(self, url: str, payload: str) -> None:
532+
def _emit_generic(self, url: str, payload: Union[str, Any]) -> requests.Response:
533+
if not isinstance(payload, str):
534+
payload = json.dumps(payload)
535+
396536
curl_command = make_curl_command(self._session, "POST", url, payload)
397537
payload_size = len(payload)
398538
if payload_size > INGEST_MAX_PAYLOAD_BYTES:
@@ -408,6 +548,7 @@ def _emit_generic(self, url: str, payload: str) -> None:
408548
try:
409549
response = self._session.post(url, data=payload)
410550
response.raise_for_status()
551+
return response
411552
except HTTPError as e:
412553
try:
413554
info: Dict = response.json()

metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py

+15
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@
4949
)
5050

5151

52+
class RestSinkEndpoint(ConfigEnum):
    """Which GMS ingestion API the REST sink targets: Rest.li or OpenAPI."""

    RESTLI = auto()
    OPENAPI = auto()
55+
56+
5257
class RestSinkMode(ConfigEnum):
5358
SYNC = auto()
5459
ASYNC = auto()
@@ -64,8 +69,15 @@ class RestSinkMode(ConfigEnum):
6469
)
6570

6671

72+
# Default ingestion endpoint for the REST sink; Rest.li unless overridden
# through the DATAHUB_REST_SINK_DEFAULT_ENDPOINT environment variable.
_DEFAULT_REST_SINK_ENDPOINT = pydantic.parse_obj_as(
    RestSinkEndpoint,
    os.getenv("DATAHUB_REST_SINK_DEFAULT_ENDPOINT", RestSinkEndpoint.RESTLI),
)
76+
77+
6778
class DatahubRestSinkConfig(DatahubClientConfig):
6879
mode: RestSinkMode = _DEFAULT_REST_SINK_MODE
80+
endpoint: RestSinkEndpoint = _DEFAULT_REST_SINK_ENDPOINT
6981

7082
# These only apply in async modes.
7183
max_threads: pydantic.PositiveInt = _DEFAULT_REST_SINK_MAX_THREADS
@@ -172,6 +184,9 @@ def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
172184
ca_certificate_path=config.ca_certificate_path,
173185
client_certificate_path=config.client_certificate_path,
174186
disable_ssl_verification=config.disable_ssl_verification,
187+
openapi_ingestion=True
188+
if config.endpoint == RestSinkEndpoint.OPENAPI
189+
else False,
175190
)
176191

177192
@property

0 commit comments

Comments
 (0)