Skip to content

Commit 46215ae

Browse files
nmadanNamrata MadanAo GuoRohan GujarathiZhankuil
authored
Feature: SageMaker Remote Function (#3797)
Co-authored-by: Namrata Madan <[email protected]> Co-authored-by: Ao Guo <[email protected]> Co-authored-by: Rohan Gujarathi <[email protected]> Co-authored-by: Zhankui Lu <[email protected]> Co-authored-by: Dipankar Patro <[email protected]> Co-authored-by: Mourya Baddam <[email protected]>
1 parent ebd48c9 commit 46215ae

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+7832
-236
lines changed

.flake8

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ application_import_names = sagemaker, tests
33
import-order-style = google
44
per-file-ignores =
55
tests/unit/test_tuner.py: F405
6+
src/sagemaker/config/config_schema.py: E501

README.rst

+2
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ To run the integration tests, the following prerequisites must be met
133133
1. AWS account credentials are available in the environment for the boto3 client to use.
134134
2. The AWS account has an IAM role named :code:`SageMakerRole`.
135135
It should have the AmazonSageMakerFullAccess policy attached as well as a policy with `the necessary permissions to use Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei-setup.html>`__.
136+
3. To run remote_function tests, dummy ecr repo should be created. It can be created by running -
137+
:code:`aws ecr create-repository --repository-name remote-function-dummy-container`
136138

137139
We recommend selectively running just those integration tests you'd like to run. You can filter by individual test function names with:
138140

requirements/extras/test_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ sagemaker-experiments==0.1.35
2121
Jinja2==3.0.3
2222
pandas>=1.3.5,<1.5
2323
scikit-learn==1.0.2
24+
cloudpickle==2.2.1

setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def read_requirements(filename):
4949
required_packages = [
5050
"attrs>=20.3.0,<23",
5151
"boto3>=1.26.28,<2.0",
52+
"cloudpickle==2.2.1",
5253
"google-pasta",
5354
"numpy>=1.9.0,<2.0",
5455
"protobuf>=3.1,<4.0",
@@ -62,6 +63,7 @@ def read_requirements(filename):
6263
"PyYAML==5.4.1",
6364
"jsonschema",
6465
"platformdirs",
66+
"tblib==1.7.0",
6567
]
6668

6769
# Specific use case dependencies

src/sagemaker/config/config_schema.py

+101-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@
4444
SAGEMAKER = "SageMaker"
4545
PYTHON_SDK = "PythonSDK"
4646
MODULES = "Modules"
47+
REMOTE_FUNCTION = "RemoteFunction"
48+
DEPENDENCIES = "Dependencies"
49+
PRE_EXECUTION_SCRIPT = "PreExecutionScript"
50+
PRE_EXECUTION_COMMANDS = "PreExecutionCommands"
51+
ENVIRONMENT_VARIABLES = "EnvironmentVariables"
52+
IMAGE_URI = "ImageUri"
53+
INCLUDE_LOCAL_WORKDIR = "IncludeLocalWorkDir"
54+
INSTANCE_TYPE = "InstanceType"
55+
S3_KMS_KEY_ID = "S3KmsKeyId"
56+
S3_ROOT_URI = "S3RootUri"
57+
JOB_CONDA_ENV = "JobCondaEnvironment"
4758
OFFLINE_STORE_CONFIG = "OfflineStoreConfig"
4859
ONLINE_STORE_CONFIG = "OnlineStoreConfig"
4960
S3_STORAGE_CONFIG = "S3StorageConfig"
@@ -221,6 +232,49 @@ def _simple_path(*args: str):
221232
SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES
222233
)
223234

235+
REMOTE_FUNCTION_DEPENDENCIES = _simple_path(
236+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES
237+
)
238+
REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS = _simple_path(
239+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_COMMANDS
240+
)
241+
REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT = _simple_path(
242+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_SCRIPT
243+
)
244+
REMOTE_FUNCTION_ENVIRONMENT_VARIABLES = _simple_path(
245+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENVIRONMENT_VARIABLES
246+
)
247+
REMOTE_FUNCTION_IMAGE_URI = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, IMAGE_URI)
248+
REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR = _simple_path(
249+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INCLUDE_LOCAL_WORKDIR
250+
)
251+
REMOTE_FUNCTION_INSTANCE_TYPE = _simple_path(
252+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INSTANCE_TYPE
253+
)
254+
REMOTE_FUNCTION_JOB_CONDA_ENV = _simple_path(
255+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, JOB_CONDA_ENV
256+
)
257+
REMOTE_FUNCTION_ROLE_ARN = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ROLE_ARN)
258+
REMOTE_FUNCTION_S3_KMS_KEY_ID = _simple_path(
259+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_KMS_KEY_ID
260+
)
261+
REMOTE_FUNCTION_S3_ROOT_URI = _simple_path(
262+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_ROOT_URI
263+
)
264+
REMOTE_FUNCTION_TAGS = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, TAGS)
265+
REMOTE_FUNCTION_VOLUME_KMS_KEY_ID = _simple_path(
266+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VOLUME_KMS_KEY_ID
267+
)
268+
REMOTE_FUNCTION_VPC_CONFIG_SUBNETS = _simple_path(
269+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SUBNETS
270+
)
271+
REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS = _simple_path(
272+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
273+
)
274+
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
275+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
276+
)
277+
224278
# Paths for reference elsewhere in the SDK.
225279
# Names include the schema version since the paths could change with other schema versions
226280
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
@@ -245,7 +299,6 @@ def _simple_path(*args: str):
245299
SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
246300
)
247301

