
Commit ac51185

refactor fp8 + add e3m4 (fn)
1 parent dcf91f9 commit ac51185


2 files changed (+83 -73 lines)


model.cpp

+79 -72
@@ -609,87 +609,72 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
-uint16_t f8_e4m3_to_f16(uint8_t f8) {
-    // do we need to support uz?
-
-    const uint32_t exponent_bias = 7;
-    if (f8 == 0xff) {
-        return ggml_fp32_to_fp16(-NAN);
-    } else if (f8 == 0x7f) {
-        return ggml_fp32_to_fp16(NAN);
+uint16_t f8_e3m4_to_f16(uint8_t fp8) {
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint32_t exponent_bias = 0x3;  // 2^(3-1)-1
+    const uint32_t f16_bias = 0xF;  // 2^(5-1)-1
+    const int mantissa_bits = 4;
+    const uint8_t mantissa_max = 0xF;  // 2^4-1
 
-    uint32_t sign = f8 & 0x80;
-    uint32_t exponent = (f8 & 0x78) >> 3;
-    uint32_t mantissa = f8 & 0x07;
-    uint32_t result = sign << 24;
-    if (exponent == 0) {
-        if (mantissa > 0) {
-            exponent = 0x7f - exponent_bias;
-
-            // yes, 2 times
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
 
-            result |= (mantissa & 0x03) << 21;
-            result |= exponent << 23;
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
         }
-    } else {
-        result |= mantissa << 20;
-        exponent += 0x7f - exponent_bias;
-        result |= exponent << 23;
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 6;
 
-    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
 }
 
-uint16_t f8_e5m2_to_f16(uint8_t fp8) {
-    uint8_t sign = (fp8 >> 7) & 0x1;
-    uint8_t exponent = (fp8 >> 2) & 0x1F;
-    uint8_t mantissa = fp8 & 0x3;
-
-    uint16_t fp16_sign = sign << 15;
-    uint16_t fp16_exponent;
-    uint16_t fp16_mantissa;
-
-    if (exponent == 0 && mantissa == 0) { // zero
-        return fp16_sign;
+uint16_t f8_e4m3_to_f16(uint8_t fp8) {
+    // do we need to support uz?
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint32_t exponent_bias = 0x7;  // 2^(4-1)-1
+    const uint32_t f16_bias = 0xF;  // 2^(5-1)-1
+    const int mantissa_bits = 3;
+    const uint8_t mantissa_max = 0x7;  // 2^3-1
 
-    if (exponent == 0x1F) { // NAN and INF
-        fp16_exponent = 0x1F;
-        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
-        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-    }
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
 
-    if (exponent == 0) { // subnormal numbers
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-        return fp16_sign | fp16_mantissa;
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
+        }
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 7;
 
-    // normal numbers
-    int16_t true_exponent = (int16_t)exponent - 15 + 15;
-    if (true_exponent <= 0) {
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-    } else if (true_exponent >= 0x1F) {
-        fp16_exponent = 0x1F;
-        fp16_mantissa = 0;
-    } else {
-        fp16_exponent = (uint16_t)true_exponent;
-        fp16_mantissa = mantissa << 8;
-    }
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
+}
 
-    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    // do we need to support fnuz?
+    return static_cast<uint16_t>(fp8) << 8;
 }
 
 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
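
Both new converters assemble the FP16 bit pattern directly instead of going through a 32-bit float and ggml_fp32_to_fp16: split the byte into sign, exponent and mantissa, re-bias the exponent from the fp8 bias (3 for E3M4, 7 for E4M3) to FP16's bias of 15, renormalize subnormal inputs by shifting the mantissa up, and left-align the mantissa into FP16's 10 fraction bits. E5M2 collapses to a single shift because it shares FP16's 5-bit exponent and bias, so placing the byte in the high half of the 16-bit word already yields a valid FP16 encoding; the E3M4/E4M3 converters reuse the same shift as their +/-0 and NaN shortcut, since shifting 0x00/0x80 gives FP16 +/-0 and shifting 0x7F/0xFF gives an FP16 NaN. The following minimal sketch (not part of the commit; the input byte is chosen only for illustration) walks one normal E4M3 value through the same re-biasing:

#include <cstdint>
#include <cstdio>

// Decode a single normal (exponent != 0) FP8 E4M3 byte into FP16 bits, using the
// same sign/exponent/mantissa split and re-biasing as f8_e4m3_to_f16 above.
int main() {
    const uint8_t fp8 = 0x40;  // s=0, e=1000b (8), m=000b  ->  2^(8-7) * 1.0 = 2.0

    uint8_t sign = (fp8 >> 7) & 0x1;
    uint8_t exponent = (fp8 >> 3) & 0xF;  // 4 exponent bits, E4M3 bias 7
    uint8_t mantissa = fp8 & 0x7;         // 3 mantissa bits

    uint16_t f16 = (uint16_t)(sign << 15)
                 | (uint16_t)((exponent + (15 - 7)) << 10)  // re-bias to FP16's bias of 15
                 | (uint16_t)(mantissa << 7);               // left-align 3 mantissa bits into 10
    printf("0x%02X -> FP16 bits 0x%04X (expected 0x4000, i.e. 2.0)\n", fp8, f16);
    return 0;
}
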
@@ -699,6 +684,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     }
 }
 
+void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e3m4_to_f16(src[i]);
+    }
+}
+
 void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
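
The *_to_f16_vec helpers convert in place: the buffer is allocated at the final fp16 size, the raw fp8 payload sits in its first half, and src and dst alias the same memory. Iterating from the last element down is what makes this safe: dst[i] occupies bytes 2*i and 2*i+1, which lie at or beyond src[i], so source bytes that still have to be read are never overwritten. A small self-contained sketch of the pattern (with a plain zero-extension standing in for the fp8 decode):

#include <cstdint>
#include <cstdio>
#include <vector>

// In-place widening in the style of f8_e3m4_to_f16_vec: the buffer is sized for
// the 2-byte results, the 1-byte inputs occupy its first half, and the backwards
// loop keeps the not-yet-read source bytes intact.
int main() {
    const int64_t n = 4;
    std::vector<uint8_t> buf(n * 2);                  // sized for the widened output
    for (int64_t i = 0; i < n; i++) buf[i] = 10 + i;  // raw 1-byte payload up front

    uint8_t* src = buf.data();
    uint16_t* dst = reinterpret_cast<uint16_t*>(buf.data());
    for (int64_t i = n - 1; i >= 0; i--) {
        dst[i] = static_cast<uint16_t>(src[i]);  // model.cpp calls f8_e3m4_to_f16(src[i]) here
    }

    for (int64_t i = 0; i < n; i++) {
        printf("%u ", (unsigned)dst[i]);  // prints: 10 11 12 13
    }
    printf("\n");
    return 0;
}
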
@@ -946,6 +938,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E3M4") {
+        ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E5M2") {
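
All three F8_* dtypes map to GGML_TYPE_F16 here, since this loader has no fp8 ggml tensor type to hand the data to; the raw bytes are widened to fp16 at load time by the f8_*_to_f16_vec helpers above.
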
@@ -1059,6 +1053,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E3M4") {
+            tensor_storage.is_f8_e3m4 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else if (dtype == "F8_E4M3") {
             tensor_storage.is_f8_e4m3 = true;
             // f8 -> f16
@@ -1461,10 +1459,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;
 
-    bool has_multiple_encoders = false;
-    bool is_unet = false;
+    bool has_multiple_encoders = false;
+    bool is_unet               = false;
 
-    bool is_xl = false;
+    bool is_xl   = false;
     bool is_flux = false;
 
 #define found_family (is_xl || is_flux)
@@ -1481,7 +1479,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
             is_unet = true;
-            if(has_multiple_encoders){
+            if (has_multiple_encoders) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1490,7 +1488,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
             has_multiple_encoders = true;
-            if(is_unet){
+            if (is_unet) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1779,6 +1777,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
@@ -1793,6 +1794,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
@@ -1811,6 +1815,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());

model.h

+4 -1
@@ -89,6 +89,7 @@ struct TensorStorage {
     std::string name;
     ggml_type type = GGML_TYPE_F32;
     bool is_bf16 = false;
+    bool is_f8_e3m4 = false;
     bool is_f8_e4m3 = false;
     bool is_f8_e5m2 = false;
     int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
@@ -120,7 +121,7 @@ struct TensorStorage {
     }
 
     int64_t nbytes_to_read() const {
-        if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
+        if (is_bf16 || is_f8_e3m4 || is_f8_e4m3 || is_f8_e5m2) {
             return nbytes() / 2;
         } else {
             return nbytes();
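
This is the counterpart of the GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2) checks in init_from_safetensors_file: an fp8 tensor is registered with ggml type F16, so nbytes() reports 2 bytes per element, while the safetensors file stores only 1 byte per element. The loader therefore reads nbytes() / 2 bytes and expands them in place with the matching f8_*_to_f16_vec call.
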
@@ -168,6 +169,8 @@ struct TensorStorage {
         const char* type_name = ggml_type_name(type);
         if (is_bf16) {
             type_name = "bf16";
+        } else if (is_f8_e3m4) {
+            type_name = "f8_e3m4";
         } else if (is_f8_e4m3) {
             type_name = "f8_e4m3";
         } else if (is_f8_e5m2) {
