Skip to content

Commit 85e9a12

Browse files
authored
fix: preprocess tensor names in tensor types map (#607)
Thank you for your contribution
1 parent fbd42b6 commit 85e9a12

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

model.cpp

+40-11
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,26 @@ std::string convert_tensor_name(std::string name) {
558558
return new_name;
559559
}
560560

561+
void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
562+
std::string new_name = convert_tensor_name(name);
563+
564+
if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
565+
size_t prefix_size = new_name.find("attn.in_proj_weight");
566+
std::string prefix = new_name.substr(0, prefix_size);
567+
tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
568+
tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
569+
tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
570+
} else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
571+
size_t prefix_size = new_name.find("attn.in_proj_bias");
572+
std::string prefix = new_name.substr(0, prefix_size);
573+
tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
574+
tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
575+
tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
576+
} else {
577+
tensor_storages_types[new_name] = type;
578+
}
579+
}
580+
561581
void preprocess_tensor(TensorStorage tensor_storage,
562582
std::vector<TensorStorage>& processed_tensor_storages) {
563583
std::vector<TensorStorage> result;
@@ -927,7 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
927947
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
928948

929949
tensor_storages.push_back(tensor_storage);
930-
tensor_storages_types[tensor_storage.name] = tensor_storage.type;
950+
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
931951
}
932952

933953
gguf_free(ctx_gguf_);
@@ -1072,7 +1092,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10721092
}
10731093

10741094
tensor_storages.push_back(tensor_storage);
1075-
tensor_storages_types[tensor_storage.name] = tensor_storage.type;
1095+
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
10761096

10771097
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
10781098
}
@@ -1403,7 +1423,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
14031423
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
14041424
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
14051425
tensor_storages.push_back(reader.tensor_storage);
1406-
tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
1426+
add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
14071427

14081428
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
14091429
// reset
@@ -1461,10 +1481,10 @@ SDVersion ModelLoader::get_sd_version() {
14611481
TensorStorage token_embedding_weight, input_block_weight;
14621482
bool input_block_checked = false;
14631483

1464-
bool has_multiple_encoders = false;
1465-
bool is_unet = false;
1484+
bool has_multiple_encoders = false;
1485+
bool is_unet = false;
14661486

1467-
bool is_xl = false;
1487+
bool is_xl = false;
14681488
bool is_flux = false;
14691489

14701490
#define found_family (is_xl || is_flux)
@@ -1481,7 +1501,7 @@ SDVersion ModelLoader::get_sd_version() {
14811501
}
14821502
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
14831503
is_unet = true;
1484-
if(has_multiple_encoders){
1504+
if (has_multiple_encoders) {
14851505
is_xl = true;
14861506
if (input_block_checked) {
14871507
break;
@@ -1490,7 +1510,7 @@ SDVersion ModelLoader::get_sd_version() {
14901510
}
14911511
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
14921512
has_multiple_encoders = true;
1493-
if(is_unet){
1513+
if (is_unet) {
14941514
is_xl = true;
14951515
if (input_block_checked) {
14961516
break;
@@ -1635,11 +1655,20 @@ ggml_type ModelLoader::get_vae_wtype() {
16351655
void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
16361656
for (auto& pair : tensor_storages_types) {
16371657
if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
1658+
bool found = false;
16381659
for (auto& tensor_storage : tensor_storages) {
1639-
if (tensor_storage.name == pair.first) {
1640-
if (tensor_should_be_converted(tensor_storage, wtype)) {
1641-
pair.second = wtype;
1660+
std::map<std::string, ggml_type> temp;
1661+
add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
1662+
for (auto& preprocessed_name : temp) {
1663+
if (preprocessed_name.first == pair.first) {
1664+
if (tensor_should_be_converted(tensor_storage, wtype)) {
1665+
pair.second = wtype;
1666+
}
1667+
found = true;
1668+
break;
16421669
}
1670+
}
1671+
if (found) {
16431672
break;
16441673
}
16451674
}

0 commit comments

Comments
 (0)