Commit a797f8c
Diffusion Fast Example (#2902)
* initial commit for Diffusion Fast
* code cleanup
* lint failure
* Integrated with new api ..load_pipeline
* spellcheck
* updated based on review comments
1 parent cfb4285 commit a797f8c

File tree (7 files changed, +258 −0 lines changed):

examples/image_generation/diffusion_fast/Download_model.py
examples/image_generation/diffusion_fast/README.md
examples/image_generation/diffusion_fast/config.properties
examples/image_generation/diffusion_fast/diffusion_fast_handler.py
examples/image_generation/diffusion_fast/model_config.yaml
examples/image_generation/diffusion_fast/query.py
ts_scripts/spellcheck_conf/wordlist.txt
examples/image_generation/diffusion_fast/Download_model.py
@@ -0,0 +1,9 @@
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
)
pipeline.save_pretrained("./Base_Diffusion_model")
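
(Not part of the commit: a quick, hypothetical sanity check that the weights were saved correctly — reloading the pipeline from the local directory; assumes `diffusers` is installed.)

```python
# Hypothetical sanity check (not part of this commit): reload the saved
# pipeline to confirm the download and save completed.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("./Base_Diffusion_model")
print(type(pipe).__name__)  # expected: StableDiffusionXLPipeline
```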
examples/image_generation/diffusion_fast/README.md
@@ -0,0 +1,64 @@

## Diffusion-Fast

[Diffusion Fast](https://github.com/huggingface/diffusion-fast) is a simple and efficient PyTorch-native way of optimizing Stable Diffusion XL (SDXL).

It features:
* Running with bfloat16 precision
* scaled_dot_product_attention (SDPA)
* torch.compile
* Combined q, k, v projections for the attention computation
* Dynamic int8 quantization

A sketch of how these optimizations fit together is shown below.
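
The following is a minimal, hypothetical sketch of the first four optimizations using the `diffusers` API; the actual wiring lives in diffusion-fast's `pipeline_utils.py`, and the int8 dynamic quantization step (applied via the pinned `ao` package) is omitted here.

```python
import torch
from diffusers import StableDiffusionXLPipeline

# bfloat16 precision
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to("cuda")

# Combine the q, k, v projections in the attention blocks
pipe.fuse_qkv_projections()

# SDPA is the default attention backend in recent PyTorch/diffusers releases,
# so no explicit call is needed here.

# torch.compile the heaviest components
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```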

Details about the optimizations and various results can be found in this [blog post](https://pytorch.org/blog/accelerating-generative-ai-3/).
The example has been tested on A10, A100, and H100 GPUs.


#### Prerequisites

`cd` to the example folder `examples/image_generation/diffusion_fast`.

Install the dependencies and upgrade torch to a nightly build (currently required):

```bash
git clone https://github.com/huggingface/diffusion-fast.git
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
pip install accelerate transformers peft
pip install --no-cache-dir git+https://github.com/pytorch-labs/ao@54bcd5a10d0abbe7b0c045052029257099f83fd9
pip install pandas matplotlib seaborn
```

### Step 1: Download the Stable Diffusion model

```bash
python Download_model.py
```

This saves the model in `Base_Diffusion_model`.

### Step 2: Generate model archive

At this stage we're creating the model archive which includes the configuration of our model in [model_config.yaml](./model_config.yaml).
It's also the point where we need to decide if we want to deploy our model on a single GPU or on multiple GPUs.
For the single-GPU case we can use the default configuration found in [model_config.yaml](./model_config.yaml).

```bash
torch-model-archiver --model-name diffusion_fast --version 1.0 --handler diffusion_fast_handler.py --config-file model_config.yaml --extra-files "diffusion-fast/utils/pipeline_utils.py" --archive-format no-archive
mv Base_Diffusion_model diffusion_fast/
```

With `--archive-format no-archive`, the archiver writes an exploded `diffusion_fast/` directory instead of a `.mar` file; the downloaded weights are then moved into it.

### Step 3: Add the model archive to the model store

```bash
mkdir model_store
mv diffusion_fast model_store
```

### Step 4: Start TorchServe

```bash
torchserve --start --ts-config config.properties --model-store model_store --models diffusion_fast
```
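
Once the worker is up, you can optionally verify the server is healthy via TorchServe's ping endpoint (an illustrative check, not part of this example):

```python
import requests

# TorchServe health check; expects {"status": "Healthy"} once the worker is ready
print(requests.get("http://localhost:8080/ping").text)
```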

### Step 5: Run inference

```bash
python query.py --url "http://localhost:8080/predictions/diffusion_fast" --prompt "a photo of an astronaut riding a horse on mars"
```

The generated image will be written to a file named `output-<timestamp>.jpg`.
examples/image_generation/diffusion_fast/config.properties
@@ -0,0 +1,4 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
max_response_size=655350000
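
(The three addresses expose TorchServe's inference, management, and metrics APIs; `max_response_size` is raised because the handler returns each image as a JSON-encoded pixel array. As an illustrative check, not part of the commit, one could list the registered models through the management port:)

```python
import requests

# Management API: list models registered with the running TorchServe instance
print(requests.get("http://127.0.0.1:8081/models").text)
```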
examples/image_generation/diffusion_fast/diffusion_fast_handler.py
@@ -0,0 +1,133 @@
import logging
import os
from pathlib import Path

import numpy as np
import torch
from pipeline_utils import load_pipeline

from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class DiffusionFastHandler(BaseHandler):
    """
    Diffusion-Fast handler class for text-to-image generation.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Loads and initializes the Diffusion Fast model.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.context = ctx
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
            self.map_location = "cuda"
            self.device = torch.device(
                self.map_location + ":" + str(properties.get("gpu_id"))
            )
        else:
            self.map_location = "cpu"
            self.device = torch.device(self.map_location)

        self.num_inference_steps = ctx.model_yaml_config["handler"][
            "num_inference_steps"
        ]

        # Parameters for the model
        compile_unet = ctx.model_yaml_config["handler"]["compile_unet"]
        compile_vae = ctx.model_yaml_config["handler"]["compile_vae"]
        compile_mode = ctx.model_yaml_config["handler"]["compile_mode"]
        enable_fused_projections = ctx.model_yaml_config["handler"][
            "enable_fused_projections"
        ]
        do_quant = ctx.model_yaml_config["handler"]["do_quant"]
        change_comp_config = ctx.model_yaml_config["handler"]["change_comp_config"]
        no_sdpa = ctx.model_yaml_config["handler"]["no_sdpa"]
        no_bf16 = ctx.model_yaml_config["handler"]["no_bf16"]
        upcast_vae = ctx.model_yaml_config["handler"]["upcast_vae"]

        # Load model weights
        model_path = Path(ctx.model_yaml_config["handler"]["model_path"])
        ckpt = os.path.join(model_dir, model_path)

        self.pipeline = load_pipeline(
            ckpt=ckpt,
            compile_unet=compile_unet,
            compile_vae=compile_vae,
            compile_mode=compile_mode,
            enable_fused_projections=enable_fused_projections,
            do_quant=do_quant,
            change_comp_config=change_comp_config,
            no_bf16=no_bf16,
            no_sdpa=no_sdpa,
            upcast_vae=upcast_vae,
        )

        logger.info("Diffusion Fast model loaded successfully")

        self.initialized = True

    @timed
    def preprocess(self, requests):
        """Basic text preprocessing of the user's prompt.
        Args:
            requests (list): A list of requests carrying the input text.
        Returns:
            list: A list of prompts.
        """
        assert (
            len(requests) == 1
        ), "Diffusion Fast is currently only supported with batch_size=1"

        inputs = []
        for data in requests:
            input_text = data.get("data")
            if input_text is None:
                input_text = data.get("body")
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            inputs.append(input_text)
        return inputs

    @timed
    def inference(self, inputs):
        """Generates the images for the received prompts.
        Args:
            inputs (list): List of prompts from the preprocess function.
        Returns:
            list: A list of generated images, one per input prompt.
        """
        inferences = self.pipeline(
            inputs, num_inference_steps=self.num_inference_steps, height=768, width=768
        ).images

        return inferences

    @timed
    def postprocess(self, inference_output):
        """Converts the generated images into a TorchServe-readable format.
        Args:
            inference_output (list): The generated images.
        Returns:
            list: The images, each as a nested list of pixel values.
        """
        images = []
        for image in inference_output:
            images.append(np.array(image).tolist())
        return images
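
(To make the postprocess format concrete, here is a small, self-contained round-trip sketch, not part of the commit: the handler turns each PIL image into nested lists, TorchServe serializes them as JSON, and the client rebuilds the image — exactly what `query.py` below does.)

```python
import json

import numpy as np
from PIL import Image

img = Image.new("RGB", (8, 8), color=(255, 0, 0))  # stand-in for a generated image
payload = json.dumps(np.array(img).tolist())       # what postprocess + TorchServe emit
restored = Image.fromarray(np.array(json.loads(payload), dtype="uint8"))
assert list(restored.getdata()) == list(img.getdata())
```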
examples/image_generation/diffusion_fast/model_config.yaml
@@ -0,0 +1,18 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 3600
deviceType: "gpu"
handler:
    model_path: "Base_Diffusion_model"
    num_inference_steps: 30
    compile_unet: true
    compile_mode: "max-autotune"
    compile_vae: true
    enable_fused_projections: true
    do_quant: "int8dynamic"
    change_comp_config: true
    no_sdpa: false
    no_bf16: false
    upcast_vae: false
    profile: true
examples/image_generation/diffusion_fast/query.py
@@ -0,0 +1,27 @@
import argparse
import json
from datetime import datetime

import numpy as np
import requests
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument(
    "--url", type=str, required=True, help="Torchserve inference endpoint"
)
parser.add_argument(
    "--prompt", type=str, required=True, help="Prompt for image generation"
)
parser.add_argument(
    "--filename",
    type=str,
    default="output-{}.jpg".format(str(datetime.now().strftime("%Y%m%d%H%M%S"))),
    help="Filename of output image",
)
args = parser.parse_args()

response = requests.post(args.url, data=args.prompt)

# Construct image from response
image = Image.fromarray(np.array(json.loads(response.text), dtype="uint8"))
image.save(args.filename)

ts_scripts/spellcheck_conf/wordlist.txt
@@ -1164,6 +1164,9 @@ compilable
 nightlies
 torchexportaotcompile
 autotune
+SDXL
+SDPA
+bfloat
 bb
 babyllama
 libbabyllama