
Commit f1cf28a

lora: support tucker decomposition (from loha)
1 parent 5eb15ef

File tree

2 files changed (+54 −6 lines)


ggml_extend.hpp

+27-1
@@ -52,6 +52,32 @@
 #define __STATIC_INLINE__ static inline
 #endif
 
+__STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only, const char* mark);
+
+// n-mode tensor-matrix product
+// example: 2-mode product
+// A: [ne03, k, ne01, ne00]
+// B: k rows, m columns => [k, m]
+// result is [ne03, m, ne01, ne00]
+__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
+    // reshape A
+    // swap 0th and nth axis
+    a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
+    int ne1 = a->ne[1];
+    int ne2 = a->ne[2];
+    int ne3 = a->ne[3];
+    // make 2D
+    a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));
+
+    struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b)));
+
+    // reshape output (same shape as a after permutation except first dim)
+    result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3);
+    // swap back 0th and nth axis
+    result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0);
+    return result;
+}
+
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
     (void)level;
     (void)user_data;
@@ -319,7 +345,7 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
     for (int iy = 0; iy < height; iy++) {
         float m = ggml_tensor_get_f32(mask, ix, iy);
         for (int k = 0; k < channels; k++) {
-            float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+            float value = ((float)(m < 254.5 / 255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
             ggml_tensor_set_f32(output, value, ix, iy, k);
         }
     }
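A minimal shape-level sketch of how the new ggml_mul_n_mode helper can be exercised, following the 2-mode example in its comment. This is not part of the commit: it assumes ggml_extend.hpp and its ggml dependencies resolve on the build's include path, and the concrete sizes (5, 4, 3, 2 and m = 6) are illustrative; only op shapes are checked, no graph is computed.

#include <cstdio>
#include "ggml.h"
#include "ggml_extend.hpp"

int main() {
    struct ggml_init_params params = {/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false};
    struct ggml_context* ctx = ggml_init(params);

    // A: ne = {ne00, ne01, k, ne03} = {5, 4, 3, 2};
    // written in the comment's reversed order this is [ne03, k, ne01, ne00] = [2, 3, 4, 5].
    struct ggml_tensor* a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 4, 3, 2);
    // B: k rows, m columns => ne = {k, m} = {3, 6}
    struct ggml_tensor* b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 6);

    // 2-mode product: contracts A's ne[2] (= k) against B's rows.
    struct ggml_tensor* r = ggml_mul_n_mode(ctx, a, b, 2);

    // Expected {5, 4, 6, 2}, i.e. [ne03, m, ne01, ne00] in the reversed notation.
    printf("%lld %lld %lld %lld\n",
           (long long)r->ne[0], (long long)r->ne[1], (long long)r->ne[2], (long long)r->ne[3]);

    ggml_free(ctx);
    return 0;
}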

lora.hpp

+27-5
@@ -244,12 +244,15 @@ struct LoraModel : public GGMLRunner {
         std::vector<std::string> keys = to_lora_keys(k_tensor, version);
         if (keys.size() == 0)
             continue;
+
+        ggml_tensor* lora_mid = NULL;  // tau for tucker decomposition
         ggml_tensor* lora_up = NULL;
         ggml_tensor* lora_down = NULL;
         for (auto& key : keys) {
             std::string alpha_name = "";
             std::string scale_name = "";
             std::string split_q_scale_name = "";
+            std::string lora_mid_name = "";
             std::string lora_down_name = "";
             std::string lora_up_name = "";

@@ -584,8 +587,10 @@
             }

             lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-            alpha_name = lora_pre[type] + key + ".alpha";
-            scale_name = lora_pre[type] + key + ".scale";
+            lora_mid_name = lora_pre[type] + key + ".lora_mid.weight";
+
+            alpha_name = lora_pre[type] + key + ".alpha";
+            scale_name = lora_pre[type] + key + ".scale";

             if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                 lora_up = lora_tensors[lora_up_name];
@@ -594,6 +599,12 @@
             if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
                 lora_down = lora_tensors[lora_down_name];
             }
+
+            if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
+                lora_mid = lora_tensors[lora_mid_name];
+                applied_lora_tensors.insert(lora_mid_name);
+            }
+
             applied_lora_tensors.insert(lora_up_name);
             applied_lora_tensors.insert(lora_down_name);
             applied_lora_tensors.insert(alpha_name);
@@ -622,9 +633,20 @@

             // ggml_mul_mat requires tensor b transposed
             lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
-            struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
-            updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
-            updown = ggml_reshape(compute_ctx, updown, weight);
+            struct ggml_tensor* updown = NULL;
+            if (lora_mid == NULL) {
+                updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
+                updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
+            } else {
+                // undoing tucker decomposition for conv layers.
+                // lora_mid  has shape (3, 3, Rank, Rank)
+                // lora_down has shape (Rank, In, 1, 1)
+                // lora_up   has shape (Rank, Out, 1, 1)
+                // conv layer shape is (3, 3, Out, In)
+                updown = ggml_mul_n_mode(compute_ctx, ggml_mul_n_mode(compute_ctx, lora_mid, lora_down, 3), lora_up, 2);
+                updown = ggml_cont(compute_ctx, updown);
+            }
+            updown = ggml_reshape(compute_ctx, updown, weight);
             GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
             updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
             ggml_tensor* final_weight;
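For reference, the contraction that the two nested ggml_mul_n_mode calls realize is a plain Tucker reconstruction: both Rank axes of the lora_mid core are contracted, one against lora_down and one against lora_up. The sketch below is a naive CPU restatement of that contraction using the axis ordering from the shape comments above; the function name rebuild_tucker, the flat row-major layouts, and the loop order are illustrative assumptions (the actual ggml graph additionally folds in the earlier transpose of lora_down), not part of the commit.

#include <vector>

// Naive Tucker rebuild: delta[h][w][o][i] = sum over r, s of
//   core[h][w][r][s] * up[r][o] * down[s][i]
// core  ~ lora_mid  (kh * kw * rank * rank floats)
// up    ~ lora_up   (rank * out_ch floats)
// down  ~ lora_down (rank * in_ch floats)
std::vector<float> rebuild_tucker(const std::vector<float>& core,
                                  const std::vector<float>& up,
                                  const std::vector<float>& down,
                                  int kh, int kw, int rank, int out_ch, int in_ch) {
    std::vector<float> delta(kh * kw * out_ch * in_ch, 0.0f);
    for (int h = 0; h < kh; h++)
        for (int w = 0; w < kw; w++)
            for (int o = 0; o < out_ch; o++)
                for (int i = 0; i < in_ch; i++) {
                    float acc = 0.0f;
                    for (int r = 0; r < rank; r++)
                        for (int s = 0; s < rank; s++)
                            acc += core[((h * kw + w) * rank + r) * rank + s]
                                 * up[r * out_ch + o]
                                 * down[s * in_ch + i];
                    delta[((h * kw + w) * out_ch + o) * in_ch + i] = acc;
                }
    return delta;
}

As in the plain lora_up x lora_down path, the reconstructed delta is then reshaped to the target weight's shape and scaled by scale_value before being added to the model weight.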
