Skip to content

Commit 5c1682a

Browse files
agunapalUbuntuUbuntu
authored
TorchServe linux-aarch64 experimental support (#3071)
* Changes for building TorchServe on linux aarch64 * Changes for building TorchServe on linux aarch64 * Added an example for linux aarch64 * Doc update for linux aarch64 * Doc update for linux aarch64 * Doc update for linux aarch64 * removed torchtext for aarch64 * lint failure * lint failure * Build conda binaries * Build conda binaries * resolving merge conflicts * resolving merge conflicts * update documentation * review comments * Updated based on review comments --------- Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Ubuntu <[email protected]>
1 parent a69e561 commit 5c1682a

18 files changed

+202
-5
lines changed

binaries/conda/build_packages.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
PACKAGES = ["torchserve", "torch-model-archiver", "torch-workflow-archiver"]
2323

2424
# conda convert supported platforms https://docs.conda.io/projects/conda-build/en/stable/resources/commands/conda-convert.html
25-
PLATFORMS = ["linux-64", "osx-64", "win-64", "osx-arm64"] # Add a new platform here
25+
PLATFORMS = [
26+
"linux-64",
27+
"osx-64",
28+
"win-64",
29+
"osx-arm64",
30+
"linux-aarch64",
31+
] # Add a new platform here
2632

2733
if os.name == "nt":
2834
# Assumes miniconda is installed in windows

docs/linux_aarch64.md

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# TorchServe on linux aarch64 - Experimental
2+
3+
TorchServe has been tested and found to work on linux aarch64 for some of the examples.
4+
- Tested on an Amazon Graviton 3 instance (m7g.4xlarge)
5+
6+
## Installation
7+
8+
Currently, installing from PyPI or from source works.
9+
10+
```
11+
python ts_scripts/install_dependencies.py
12+
pip install torchserve torch-model-archiver torch-workflow-archiver
13+
```
14+
15+
## Optimizations
16+
17+
You can also enable these optimizations for Graviton 3 to get improved performance. More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/)
18+
```
19+
export DNNL_DEFAULT_FPMATH_MODE=BF16
20+
export LRU_CACHE_CAPACITY=1024
21+
```
22+
23+
## Example
24+
25+
This [example](https://github.com/pytorch/serve/tree/master/examples/text_to_speech_synthesizer/SpeechT5) on Text to Speech synthesis was verified to be working on Graviton 3
26+
27+
## To Dos
28+
- CI
29+
- Regression tests
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Text to Speech synthesis with SpeechT5
2+
3+
This is an example showing text to speech synthesis using the SpeechT5 model. It has been verified to work on a Graviton 3 (`linux-aarch64`) instance.
4+
5+
While running this model on `linux-aarch64`, you can enable these optimizations
6+
7+
```
8+
export DNNL_DEFAULT_FPMATH_MODE=BF16
9+
export LRU_CACHE_CAPACITY=1024
10+
```
11+
More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/)
12+
13+
14+
## Pre-requisites
15+
```
16+
chmod +x setup.sh
17+
./setup.sh
18+
```
19+
20+
## Download model
21+
22+
This saves the model artifacts to `model_artifacts` directory
23+
```
24+
huggingface-cli login
25+
python download_model.py
26+
```
27+
28+
## Create model archive
29+
30+
```
31+
mkdir model_store
32+
33+
torch-model-archiver --model-name SpeechT5-TTS --version 1.0 --handler text_to_speech_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
34+
35+
mv model_artifacts/* model_store/SpeechT5-TTS/
36+
```
37+
38+
## Start TorchServe
39+
40+
```
41+
torchserve --start --ncs --model-store model_store --models SpeechT5-TTS
42+
```
43+
44+
## Send Inference request
45+
46+
```
47+
curl http://127.0.0.1:8080/predictions/SpeechT5-TTS -T sample_input.txt -o speech.wav
48+
```
49+
50+
This generates an audio file `speech.wav` corresponding to the text in `sample_input.txt`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from datasets import load_dataset
2+
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
3+
4+
# Fetch the three SpeechT5 components used for text-to-speech synthesis.
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
hifigan_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Dataset of speaker x-vectors; only the "validation" split is downloaded.
xvector_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

