Commit a3c3b25

Merge branch 'main' into add-diffusers-utils

2 parents 0760d75 + 44e3dec
File tree (5 files changed, +221 −1)

- README.md
- setup.py
- src/sagemaker_huggingface_inference_toolkit/optimum_utils.py
- src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
- tests/unit/test_optimum_utils.py

README.md (+15)

@@ -117,6 +117,21 @@ The `HF_TRUST_REMOTE_CODE` environment variable defines whether or not to allow f
 HF_TRUST_REMOTE_CODE="True"
 ```
 
+#### `HF_OPTIMUM_BATCH_SIZE`
+
+The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size used when compiling the model for Neuron. The default value is `1`. It is not required when the model is already converted.
+
+```bash
+HF_OPTIMUM_BATCH_SIZE="1"
+```
+
+#### `HF_OPTIMUM_SEQUENCE_LENGTH`
+
+The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length used when compiling the model for Neuron. There is no default value. It is not required when the model is already converted.
+
+```bash
+HF_OPTIMUM_SEQUENCE_LENGTH="128"
+```
 
 ---
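For context, a minimal sketch of how these two variables are consumed by the `get_input_shapes` helper added in this commit; the model directory `/opt/ml/model` is SageMaker's conventional path and is used here only as an illustrative assumption:

```python
import os

# HF_OPTIMUM_SEQUENCE_LENGTH has no default and must be set when the model
# has not been compiled for Neuron yet; HF_OPTIMUM_BATCH_SIZE defaults to 1.
os.environ["HF_OPTIMUM_BATCH_SIZE"] = "1"
os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "128"

from sagemaker_huggingface_inference_toolkit.optimum_utils import get_input_shapes

# Falls back to the environment variables whenever the model config does not
# carry neuron_batch_size / neuron_sequence_length entries.
print(get_input_shapes("/opt/ml/model"))  # {'batch_size': 1, 'sequence_length': 128}
```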

setup.py (+1 −1)

@@ -30,7 +30,7 @@
 # We don't declare our dependency on transformers here because we build with
 # different packages for different variants
 
-VERSION = "2.1.2"
+VERSION = "2.3.0.dev0"
 
 
 # Ubuntu packages
src/sagemaker_huggingface_inference_toolkit/optimum_utils.py (new file, +99)

# Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import logging
import os


_optimum_neuron = False
if importlib.util.find_spec("optimum") is not None:
    if importlib.util.find_spec("optimum.neuron") is not None:
        _optimum_neuron = True

logger = logging.getLogger(__name__)


def is_optimum_neuron_available():
    return _optimum_neuron


def get_input_shapes(model_dir):
    """Get input shapes from the model config file, falling back to the HF_OPTIMUM_BATCH_SIZE
    and HF_OPTIMUM_SEQUENCE_LENGTH environment variables when the config carries no Neuron shapes."""
    from transformers import AutoConfig

    input_shapes = {}
    input_shapes_available = False
    # try to get input shapes from config file
    try:
        config = AutoConfig.from_pretrained(model_dir)
        if hasattr(config, "neuron_batch_size") and hasattr(config, "neuron_sequence_length"):
            input_shapes["batch_size"] = config.neuron_batch_size
            input_shapes["sequence_length"] = config.neuron_sequence_length
            input_shapes_available = True
            logger.info(
                f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}"
            )
            if os.environ.get("HF_OPTIMUM_BATCH_SIZE", None) is not None:
                logger.warning(
                    "HF_OPTIMUM_BATCH_SIZE environment variable is set. Environment variable will be ignored and input shapes from config file will be used."
                )
            if os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) is not None:
                logger.warning(
                    "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is set. Environment variable will be ignored and input shapes from config file will be used."
                )
    except Exception:
        input_shapes_available = False

    # return input shapes if available
    if input_shapes_available:
        return input_shapes

    # extract input shapes from environment variables
    sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None)
    # guard against an unset variable before the int() conversion
    if sequence_length is None or not int(sequence_length) > 0:
        raise ValueError(
            f"HF_OPTIMUM_SEQUENCE_LENGTH must be set to a positive integer. Current value is {sequence_length}"
        )
    batch_size = os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1)
    logger.info(
        f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}"
    )
    return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)}


def get_optimum_neuron_pipeline(task, model_dir):
    """Get the optimum-neuron pipeline for a given task. Checks that the task is supported by
    optimum-neuron and, if the model is not yet converted, that the required environment
    variables are set. Raises an error if a check fails."""
    from optimum.neuron.pipelines import NEURONX_SUPPORTED_TASKS, pipeline
    from optimum.neuron.utils import NEURON_FILE_NAME

    # check task support
    if task not in NEURONX_SUPPORTED_TASKS:
        raise ValueError(
            f"Task {task} is not supported by optimum neuron and inf2. Supported tasks are: {list(NEURONX_SUPPORTED_TASKS.keys())}"
        )

    # check if model is already converted and has input shapes available
    export = True
    if NEURON_FILE_NAME in os.listdir(model_dir):
        export = False
    if export:
        logger.info("Model is not converted. Checking if required environment variables are set and converting model.")

    # get static input shapes to run inference
    input_shapes = get_input_shapes(model_dir)
    # get optimum neuron pipeline
    neuron_pipe = pipeline(task, model=model_dir, export=export, input_shapes=input_shapes)

    return neuron_pipe
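A hedged usage sketch for the helper above (the task, shapes, and `/opt/ml/model` path are illustrative assumptions; on SageMaker the toolkit invokes this internally):

```python
import os

from sagemaker_huggingface_inference_toolkit.optimum_utils import get_optimum_neuron_pipeline

# If model_dir contains no compiled model.neuron artifact, these shapes drive
# the on-the-fly export; if the config already records Neuron shapes, they win.
os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "128"
os.environ["HF_OPTIMUM_BATCH_SIZE"] = "1"

pipe = get_optimum_neuron_pipeline(task="text-classification", model_dir="/opt/ml/model")
print(pipe("This is a test"))
```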

src/sagemaker_huggingface_inference_toolkit/transformers_utils.py (+4)

@@ -24,6 +24,7 @@
 from transformers.pipelines import Conversation, Pipeline
 
 from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available
+from sagemaker_huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available
 
 
 if is_tf_available():
@@ -73,6 +74,9 @@ def strtobool(val):
     "ckpt": "*ckpt",
 }
 
+if is_optimum_neuron_available():
+    FILE_LIST_NAMES.append("model.neuron")
+
 REPO_ID_SEPARATOR = "__"
 
 ARCHITECTURES_2_TASK = {

tests/unit/test_optimum_utils.py (new file, +102)

# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile

import pytest
from transformers.testing_utils import require_torch

from sagemaker_huggingface_inference_toolkit.optimum_utils import (
    get_input_shapes,
    get_optimum_neuron_pipeline,
    is_optimum_neuron_available,
)
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub


require_inferentia = pytest.mark.skipif(
    not is_optimum_neuron_available(),
    reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.",
)


REMOTE_NOT_CONVERTED_MODEL = "hf-internal-testing/tiny-random-BertModel"
REMOTE_CONVERTED_MODEL = "optimum/tiny_random_bert_neuron"
TASK = "text-classification"


@require_torch
@require_inferentia
def test_not_supported_task():
    # an unsupported task must be passed to the pipeline factory itself
    with pytest.raises(Exception):
        get_optimum_neuron_pipeline(task="not-supported-task", model_dir=os.getcwd())


@require_torch
@require_inferentia
def test_get_input_shapes_from_file():
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=REMOTE_CONVERTED_MODEL,
            model_dir=tmpdirname,
        )
        input_shapes = get_input_shapes(model_dir=storage_folder)
        assert input_shapes["batch_size"] == 1
        assert input_shapes["sequence_length"] == 16


@require_torch
@require_inferentia
def test_get_input_shapes_from_env():
    os.environ["HF_OPTIMUM_BATCH_SIZE"] = "4"
    os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32"
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=REMOTE_NOT_CONVERTED_MODEL,
            model_dir=tmpdirname,
        )
        input_shapes = get_input_shapes(model_dir=storage_folder)
        assert input_shapes["batch_size"] == 4
        assert input_shapes["sequence_length"] == 32


@require_torch
@require_inferentia
def test_get_optimum_neuron_pipeline_from_converted_model():
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.system(
            f"optimum-cli export neuron --model philschmid/tiny-distilbert-classification --sequence_length 32 --batch_size 1 {tmpdirname}"
        )
        pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=tmpdirname)
        r = pipe("This is a test")

        assert r[0]["score"] > 0.0
        assert isinstance(r[0]["label"], str)


@require_torch
@require_inferentia
def test_get_optimum_neuron_pipeline_from_non_converted_model():
    os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32"
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = _load_model_from_hub(
            model_id=REMOTE_NOT_CONVERTED_MODEL,
            model_dir=tmpdirname,
        )
        pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=storage_folder)
        r = pipe("This is a test")

        assert r[0]["score"] > 0.0
        assert isinstance(r[0]["label"], str)
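For reference, the pre-compilation that the converted-model test drives through `optimum-cli` can also be done from Python. A hedged sketch using optimum-neuron's model class (the class name and keyword shapes follow optimum-neuron's documented export API, but should be verified against the version pinned by this toolkit):

```python
from optimum.neuron import NeuronModelForSequenceClassification

# Compile with static shapes matching the CLI flags used in the test above.
model = NeuronModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-distilbert-classification",
    export=True,
    batch_size=1,
    sequence_length=32,
)
model.save_pretrained("./compiled_model")  # directory then usable as model_dir
```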
