diff --git a/.github/workflows/ci-cpu-cpp.yml b/.github/workflows/ci-cpu-cpp.yml
index ec0a9c88d8..0b3a6529c2 100644
--- a/.github/workflows/ci-cpu-cpp.yml
+++ b/.github/workflows/ci-cpu-cpp.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macOS-latest]
+        os: [ubuntu-20.04, macOS-latest]
     steps:
       - name: Checkout TorchServe
         uses: actions/checkout@v2
@@ -31,4 +31,4 @@ jobs:
           python ts_scripts/install_dependencies.py --environment=dev --cpp
       - name: Build
         run: |
-          cd cpp && ./build.sh --install-dependencies
+          cd cpp && ./build.sh
diff --git a/.gitmodules b/.gitmodules
index 3125a3b997..f24d0431c1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,3 +10,15 @@
 [submodule "cpp/third-party/llama2.so"]
 	path = cpp/third-party/llama2.so
 	url = https://github.com/mreso/llama2.so.git
+[submodule "cpp/third-party/folly"]
+	path = cpp/third-party/folly
+	url = https://github.com/facebook/folly.git
+[submodule "cpp/third-party/yaml-cpp"]
+	path = cpp/third-party/yaml-cpp
+	url = https://github.com/jbeder/yaml-cpp.git
+[submodule "cpp/third-party/tokenizers-cpp"]
+	path = cpp/third-party/tokenizers-cpp
+	url = https://github.com/mlc-ai/tokenizers-cpp.git
+[submodule "cpp/third-party/kineto"]
+	path = cpp/third-party/kineto
+	url = https://github.com/pytorch/kineto.git
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 01bf8b9ce8..f466ee6a6b 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
+cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR)
 project(torchserve_cpp VERSION 0.1)
 
 set(CMAKE_CXX_STANDARD 17)
diff --git a/cpp/README.md b/cpp/README.md
index 42df03fbd1..70b5094b2d 100644
--- a/cpp/README.md
+++ b/cpp/README.md
@@ -2,6 +2,7 @@
 ## Requirements
 * C++17
 * GCC version: gcc-9
+* cmake version: 3.26.4+
 
 ## Installation and Running TorchServe CPP
 ### Install dependencies
diff --git a/cpp/build.sh b/cpp/build.sh
index 5beb2b79b5..b986163852 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -24,13 +24,6 @@ function install_folly() {
   FOLLY_SRC_DIR=$BASE_DIR/third-party/folly
   FOLLY_BUILD_DIR=$DEPS_DIR/folly-build
 
-  if [ ! -d "$FOLLY_SRC_DIR" ] ; then
-    echo -e "${COLOR_GREEN}[ INFO ] Cloning folly repo ${COLOR_OFF}"
-    git clone https://github.com/facebook/folly.git "$FOLLY_SRC_DIR"
-    cd $FOLLY_SRC_DIR
-    git checkout tags/v2024.01.29.00
-  fi
-
   if [ ! -d "$FOLLY_BUILD_DIR" ] ; then
     echo -e "${COLOR_GREEN}[ INFO ] Building Folly ${COLOR_OFF}"
     cd $FOLLY_SRC_DIR
@@ -60,9 +53,7 @@ function install_kineto() {
   elif [ "$PLATFORM" = "Mac" ]; then
     KINETO_SRC_DIR=$BASE_DIR/third-party/kineto
 
-    if [ ! -d "$KINETO_SRC_DIR" ] ; then
-      echo -e "${COLOR_GREEN}[ INFO ] Cloning kineto repo ${COLOR_OFF}"
-      git clone --recursive https://github.com/pytorch/kineto.git "$KINETO_SRC_DIR"
+    if [ ! -d "$KINETO_SRC_DIR/libkineto/build" ] ; then
       cd $KINETO_SRC_DIR/libkineto
       mkdir build && cd build
       cmake ..
@@ -128,13 +119,6 @@ function install_yaml_cpp() {
   YAML_CPP_SRC_DIR=$BASE_DIR/third-party/yaml-cpp
   YAML_CPP_BUILD_DIR=$DEPS_DIR/yaml-cpp-build
 
-  if [ ! -d "$YAML_CPP_SRC_DIR" ] ; then
-    echo -e "${COLOR_GREEN}[ INFO ] Cloning yaml-cpp repo ${COLOR_OFF}"
-    git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR"
-    cd $YAML_CPP_SRC_DIR
-    git checkout tags/0.8.0
-  fi
-
   if [ ! -d "$YAML_CPP_BUILD_DIR" ] ; then
     echo -e "${COLOR_GREEN}[ INFO ] Building yaml-cpp ${COLOR_OFF}"
@@ -187,6 +171,16 @@ function prepare_test_files() {
     local LLAMA_SO_DIR=${BASE_DIR}/third-party/llama2.so/
     PYTHONPATH=${LLAMA_SO_DIR}:${PYTHONPATH} python ${BASE_DIR}/../examples/cpp/aot_inductor/llama2/compile.py --checkpoint ${HANDLER_DIR}/stories15M.pt ${HANDLER_DIR}/stories15M.so
   fi
+  if [ ! -f "${EX_DIR}/aot_inductor/bert_handler/bert-seq.so" ]; then
+    pip install transformers
+    local HANDLER_DIR=${EX_DIR}/aot_inductor/bert_handler/
+    export TOKENIZERS_PARALLELISM=false
+    cd ${BASE_DIR}/../examples/cpp/aot_inductor/bert/
+    python aot_compile_export.py
+    mv bert-seq.so ${HANDLER_DIR}/bert-seq.so
+    mv Transformer_model/tokenizer.json ${HANDLER_DIR}/tokenizer.json
+    export TOKENIZERS_PARALLELISM=""
+  fi
   if [ ! -f "${EX_DIR}/aot_inductor/resnet_handler/resne50_pt2.so" ]; then
     local HANDLER_DIR=${EX_DIR}/aot_inductor/resnet_handler/
     cd ${HANDLER_DIR}
@@ -376,7 +370,7 @@ cd $BASE_DIR
 git submodule update --init --recursive
 
 install_folly
-install_kineto
+#install_kineto
 install_libtorch
 install_yaml_cpp
 build_llama_cpp
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index 513a3cdcb5..8c9d928691 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -5,10 +5,11 @@ add_subdirectory("../../../examples/cpp/llamacpp/" "${CMAKE_CURRENT_BINARY_DIR}/
 
 add_subdirectory("../../../examples/cpp/mnist/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/mnist/mnist_handler/")
 
-
 # PT2.2 torch.export does not support Mac
 if(CMAKE_SYSTEM_NAME MATCHES "Linux")
   add_subdirectory("../../../examples/cpp/aot_inductor/llama2/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/llama_handler/")
+  add_subdirectory("../../../examples/cpp/aot_inductor/bert" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/bert_handler/")
+  add_subdirectory("../../../examples/cpp/aot_inductor/resnet" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/resnet_handler/")
 endif()
diff --git a/cpp/src/utils/file_system.cc b/cpp/src/utils/file_system.cc
index 7ba9b13501..945f57a0c8 100644
--- a/cpp/src/utils/file_system.cc
+++ b/cpp/src/utils/file_system.cc
@@ -1,4 +1,8 @@
 #include "src/utils/file_system.hh"
+#include "src/utils/logging.hh"
+
+#include <folly/FileUtil.h>
+#include <folly/json.h>
 
 namespace torchserve {
 std::unique_ptr<std::istream> FileSystem::GetStream(
@@ -10,4 +14,37 @@ std::unique_ptr<std::istream> FileSystem::GetStream(
   }
   return file_stream;
 }
-}  // namespace torchserve
\ No newline at end of file
+
+std::string FileSystem::LoadBytesFromFile(const std::string& path) {
+  std::ifstream fs(path, std::ios::in | std::ios::binary);
+  if (fs.fail()) {
+    TS_LOGF(ERROR, "Cannot open tokenizer file {}", path);
+    throw;
+  }
+  std::string data;
+  fs.seekg(0, std::ios::end);
+  size_t size = static_cast<size_t>(fs.tellg());
+  fs.seekg(0, std::ios::beg);
+  data.resize(size);
+  fs.read(data.data(), size);
+  return data;
+}
+
+std::unique_ptr<folly::dynamic> FileSystem::LoadJsonFile(const std::string& file_path) {
+  std::string content;
+  if (!folly::readFile(file_path.c_str(), content)) {
+    TS_LOGF(ERROR, "{} not found", file_path);
+    throw;
+  }
+  return std::make_unique<folly::dynamic>(folly::parseJson(content));
+}
+
+const folly::dynamic& FileSystem::GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key) {
+  if (json->find(key) != json->items().end()) {
+    return (*json)[key];
+  } else {
+    TS_LOG(ERROR, "Required field {} not found in JSON.", key);
+    throw;
+  }
+}
+}  // namespace torchserve
diff --git a/cpp/src/utils/file_system.hh b/cpp/src/utils/file_system.hh
index 352ccdcbb8..dd21fcbf7b 100644
--- a/cpp/src/utils/file_system.hh
+++ b/cpp/src/utils/file_system.hh
@@ -1,8 +1,7 @@
 #ifndef TS_CPP_UTILS_FILE_SYSTEM_HH_
 #define TS_CPP_UTILS_FILE_SYSTEM_HH_
 
-#include
-
+#include <folly/dynamic.h>
 #include <fstream>
 #include <memory>
 #include <string>
@@ -11,6 +10,9 @@ namespace torchserve {
 class FileSystem {
  public:
   static std::unique_ptr<std::istream> GetStream(const std::string& path);
+  static std::string LoadBytesFromFile(const std::string& path);
+  static std::unique_ptr<folly::dynamic> LoadJsonFile(const std::string& file_path);
+  static const folly::dynamic& GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key);
 };
 }  // namespace torchserve
 #endif  // TS_CPP_UTILS_FILE_SYSTEM_HH_
diff --git a/cpp/test/examples/examples_test.cc b/cpp/test/examples/examples_test.cc
index 0e370fc654..433a94d7f6 100644
--- a/cpp/test/examples/examples_test.cc
+++ b/cpp/test/examples/examples_test.cc
@@ -60,6 +60,29 @@ TEST_F(ModelPredictTest, TestLoadPredictLlamaCppHandler) {
       base_dir + "llamacpp_handler", base_dir + "prompt.txt", "llm_ts", 200);
 }
 
+TEST_F(ModelPredictTest, TestLoadPredictAotInductorBertHandler) {
+  std::string base_dir = "_build/test/resources/examples/aot_inductor/";
+  std::string file1 = base_dir + "bert_handler/bert-seq.so";
+  std::string file2 = base_dir + "bert_handler/tokenizer.json";
+
+  std::ifstream f1(file1);
+  std::ifstream f2(file2);
+
+  if (!f1.good() || !f2.good())
+    GTEST_SKIP() << "Skipping TestLoadPredictAotInductorBertHandler because "
+                    "of missing files: "
+                 << file1 << " or " << file2;
+
+  this->LoadPredict(
+      std::make_shared<torchserve::LoadModelRequest>(
+          base_dir + "bert_handler", "bert_aot",
+          torch::cuda::is_available() ? 0 : -1, "", "", 1, false),
+      base_dir + "bert_handler",
+      base_dir + "bert_handler/sample_text.txt",
+      "bert_ts",
+      200);
+}
+
 TEST_F(ModelPredictTest, TestLoadPredictAotInductorResnetHandler) {
   std::string base_dir = "_build/test/resources/examples/aot_inductor/";
   std::string file1 = base_dir + "resnet_handler/resnet50_pt2.so";
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
new file mode 100644
index 0000000000..c5f5d519f3
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
@@ -0,0 +1,11 @@
+{
+  "createdOn": "12/02/2024 21:09:26",
+  "runtime": "LSP",
+  "model": {
+    "modelName": "bertcppaot",
+    "handler": "libbert_handler:BertCppHandler",
+    "modelVersion": "1.0",
+    "configFile": "model-config.yaml"
+  },
+  "archiverVersion": "0.9.0"
+}
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json b/cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json
new file mode 100644
index 0000000000..9ccff719f6
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json
@@ -0,0 +1,4 @@
+{
+ "0":"Not Accepted",
+ "1":"Accepted"
+}
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml b/cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml
new file mode 100644
index 0000000000..f44839848c
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml
@@ -0,0 +1,13 @@
+minWorkers: 1
+maxWorkers: 1
+batchSize: 2
+
+handler:
+  model_so_path: "bert-seq.so"
+  tokenizer_path: "tokenizer.json"
+  mapping: "index_to_name.json"
+  model_name: "bert-base-uncased"
+  mode: "sequence_classification"
+  do_lower_case: true
+  num_labels: 2
+  max_length: 150
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt b/cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
new file mode 100644
index 0000000000..4c15a88ad2
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
@@ -0,0 +1 @@
+Bloomberg has decided to publish a new report on the global economy.
diff --git a/cpp/third-party/folly b/cpp/third-party/folly
new file mode 160000
index 0000000000..323e467e23
--- /dev/null
+++ b/cpp/third-party/folly
@@ -0,0 +1 @@
+Subproject commit 323e467e2375e535e10bda62faf2569e8f5c9b19
diff --git a/cpp/third-party/kineto b/cpp/third-party/kineto
new file mode 160000
index 0000000000..594c63c50d
--- /dev/null
+++ b/cpp/third-party/kineto
@@ -0,0 +1 @@
+Subproject commit 594c63c50dd9684a592ad7670ecdef6dd5e36be7
diff --git a/cpp/third-party/tokenizers-cpp b/cpp/third-party/tokenizers-cpp
new file mode 160000
index 0000000000..27dbe17d72
--- /dev/null
+++ b/cpp/third-party/tokenizers-cpp
@@ -0,0 +1 @@
+Subproject commit 27dbe17d7268801ec720569167af905c88d3db50
diff --git a/cpp/third-party/yaml-cpp b/cpp/third-party/yaml-cpp
new file mode 160000
index 0000000000..f732014112
--- /dev/null
+++ b/cpp/third-party/yaml-cpp
@@ -0,0 +1 @@
+Subproject commit f7320141120f720aecc4c32be25586e7da9eb978
diff --git a/examples/cpp/aot_inductor/bert/CMakeLists.txt b/examples/cpp/aot_inductor/bert/CMakeLists.txt
new file mode 100644
index 0000000000..a4f48301fc
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TOKENZIER_CPP_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../../cpp/third-party/tokenizers-cpp)
+add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL)
+add_library(bert_handler SHARED src/bert_handler.cc)
+target_include_directories(bert_handler PRIVATE ${TOKENZIER_CPP_PATH}/include)
+target_link_libraries(bert_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES} tokenizers_cpp)
diff --git a/examples/cpp/aot_inductor/bert/README.md b/examples/cpp/aot_inductor/bert/README.md
new file mode 100644
index 0000000000..f10e106f86
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/README.md
@@ -0,0 +1,61 @@
+This example uses AOTInductor to compile the [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) model into a shared object (`.so`) file (see the script [aot_compile_export.py](aot_compile_export.py)). In PyTorch 2.2, the maximum `MAX_SEQ_LENGTH` supported by this script is 511.
+
+Then, this example loads the model and runs prediction using libtorch. The handler C++ source code for this example can be found [here](src).
+
+### Setup
+1. Follow the instructions in [README.md](../../../../cpp/README.md) to build the TorchServe C++ backend.
+
+```
+cd serve/cpp
+./build.sh
+```
+
+The build script will create the necessary artifacts for this example.
+To recreate them by hand you can follow the `prepare_test_files` function of the [build.sh](../../../../cpp/build.sh) script.
+We will need the handler .so file as well as bert-seq.so and tokenizer.json.
+
+2. Create a [model-config.yaml](model-config.yaml)
+
+```yaml
+minWorkers: 1
+maxWorkers: 1
+batchSize: 2
+
+handler:
+  model_so_path: "bert-seq.so"
+  tokenizer_path: "tokenizer.json"
+  mapping: "index_to_name.json"
+  model_name: "bert-base-uncased"
+  mode: "sequence_classification"
+  do_lower_case: true
+  num_labels: 2
+  max_length: 150
+```
+
+### Generate Model Artifact Folder
+
+```bash
+torch-model-archiver --model-name bertcppaot --version 1.0 --handler ../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/libbert_handler:BertCppHandler --runtime LSP --extra-files index_to_name.json,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/bert-seq.so,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/tokenizer.json --config-file model-config.yaml --archive-format no-archive
+```
+
+Create a model store directory and move the `bertcppaot` folder into it:
+
+```
+mkdir model_store
+mv bertcppaot model_store/
+```
+
+### Inference
+
+Start TorchServe using the following command:
+
+```
+torchserve --ncs --model-store model_store/ --models bertcppaot
+```
+
+Run inference using the following command:
+
+```
+curl http://localhost:8080/predictions/bertcppaot -T ../../../../cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
+Not Accepted
+```
diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
new file mode 100644
index 0000000000..2a01ad1c21
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -0,0 +1,121 @@
+import os
+import sys
+
+import torch
+import yaml
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    set_seed,
+)
+
+set_seed(1)
+# PT2.2 limits the maximum batch size and sequence length for this export
+MAX_BATCH_SIZE = 15
+MAX_SEQ_LENGTH = 511
+
+
+def transformers_model_dowloader(
+    mode,
+    pretrained_model_name,
+    num_labels,
+    do_lower_case,
+    max_length,
+    batch_size,
+):
+    print("Download model and tokenizer", pretrained_model_name)
+    # loading pre-trained model and tokenizer
+    if mode == "sequence_classification":
+        config = AutoConfig.from_pretrained(
+            pretrained_model_name,
+            num_labels=num_labels,
+            torchscript=False,
+            return_dict=False,
+        )
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_name, config=config
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name, do_lower_case=do_lower_case
+        )
+    else:
+        sys.exit(f"mode={mode} has not been implemented in this cpp example yet.")
+
+    NEW_DIR = "./Transformer_model"
+    try:
+        os.mkdir(NEW_DIR)
+    except OSError:
+        print("Creation of directory %s failed" % NEW_DIR)
+    else:
+        print("Successfully created directory %s " % NEW_DIR)
+
+    print(
+        "Save model and tokenizer model based on the setting from setup_config",
+        pretrained_model_name,
+        "in directory",
+        NEW_DIR,
+    )
+
+    model.save_pretrained(NEW_DIR)
+    tokenizer.save_pretrained(NEW_DIR)
+
+    with torch.no_grad():
+        model.eval()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        model = model.to(device=device)
+        dummy_input = "This is a dummy input for torch compile export"
+        inputs = tokenizer.encode_plus(
+            dummy_input,
+            max_length=max_length,
+            padding=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        input_ids = torch.cat([inputs["input_ids"]] * batch_size, 0).to(device)
+        attention_mask = torch.cat([inputs["attention_mask"]] * batch_size, 0).to(
+            device
+        )
+        batch_dim = torch.export.Dim("batch", min=1, max=MAX_BATCH_SIZE)
+        seq_len_dim = torch.export.Dim("seq_len", min=1, max=MAX_SEQ_LENGTH)
+        torch._C._GLIBCXX_USE_CXX11_ABI = True
+        model_so_path = torch._export.aot_compile(
+            model,
+            (input_ids, attention_mask),
+            dynamic_shapes={
+                "input_ids": (batch_dim, seq_len_dim),
+                "attention_mask": (batch_dim, seq_len_dim),
+            },
+            options={
+                "aot_inductor.output_path": os.path.join(os.getcwd(), "bert-seq.so"),
+                "max_autotune": True,
+            },
+        )
+
+    return
+
+
+if __name__ == "__main__":
+    dirname = os.path.dirname(__file__)
+    if len(sys.argv) > 1:
+        filename = os.path.join(dirname, sys.argv[1])
+    else:
+        filename = os.path.join(dirname, "model-config.yaml")
+    with open(filename, "r") as f:
+        settings = yaml.safe_load(f)
+
+    mode = settings["handler"]["mode"]
+    model_name = settings["handler"]["model_name"]
+    num_labels = int(settings["handler"]["num_labels"])
+    do_lower_case = bool(settings["handler"]["do_lower_case"])
+    max_length = int(settings["handler"]["max_length"])
+    batch_size = int(settings["batchSize"])
+    transformers_model_dowloader(
+        mode,
+        model_name,
+        num_labels,
+        do_lower_case,
+        max_length,
+        batch_size,
+    )
diff --git a/examples/cpp/aot_inductor/bert/index_to_name.json b/examples/cpp/aot_inductor/bert/index_to_name.json
new file mode 100644
index 0000000000..9ccff719f6
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/index_to_name.json
@@ -0,0 +1,4 @@
+{
+ "0":"Not Accepted",
+ "1":"Accepted"
+}
diff --git a/examples/cpp/aot_inductor/bert/model-config.yaml b/examples/cpp/aot_inductor/bert/model-config.yaml
new file mode 100644
index 0000000000..f44839848c
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/model-config.yaml
@@ -0,0 +1,13 @@
+minWorkers: 1
+maxWorkers: 1
+batchSize: 2
+
+handler:
+  model_so_path: "bert-seq.so"
+  tokenizer_path: "tokenizer.json"
+  mapping: "index_to_name.json"
+  model_name: "bert-base-uncased"
+  mode: "sequence_classification"
+  do_lower_case: true
+  num_labels: 2
+  max_length: 150
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
new file mode 100644
index 0000000000..3fdf950a1f
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -0,0 +1,206 @@
+#include "bert_handler.hh"
+#include "src/utils/file_system.hh"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace bert {
+std::pair<std::shared_ptr<void>, std::shared_ptr<torch::Device>>
+BertCppHandler::LoadModel(
+    std::shared_ptr<torchserve::LoadModelRequest>& load_model_request) {
+  try {
+    TS_LOG(INFO, "start LoadModel");
+    auto device = GetTorchDevice(load_model_request);
+
+    const std::string modelConfigYamlFilePath =
+        fmt::format("{}/{}", load_model_request->model_dir, "model-config.yaml");
+    model_config_yaml_ = std::make_unique<YAML::Node>(YAML::LoadFile(modelConfigYamlFilePath));
+
+    const std::string mapFilePath =
+        fmt::format("{}/{}", load_model_request->model_dir,
+                    (*model_config_yaml_)["handler"]["mapping"].as<std::string>());
+    mapping_json_ = torchserve::FileSystem::LoadJsonFile(mapFilePath);
+
+    max_length_ = (*model_config_yaml_)["handler"]["max_length"].as<int>();
+
+    std::string tokenizer_path =
+        fmt::format("{}/{}", load_model_request->model_dir,
+                    (*model_config_yaml_)["handler"]["tokenizer_path"].as<std::string>());
+    auto tokenizer_blob = torchserve::FileSystem::LoadBytesFromFile(tokenizer_path);
+    tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(tokenizer_blob);
+
+    std::string model_so_path =
+        fmt::format("{}/{}", load_model_request->model_dir,
+                    (*model_config_yaml_)["handler"]["model_so_path"].as<std::string>());
+
+    c10::InferenceMode mode;
+
+    if (device->is_cuda()) {
+      return std::make_pair(
+          std::make_shared<torch::inductor::AOTIModelContainerRunnerCuda>(model_so_path.c_str(), 1, device->str().c_str()),
+          device);
+    } else {
+      return std::make_pair(
+          std::make_shared<torch::inductor::AOTIModelContainerRunnerCpu>(model_so_path.c_str()),
+          device);
+    }
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
+            load_model_request->model_name, load_model_request->gpu_id,
+            e.msg());
+    throw e;
+  } catch (const std::runtime_error& e) {
+    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
+            load_model_request->model_name, load_model_request->gpu_id,
+            e.what());
+    throw e;
+  }
+}
+
+c10::IValue BertCppHandler::Preprocess(
+    std::shared_ptr<torch::Device> &device,
+    std::pair<std::string &, std::map<uint8_t, std::string> &> &idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceRequestBatch> &request_batch,
+    std::shared_ptr<torchserve::InferenceResponseBatch> &response_batch) {
+  auto options = torch::TensorOptions().dtype(torch::kLong);
+  auto attention_mask = torch::zeros({static_cast<long>(request_batch->size()), max_length_}, torch::kLong);
+  auto batch_tokens = torch::full({static_cast<long>(request_batch->size()), max_length_}, tokenizer_->TokenToId("<pad>"), torch::kLong);
+
+  uint8_t idx = 0;
+  for (auto& request : *request_batch) {
+    try {
+      (*response_batch)[request.request_id] =
+          std::make_shared<torchserve::InferenceResponse>(request.request_id);
+      idx_to_req_id.first += idx_to_req_id.first.empty()
+                                 ? request.request_id
+                                 : "," + request.request_id;
+
+      auto data_it = request.parameters.find(
+          torchserve::PayloadType::kPARAMETER_NAME_DATA);
+      auto dtype_it =
+          request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
+      if (data_it == request.parameters.end()) {
+        data_it = request.parameters.find(
+            torchserve::PayloadType::kPARAMETER_NAME_BODY);
+        dtype_it = request.headers.find(
+            torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
+      }
+
+      if (data_it == request.parameters.end() ||
+          dtype_it == request.headers.end()) {
+        (*response_batch)[request.request_id]->SetResponse(
+            500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
+            "Empty payload");
+        continue;
+      }
+
+      std::string msg = torchserve::Converter::VectorToStr(data_it->second);
+      // tokenization
+      std::vector<int> token_ids = tokenizer_->Encode(msg);
+      int cur_token_ids_length = (int)token_ids.size();
+      if (cur_token_ids_length > max_length_) {
+        TS_LOGF(ERROR, "prompt too long ({} tokens, max {})", cur_token_ids_length, max_length_);
+      }
+      for (int i = 0; i < std::min(cur_token_ids_length, max_length_); i++) {
+        attention_mask[idx][i] = 1;
+        batch_tokens[idx][i] = token_ids[i];
+      }
+
+      idx_to_req_id.second[idx++] = request.request_id;
+    } catch (const std::runtime_error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              request.request_id, e.what());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to load tensor");
+    } catch (const c10::Error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}",
+              request.request_id, e.msg());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to load tensor");
+    }
+  }
+  auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get());
+  batch_ivalue.emplace_back(batch_tokens.to(*device));
+  batch_ivalue.emplace_back(attention_mask.to(*device));
+
+  return batch_ivalue;
+}
+
+c10::IValue BertCppHandler::Inference(
+    std::shared_ptr<void> model, c10::IValue &inputs,
+    std::shared_ptr<torch::Device> &device,
+    std::pair<std::string &, std::map<uint8_t, std::string> &> &idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch> &response_batch) {
+  c10::InferenceMode mode;
+  try {
+    std::shared_ptr<torch::inductor::AOTIModelContainerRunner> runner;
+    if (device->is_cuda()) {
+      runner = std::static_pointer_cast<torch::inductor::AOTIModelContainerRunnerCuda>(model);
+    } else {
+      runner = std::static_pointer_cast<torch::inductor::AOTIModelContainerRunnerCpu>(model);
+    }
+
+    auto batch_output_tensor_vector = runner->run(inputs.toTensorVector());
+    return c10::IValue(batch_output_tensor_vector[0]);
+  } catch (std::runtime_error& e) {
+    TS_LOG(ERROR, e.what());
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "Failed to apply inference on input, c10 error:{}", e.msg());
+  }
+}
+
+void BertCppHandler::Postprocess(
+    c10::IValue &inputs,
+    std::pair<std::string &, std::map<uint8_t, std::string> &> &idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch> &response_batch) {
+  auto& data = inputs.toTensor();
+  for (const auto &kv : idx_to_req_id.second) {
+    try {
+      auto out = data[kv.first].unsqueeze(0);
+      auto y_hat = torch::argmax(out, 1).item<int>();
+      auto predicted_idx = std::to_string(y_hat);
+      auto response = (*response_batch)[kv.second];
+
+      response->SetResponse(200, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            torchserve::FileSystem::GetJsonValue(mapping_json_, predicted_idx).asString());
+    } catch (const std::runtime_error &e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              kv.second, e.what());
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to postprocess tensor");
+    } catch (const c10::Error &e) {
+      TS_LOGF(ERROR,
+              "Failed to postprocess tensor for request id: {}, error: {}",
+              kv.second, e.msg());
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to postprocess tensor");
+    }
+  }
+}
+}  // namespace bert
+
+#if defined(__linux__) || defined(__APPLE__)
+extern "C" {
+torchserve::BaseHandler *allocatorBertCppHandler() {
+  return new bert::BertCppHandler();
+}
+
+void deleterBertCppHandler(torchserve::BaseHandler *p) {
+  if (p != nullptr) {
+    delete static_cast<bert::BertCppHandler *>(p);
+  }
+}
+}
+#endif
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh
new file mode 100644
index 0000000000..80a3d68cb0
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh
@@ -0,0 +1,50 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "src/backends/handler/base_handler.hh"
+
+namespace bert {
+class BertCppHandler : public torchserve::BaseHandler {
+ public:
+  // NOLINTBEGIN(bugprone-exception-escape)
+  BertCppHandler() = default;
+  // NOLINTEND(bugprone-exception-escape)
+  ~BertCppHandler() noexcept = default;
+
+  std::pair<std::shared_ptr<void>, std::shared_ptr<torch::Device>> LoadModel(
+      std::shared_ptr<torchserve::LoadModelRequest>& load_model_request)
+      override;
+
+  c10::IValue Preprocess(
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  c10::IValue Inference(
+      std::shared_ptr<void> model, c10::IValue& inputs,
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  void Postprocess(
+      c10::IValue& data,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+private:
+  std::unique_ptr<folly::dynamic> mapping_json_;
+  std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
+  std::unique_ptr<YAML::Node> model_config_yaml_;
+  int max_length_;
+};
+}  // namespace bert
diff --git a/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc b/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc
index a7d55381bb..351ac7e216 100644
--- a/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc
+++ b/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc
@@ -1,30 +1,17 @@
 #include "resnet_handler.hh"
+#include "src/utils/file_system.hh"
+
+#include
+#include
+#include
+#include
 #include
 
 namespace resnet {
-std::unique_ptr<folly::dynamic> ResnetCppHandler::LoadJsonFile(const std::string& file_path) {
-  std::string content;
-  if (!folly::readFile(file_path.c_str(), content)) {
-    TS_LOGF(ERROR, "{}} not found", file_path);
-    throw;
-  }
-  return std::make_unique<folly::dynamic>(folly::parseJson(content));
-}
-
-const folly::dynamic& ResnetCppHandler::GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key) {
-  if (json->find(key) != json->items().end()) {
-    return (*json)[key];
-  } else {
-    TS_LOG(ERROR, "Required field {} not found in JSON.", key);
-    throw ;
-  }
-}
-
 std::string ResnetCppHandler::MapClassToLabel(const torch::Tensor& classes, const torch::Tensor& probs) {
   folly::dynamic map = folly::dynamic::object;
   for (int i = 0; i < classes.sizes()[0]; i++) {
-    auto class_value = GetJsonValue(mapping_json_, std::to_string(classes[i].item<long>()));
+    auto class_value = torchserve::FileSystem::GetJsonValue(mapping_json_, std::to_string(classes[i].item<long>()));
     map[class_value[1].asString()] = probs[i].item<float>();
   }
@@ -44,12 +31,12 @@ ResnetCppHandler::LoadModel(
   const std::string mapFilePath =
       fmt::format("{}/{}", load_model_request->model_dir,
                   (*model_config_yaml_)["handler"]["mapping"].as<std::string>());
-  mapping_json_ = LoadJsonFile(mapFilePath);
+  mapping_json_ = torchserve::FileSystem::LoadJsonFile(mapFilePath);
 
   std::string model_so_path =
       fmt::format("{}/{}", load_model_request->model_dir,
                   (*model_config_yaml_)["handler"]["model_so_path"].as<std::string>());
-  mapping_json_ = LoadJsonFile(mapFilePath);
+  mapping_json_ = torchserve::FileSystem::LoadJsonFile(mapFilePath);
 
   c10::InferenceMode mode;
   if (device->is_cuda()) {
diff --git a/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh b/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh
index db20b551fe..4e43ea9fad 100644
--- a/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh
+++ b/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh
@@ -2,12 +2,7 @@
 
 #include
 #include
-#include
-#include
-#include
 #include
-#include
-#include
 #include
 
 #include "src/backends/handler/base_handler.hh"
@@ -45,8 +40,6 @@ class ResnetCppHandler : public torchserve::BaseHandler {
   override;
 
  private:
-  std::unique_ptr<folly::dynamic> LoadJsonFile(const std::string& file_path);
-  const folly::dynamic& GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key);
   std::string MapClassToLabel(const torch::Tensor& classes, const torch::Tensor& probs);
 
   std::unique_ptr<folly::dynamic> mapping_json_;
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py
index 8727a189c3..ff9ae4c25e 100644
--- a/ts_scripts/install_dependencies.py
+++ b/ts_scripts/install_dependencies.py
@@ -45,6 +45,11 @@
     "ninja-build",
     "clang-tidy",
     "clang-format",
+    "build-essential",
+    "libgoogle-perftools-dev",
+    "rustc",
+    "cargo",
+    "libunwind-dev",
 )
 
 CPP_DARWIN_DEPENDENCIES = (
diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index 4077b5baa9..d19638e5c1 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1187,6 +1187,7 @@
 FxGraphCache
 TorchInductor
 fx
 locustapache
+bertcppaot
 resnetcppaot
 FINhR
 IBY
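
The diff moves `LoadJsonFile`/`GetJsonValue` out of the ResNet handler into the shared `torchserve::FileSystem` utility, so both the ResNet and BERT handlers resolve labels the same way. A minimal sketch of calling the new helpers from handler-side code (the file name and the "0" key are illustrative, and the snippet assumes the TorchServe C++ include paths plus folly are available on the build):

```cpp
#include "src/utils/file_system.hh"

#include <iostream>
#include <string>

int main() {
  // Load index_to_name.json into a folly::dynamic owned by a unique_ptr.
  auto mapping = torchserve::FileSystem::LoadJsonFile("index_to_name.json");

  // GetJsonValue logs and throws when the key is missing, so only look up
  // indices the model can actually produce.
  const auto& label = torchserve::FileSystem::GetJsonValue(mapping, "0");
  std::cout << label.asString() << std::endl;  // "Not Accepted" with the bert example mapping
  return 0;
}
```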
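`BertCppHandler::Preprocess` relies on the newly vendored tokenizers-cpp submodule together with `FileSystem::LoadBytesFromFile`. A rough standalone sketch of that tokenization step, assuming the conventional `tokenizers_cpp.h` header name and a `tokenizer.json` produced by `aot_compile_export.py`:

```cpp
#include <tokenizers_cpp.h>

#include <iostream>
#include <string>
#include <vector>

#include "src/utils/file_system.hh"

int main() {
  // Read the serialized HuggingFace tokenizer and build a tokenizer from it,
  // as BertCppHandler::LoadModel does.
  std::string blob = torchserve::FileSystem::LoadBytesFromFile("tokenizer.json");
  auto tokenizer = tokenizers::Tokenizer::FromBlobJSON(blob);

  // Encode the sample request body; the handler copies these ids into a
  // fixed-size [batch, max_length] tensor and builds the attention mask.
  std::vector<int32_t> ids = tokenizer->Encode(
      "Bloomberg has decided to publish a new report on the global economy.");
  std::cout << "token count: " << ids.size() << std::endl;
  return 0;
}
```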
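Preprocess then packs every request into fixed-size `[batch, max_length]` tensors: token ids start out as padding, the attention mask starts at zero, and only positions that received real tokens are overwritten. A small libtorch-only illustration of that padding scheme (the token ids are made up; 150 mirrors `max_length` in model-config.yaml):

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  const int64_t batch_size = 2;
  const int64_t max_length = 150;  // handler.max_length in model-config.yaml
  const int64_t pad_id = 0;        // stands in for tokenizer_->TokenToId("<pad>")

  // Same layout the handler allocates: everything starts as padding / masked out.
  auto batch_tokens = torch::full({batch_size, max_length}, pad_id, torch::kLong);
  auto attention_mask = torch::zeros({batch_size, max_length}, torch::kLong);

  // Pretend the first request tokenized to five ids; copy them in and unmask them.
  std::vector<int> token_ids{101, 7325, 2038, 2787, 102};
  for (size_t i = 0; i < token_ids.size() && static_cast<int64_t>(i) < max_length; ++i) {
    batch_tokens[0][i] = token_ids[i];
    attention_mask[0][i] = 1;
  }

  std::cout << attention_mask.sum().item<long>() << std::endl;  // prints 5
  return 0;
}
```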