@@ -928,6 +928,183 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
928
928
printf (latex_footer);
929
929
}
930
930
931
+ // MNK MMA Layout to SVG -- 8-value color coded by thread
932
+ template <class LayoutC , class ThrIDC ,
933
+ class LayoutA , class ThrIDA ,
934
+ class LayoutB , class ThrIDB >
935
+ CUTE_HOST_DEVICE
936
+ void
937
+ print_svg_mma (LayoutC const & C, ThrIDC const & TC, // (m,n) -> (tid,vid) and tid -> thr_idx
938
+ LayoutA const & A, ThrIDA const & TA, // (m,k) -> (tid,vid) and tid -> thr_idx
939
+ LayoutB const & B, ThrIDB const & TB) // (n,k) -> (tid,vid) and tid -> thr_idx
940
+ {
941
+ char const *color_map[8 ] = {" 175,175,255" , " 175,255,175" , " 255,255,175" ,
942
+ " 255,175,175" , " 210,210,255" , " 210,255,210" ,
943
+ " 255,255,210" , " 255,210,210" };
944
+
945
+ const int cell_width = 20 ;
946
+ const int cell_height = 20 ;
947
+
948
+ const int page_width = (size<1 >(A) + size<0 >(B) + 2 ) * cell_width;
949
+ const int page_height = (size<1 >(B) + size<0 >(A) + 2 ) * cell_height;
950
+
951
+ // header
952
+ printf (" <svg width=\" 100%%\" height=\" 100%%\" viewBox=\" 0 0 %d %d\" "
953
+ " preserveAspectRatio=\" xMidYMid meet\" "
954
+ " xmlns=\" http://www.w3.org/2000/svg\" >\n " ,
955
+ page_width, page_height);
956
+
957
+ // C
958
+ int c_base_x = (size<1 >(A) + 2 ) * cell_width;
959
+ int c_base_y = (size<1 >(B) + 2 ) * cell_height;
960
+ for (int m = 0 ; m < cute::size<0 >(C); ++m) {
961
+ for (int n = 0 ; n < cute::size<1 >(C); ++n) {
962
+
963
+ int thrid = C (m, n) % size (TC);
964
+ int val_idx = C (m, n) / size (TC);
965
+ int thr_idx = TC (thrid);
966
+
967
+ int x = n * cell_width + c_base_x;
968
+ int y = m * cell_height + c_base_y;
969
+
970
+ int thr_x = x + cell_width / 2 ;
971
+ int thr_y = y + cell_height / 4 ;
972
+ int val_x = x + cell_width / 2 ;
973
+ int val_y = y + cell_height * 3 / 4 ;
974
+
975
+ printf (" <rect x=\" %d\" y=\" %d\" width=\" %d\" height=\" %d\" "
976
+ " fill=\" rgb(%s)\" stroke=\" black\" />\n " ,
977
+ x, y, cell_width, cell_height, color_map[thr_idx % 8 ]);
978
+
979
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
980
+ " alignment-baseline=\" central\" font-size=\" 8\" >T%d</text>\n " ,
981
+ thr_x, thr_y, thr_idx);
982
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
983
+ " alignment-baseline=\" central\" font-size=\" 8\" >V%d</text>\n " ,
984
+ val_x, val_y, val_idx);
985
+ }
986
+ }
987
+
988
+ // A
989
+ int a_base_x = cell_width;
990
+ int a_base_y = (size<1 >(B) + 2 ) * cell_height;
991
+ for (int m = 0 ; m < size<0 >(A); ++m) {
992
+ for (int k = 0 ; k < size<1 >(A); ++k) {
993
+ int thrid = A (m, k) % size (TA);
994
+ int val_idx = A (m, k) / size (TA);
995
+ int thr_idx = TA (thrid);
996
+
997
+ int x = k * cell_width + a_base_x;
998
+ int y = m * cell_height + a_base_y;
999
+
1000
+ int thr_x = x + cell_width / 2 ;
1001
+ int thr_y = y + cell_height / 4 ;
1002
+ int val_x = x + cell_width / 2 ;
1003
+ int val_y = y + cell_height * 3 / 4 ;
1004
+
1005
+ printf (" <rect x=\" %d\" y=\" %d\" width=\" %d\" height=\" %d\" "
1006
+ " fill=\" rgb(%s)\" stroke=\" black\" />\n " ,
1007
+ x, y, cell_width, cell_height, color_map[thr_idx % 8 ]);
1008
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1009
+ " alignment-baseline=\" central\" font-size=\" 8\" >T%d</text>\n " ,
1010
+ thr_x, thr_y, thr_idx);
1011
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1012
+ " alignment-baseline=\" central\" font-size=\" 8\" >V%d</text>\n " ,
1013
+ val_x, val_y, val_idx);
1014
+ }
1015
+ }
1016
+
1017
+ // B
1018
+ int b_base_x = (size<1 >(A) + 2 ) * cell_width;
1019
+ int b_base_y = cell_height;
1020
+ for (int n = 0 ; n < size<0 >(B); ++n) {
1021
+ for (int k = 0 ; k < size<1 >(B); ++k) {
1022
+ int thrid = B (n, k) % size (TB);
1023
+ int val_idx = B (n, k) / size (TB);
1024
+ int thr_idx = TB (thrid);
1025
+
1026
+ int x = n * cell_width + b_base_x;
1027
+ int y = k * cell_height + b_base_y;
1028
+
1029
+ int thr_x = x + cell_width / 2 ;
1030
+ int thr_y = y + cell_height / 4 ;
1031
+ int val_x = x + cell_width / 2 ;
1032
+ int val_y = y + cell_height * 3 / 4 ;
1033
+
1034
+ printf (" <rect x=\" %d\" y=\" %d\" width=\" %d\" height=\" %d\" "
1035
+ " fill=\" rgb(%s)\" stroke=\" black\" />\n " ,
1036
+ x, y, cell_width, cell_height, color_map[thr_idx % 8 ]);
1037
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1038
+ " alignment-baseline=\" central\" font-size=\" 8\" >T%d</text>\n " ,
1039
+ thr_x, thr_y, thr_idx);
1040
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1041
+ " alignment-baseline=\" central\" font-size=\" 8\" >V%d</text>\n " ,
1042
+ val_x, val_y, val_idx);
1043
+ }
1044
+ }
1045
+
1046
+ // A labels
1047
+ for (int m = 0 ; m < size<0 >(A); ++m) {
1048
+ int x = cell_width / 2 ;
1049
+ int y = m * cell_height + cell_height / 2 + a_base_y;
1050
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1051
+ " alignment-baseline=\" central\" font-size=\" 12\" >%d</text>\n " ,
1052
+ x, y, m);
1053
+ }
1054
+ for (int k = 0 ; k < size<1 >(A); ++k) {
1055
+ int x = cell_width + k * cell_width + cell_width / 2 ;
1056
+ int y = -cell_height / 2 + a_base_y;
1057
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1058
+ " alignment-baseline=\" central\" font-size=\" 12\" >%d</text>\n " ,
1059
+ x, y, k);
1060
+ }
1061
+
1062
+ // B labels
1063
+ for (int n = 0 ; n < size<0 >(B); ++n) {
1064
+ int x = b_base_x + cell_width * n + cell_width / 2 ;
1065
+ int y = cell_height / 2 ;
1066
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1067
+ " alignment-baseline=\" central\" font-size=\" 12\" >%d</text>\n " ,
1068
+ x, y, n);
1069
+ }
1070
+ for (int k = 0 ; k < size<1 >(B); ++k) {
1071
+ int x = b_base_x - cell_width / 2 ;
1072
+ int y = cell_height * (k + 1 ) + cell_height / 2 ;
1073
+ printf (" <text x=\" %d\" y=\" %d\" text-anchor=\" middle\" "
1074
+ " alignment-baseline=\" central\" font-size=\" 12\" >%d</text>\n " ,
1075
+ x, y, k);
1076
+ }
1077
+
1078
+ // footer
1079
+ printf (" </svg>" );
1080
+ }
1081
+
1082
+ template <class ... Args>
1083
+ CUTE_HOST_DEVICE
1084
+ void
1085
+ print_svg (MMA_Atom<Args...> const &mma_atom) {
1086
+ print_svg (make_tiled_mma (mma_atom));
1087
+ }
1088
+
1089
+ template <class ... Args>
1090
+ CUTE_HOST_DEVICE
1091
+ void
1092
+ print_svg (TiledMMA<Args...> const &mma) {
1093
+ auto layout_and_thrid_C = mma.get_layoutC_MN ();
1094
+ auto layoutC_MN = get<0 >(layout_and_thrid_C);
1095
+ auto thrID_C = get<1 >(layout_and_thrid_C);
1096
+
1097
+ auto layout_and_thrid_A = mma.get_layoutA_MK ();
1098
+ auto layoutA_MK = get<0 >(layout_and_thrid_A);
1099
+ auto thrID_A = get<1 >(layout_and_thrid_A);
1100
+
1101
+ auto layout_and_thrid_B = mma.get_layoutB_NK ();
1102
+ auto layoutB_NK = get<0 >(layout_and_thrid_B);
1103
+ auto thrID_B = get<1 >(layout_and_thrid_B);
1104
+
1105
+ print_svg_mma (layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
1106
+ }
1107
+
931
1108
} // namespace cute
932
1109
933
1110
// //////////////////////////////////////////////////////////////////////////////////////////////////
0 commit comments