@@ -609,87 +609,72 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }

-uint16_t f8_e4m3_to_f16(uint8_t f8) {
-    // do we need to support uz?
-
-    const uint32_t exponent_bias = 7;
-    if (f8 == 0xff) {
-        return ggml_fp32_to_fp16(-NAN);
-    } else if (f8 == 0x7f) {
-        return ggml_fp32_to_fp16(NAN);
+uint16_t f8_e3m4_to_f16(uint8_t fp8) {
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint32_t exponent_bias = 0x3;  // 2^(3-1)-1
+    const uint32_t f16_bias = 0xF;       // 2^(5-1)-1
+    const int mantissa_bits = 4;
+    const uint8_t mantissa_max = 0xF;    // 2^4-1

-    uint32_t sign = f8 & 0x80;
-    uint32_t exponent = (f8 & 0x78) >> 3;
-    uint32_t mantissa = f8 & 0x07;
-    uint32_t result = sign << 24;
-    if (exponent == 0) {
-        if (mantissa > 0) {
-            exponent = 0x7f - exponent_bias;
-
-            // yes, 2 times
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;

-            result |= (mantissa & 0x03) << 21;
-            result |= exponent << 23;
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
         }
-    } else {
-        result |= mantissa << 20;
-        exponent += 0x7f - exponent_bias;
-        result |= exponent << 23;
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 6;

-    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
 }

-uint16_t f8_e5m2_to_f16(uint8_t fp8) {
-    uint8_t sign = (fp8 >> 7) & 0x1;
-    uint8_t exponent = (fp8 >> 2) & 0x1F;
-    uint8_t mantissa = fp8 & 0x3;
-
-    uint16_t fp16_sign = sign << 15;
-    uint16_t fp16_exponent;
-    uint16_t fp16_mantissa;
-
-    if (exponent == 0 && mantissa == 0) {  // zero
-        return fp16_sign;
+uint16_t f8_e4m3_to_f16(uint8_t fp8) {
+    // do we need to support uz?
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
     }
+    const uint32_t exponent_bias = 0x7;  // 2^(4-1)-1
+    const uint32_t f16_bias = 0xF;       // 2^(5-1)-1
+    const int mantissa_bits = 3;
+    const uint8_t mantissa_max = 0x7;    // 2^3-1

-    if (exponent == 0x1F) {  // NAN and INF
-        fp16_exponent = 0x1F;
-        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
-        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-    }
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;

-    if (exponent == 0) {  // subnormal numbers
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-        return fp16_sign | fp16_mantissa;
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
+        }
+        mantissa &= mantissa_max;
     }
+    uint16_t fp16_mantissa = mantissa << 7;

-    // normal numbers
-    int16_t true_exponent = (int16_t)exponent - 15 + 15;
-    if (true_exponent <= 0) {
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-    } else if (true_exponent >= 0x1F) {
-        fp16_exponent = 0x1F;
-        fp16_mantissa = 0;
-    } else {
-        fp16_exponent = (uint16_t)true_exponent;
-        fp16_mantissa = mantissa << 8;
-    }
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
+}

-    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    // do we need to support fnuz?
+    return static_cast<uint16_t>(fp8) << 8;
 }

 void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
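
Aside (not part of the patch): a quick way to sanity-check the three conversions above is to feed them the FP8 encodings of 1.0, which should all come out as 0x3C00, the FP16 bit pattern for 1.0. The sketch below is illustrative only; it assumes the functions from this diff are linked in (or pasted above main), and the expected input bytes are hand-derived from the E3M4/E4M3/E5M2 layouts.

// sanity_check.cpp -- illustrative only, not part of model.cpp
#include <cassert>
#include <cstdint>

// declarations of the converters shown in this diff (definitions live in model.cpp)
uint16_t f8_e3m4_to_f16(uint8_t fp8);
uint16_t f8_e4m3_to_f16(uint8_t fp8);
uint16_t f8_e5m2_to_f16(uint8_t fp8);

int main() {
    // 1.0 is sign 0, exponent field == bias, mantissa 0 in every format
    assert(f8_e3m4_to_f16(0x30) == 0x3C00);  // 0 011 0000: exp 3 -> 3 + (15 - 3) = 15
    assert(f8_e4m3_to_f16(0x38) == 0x3C00);  // 0 0111 000: exp 7 -> 7 + (15 - 7) = 15
    assert(f8_e5m2_to_f16(0x3C) == 0x3C00);  // 0 01111 00: same layout as the FP16 high byte
    return 0;
}
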
@@ -699,6 +684,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
     }
 }

+void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
+    // support inplace op
+    for (int64_t i = n - 1; i >= 0; i--) {
+        dst[i] = f8_e3m4_to_f16(src[i]);
+    }
+}
+
 void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
     // support inplace op
     for (int64_t i = n - 1; i >= 0; i--) {
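
Aside (not part of the patch): f8_e3m4_to_f16_vec follows the same in-place convention as the existing f8_e4m3/f8_e5m2 helpers: src and dst may point at the same buffer, with the FP8 bytes packed at the front and the buffer sized for the FP16 result. Iterating from i = n - 1 down to 0 works because dst[i] lands on bytes 2*i and 2*i + 1, i.e. at or past the source byte being read, so every source byte is consumed before it can be overwritten; a forward loop would clobber src[2], src[3], ... before converting them. A hypothetical caller (buffer name and reader invented for illustration) would look like:

std::vector<uint8_t> buf(nelements * sizeof(uint16_t));  // room for the f16 result
read_fp8_payload(buf.data(), nelements);                 // hypothetical: fills the first nelements bytes
f8_e3m4_to_f16_vec(buf.data(), (uint16_t*)buf.data(), nelements);
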
@@ -946,6 +938,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
         ttype = GGML_TYPE_F32;
     } else if (dtype == "F32") {
         ttype = GGML_TYPE_F32;
+    } else if (dtype == "F8_E3M4") {
+        ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E4M3") {
         ttype = GGML_TYPE_F16;
     } else if (dtype == "F8_E5M2") {
@@ -1059,6 +1053,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         if (dtype == "BF16") {
             tensor_storage.is_bf16 = true;
             GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+        } else if (dtype == "F8_E3M4") {
+            tensor_storage.is_f8_e3m4 = true;
+            // f8 -> f16
+            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
         } else if (dtype == "F8_E4M3") {
             tensor_storage.is_f8_e4m3 = true;
             // f8 -> f16
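
Aside (not part of the patch): the `* 2` in these asserts is there because str_to_ggml_type maps the F8 dtypes to GGML_TYPE_F16, so tensor_storage.nbytes() counts two bytes per element while the safetensors payload (tensor_data_size) holds one byte per element. For example, a 4096 x 4096 F8_E3M4 weight occupies 16 MiB in the file but 32 MiB once widened to f16 in memory, mirroring the BF16 -> F32 case above it.
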
@@ -1461,10 +1459,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;

-    bool has_multiple_encoders = false;
-    bool is_unet = false;
+    bool has_multiple_encoders = false;
+    bool is_unet               = false;

-    bool is_xl = false;
+    bool is_xl   = false;
     bool is_flux = false;

 #define found_family (is_xl || is_flux)
@@ -1481,7 +1479,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
             is_unet = true;
-            if (has_multiple_encoders){
+            if (has_multiple_encoders) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1490,7 +1488,7 @@ SDVersion ModelLoader::get_sd_version() {
         }
         if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
             has_multiple_encoders = true;
-            if (is_unet){
+            if (is_unet) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1779,6 +1777,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
@@ -1793,6 +1794,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
@@ -1811,6 +1815,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 if (tensor_storage.is_bf16) {
                     // inplace op
                     bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
+                } else if (tensor_storage.is_f8_e3m4) {
+                    // inplace op
+                    f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 } else if (tensor_storage.is_f8_e4m3) {
                     // inplace op
                     f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());