@@ -537,8 +537,10 @@ void process_shaders() {
537
537
for (auto src0_f16 : {false , true }) {
538
538
for (auto src1_f16 : {false , true }) {
539
539
for (auto dst_f16 : {false , true }) {
540
- auto name = op + get_suffix (src0_f16, src1_f16, dst_f16);
541
- string_to_spv (name.c_str (), op + " .comp" , {{" A_TYPE" , get_type_str (src0_f16)}, {" B_TYPE" , get_type_str (src1_f16)}, {" D_TYPE" , get_type_str (dst_f16)}, {" FLOAT_TYPE" , " float" }});
540
+ for (auto rte : {false , true }) {
541
+ auto name = op + get_suffix (src0_f16, src1_f16, dst_f16) + (rte ? " _rte" : " " );
542
+ string_to_spv (name.c_str (), op + " .comp" , {{" A_TYPE" , get_type_str (src0_f16)}, {" B_TYPE" , get_type_str (src1_f16)}, {" D_TYPE" , get_type_str (dst_f16)}, {" FLOAT_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
543
+ }
542
544
}
543
545
}
544
546
}
@@ -592,16 +594,19 @@ void process_shaders() {
592
594
string_to_spv (" sigmoid_f16" , " sigmoid.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
593
595
string_to_spv (" sigmoid_f32" , " sigmoid.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
594
596
595
- string_to_spv (" geglu_f16" , " geglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
596
- string_to_spv (" geglu_f32" , " geglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
597
- string_to_spv (" reglu_f16" , " reglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
598
- string_to_spv (" reglu_f32" , " reglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
599
- string_to_spv (" swiglu_f16" , " swiglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
600
- string_to_spv (" swiglu_f32" , " swiglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
601
- string_to_spv (" geglu_erf_f16" , " geglu_erf.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
602
- string_to_spv (" geglu_erf_f32" , " geglu_erf.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
603
- string_to_spv (" geglu_quick_f16" ," geglu_quick.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
604
- string_to_spv (" geglu_quick_f32" ," geglu_quick.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
597
+ for (auto rte : {false , true }) {
598
+ std::string suffix = rte ? " _rte" : " " ;
599
+ string_to_spv (" geglu_f16" + suffix, " geglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
600
+ string_to_spv (" geglu_f32" + suffix, " geglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
601
+ string_to_spv (" reglu_f16" + suffix, " reglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
602
+ string_to_spv (" reglu_f32" + suffix, " reglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
603
+ string_to_spv (" swiglu_f16" + suffix, " swiglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
604
+ string_to_spv (" swiglu_f32" + suffix, " swiglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
605
+ string_to_spv (" geglu_erf_f16" + suffix, " geglu_erf.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
606
+ string_to_spv (" geglu_erf_f32" + suffix, " geglu_erf.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
607
+ string_to_spv (" geglu_quick_f16" + suffix," geglu_quick.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
608
+ string_to_spv (" geglu_quick_f32" + suffix," geglu_quick.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
609
+ }
605
610
606
611
string_to_spv (" leaky_relu_f32" , " leaky_relu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
607
612
string_to_spv (" silu_back_f32" , " silu_back.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }});
@@ -709,11 +714,59 @@ void write_output_files() {
709
714
std::remove (path.c_str ());
710
715
}
711
716
}
717
+
718
+ std::string suffixes[2 ] = {" _f32" , " _f16" };
712
719
for (const char *op : {" add" , " sub" , " mul" , " div" }) {
713
- fprintf (hdr, " extern unsigned char *%s_data[2][2][2];\n " , op);
714
- fprintf (hdr, " extern uint64_t %s_len[2][2][2];\n " , op);
715
- fprintf (src, " unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n " , op, op, op, op, op, op, op, op, op);
716
- fprintf (src, " uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n " , op, op, op, op, op, op, op, op, op);
720
+ fprintf (hdr, " extern unsigned char *%s_data[2][2][2][2];\n " , op);
721
+ fprintf (hdr, " extern uint64_t %s_len[2][2][2][2];\n " , op);
722
+ std::string data = " unsigned char *" + std::string (op) + " _data[2][2][2][2] = " ;
723
+ std::string len = " uint64_t " + std::string (op) + " _len[2][2][2][2] = " ;
724
+ for (uint32_t t0 = 0 ; t0 < 2 ; ++t0) {
725
+ if (t0 == 0 ) {
726
+ data += " {" ;
727
+ len += " {" ;
728
+ }
729
+ for (uint32_t t1 = 0 ; t1 < 2 ; ++t1) {
730
+ if (t1 == 0 ) {
731
+ data += " {" ;
732
+ len += " {" ;
733
+ }
734
+ for (uint32_t t2 = 0 ; t2 < 2 ; ++t2) {
735
+ if (t2 == 0 ) {
736
+ data += " {" ;
737
+ len += " {" ;
738
+ }
739
+ for (uint32_t rte = 0 ; rte < 2 ; ++rte) {
740
+ if (rte == 0 ) {
741
+ data += " {" ;
742
+ len += " {" ;
743
+ }
744
+ data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte" : " " );
745
+ len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte" : " " );
746
+ data += " _data," ;
747
+ len += " _len," ;
748
+ if (rte == 1 ) {
749
+ data += " }, " ;
750
+ len += " }, " ;
751
+ }
752
+ }
753
+ if (t2 == 1 ) {
754
+ data += " }, " ;
755
+ len += " }, " ;
756
+ }
757
+ }
758
+ if (t1 == 1 ) {
759
+ data += " }, " ;
760
+ len += " }, " ;
761
+ }
762
+ }
763
+ if (t0 == 1 ) {
764
+ data += " };\n " ;
765
+ len += " };\n " ;
766
+ }
767
+ }
768
+ fprintf (src, data.c_str ());
769
+ fprintf (src, len.c_str ());
717
770
}
718
771
fclose (hdr);
719
772
fclose (src);
0 commit comments