# Write every pretrained component into its own sub-directory of
# model_artifacts, in the same order as before: model, processor, vocoder.
for component, target_dir in (
    (tts_model, "model_artifacts/model"),
    (tts_processor, "model_artifacts/processor"),
    (hifigan_vocoder, "model_artifacts/vocoder"),
):
    component.save_pretrained(save_directory=target_dir)

# The embeddings are a dataset rather than a pretrained component, so they
# are persisted with save_to_disk instead of save_pretrained.
xvector_dataset.save_to_disk("model_artifacts/speaker_embeddings")
print("Save model artifacts to directory model_artifacts")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
minWorkers: 1
2+
maxWorkers: 1
3+
handler:
4+
model: "model"
5+
vocoder: "vocoder"
6+
processor: "processor"
7+
speaker_embeddings: "speaker_embeddings"
8+
output_dir: "/tmp"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"I love San Francisco"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
# Installs the system and Python dependencies needed by the SpeechT5
# text-to-speech example.

# Needed for soundfile.
# apt-get is used instead of apt because apt's command-line interface is not
# guaranteed to be stable for use in scripts.
sudo apt-get install -y libsndfile1

# The extras specifier is quoted so that the shell never treats
# "datasets[audio]" as a glob pattern that could match files in the
# current directory.
pip install --upgrade transformers sentencepiece "datasets[audio]" soundfile
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import logging
2+
import os
3+
import uuid
4+
5+
import soundfile as sf
6+
import torch
7+
from datasets import load_from_disk
8+
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
9+
10+
from ts.torch_handler.base_handler import BaseHandler
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class SpeechT5_TTS(BaseHandler):
    """TorchServe handler that turns a text request into WAV audio bytes.

    The handler loads a SpeechT5 processor, text-to-speech model, HiFi-GAN
    vocoder and a dataset of speaker x-vectors during ``initialize``. Each
    request is tokenized, synthesized with ``generate_speech`` and returned
    as the raw bytes of a WAV file.
    """

    # Sample rate (Hz) passed to soundfile when writing the WAV file.
    # NOTE(review): presumably matches the rate of the SpeechT5 checkpoints;
    # confirm against the model card if a different checkpoint is used.
    SAMPLE_RATE = 16000

    # Row of the speaker-embedding dataset whose x-vector is used as the voice.
    SPEAKER_INDEX = 7306

    def __init__(self):
        # Let the base handler set up its own state before adding ours.
        super().__init__()
        self.model = None
        self.processor = None
        self.vocoder = None
        self.speaker_embeddings = None
        self.output_dir = "/tmp"

    @staticmethod
    def _resolve_artifact(model_dir, name):
        """Return ``name`` located inside ``model_dir`` when it exists there.

        Falls back to ``name`` unchanged otherwise, so values that are not
        paths inside the model directory keep working exactly as before.
        """
        if model_dir:
            candidate = os.path.join(model_dir, name)
            if os.path.exists(candidate):
                return candidate
        return name

    def initialize(self, ctx):
        """Load the processor, model, vocoder and speaker embedding.

        Args:
            ctx: TorchServe context. ``ctx.system_properties`` supplies
                ``model_dir`` and ``ctx.model_yaml_config["handler"]``
                supplies the ``processor``, ``model``, ``vocoder``,
                ``speaker_embeddings`` and ``output_dir`` entries.

        Raises:
            KeyError: If a required entry is missing from the handler config.
        """
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Read every config entry up front so a missing key fails before any
        # (slow) model loading starts.
        handler_config = ctx.model_yaml_config["handler"]
        processor = handler_config["processor"]
        model = handler_config["model"]
        vocoder = handler_config["vocoder"]
        embeddings_dataset = handler_config["speaker_embeddings"]
        self.output_dir = handler_config["output_dir"]

        # Prefer artifacts found under model_dir so loading does not depend
        # on the worker's current working directory.
        self.processor = SpeechT5Processor.from_pretrained(
            self._resolve_artifact(model_dir, processor)
        )
        self.model = SpeechT5ForTextToSpeech.from_pretrained(
            self._resolve_artifact(model_dir, model)
        )
        self.vocoder = SpeechT5HifiGan.from_pretrained(
            self._resolve_artifact(model_dir, vocoder)
        )

        # load xvector containing speaker's voice characteristics from a dataset
        embeddings_dataset = load_from_disk(
            self._resolve_artifact(model_dir, embeddings_dataset)
        )
        # unsqueeze(0) adds a leading dimension of size 1 to the x-vector.
        self.speaker_embeddings = torch.tensor(
            embeddings_dataset[self.SPEAKER_INDEX]["xvector"]
        ).unsqueeze(0)

        self.initialized = True

    def preprocess(self, requests):
        """Tokenize the text of a single request.

        Args:
            requests: List of request dicts; exactly one is supported. The
                text is taken from the ``data`` entry, or ``body`` when
                ``data`` is missing or empty.

        Returns:
            The processor output for the text (PyTorch tensors).

        Raises:
            ValueError: If the batch does not hold exactly one request, or
                the request carries no text.
        """
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently drop this validation.
        if len(requests) != 1:
            raise ValueError("This is currently supported with batch_size=1")
        req_data = requests[0]

        input_data = req_data.get("data") or req_data.get("body")

        if isinstance(input_data, (bytes, bytearray)):
            input_data = input_data.decode("utf-8")

        # Fail with a clear message rather than handing None or an empty
        # payload to the tokenizer.
        if not input_data:
            raise ValueError("Request contains no text to synthesize")

        inputs = self.processor(text=input_data, return_tensors="pt")

        return inputs

    def inference(self, inputs):
        """Generate speech for the tokenized text.

        Args:
            inputs: Processor output from ``preprocess``; only its
                ``input_ids`` entry is used.

        Returns:
            The output of ``generate_speech`` called with the loaded speaker
            embedding and vocoder.
        """
        output = self.model.generate_speech(
            inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder
        )
        return output

    def postprocess(self, inference_output):
        """Encode the generated audio as WAV and return its bytes.

        The audio is written to a uniquely named temporary file under
        ``self.output_dir``, read back as bytes, and the file is removed.

        Args:
            inference_output: Tensor returned by ``inference``.

        Returns:
            A single-element list holding the WAV file contents as bytes.
        """
        path = os.path.join(self.output_dir, "{}.wav".format(uuid.uuid4().hex))
        try:
            sf.write(path, inference_output.numpy(), samplerate=self.SAMPLE_RATE)
            with open(path, "rb") as output:
                data = output.read()
        finally:
            # Remove the temporary file even when writing or reading failed,
            # so failed requests do not leave files behind in output_dir.
            if os.path.exists(path):
                os.remove(path)
        return [data]