248-
249302
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
250303
"$schema": "https://json-schema.org/draft/2020-12/schema",
251304
TYPE: OBJECT,
@@ -377,6 +430,23 @@ def _simple_path(*args: str):
377430
"minItems": 0,
378431
"maxItems": 50,
379432
},
433+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment
434+
"environmentVariables": {
435+
TYPE: OBJECT,
436+
ADDITIONAL_PROPERTIES: False,
437+
PATTERN_PROPERTIES: {
438+
r"([a-zA-Z_][a-zA-Z0-9_]*){1,512}": {
439+
TYPE: "string",
440+
"pattern": r"[\S\s]*",
441+
"maxLength": 512,
442+
}
443+
},
444+
"maxProperties": 48,
445+
},
446+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
447+
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
448+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
449+
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
380450
},
381451
PROPERTIES: {
382452
SCHEMA_VERSION: {
@@ -406,6 +476,36 @@ def _simple_path(*args: str):
406476
# Any SageMaker Python SDK specific configuration will be added here.
407477
TYPE: OBJECT,
408478
ADDITIONAL_PROPERTIES: False,
479+
PROPERTIES: {
480+
REMOTE_FUNCTION: {
481+
TYPE: OBJECT,
482+
ADDITIONAL_PROPERTIES: False,
483+
PROPERTIES: {
484+
DEPENDENCIES: {TYPE: "string"},
485+
PRE_EXECUTION_COMMANDS: {
486+
TYPE: "array",
487+
"items": {"$ref": "#/definitions/preExecutionCommand"},
488+
},
489+
PRE_EXECUTION_SCRIPT: {TYPE: "string"},
490+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {
491+
TYPE: "boolean"
492+
},
493+
ENVIRONMENT_VARIABLES: {
494+
"$ref": "#/definitions/environmentVariables"
495+
},
496+
IMAGE_URI: {TYPE: "string"},
497+
INCLUDE_LOCAL_WORKDIR: {TYPE: "boolean"},
498+
INSTANCE_TYPE: {TYPE: "string"},
499+
JOB_CONDA_ENV: {TYPE: "string"},
500+
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
501+
S3_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
502+
S3_ROOT_URI: {"$ref": "#/definitions/s3Uri"},
503+
TAGS: {"$ref": "#/definitions/tags"},
504+
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
505+
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
506+
},
507+
}
508+
},
409509
}
410510
},
411511
},

src/sagemaker/experiments/run.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
715715

716716
self.close()
717717

718+
def __getstate__(self):
719+
"""Overriding this method to prevent instance of Run from being pickled.
720+
721+
Raise:
722+
NotImplementedError: If attempting to pickle this instance.
723+
"""
724+
raise NotImplementedError("Instance of Run type is not allowed to be pickled.")
725+
718726

