Skip to content

Commit fa4ff7b

Browse files
feat(openapi-ingestion): implement openapi ingestion (#12757)
1 parent 31df9c4 commit fa4ff7b

File tree

6 files changed

+409
-16
lines changed

6 files changed

+409
-16
lines changed

metadata-ingestion/src/datahub/emitter/rest_emitter.py

+165-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import json
55
import logging
66
import os
7+
from collections import defaultdict
8+
from dataclasses import dataclass
9+
from enum import auto
710
from json.decoder import JSONDecodeError
811
from typing import (
912
TYPE_CHECKING,
@@ -17,6 +20,7 @@
1720
Union,
1821
)
1922

23+
import pydantic
2024
import requests
2125
from deprecated import deprecated
2226
from requests.adapters import HTTPAdapter, Retry
@@ -27,10 +31,12 @@
2731
from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url, get_or_else
2832
from datahub.cli.env_utils import get_boolean_env_variable
2933
from datahub.configuration.common import (
34+
ConfigEnum,
3035
ConfigModel,
3136
ConfigurationError,
3237
OperationalError,
3338
)
39+
from datahub.emitter.aspect import JSON_CONTENT_TYPE
3440
from datahub.emitter.generic_emitter import Emitter
3541
from datahub.emitter.mcp import MetadataChangeProposalWrapper
3642
from datahub.emitter.request_helper import make_curl_command
@@ -77,6 +83,17 @@
7783
)
7884

7985

86+
class RestSinkEndpoint(ConfigEnum):
    # Which GMS ingestion API the REST emitter/sink talks to.
    RESTLI = auto()   # legacy Rest.li endpoint (/aspects?action=ingestProposal)
    OPENAPI = auto()  # OpenAPI v3 endpoint (/openapi/v3/entity/...)


# Process-wide default endpoint, overridable through the
# DATAHUB_REST_SINK_DEFAULT_ENDPOINT environment variable. Parsed with
# pydantic so an env-supplied string is validated against the enum
# (presumably ConfigEnum accepts case-insensitive names — TODO confirm).
DEFAULT_REST_SINK_ENDPOINT = pydantic.parse_obj_as(
    RestSinkEndpoint,
    os.getenv("DATAHUB_REST_SINK_DEFAULT_ENDPOINT", RestSinkEndpoint.RESTLI),
)
95+
96+
8097
class RequestsSessionConfig(ConfigModel):
8198
timeout: Union[float, Tuple[float, float], None] = _DEFAULT_TIMEOUT_SEC
8299

@@ -143,10 +160,31 @@ def build_session(self) -> requests.Session:
143160
return session
144161

145162

163+
@dataclass
164+
class _Chunk:
165+
items: List[str]
166+
total_bytes: int = 0
167+
168+
def add_item(self, item: str) -> bool:
169+
item_bytes = len(item.encode())
170+
if not self.items: # Always add at least one item even if over byte limit
171+
self.items.append(item)
172+
self.total_bytes += item_bytes
173+
return True
174+
self.items.append(item)
175+
self.total_bytes += item_bytes
176+
return True
177+
178+
@staticmethod
179+
def join(chunk: "_Chunk") -> str:
180+
return "[" + ",".join(chunk.items) + "]"
181+
182+
146183
class DataHubRestEmitter(Closeable, Emitter):
147184
_gms_server: str
148185
_token: Optional[str]
149186
_session: requests.Session
187+
_openapi_ingestion: bool
150188

