enabling Intel(R) Extension for PyTorch* #16

Open · wants to merge 4 commits into base: main
22 changes: 19 additions & 3 deletions Dockerfile
@@ -1,9 +1,13 @@
## Global Args #################################################################
ARG BASE_UBI_IMAGE_TAG=9.3-1476
ARG PROTOC_VERSION=25.1
-#ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
-ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
-ARG PYTORCH_VERSION=2.3.0.dev20231221
+ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
+ARG PYTORCH_VERSION=2.1.0
+# ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
+# ARG PYTORCH_VERSION=2.3.0.dev20231221
+ARG IPEX_INDEX="https://pytorch-extension.intel.com/release-whl/stable/cpu/us/"
+ARG IPEX_VERSION=2.1.100


## Base Layer ##################################################################
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} as base
@@ -148,6 +152,7 @@ WORKDIR /usr/src

# Install specific version of torch
RUN pip install torch=="$PYTORCH_VERSION+cpu" --index-url "${PYTORCH_INDEX}/cpu" --no-cache-dir
RUN pip install intel-extension-for-pytorch=="$IPEX_VERSION" --extra-index-url "${IPEX_INDEX}" --no-cache-dir

COPY server/Makefile server/Makefile

@@ -174,6 +179,8 @@ RUN cd integration_tests && make install
FROM cuda-devel as python-builder
ARG PYTORCH_INDEX
ARG PYTORCH_VERSION
ARG IPEX_INDEX
ARG IPEX_VERSION

RUN dnf install -y unzip git ninja-build && dnf clean all

@@ -187,6 +194,7 @@ ENV PATH=/opt/miniconda/bin:$PATH
# Install specific version of torch
RUN pip install ninja==1.11.1.1 --no-cache-dir
RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu118" --no-cache-dir
RUN pip install intel-extension-for-pytorch~="$IPEX_VERSION" --extra-index-url "${IPEX_INDEX}" --no-cache-dir


## Build flash attention v2 ####################################################
@@ -241,6 +249,14 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build /usr/sr
FROM base as flash-att-v2-cache
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build /usr/src/flash-attention-v2/build

## Set up environment variables for performance on Xeon
ENV KMP_BLOCKTIME=INF
ENV KMP_TPAUSE=0
ENV KMP_SETTINGS=1
ENV KMP_AFFINITY=granularity=fine,compact,1,0
ENV KMP_FORJOIN_BARRIER_PATTERN=dist,dist
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist

## Final Inference Server image ################################################
FROM cuda-runtime as server-release
50 changes: 50 additions & 0 deletions README.md
@@ -157,3 +157,53 @@ They are all prefixed with `tgi_`. Descriptions will be added to the table below
| `tgi_tokenize_request_input_count` | `counter` | | Count of tokenize request inputs (batch of n counts as n) |
| `tgi_tokenize_request_tokens` | `histogram` | | Count of tokenized tokens per tokenize request |
| `tgi_tokenize_request_duration` | `histogram` | | Tokenize request duration (in seconds) |

### Run Inference Locally with Intel(R) Extension for PyTorch*

#### 0. Build the image

```
make build
```

This command prints the Docker image ID for `text-gen-server`. Set `IMAGE_ID` in the commands below to that value.

#### 1. Run the server

```
export IMAGE_ID=<image_id>
export MODEL=<model>
export volume=$PWD/data
mkdir $volume
chmod 777 $volume
```

It's possible to use `text-generation-server download-weights`, but in this example we use a model that we download locally with `transformers-cli`.

```
transformers-cli download $MODEL
```
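
If you use `transformers-cli`, the weights land in the local Hugging Face cache and are moved into `$volume` in the next step. As an alternative sketch (not part of this PR), `huggingface_hub` can download straight into the mounted directory, assuming the `huggingface_hub` package is installed:

```
# Hypothetical alternative to transformers-cli: download directly into $volume.
# Assumes `pip install huggingface_hub`; MODEL and volume are the env vars set above.
import os
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id=os.environ["MODEL"],     # model id set in $MODEL
    cache_dir=os.environ["volume"],  # hub-style cache layout inside the mounted data dir
)
```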

Move the model from `~/.cache/huggingface/hub/` to `$volume`. You can then run the inference server with:

```
docker run -p 8033:8033 -p 3000:3000 -e TRANSFORMERS_CACHE=/data -e HUGGINGFACE_HUB_CACHE=/data -e DEPLOYMENT_FRAMEWORK=hf_transformers_ipex -e MODEL_NAME=$MODEL -v $volume:/data $IMAGE_ID text-generation-launcher --dtype-str bfloat16
```

#### 2. Prepare the client

Install gRPC in a Python environment: `pip install grpcio grpcio-tools`

In the repository root, run:
```
python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generate.proto
python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generation.proto
```
This generates the gRPC stub modules in the `pb` directory.
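
The generated modules land next to `pb/client.py` (shown below), which imports them as:

```
import generation_pb2 as pb2
import generation_pb2_grpc as gpb2
```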

Then to run inference:
```
python pb/client.py
```

Edit `pb/client.py` to change the prompts.
33 changes: 33 additions & 0 deletions pb/client.py
@@ -0,0 +1,33 @@
import json
import time

import grpc
import requests
from google.protobuf import json_format

import generation_pb2 as pb2
import generation_pb2_grpc as gpb2

port = 8033
channel = grpc.insecure_channel(f"localhost:{port}")
stub = gpb2.GenerationServiceStub(channel)

# warmup inference
for i in range(5):
text = "hello world"
message = json_format.ParseDict(
{"requests": [{"text": text}]}, pb2.BatchedGenerationRequest()
)
response = stub.Generate(message)

# time inference
for prompt in ["The weather is", "The cat is walking on", "I would like to"]:
# for prompt in ["def hello_world():"]:
message = json_format.ParseDict(
{"requests": [{"text": prompt}]}, pb2.BatchedGenerationRequest()
)
start = time.perf_counter()
response = stub.Generate(message)
end = time.perf_counter()
print(prompt, response)
print(f"Duration: {end-start:.2f}")
@@ -0,0 +1,52 @@
import os
import torch
import intel_extension_for_pytorch as ipex
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from text_generation_server.inference_engine.engine import BaseInferenceEngine
from text_generation_server.utils.hub import TRUST_REMOTE_CODE
from typing import Any, Optional


class InferenceEngine(BaseInferenceEngine):
def __init__(
self,
model_path: str,
model_class: type[_BaseAutoModelClass],
dtype: torch.dtype,
quantize: Optional[str],
model_config: Optional[Any]
) -> None:
super().__init__(model_path, model_config)

kwargs = {
"pretrained_model_name_or_path": model_path,
"local_files_only": True,
"trust_remote_code": TRUST_REMOTE_CODE,
"torchscript": 'jit',
"torch_dtype": dtype
}

if model_config.model_type == "mpt":
model_config.init_device = str(self.device)
kwargs["config"] = model_config

try:
ipex._C.disable_jit_linear_repack()
except Exception:
pass

torch._C._jit_set_texpr_fuser_enabled(False)

slow_but_exact = os.getenv('BLOOM_SLOW_BUT_EXACT', 'false').lower() == 'true'
if slow_but_exact:
kwargs["slow_but_exact"] = True

with self.device:
self.model = model_class.from_pretrained(**kwargs).requires_grad_(False).eval()

self.model = self.model.to(memory_format=torch.channels_last)
self.model = ipex.optimize_transformers(self.model, dtype=dtype, inplace=True)
print('Intel(R) Extension for PyTorch* enabled')

self.model.to(self.device)
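
Outside the server plumbing, the core optimization pattern used above is small. A minimal standalone sketch, assuming `torch`, `transformers`, and an IPEX 2.1.x build that provides `ipex.optimize_transformers` (the model name is only a placeholder, not one referenced by this PR):

```
# Minimal sketch of the IPEX CPU optimization pattern from the engine above.
# Assumes intel-extension-for-pytorch ~= 2.1 and transformers are installed;
# "gpt2" is a placeholder model.
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
model = model.eval().requires_grad_(False)
model = model.to(memory_format=torch.channels_last)  # preferred CPU memory layout
model = ipex.optimize_transformers(model, dtype=torch.bfloat16, inplace=True)
```

The engine above layers the server-specific pieces (device placement, MPT config handling, the TorchScript-related flags) on top of this same pattern.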
3 changes: 3 additions & 0 deletions server/text_generation_server/models/causal_lm.py
@@ -575,6 +575,9 @@ def __init__(
_, past_key_values, _ = self.forward(input_ids=one_token, attention_mask=one_token)
if torch.is_tensor(past_key_values[0]):
self.batch_type = CombinedKVCausalLMBatch
elif 'ipex' in deployment_framework:
print(deployment_framework)
self.batch_type = CausalLMBatch
else:
# check the ordering of the key tensor dimensions
key_past, value_past = past_key_values[0]