@@ -558,6 +558,26 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
 
+void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+    std::string new_name = convert_tensor_name(name);
+
+    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
+        size_t prefix_size = new_name.find("attn.in_proj_weight");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
+    } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
+        size_t prefix_size = new_name.find("attn.in_proj_bias");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
+    } else {
+        tensor_storages_types[new_name] = type;
+    }
+}
+
 void preprocess_tensor(TensorStorage tensor_storage,
                        std::vector<TensorStorage>& processed_tensor_storages) {
     std::vector<TensorStorage> result;
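
The new helper fans one fused CLIP attention tensor out into the three per-projection entries the rest of the loader deals in. A minimal standalone sketch of that splitting behavior, with ends_with() re-declared locally, a string standing in for enum ggml_type, and an illustrative tensor name of the shape convert_tensor_name() would produce:

// Standalone sketch of the in_proj fan-out; not the real loader code.
#include <iostream>
#include <map>
#include <string>

static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    std::map<std::string, std::string> types;  // stand-in for std::map<std::string, enum ggml_type>
    std::string name = "cond_stage_model.transformer.text_model.encoder.layers.0.attn.in_proj_weight";  // illustrative

    if (name.find("cond_stage_model") != std::string::npos && ends_with(name, "attn.in_proj_weight")) {
        std::string prefix = name.substr(0, name.find("attn.in_proj_weight"));
        // The fused in_proj weight is split into q/k/v at load time, so the
        // type map must carry the post-split names, all with the same type.
        types[prefix + "self_attn.q_proj.weight"] = "f16";
        types[prefix + "self_attn.k_proj.weight"] = "f16";
        types[prefix + "self_attn.v_proj.weight"] = "f16";
    }

    for (const auto& kv : types) {
        std::cout << kv.first << "\n";  // ...layers.0.self_attn.{q,k,v}_proj.weight
    }
    return 0;
}
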
@@ -927,7 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
     }
 
     gguf_free(ctx_gguf_);
@@ -1072,7 +1092,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
 
         // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
     }
@@ -1403,7 +1423,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
                 // printf("ZIP got tensor %s\n", reader.tensor_storage.name.c_str());
                 reader.tensor_storage.name = prefix + reader.tensor_storage.name;
                 tensor_storages.push_back(reader.tensor_storage);
-                tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
+                add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
 
                 // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
                 // reset
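
With this hunk, all three loaders (GGUF, safetensors, and pickle checkpoints) register tensor types through the same helper. The point, presumably, is that tensor_storages_types should be keyed by the names tensors carry after preprocessing: the old one-line assignment recorded a fused in_proj tensor only under its on-disk name, so a later lookup by a split q/k/v name would miss. A self-contained sketch of that mismatch (names and type strings are illustrative):

// Sketch: why keying the type map by the raw on-disk name breaks lookups.
#include <iostream>
#include <map>
#include <string>

int main() {
    std::map<std::string, std::string> old_map, new_map;

    // Old behavior: one entry under the fused on-disk name.
    old_map["cond_stage_model.transformer.text_model.encoder.layers.0.attn.in_proj_weight"] = "q4_0";

    // New behavior: entries under the post-split names the loader uses.
    const std::string prefix = "cond_stage_model.transformer.text_model.encoder.layers.0.";
    new_map[prefix + "self_attn.q_proj.weight"] = "q4_0";
    new_map[prefix + "self_attn.k_proj.weight"] = "q4_0";
    new_map[prefix + "self_attn.v_proj.weight"] = "q4_0";

    const std::string wanted = prefix + "self_attn.q_proj.weight";
    std::cout << "old map hit: " << old_map.count(wanted) << "\n";  // 0 -> lookup missed
    std::cout << "new map hit: " << new_map.count(wanted) << "\n";  // 1 -> lookup succeeds
    return 0;
}
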
@@ -1461,10 +1481,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;
 
-    bool has_multiple_encoders = false;
-    bool is_unet = false;
+    bool has_multiple_encoders = false;
+    bool is_unet               = false;
 
-    bool is_xl = false;
+    bool is_xl   = false;
     bool is_flux = false;
 
 #define found_family (is_xl || is_flux)
@@ -1481,7 +1501,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
             is_unet = true;
-            if (has_multiple_encoders){
+            if (has_multiple_encoders) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1490,7 +1510,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
             has_multiple_encoders = true;
-            if (is_unet){
+            if (is_unet) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1635,11 +1655,20 @@ ggml_type ModelLoader::get_vae_wtype() {
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
     for (auto& pair : tensor_storages_types) {
         if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+            bool found = false;
             for (auto& tensor_storage : tensor_storages) {
-                if (tensor_storage.name == pair.first) {
-                    if (tensor_should_be_converted(tensor_storage, wtype)) {
-                        pair.second = wtype;
+                std::map<std::string, ggml_type> temp;
+                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
+                for (auto& preprocessed_name : temp) {
+                    if (preprocessed_name.first == pair.first) {
+                        if (tensor_should_be_converted(tensor_storage, wtype)) {
+                            pair.second = wtype;
+                        }
+                        found = true;
+                        break;
                     }
+                }
+                if (found) {
                     break;
                 }
             }
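
set_wtype_override() faces the inverse problem: pair.first is now a post-preprocess key, so it can no longer be compared directly against the raw tensor_storage.name. Instead, each stored tensor is expanded through the helper into a temporary map and the override key is matched against those expanded keys. A self-contained sketch of that expand-then-match pattern, with a hypothetical expand() standing in for add_preprocess_tensor_storage_types() and illustrative names:

// Sketch of expand-then-match; expand() is a made-up stand-in.
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Fans a raw name out to the key(s) the type map would hold for it.
static std::map<std::string, int> expand(const std::string& raw_name) {
    std::map<std::string, int> out;
    if (raw_name == "layers.0.attn.in_proj_weight") {  // pretend this is a fused tensor
        out["layers.0.self_attn.q_proj.weight"] = 0;
        out["layers.0.self_attn.k_proj.weight"] = 0;
        out["layers.0.self_attn.v_proj.weight"] = 0;
    } else {
        out[raw_name] = 0;
    }
    return out;
}

int main() {
    std::vector<std::string> raw_names = {"layers.0.attn.in_proj_weight", "conv_in.weight"};
    std::string override_key = "layers.0.self_attn.k_proj.weight";

    for (const auto& raw : raw_names) {
        bool found = false;
        for (const auto& kv : expand(raw)) {
            if (kv.first == override_key) {  // match against expanded keys,
                found = true;                // not the raw on-disk name
                break;
            }
        }
        if (found) {
            std::cout << override_key << " is backed by raw tensor " << raw << "\n";
            break;
        }
    }
    return 0;
}
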