151189
def __init__(
152190
self,
@@ -162,6 +200,7 @@ def __init__(
162200
ca_certificate_path: Optional[str] = None,
163201
client_certificate_path: Optional[str] = None,
164202
disable_ssl_verification: bool = False,
203+
openapi_ingestion: bool = False,
165204
):
166205
if not gms_server:
167206
raise ConfigurationError("gms server is required")
@@ -174,9 +213,13 @@ def __init__(
174213
self._gms_server = fixup_gms_url(gms_server)
175214
self._token = token
176215
self.server_config: Dict[str, Any] = {}
177-
216+
self._openapi_ingestion = openapi_ingestion
178217
self._session = requests.Session()
179218

219+
logger.debug(
220+
f"Using {'OpenAPI' if self._openapi_ingestion else 'Restli'} for ingestion."
221+
)
222+
180223
headers = {
181224
"X-RestLi-Protocol-Version": "2.0.0",
182225
"X-DataHub-Py-Cli-Version": nice_version_name(),
@@ -264,6 +307,43 @@ def to_graph(self) -> "DataHubGraph":
264307

265308
return DataHubGraph.from_emitter(self)
266309

310+
def _to_openapi_request(
    self,
    mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
    async_flag: Optional[bool] = None,
    async_default: bool = False,
) -> Optional[Tuple[str, List[Dict[str, Any]]]]:
    """Translate an MCP into an OpenAPI v3 entity-endpoint request.

    Returns a ``(url, payload_list)`` pair, or ``None`` when the proposal
    carries no aspect/aspectName and therefore cannot be sent via OpenAPI.
    ``async_flag`` overrides ``async_default`` when it is not None.
    """
    if not (mcp.aspect and mcp.aspectName):
        return None

    use_async = async_default if async_flag is None else async_flag
    url = (
        f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}"
        f"?async={'true' if use_async else 'false'}"
    )

    if isinstance(mcp, MetadataChangeProposalWrapper):
        aspect_value = pre_json_transform(mcp.to_obj(simplified_structure=True))[
            "aspect"
        ]["json"]
    else:
        obj = mcp.aspect.to_obj()
        # Raw proposals may wrap a JSON aspect as a serialized string; unwrap
        # it before the transform so the server receives structured JSON.
        if obj.get("value") and obj.get("contentType") == JSON_CONTENT_TYPE:
            obj = json.loads(obj["value"])
        aspect_value = pre_json_transform(obj)

    system_metadata = mcp.systemMetadata.to_obj() if mcp.systemMetadata else None
    payload = [
        {
            "urn": mcp.entityUrn,
            mcp.aspectName: {
                "value": aspect_value,
                "systemMetadata": system_metadata,
            },
        }
    ]
    return url, payload
346+
267347
def emit(
268348
self,
269349
item: Union[
def emit_mcp(
    self,
    mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
    async_flag: Optional[bool] = None,
) -> None:
    """Emit a single metadata change proposal.

    Routes to the OpenAPI endpoint when the emitter was constructed with
    ``openapi_ingestion=True`` (single emits default to async=false there),
    otherwise posts to the legacy Rest.li ``ingestProposal`` action.
    """
    ensure_has_system_metadata(mcp)

    if self._openapi_ingestion:
        request = self._to_openapi_request(mcp, async_flag, async_default=False)
        # A proposal without aspect/aspectName yields no request; nothing to send.
        if request is not None:
            self._emit_generic(request[0], payload=request[1])
        return

    payload_dict = {"proposal": pre_json_transform(mcp.to_obj())}
    if async_flag is not None:
        payload_dict["async"] = "true" if async_flag else "false"

    self._emit_generic(
        f"{self._gms_server}/aspects?action=ingestProposal",
        json.dumps(payload_dict),
    )
332418

333419
def emit_mcps(
334420
self,
@@ -337,10 +423,75 @@ def emit_mcps(
337423
) -> int:
338424
if _DATAHUB_EMITTER_TRACE:
339425
logger.debug(f"Attempting to emit MCP batch of size {len(mcps)}")
340-
url = f"{self._gms_server}/aspects?action=ingestProposalBatch"
426+
341427
for mcp in mcps:
342428
ensure_has_system_metadata(mcp)
343429

430+
if self._openapi_ingestion:
431+
return self._emit_openapi_mcps(mcps, async_flag)
432+
else:
433+
return self._emit_restli_mcps(mcps, async_flag)
434+
435+
def _emit_openapi_mcps(
    self,
    mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
    async_flag: Optional[bool] = None,
) -> int:
    """Emit a batch of MCPs through the OpenAPI endpoint.

    Proposals are grouped by their target entity URL, and each group is
    split into chunks bounded by both a byte budget
    (``INGEST_MAX_PAYLOAD_BYTES``) and an item count
    (``BATCH_INGEST_MAX_PAYLOAD_LENGTH``). Every item is serialized exactly
    once and chunks are joined by plain string concatenation. A chunk always
    accepts at least one item, even if that single item exceeds the byte
    budget on its own.

    :param mcps: metadata change proposals to transmit
    :param async_flag: per-request async override; when unset, batch
        ingestion defaults to async=true
    :return: number of HTTP requests issued
    """
    # Group serialized items by entity URL; each URL starts with one empty chunk.
    batches: Dict[str, List[_Chunk]] = defaultdict(lambda: [_Chunk(items=[])])

    for mcp in mcps:
        request = self._to_openapi_request(mcp, async_flag, async_default=True)
        if not request:
            # No aspect/aspectName — nothing emittable for this proposal.
            continue

        url_chunks = batches[request[0]]  # hoisted: one dict lookup per item
        current_chunk = url_chunks[-1]

        # Serialize once; reuse for both size accounting and the payload.
        serialized_item = json.dumps(request[1][0])
        item_bytes = len(serialized_item.encode())

        # Open a new chunk when this item would push the current one past
        # either limit — unless the chunk is empty (always accept one item).
        if current_chunk.items and (
            current_chunk.total_bytes + item_bytes > INGEST_MAX_PAYLOAD_BYTES
            or len(current_chunk.items) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
        ):
            current_chunk = _Chunk(items=[])
            url_chunks.append(current_chunk)

        current_chunk.add_item(serialized_item)

    responses = []
    for url, chunks in batches.items():
        for chunk in chunks:
            responses.append(self._emit_generic(url, payload=_Chunk.join(chunk)))

    return len(responses)
487+
488+
def _emit_restli_mcps(
489+
self,
490+
mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
491+
async_flag: Optional[bool] = None,
492+
) -> int:
493+
url = f"{self._gms_server}/aspects?action=ingestProposalBatch"
494+
344495
mcp_objs = [pre_json_transform(mcp.to_obj()) for mcp in mcps]
345496

346497
# As a safety mechanism, we need to make sure we don't exceed the max payload size for GMS.
@@ -392,7 +543,10 @@ def emit_usage(self, usageStats: UsageAggregation) -> None:
392543
payload = json.dumps(snapshot)
393544
self._emit_generic(url, payload)
394545

395-
def _emit_generic(self, url: str, payload: str) -> None:
546+
def _emit_generic(self, url: str, payload: Union[str, Any]) -> requests.Response:
547+
if not isinstance(payload, str):
548+
payload = json.dumps(payload)
549+
396550
curl_command = make_curl_command(self._session, "POST", url, payload)
397551
payload_size = len(payload)
398552
if payload_size > INGEST_MAX_PAYLOAD_BYTES:
@@ -408,6 +562,7 @@ def _emit_generic(self, url: str, payload: str) -> None:
408562
try:
409563
response = self._session.post(url, data=payload)
410564
response.raise_for_status()
565+
return response
411566
except HTTPError as e:
412567
try:
413568
info: Dict = response.json()

metadata-ingestion/src/datahub/ingestion/graph/client.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP
3333
from datahub.emitter.mce_builder import DEFAULT_ENV, Aspect
3434
from datahub.emitter.mcp import MetadataChangeProposalWrapper
35-
from datahub.emitter.rest_emitter import DatahubRestEmitter
35+
from datahub.emitter.rest_emitter import (
36+
DEFAULT_REST_SINK_ENDPOINT,
37+
DatahubRestEmitter,
38+
RestSinkEndpoint,
39+
)
3640
from datahub.emitter.serialization_helper import post_json_transform
3741
from datahub.ingestion.graph.config import (
3842
DatahubClientConfig as DatahubClientConfig,
@@ -141,6 +145,7 @@ def __init__(self, config: DatahubClientConfig) -> None:
141145
ca_certificate_path=self.config.ca_certificate_path,
142146
client_certificate_path=self.config.client_certificate_path,
143147
disable_ssl_verification=self.config.disable_ssl_verification,
148+
openapi_ingestion=DEFAULT_REST_SINK_ENDPOINT == RestSinkEndpoint.OPENAPI,
144149
)
145150

146151
self.server_id = _MISSING_SERVER_ID

metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from datahub.emitter.mcp_builder import mcps_from_mce
2121
from datahub.emitter.rest_emitter import (
2222
BATCH_INGEST_MAX_PAYLOAD_LENGTH,
23+
DEFAULT_REST_SINK_ENDPOINT,
2324
DataHubRestEmitter,
25+
RestSinkEndpoint,
2426
)
2527
from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
2628
from datahub.ingestion.api.sink import (
@@ -66,6 +68,7 @@ class RestSinkMode(ConfigEnum):
6668

6769
class DatahubRestSinkConfig(DatahubClientConfig):
6870
mode: RestSinkMode = _DEFAULT_REST_SINK_MODE
71+
endpoint: RestSinkEndpoint = DEFAULT_REST_SINK_ENDPOINT
6972

7073
# These only apply in async modes.
7174
max_threads: pydantic.PositiveInt = _DEFAULT_REST_SINK_MAX_THREADS
@@ -172,6 +175,7 @@ def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
172175
ca_certificate_path=config.ca_certificate_path,
173176
client_certificate_path=config.client_certificate_path,
174177
disable_ssl_verification=config.disable_ssl_verification,
178+
openapi_ingestion=config.endpoint == RestSinkEndpoint.OPENAPI,
175179
)
176180

177181
@property

0 commit comments

Comments
 (0)