Skip to content

Commit 20cec21

Browse files
committed
Removing temp credential file creation
1 parent e52fc78 commit 20cec21

File tree

4 files changed

+48
-36
lines changed

4 files changed

+48
-36
lines changed

metadata-ingestion/src/datahub/ingestion/source/common/gcp_credentials_config.py

+6
Original file line number | Diff line number | Diff line change
@@ -51,3 +51,9 @@ def create_credential_temp_file(self, project_id: Optional[str] = None) -> str:
5151
cred_json = json.dumps(configs, indent=4, separators=(",", ": "))
5252
fp.write(cred_json.encode())
5353
return fp.name
54+
55+
def to_dict(self, project_id: Optional[str] = None) -> Dict[str, str]:
56+
configs = self.dict()
57+
if project_id:
58+
configs["project_id"] = project_id
59+
return configs

metadata-ingestion/src/datahub/ingestion/source/vertexai/config.py

+6-7
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
1-
from typing import Any, Optional
1+
from typing import Any, Dict, Optional
22

3-
from pydantic import Field, PrivateAttr
3+
from pydantic import Field
44

55
from datahub.configuration.source_common import EnvConfigMixin
66
from datahub.ingestion.source.common.gcp_credentials_config import GCPCredential
@@ -23,11 +23,10 @@ class VertexAIConfig(EnvConfigMixin):
2323
description=("VertexUI URI"),
2424
)
2525

26-
_credentials_path: Optional[str] = PrivateAttr(None)
27-
2826
def __init__(self, **data: Any):
2927
super().__init__(**data)
28+
29+
def get_credentials(self) -> Optional[Dict[str, str]]:
3030
if self.credential:
31-
self._credentials_path = self.credential.create_credential_temp_file(
32-
project_id=self.project_id
33-
)
31+
return self.credential.to_dict(self.project_id)
32+
return None

metadata-ingestion/src/datahub/ingestion/source/vertexai/vertexai.py

+16-5
Original file line number | Diff line number | Diff line change
@@ -113,11 +113,10 @@ def __init__(self, ctx: PipelineContext, config: VertexAIConfig):
113113
self.config = config
114114
self.report = SourceReport()
115115

116+
creds = self.config.get_credentials()
116117
credentials = (
117-
service_account.Credentials.from_service_account_file(
118-
self.config._credentials_path
119-
)
120-
if self.config.credential
118+
service_account.Credentials.from_service_account_info(**creds)
119+
if creds
121120
else None
122121
)
123122

@@ -294,6 +293,16 @@ def _gen_run_execution(
294293
upstreamInstances=[self._make_experiment_run_urn(exp, run)],
295294
parentInstance=self._make_experiment_run_urn(exp, run),
296295
),
296+
(
297+
DataProcessInstanceInputClass(
298+
inputs=[],
299+
inputEdges=[
300+
EdgeClass(
301+
destinationUrn=self._make_experiment_run_urn(exp, run)
302+
),
303+
],
304+
)
305+
),
297306
],
298307
)
299308

@@ -342,7 +351,9 @@ def _gen_experiment_run_mcps(
342351
(
343352
DataProcessInstanceRunEventClass(
344353
status=DataProcessRunStatusClass.COMPLETE,
345-
timestampMillis=created_time if created_time else 0, # None is not allowed, 0 as default value
354+
timestampMillis=created_time
355+
if created_time
356+
else 0, # None is not allowed, 0 as default value
346357
result=DataProcessInstanceRunResultClass(
347358
type=run_result_type,
348359
nativeResultType=self.platform,

metadata-ingestion/tests/unit/test_vertexai_source.py

+20-24
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
import json
32
from typing import List
43
from unittest.mock import patch
54

@@ -110,7 +109,6 @@ def test_get_ml_model_properties_mcps(source: VertexAISource) -> None:
110109

111110
# Run _gen_ml_model_mcps
112111
actual_mcps = list(source._gen_ml_model_mcps(model_meta))
113-
114112
actual_urns = [mcp.entityUrn for mcp in actual_mcps]
115113
expected_urns = [
116114
source._make_ml_model_urn(
@@ -362,31 +360,29 @@ def test_vertexai_config_init():
362360
assert config.credential.client_id == "test-client-id"
363361
assert config.credential.auth_uri == "https://accounts.google.com/o/oauth2/auth"
364362
assert config.credential.token_uri == "https://oauth2.googleapis.com/token"
363+
assert config.credential.auth_provider_x509_cert_url == "service_account"
364+
365+
parsed_conf = config.get_credentials()
366+
assert parsed_conf is not None
367+
assert parsed_conf.get("project_id") == config_data["project_id"]
368+
assert "credential" in config_data
369+
assert parsed_conf.get("private_key_id", "") == "test-key-id"
370+
assert (
371+
parsed_conf.get("private_key", "")
372+
== "-----BEGIN PRIVATE KEY-----\ntest-private-key\n-----END PRIVATE KEY-----\n"
373+
)
374+
assert (
375+
parsed_conf.get("client_email")
376+
== "[email protected]"
377+
)
378+
assert parsed_conf.get("client_id") == "test-client-id"
379+
assert parsed_conf.get("auth_uri") == "https://accounts.google.com/o/oauth2/auth"
380+
assert parsed_conf.get("token_uri") == "https://oauth2.googleapis.com/token"
365381
assert (
366-
config.credential.auth_provider_x509_cert_url
382+
parsed_conf.get("auth_provider_x509_cert_url")
367383
== "https://www.googleapis.com/oauth2/v1/certs"
368384
)
369-
370-
assert config._credentials_path is not None
371-
with open(config._credentials_path, "r") as file:
372-
content = json.loads(file.read())
373-
assert content["project_id"] == "test-project"
374-
assert content["private_key_id"] == "test-key-id"
375-
assert content["private_key_id"] == "test-key-id"
376-
assert (
377-
content["private_key"]
378-
== "-----BEGIN PRIVATE KEY-----\ntest-private-key\n-----END PRIVATE KEY-----\n"
379-
)
380-
assert (
381-
content["client_email"] == "[email protected]"
382-
)
383-
assert content["client_id"] == "test-client-id"
384-
assert content["auth_uri"] == "https://accounts.google.com/o/oauth2/auth"
385-
assert content["token_uri"] == "https://oauth2.googleapis.com/token"
386-
assert (
387-
content["auth_provider_x509_cert_url"]
388-
== "https://www.googleapis.com/oauth2/v1/certs"
389-
)
385+
assert parsed_conf.get("type") == "service_account"
390386

391387

392388
def test_get_input_dataset_mcps(source: VertexAISource) -> None:

0 commit comments

Comments
 (0)