Skip to content

Commit da8f822

Browse files
authored
feat(ingest/mlflow): Support configurable base_external_url (#12167)
1 parent 83904b7 commit da8f822

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

metadata-ingestion/src/datahub/ingestion/source/mlflow.py

+30 -5
Original file line number | Diff line number | Diff line change
@@ -38,16 +38,30 @@
3838
class MLflowConfig(EnvConfigMixin):
3939
tracking_uri: Optional[str] = Field(
4040
default=None,
41-
description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)",
41+
description=(
42+
"Tracking server URI. If not set, an MLflow default tracking_uri is used"
43+
" (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)"
44+
),
4245
)
4346
registry_uri: Optional[str] = Field(
4447
default=None,
45-
description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)",
48+
description=(
49+
"Registry server URI. If not set, an MLflow default registry_uri is used"
50+
" (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)"
51+
),
4652
)
4753
model_name_separator: str = Field(
4854
default="_",
4955
description="A string which separates model name from its version (e.g. model_1 or model-1)",
5056
)
57+
base_external_url: Optional[str] = Field(
58+
default=None,
59+
description=(
60+
"Base URL to use when constructing external URLs to MLflow."
61+
" If not set, tracking_uri is used if it's an HTTP URL."
62+
" If neither is set, external URLs are not generated."
63+
),
64+
)
5165

5266

5367
@dataclass
@@ -279,12 +293,23 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str:
279293
)
280294
return urn
281295

282-
def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]:
296+
def _get_base_external_url_from_tracking_uri(self) -> Optional[str]:
297+
if isinstance(
298+
self.client.tracking_uri, str
299+
) and self.client.tracking_uri.startswith("http"):
300+
return self.client.tracking_uri
301+
else:
302+
return None
303+
304+
def _make_external_url(self, model_version: ModelVersion) -> Optional[str]:
283305
"""
284306
Generate URL for a Model Version to MLflow UI.
285307
"""
286-
base_uri = self.client.tracking_uri
287-
if base_uri.startswith("http"):
308+
base_uri = (
309+
self.config.base_external_url
310+
or self._get_base_external_url_from_tracking_uri()
311+
)
312+
if base_uri:
288313
return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}"
289314
else:
290315
return None

metadata-ingestion/tests/unit/test_mlflow_source.py

+13
Original file line number | Diff line number | Diff line change
@@ -136,3 +136,16 @@ def test_make_external_link_remote(source, model_version):
136136
url = source._make_external_url(model_version)
137137

138138
assert url == expected_url
139+
140+
141+
def test_make_external_link_remote_via_config(source, model_version):
142+
custom_base_url = "https://custom-server.org"
143+
source.config.base_external_url = custom_base_url
144+
source.client = MlflowClient(
145+
tracking_uri="https://dummy-mlflow-tracking-server.org"
146+
)
147+
expected_url = f"{custom_base_url}/#/models/{model_version.name}/versions/{model_version.version}"
148+
149+
url = source._make_external_url(model_version)
150+
151+
assert url == expected_url

0 commit comments

Comments
 (0)