@@ -279,10 +279,11 @@ def _get_dataset_schema(self, schema: str) -> Optional[List[Tuple[str, str]]]:
                 ]
             except (KeyError, TypeError):
                 return None
-
+        # If the schema is not formatted, return None
         return None
 
     def _get_dataset_platform_from_source_type(self, source_type):
+        # manually map mlflow platform to datahub platform
        if source_type == "gs":
            return "gcs"
        return source_type
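For context, here is a minimal, standalone sketch of what these two helpers do. It assumes the schema string uses MLflow's JSON colspec layout ({"mlflow_colspec": [...]}); the function names mirror the methods above, but the parsing details are illustrative rather than copied from the source.

import json
from typing import List, Optional, Tuple


def get_dataset_schema(schema: str) -> Optional[List[Tuple[str, str]]]:
    # Assumed layout: {"mlflow_colspec": [{"name": ..., "type": ...}, ...]}
    try:
        return [
            (field["name"], field["type"])
            for field in json.loads(schema)["mlflow_colspec"]
        ]
    except (json.JSONDecodeError, KeyError, TypeError):
        # The schema is missing or not in the expected format
        return None


def get_dataset_platform_from_source_type(source_type: str) -> str:
    # MLflow reports Google Cloud Storage sources as "gs"; DataHub's platform name is "gcs"
    if source_type == "gs":
        return "gcs"
    return source_type


print(get_dataset_schema('{"mlflow_colspec": [{"name": "age", "type": "long"}]}'))
print(get_dataset_platform_from_source_type("gs"))  # -> gcs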
@@ -301,8 +302,10 @@ def _get_dataset_input_workunits(self, run: Run) -> Iterable[MetadataWorkUnit]:
             platform = self._get_dataset_platform_from_source_type(source_type)
             custom_properties = dataset_tags
             formatted_schema = self._get_dataset_schema(dataset.schema)
+            # If the schema is not formatted, pass the schema as a custom property
             if formatted_schema is None:
                 custom_properties["schema"] = dataset.schema
+            # If the dataset is local or code, we create a local dataset reference
             if source_type in ("local", "code"):
                 local_dataset_reference = Dataset(
                     platform=platform,
@@ -312,7 +315,7 @@ def _get_dataset_input_workunits(self, run: Run) -> Iterable[MetadataWorkUnit]:
                 )
                 yield from local_dataset_reference.as_workunits()
                 dataset_reference_urns.append(str(local_dataset_reference.urn))
-
+            # Otherwise, we create a hosted dataset reference and a hosted dataset
             else:
                 hosted_dataset = Dataset(
                     platform=self._get_dataset_platform_from_source_type(source_type),
@@ -336,6 +339,7 @@ def _get_dataset_input_workunits(self, run: Run) -> Iterable[MetadataWorkUnit]:
                 yield from hosted_dataset.as_workunits()
                 yield from hosted_dataset_reference.as_workunits()
 
+        # add the dataset reference as upstream for the run
         if dataset_reference_urns:
             input_edges = [
                 EdgeClass(destinationUrn=dataset_referece_urn)
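To illustrate this last step, here is a hedged sketch of how the collected dataset reference URNs could be turned into lineage edges. It uses EdgeClass from datahub.metadata.schema_classes and the make_dataset_urn helper from datahub.emitter.mce_builder; the URNs themselves are hypothetical placeholders for whatever the loop above collected, and the aspect the connector ultimately attaches these edges to is not shown here.

from datahub.emitter.mce_builder import make_dataset_urn
from datahub.metadata.schema_classes import EdgeClass

# Hypothetical stand-ins for the URNs collected in the loop above
dataset_reference_urns = [
    make_dataset_urn(platform="s3", name="my-bucket/training-data", env="PROD"),
    make_dataset_urn(platform="local", name="data/train.csv", env="PROD"),
]

# One lineage edge per upstream dataset reference, mirroring the comprehension above
input_edges = [
    EdgeClass(destinationUrn=dataset_reference_urn)
    for dataset_reference_urn in dataset_reference_urns
]
print([edge.destinationUrn for edge in input_edges])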