requirements/developer.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pre-commit==3.3.2
1515
twine==4.0.2
1616
mypy==1.3.0
1717
torchpippy==0.1.1
18-
intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin'
18+
intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin' and platform_machine != 'aarch64'
1919
onnxruntime==1.17.1
2020
googleapis-common-protos
2121
onnx==1.16.0

requirements/torch_linux_aarch64.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
2+
--extra-index-url https://download.pytorch.org/whl/cpu
3+
-r torch_common.txt
4+
torch==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64'
5+
torchvision==0.17.1; sys_platform == 'linux' and platform_machine == 'aarch64'
6+
torchaudio==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64'

ts_scripts/install_dependencies.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,14 @@ def install_torch_packages(self, cuda_version):
118118
f"{sys.executable} -m pip install -U -r {torch_neuronx_requirements_file}"
119119
)
120120
else:
121-
os.system(
122-
f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
123-
)
121+
if platform.machine() == "aarch64":
122+
os.system(
123+
f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}_{platform.machine()}.txt"
124+
)
125+
else:
126+
os.system(
127+
f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
128+
)
124129

125130
def install_python_packages(self, cuda_version, requirements_file_path, nightly):
126131
check = "where" if platform.system() == "Windows" else "which"

ts_scripts/spellcheck_conf/wordlist.txt

+4
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,10 @@ libomp
12161216
rpath
12171217
venv
12181218
TorchInductor
1219+
Graviton
1220+
aarch
1221+
linux
1222+
SpeechT
12191223
Pytests
12201224
deviceType
12211225
XGBoost

0 commit comments

Comments
 (0)