47
47
ContainerClass ,
48
48
DataPlatformInstanceClass ,
49
49
DataProcessInstanceInputClass ,
50
+ DataProcessInstanceOutputClass ,
50
51
DataProcessInstancePropertiesClass ,
51
52
DataProcessInstanceRunEventClass ,
52
53
DataProcessInstanceRunResultClass ,
53
54
DataProcessRunStatusClass ,
54
55
DatasetPropertiesClass ,
56
+ EdgeClass ,
55
57
MetadataAttributionClass ,
56
58
MLHyperParamClass ,
57
59
MLMetricClass ,
@@ -306,6 +308,12 @@ def _gen_experiment_run_mcps(
306
308
if isinstance (run_result_type , RunResultTypeClass )
307
309
and created_time is not None
308
310
else None ,
311
+ DataProcessInstanceOutputClass (
312
+ outputs = [],
313
+ outputEdges = [
314
+ EdgeClass (destinationUrn = experiment_key .as_urn ()),
315
+ ],
316
+ ),
309
317
],
310
318
)
311
319
@@ -434,6 +442,17 @@ def _gen_training_job_mcps(
434
442
if job_meta .input_dataset
435
443
else None
436
444
)
445
+ # If Training Job has Output Model
446
+ model_urn = (
447
+ self ._make_ml_model_urn (
448
+ model_version = job_meta .output_model_version ,
449
+ model_name = self ._make_vertexai_model_name (
450
+ entity_id = job_meta .output_model .name
451
+ ),
452
+ )
453
+ if job_meta .output_model and job_meta .output_model_version
454
+ else None
455
+ )
437
456
438
457
yield from MetadataChangeProposalWrapper .construct_many (
439
458
job_urn ,
@@ -455,9 +474,22 @@ def _gen_training_job_mcps(
455
474
SubTypesClass (typeNames = [MLAssetSubTypes .VERTEX_TRAINING_JOB ]),
456
475
ContainerClass (container = self ._get_project_container ().as_urn ()),
457
476
DataPlatformInstanceClass (platform = str (DataPlatformUrn (self .platform ))),
458
- DataProcessInstanceInputClass (inputs = [dataset_urn ])
477
+ DataProcessInstanceInputClass (
478
+ inputs = [],
479
+ inputEdges = [
480
+ EdgeClass (destinationUrn = dataset_urn ),
481
+ ],
482
+ )
459
483
if dataset_urn
460
484
else None ,
485
+ DataProcessInstanceOutputClass (
486
+ outputs = [],
487
+ outputEdges = [
488
+ EdgeClass (destinationUrn = model_urn ),
489
+ ],
490
+ )
491
+ if model_urn
492
+ else None ,
461
493
],
462
494
)
463
495
@@ -593,7 +625,7 @@ def _get_input_dataset_mcps(
593
625
ContainerClass (container = self ._get_project_container ().as_urn ()),
594
626
DataPlatformInstanceClass (
595
627
platform = str (DataPlatformUrn (self .platform ))
596
- ),
628
+ )
597
629
],
598
630
)
599
631
@@ -770,7 +802,9 @@ def _gen_ml_model_mcps(
770
802
versionTag = str (model_version .version_id ),
771
803
metadataAttribution = (
772
804
MetadataAttributionClass (
773
- time = int (model_version .version_create_time .timestamp () * 1000 ),
805
+ time = int (
806
+ model_version .version_create_time .timestamp () * 1000
807
+ ),
774
808
actor = "urn:li:corpuser:datahub" ,
775
809
)
776
810
if model_version .version_create_time
0 commit comments