diff --git a/examples/pt2/README.md b/examples/pt2/README.md
index ef4578fbed..38b81e0374 100644
--- a/examples/pt2/README.md
+++ b/examples/pt2/README.md
@@ -1,6 +1,6 @@
 ## PyTorch 2.x integration
 
-PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental given that most public benchmarks have focused on training instead of inference.
+PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption.
 
 We strongly recommend you leverage newer hardware so for GPUs that would be an Ampere architecture. You'll get even more benefits from using server GPU deployments like A10G and A100 vs consumer cards. But you should expect to see some speedups for any Volta or Ampere architecture.
 
@@ -8,9 +8,10 @@ We strongly recommend you leverage newer hardware so for GPUs that would be an A
 
 Install torchserve and ensure that you're using at least `torch>=2.0.0`
 
+To use the latest nightlies, you can run the following commands
 ```sh
-python ts_scripts/install_dependencies.py --cuda=cu118
-pip install torchserve torch-model-archiver
+python ts_scripts/install_dependencies.py --cuda=cu121 --nightly_torch
+pip install torchserve-nightly torch-model-archiver-nightly
 ```
 
 ## torch.compile
@@ -27,13 +28,7 @@ You can also pass a dictionary with compile options if you need more control ove
 pt2 : {backend: inductor, mode: reduce-overhead}
 ```
 
-As an example let's expand our getting started guide with the only difference being passing in the extra `model_config.yaml` file
-
-```
-mkdir model_store
-torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier --config-file model_config.yaml
-torchserve --start --ncs --model-store model_store --models densenet161.mar
-```
+An example of using `torch.compile` can be found [here](./torch_compile/README.md)
 
 The exact same approach works with any other model, what's going on is the below
 
diff --git a/examples/pt2/torch_compile/README.md b/examples/pt2/torch_compile/README.md
new file mode 100644
index 0000000000..1a9ad5a897
--- /dev/null
+++ b/examples/pt2/torch_compile/README.md
@@ -0,0 +1,94 @@
+
+# TorchServe inference with torch.compile of densenet161 model
+
+This example shows how to take eager model of `densenet161`, configure TorchServe to use `torch.compile` and run inference using `torch.compile`
+
+
+### Pre-requisites
+
+- `PyTorch >= 2.0`
+
+Change directory to the examples directory
+Ex:  `cd  examples/pt2/torch_compile`
+
+
+### torch.compile config
+
+`torch.compile` supports a variety of config and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)
+
+In this example , we use the following config
+
+```
+echo "pt2 : {backend: inductor, mode: reduce-overhead}" > model-config.yaml
+```
+
+### Create model archive
+
+```
+wget https://download.pytorch.org/models/densenet161-8d451a50.pth
+mkdir model_store
+torch-model-archiver --model-name densenet161 --version 1.0 --model-file model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler image_classifier --config-file model-config.yaml -f
+```
+
+#### Start TorchServe
+```
+torchserve --start --ncs --model-store model_store --models densenet161.mar
+```
+
+#### Run Inference
+
+```
+curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
+```
+
+produces the output
+
+```
+{
+  "tabby": 0.4664836823940277,
+  "tiger_cat": 0.4645617604255676,
+  "Egyptian_cat": 0.06619937717914581,
+  "lynx": 0.0012969186063855886,
+  "plastic_bag": 0.00022856894065625966
+}
+```
+
+### Performance improvement from using `torch.compile`
+
+To measure the handler `preprocess`, `inference`, `postprocess` times, run the following
+
+#### Measure inference time with PyTorch eager
+
+```
+echo "handler:" > model-config.yaml && \
+echo "  profile: true" >> model-config.yaml
+```
+
+Once the `yaml` file is updated, create the model-archive, start TorchServe and run inference using the steps shown above.
+After a few iterations of warmup, we see the following
+
+```
+2024-02-03T00:54:31,136 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.118656158447266|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
+2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
+2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.16630400717258453|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
+```
+
+#### Measure inference time with `torch.compile`
+
+```
+echo "pt2: {backend: inductor, mode: reduce-overhead}" > model-config.yaml && \
+echo "handler:" >> model-config.yaml && \
+echo "  profile: true" >> model-config.yaml
+```
+
+Once the `yaml` file is updated, create the model-archive, start TorchServe and run inference using the steps shown above.
+`torch.compile` needs a few inferences to warmup. Once warmed up, we see the following
+```
+2024-02-03T00:56:14,808 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.9771199226379395|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
+2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.8818559646606445|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
+2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.19392000138759613|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
+```
+
+### Conclusion
+
+`torch.compile` reduces the inference time from 18ms to 5ms
diff --git a/examples/pt2/torch_compile/model-config.yaml b/examples/pt2/torch_compile/model-config.yaml
new file mode 100644
index 0000000000..9366d69480
--- /dev/null
+++ b/examples/pt2/torch_compile/model-config.yaml
@@ -0,0 +1 @@
+pt2 : {backend: inductor, mode: reduce-overhead}
diff --git a/examples/pt2/torch_compile/model.py b/examples/pt2/torch_compile/model.py
new file mode 100644
index 0000000000..ca7e83eb13
--- /dev/null
+++ b/examples/pt2/torch_compile/model.py
@@ -0,0 +1,27 @@
+from torchvision.models.densenet import DenseNet
+
+
+class ImageClassifier(DenseNet):
+    def __init__(self):
+        super(ImageClassifier, self).__init__(48, (6, 12, 36, 24), 96)
+
+    def load_state_dict(self, state_dict, strict=True):
+        # '.'s are no longer allowed in module names, but previous _DenseLayer
+        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+        # They are also in the checkpoints in model_urls. This pattern is used
+        # to find such keys.
+        # Credit - https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py#def _load_state_dict()
+        import re
+
+        pattern = re.compile(
+            r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
+        )
+
+        for key in list(state_dict.keys()):
+            res = pattern.match(key)
+            if res:
+                new_key = res.group(1) + res.group(2)
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
+
+        return super(ImageClassifier, self).load_state_dict(state_dict, strict)
diff --git a/test/pytest/test_example_torch_compile.py b/test/pytest/test_example_torch_compile.py
new file mode 100644
index 0000000000..c87258675e
--- /dev/null
+++ b/test/pytest/test_example_torch_compile.py
@@ -0,0 +1,82 @@
+import os
+from pathlib import Path
+
+import pytest
+import torch
+from pkg_resources import packaging
+
+from ts.torch_handler.image_classifier import ImageClassifier
+from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
+from ts.utils.util import load_label_mapping
+from ts_scripts.utils import try_and_handle
+
+CURR_FILE_PATH = Path(__file__).parent.absolute()
+REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
+EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_compile")
+TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
+MAPPING_DATA = REPO_ROOT_DIR.joinpath(
+    "examples", "image_classifier", "index_to_name.json"
+)
+MODEL_PTH_FILE = "densenet161-8d451a50.pth"
+MODEL_FILE = "model.py"
+MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")
+
+
+PT2_AVAILABLE = (
+    True
+    if packaging.version.parse(torch.__version__) > packaging.version.parse("2.0")
+    else False
+)
+
+EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "plastic_bag"]
+
+
+@pytest.fixture
+def custom_working_directory(tmp_path):
+    # Set the custom working directory
+    custom_dir = tmp_path / "model_dir"
+    custom_dir.mkdir()
+    os.chdir(custom_dir)
+    yield custom_dir
+    # Clean up and return to the original working directory
+    os.chdir(tmp_path)
+
+
+@pytest.mark.skipif(PT2_AVAILABLE == False, reason="torch version is < 2.0")
+@pytest.mark.skip(reason="Skipping as its causing other testcases to fail")
+def test_torch_compile_inference(monkeypatch, custom_working_directory):
+    monkeypatch.syspath_prepend(EXAMPLE_ROOT_DIR)
+    # Get the path to the custom working directory
+    model_dir = custom_working_directory
+
+    try_and_handle(
+        f"wget https://download.pytorch.org/models/{MODEL_PTH_FILE} -P {model_dir}"
+    )
+
+    # Handler for Image classification
+    handler = ImageClassifier()
+
+    # Context definition
+    ctx = MockContext(
+        model_pt_file=model_dir.joinpath(MODEL_PTH_FILE),
+        model_dir=EXAMPLE_ROOT_DIR.as_posix(),
+        model_file=MODEL_FILE,
+        model_yaml_config_file=MODEL_YAML_CFG_FILE,
+    )
+
+    torch.manual_seed(42 * 42)
+    handler.initialize(ctx)
+    handler.context = ctx
+    handler.mapping = load_label_mapping(MAPPING_DATA)
+
+    data = {}
+    with open(TEST_DATA, "rb") as image:
+        image_file = image.read()
+        byte_array_type = bytearray(image_file)
+        data["body"] = byte_array_type
+
+    result = handler.handle([data], ctx)
+
+    labels = list(result[0].keys())
+
+    assert labels == EXPECTED_RESULTS
diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index ccc2188e9f..d9133f66a9 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1175,3 +1175,4 @@ BabyLlamaHandler
 CMakeLists
 TorchScriptHandler
 libllamacpp
+warmup