@@ -244,12 +244,15 @@ struct LoraModel : public GGMLRunner {
             std::vector<std::string> keys = to_lora_keys(k_tensor, version);
             if (keys.size() == 0)
                 continue;
+
+            ggml_tensor* lora_mid  = NULL;  // tau for tucker decomposition
             ggml_tensor* lora_up   = NULL;
             ggml_tensor* lora_down = NULL;
             for (auto& key : keys) {
                 std::string alpha_name         = "";
                 std::string scale_name         = "";
                 std::string split_q_scale_name = "";
+                std::string lora_mid_name      = "";
                 std::string lora_down_name     = "";
                 std::string lora_up_name       = "";
 
@@ -584,8 +587,10 @@ struct LoraModel : public GGMLRunner {
                 }
 
                 lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-                alpha_name     = lora_pre[type] + key + ".alpha";
-                scale_name     = lora_pre[type] + key + ".scale";
+                lora_mid_name  = lora_pre[type] + key + ".lora_mid.weight";
+
+                alpha_name = lora_pre[type] + key + ".alpha";
+                scale_name = lora_pre[type] + key + ".scale";
 
                 if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                     lora_up = lora_tensors[lora_up_name];
@@ -594,6 +599,12 @@ struct LoraModel : public GGMLRunner {
                 if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
                     lora_down = lora_tensors[lora_down_name];
                 }
+
+                if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
+                    lora_mid = lora_tensors[lora_mid_name];
+                    applied_lora_tensors.insert(lora_mid_name);
+                }
+
                 applied_lora_tensors.insert(lora_up_name);
                 applied_lora_tensors.insert(lora_down_name);
                 applied_lora_tensors.insert(alpha_name);
@@ -622,9 +633,20 @@ struct LoraModel : public GGMLRunner {
 
                 // ggml_mul_mat requires tensor b transposed
                 lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
-                struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
-                updown                     = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
-                updown                     = ggml_reshape(compute_ctx, updown, weight);
+                struct ggml_tensor* updown = NULL;
+                if (lora_mid == NULL) {
+                    updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
+                    updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
+                } else {
+                    // undoing tucker decomposition for conv layers.
+                    // lora_mid has shape (3, 3, Rank, Rank)
+                    // lora_down has shape (Rank, In, 1, 1)
+                    // lora_up has shape (Rank, Out, 1, 1)
+                    // conv layer shape is (3, 3, Out, In)
+                    updown = ggml_mul_n_mode(compute_ctx, ggml_mul_n_mode(compute_ctx, lora_mid, lora_down, 3), lora_up, 2);
+                    updown = ggml_cont(compute_ctx, updown);
+                }
+                updown = ggml_reshape(compute_ctx, updown, weight);
                 GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
                 updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
                 ggml_tensor* final_weight;
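
A brief note on the Tucker branch above (a sketch of the underlying math, not part of the patch): assuming ggml_mul_n_mode(ctx, a, b, n) contracts dimension n of a against the rank dimension of b, the nested call rebuilds the conv-weight delta as

    dW[kw, kh, o, i] = sum_r sum_s mid[kw, kh, r, s] * down[s, i] * up[r, o]

with indices written in the ggml shape order used in the code comments. The inner call folds the last rank mode of the (3, 3, Rank, Rank) core tensor through lora_down (Rank -> In), and the outer call folds the remaining rank mode through lora_up (Rank -> Out), producing the (3, 3, Out, In) tensor that ggml_reshape then matches to the conv weight before scaling.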