|
1 | 1 | import argparse
|
2 |
| - |
3 | 2 | from mlflow_dh_client import MLflowDatahubClient
|
4 |
| - |
5 | 3 | import datahub.metadata.schema_classes as models
|
6 | 4 | from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
|
7 | 5 |
|
|
14 | 12 | client = MLflowDatahubClient(token=args.token)
|
15 | 13 |
|
16 | 14 | # Create model group
|
17 |
| - # Using property classes directly |
18 | 15 | model_group_urn = client.create_model_group(
|
19 |
| - group_id="airline_forecast_models_group_4", |
| 16 | + group_id="airline_forecast_models_group", |
20 | 17 | properties=models.MLModelGroupPropertiesClass(
|
21 |
| - name="Airline Forecast Models Group 4", |
| 18 | + name="Airline Forecast Models Group", |
22 | 19 | description="Group of models for airline passenger forecasting",
|
23 | 20 | created=models.TimeStampClass(
|
24 | 21 | time=1628580000000, actor="urn:li:corpuser:datahub"
|
|
28 | 25 |
|
29 | 26 | # Creating a model with property classes
|
30 | 27 | model_urn = client.create_model(
|
31 |
| - model_id="arima_model_5", |
| 28 | + model_id="arima_model", |
32 | 29 | properties=models.MLModelPropertiesClass(
|
33 |
| - name="ARIMA Model 6", |
| 30 | + name="ARIMA Model", |
34 | 31 | description="ARIMA model for airline passenger forecasting",
|
35 | 32 | customProperties={"team": "forecasting"},
|
| 33 | + trainingMetrics=[ |
| 34 | + models.MLMetricClass(name="accuracy", value="0.9"), |
| 35 | + models.MLMetricClass(name="precision", value="0.8"), |
| 36 | + ], |
| 37 | + hyperParams=[ |
| 38 | + models.MLHyperParamClass(name="learning_rate", value="0.01"), |
| 39 | + models.MLHyperParamClass(name="batch_size", value="32"), |
| 40 | + ], |
| 41 | + externalUrl="https:localhost:5000", |
| 42 | + created=models.TimeStampClass( |
| 43 | + time=1628580000000, actor="urn:li:corpuser:datahub" |
| 44 | + ), |
| 45 | + lastModified=models.TimeStampClass( |
| 46 | + time=1628580000000, actor="urn:li:corpuser:datahub" |
| 47 | + ), |
| 48 | + tags=["forecasting", "arima"], |
36 | 49 | ),
|
37 |
| - version="6.0", |
38 |
| - alias="arima_model_6_alias", |
| 50 | + version="1.0", |
| 51 | + alias="champion", |
39 | 52 | )
|
40 | 53 |
|
41 | 54 | # Creating an experiment with property class
|
|
45 | 58 | name="Airline Forecast Experiment",
|
46 | 59 | description="Experiment to forecast airline passenger numbers",
|
47 | 60 | customProperties={"team": "forecasting"},
|
| 61 | + created=models.TimeStampClass( |
| 62 | + time=1628580000000, actor="urn:li:corpuser:datahub" |
| 63 | + ), |
| 64 | + lastModified=models.TimeStampClass( |
| 65 | + time=1628580000000, actor="urn:li:corpuser:datahub" |
| 66 | + ), |
48 | 67 | ),
|
49 | 68 | )
|
50 | 69 |
|
|
55 | 74 | created=models.AuditStampClass(
|
56 | 75 | time=1628580000000, actor="urn:li:corpuser:datahub"
|
57 | 76 | ),
|
| 77 | + customProperties={"team": "forecasting"}, |
58 | 78 | ),
|
59 | 79 | training_run_properties=models.MLTrainingRunPropertiesClass(
|
60 | 80 | id="simple_training_run_4",
|
61 | 81 | outputUrls=["s3://my-bucket/output"],
|
62 | 82 | trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
|
| 83 | + hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")], |
| 84 | + externalUrl="https:localhost:5000", |
63 | 85 | ),
|
64 | 86 | run_result=RunResultType.FAILURE,
|
65 | 87 | start_timestamp=1628580000000,
|
|
0 commit comments