719727
def load_run(
720728
run_name: Optional[str] = None,
@@ -787,36 +795,38 @@ def load_run(
787795
Returns:
788796
Run: The loaded Run object.
789797
"""
790-
sagemaker_session = sagemaker_session or _utils.default_session()
791798
environment = _RunEnvironment.load()
792799

793800
verify_load_input_names(run_name=run_name, experiment_name=experiment_name)
794801

795-
if run_name or environment:
796-
if run_name:
797-
logger.warning(
798-
"run_name is explicitly supplied in load_run, "
799-
"which will be prioritized to load the Run object. "
800-
"In other words, the run name in the experiment config, fetched from the "
801-
"job environment or the current run context, will be ignored."
802-
)
803-
else:
804-
exp_config = get_tc_and_exp_config_from_job_env(
805-
environment=environment, sagemaker_session=sagemaker_session
806-
)
807-
run_name = Run._extract_run_name_from_tc_name(
808-
trial_component_name=exp_config[RUN_NAME],
809-
experiment_name=exp_config[EXPERIMENT_NAME],
810-
)
811-
experiment_name = exp_config[EXPERIMENT_NAME]
812-
802+
if run_name:
803+
logger.warning(
804+
"run_name is explicitly supplied in load_run, "
805+
"which will be prioritized to load the Run object. "
806+
"In other words, the run name in the experiment config, fetched from the "
807+
"job environment or the current run context, will be ignored."
808+
)
813809
run_instance = Run(
814810
experiment_name=experiment_name,
815811
run_name=run_name,
816-
sagemaker_session=sagemaker_session,
812+
sagemaker_session=sagemaker_session or _utils.default_session(),
817813
)
818814
elif _RunContext.get_current_run():
819815
run_instance = _RunContext.get_current_run()
816+
elif environment:
817+
exp_config = get_tc_and_exp_config_from_job_env(
818+
environment=environment, sagemaker_session=sagemaker_session or _utils.default_session()
819+
)
820+
run_name = Run._extract_run_name_from_tc_name(
821+
trial_component_name=exp_config[RUN_NAME],
822+
experiment_name=exp_config[EXPERIMENT_NAME],
823+
)
824+
experiment_name = exp_config[EXPERIMENT_NAME]
825+
run_instance = Run(
826+
experiment_name=experiment_name,
827+
run_name=run_name,
828+
sagemaker_session=sagemaker_session or _utils.default_session(),
829+
)
820830
else:
821831
raise RuntimeError(
822832
"Failed to load a Run object. "
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
{
2+
"versions": {
3+
"1.0": {
4+
"registries": {
5+
"us-east-2": "429704687514",
6+
"me-south-1": "117516905037",
7+
"us-west-2": "236514542706",
8+
"ca-central-1": "310906938811",
9+
"ap-east-1": "493642496378",
10+
"us-east-1": "081325390199",
11+
"ap-northeast-2": "806072073708",
12+
"eu-west-2": "712779665605",
13+
"ap-southeast-2": "52832661640",
14+
"cn-northwest-1": "390780980154",
15+
"eu-north-1": "243637512696",
16+
"cn-north-1": "390048526115",
17+
"ap-south-1": "394103062818",
18+
"eu-west-3": "615547856133",
19+
"ap-southeast-3": "276181064229",
20+
"af-south-1": "559312083959",
21+
"eu-west-1": "470317259841",
22+
"eu-central-1": "936697816551",
23+
"sa-east-1": "782484402741",
24+
"ap-northeast-3": "792733760839",
25+
"eu-south-1": "592751261982",
26+
"ap-northeast-1": "102112518831",
27+
"us-west-1": "742091327244",
28+
"ap-southeast-1": "492261229750",
29+
"me-central-1": "103105715889",
30+
"us-gov-east-1": "107072934176",
31+
"us-gov-west-1": "107173498710"
32+
},
33+
"repository": "sagemaker-base-python"
34+
}
35+
}
36+
}

src/sagemaker/image_uris.py

+26
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,29 @@ def get_training_image_uri(
663663
container_version=container_version,
664664
training_compiler_config=compiler_config,
665665
)
666+
667+
668+
def get_base_python_image_uri(region, py_version="310") -> str:
669+
"""Retrieves the image URI for base python image.
670+
671+
Args:
672+
region (str): The AWS region to use for image URI.
673+
py_version (str): The python version to use for the image. Can be 310 or 38
674+
Default to 310
675+
676+
Returns:
677+
str: The image URI string.
678+
"""
679+
680+
framework = "sagemaker-base-python"
681+
version = "1.0"
682+
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
683+
config = config_for_framework(framework)
684+
version_config = config["versions"][_version_for_config(version, config)]
685+
686+
registry = _registry_from_region(region, version_config["registries"])
687+
688+
repo = version_config["repository"] + "-" + py_version
689+
repo_and_tag = repo + ":" + version
690+
691+
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)

src/sagemaker/local/local_session.py

+12
Original file line numberDiff line numberDiff line change
@@ -674,10 +674,22 @@ def _initialize(
674674
self.sagemaker_client = LocalSagemakerClient(self)
675675
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
676676
self.local_mode = True
677+
sagemaker_config = kwargs.get("sagemaker_config", None)
678+
if sagemaker_config:
679+
validate_sagemaker_config(sagemaker_config)
677680

678681
if self.s3_endpoint_url is not None:
679682
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
680683
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
684+
self.sagemaker_config = (
685+
sagemaker_config
686+
if sagemaker_config
687+
else load_sagemaker_config(s3_resource=self.s3_resource)
688+
)
689+
else:
690+
self.sagemaker_config = (
691+
sagemaker_config if sagemaker_config else load_sagemaker_config()
692+
)
681693

682694
sagemaker_config = kwargs.get("sagemaker_config", None)
683695
if sagemaker_config:
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Defines classes and helper methods used in remote function executions."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401

0 commit comments

Comments
 (0)