|
38 | 38 | class MLflowConfig(EnvConfigMixin):
|
39 | 39 | tracking_uri: Optional[str] = Field(
|
40 | 40 | default=None,
|
41 |
| - description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)", |
| 41 | + description=( |
| 42 | + "Tracking server URI. If not set, an MLflow default tracking_uri is used" |
| 43 | + " (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)" |
| 44 | + ), |
42 | 45 | )
|
43 | 46 | registry_uri: Optional[str] = Field(
|
44 | 47 | default=None,
|
45 |
| - description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)", |
| 48 | + description=( |
| 49 | + "Registry server URI. If not set, an MLflow default registry_uri is used" |
| 50 | + " (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)" |
| 51 | + ), |
46 | 52 | )
|
47 | 53 | model_name_separator: str = Field(
|
48 | 54 | default="_",
|
49 | 55 | description="A string which separates model name from its version (e.g. model_1 or model-1)",
|
50 | 56 | )
|
| 57 | + base_external_url: Optional[str] = Field( |
| 58 | + default=None, |
| 59 | + description=( |
| 60 | + "Base URL to use when constructing external URLs to MLflow." |
| 61 | + " If not set, tracking_uri is used if it's an HTTP URL." |
| 62 | + " If neither is set, external URLs are not generated." |
| 63 | + ), |
| 64 | + ) |
51 | 65 |
|
52 | 66 |
|
53 | 67 | @dataclass
|
@@ -279,12 +293,23 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str:
|
279 | 293 | )
|
280 | 294 | return urn
|
281 | 295 |
|
282 |
| - def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]: |
| 296 | + def _get_base_external_url_from_tracking_uri(self) -> Optional[str]: |
| 297 | + if isinstance( |
| 298 | + self.client.tracking_uri, str |
| 299 | + ) and self.client.tracking_uri.startswith("http"): |
| 300 | + return self.client.tracking_uri |
| 301 | + else: |
| 302 | + return None |
| 303 | + |
| 304 | + def _make_external_url(self, model_version: ModelVersion) -> Optional[str]: |
283 | 305 | """
|
284 | 306 | Generate URL for a Model Version to MLflow UI.
|
285 | 307 | """
|
286 |
| - base_uri = self.client.tracking_uri |
287 |
| - if base_uri.startswith("http"): |
| 308 | + base_uri = ( |
| 309 | + self.config.base_external_url |
| 310 | + or self._get_base_external_url_from_tracking_uri() |
| 311 | + ) |
| 312 | + if base_uri: |
288 | 313 | return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}"
|
289 | 314 | else:
|
290 | 315 | return None
|
|
0 commit comments