Skip to content

Commit 2991ce1

Browse files
authored
Add print_svg for mma (#1733)
* add print_svg for mma * correct the code indentation
1 parent 1ebda1c commit 2991ce1

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

include/cute/atom/mma_atom.hpp

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,183 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
928928
printf(latex_footer);
929929
}
930930

931+
// MNK MMA Layout to SVG -- 8-value color coded by thread
932+
template <class LayoutC, class ThrIDC,
933+
class LayoutA, class ThrIDA,
934+
class LayoutB, class ThrIDB>
935+
CUTE_HOST_DEVICE
936+
void
937+
print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx
938+
LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx
939+
LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx
940+
{
941+
char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175",
942+
"255,175,175", "210,210,255", "210,255,210",
943+
"255,255,210", "255,210,210"};
944+
945+
const int cell_width = 20;
946+
const int cell_height = 20;
947+
948+
const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width;
949+
const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height;
950+
951+
// header
952+
printf("<svg width=\"100%%\" height=\"100%%\" viewBox=\"0 0 %d %d\" "
953+
"preserveAspectRatio=\"xMidYMid meet\" "
954+
"xmlns=\"http://www.w3.org/2000/svg\">\n",
955+
page_width, page_height);
956+
957+
// C
958+
int c_base_x = (size<1>(A) + 2) * cell_width;
959+
int c_base_y = (size<1>(B) + 2) * cell_height;
960+
for (int m = 0; m < cute::size<0>(C); ++m) {
961+
for (int n = 0; n < cute::size<1>(C); ++n) {
962+
963+
int thrid = C(m, n) % size(TC);
964+
int val_idx = C(m, n) / size(TC);
965+
int thr_idx = TC(thrid);
966+
967+
int x = n * cell_width + c_base_x;
968+
int y = m * cell_height + c_base_y;
969+
970+
int thr_x = x + cell_width / 2;
971+
int thr_y = y + cell_height / 4;
972+
int val_x = x + cell_width / 2;
973+
int val_y = y + cell_height * 3 / 4;
974+
975+
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
976+
"fill=\"rgb(%s)\" stroke=\"black\"/>\n",
977+
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
978+
979+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
980+
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
981+
thr_x, thr_y, thr_idx);
982+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
983+
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
984+
val_x, val_y, val_idx);
985+
}
986+
}
987+
988+
// A
989+
int a_base_x = cell_width;
990+
int a_base_y = (size<1>(B) + 2) * cell_height;
991+
for (int m = 0; m < size<0>(A); ++m) {
992+
for (int k = 0; k < size<1>(A); ++k) {
993+
int thrid = A(m, k) % size(TA);
994+
int val_idx = A(m, k) / size(TA);
995+
int thr_idx = TA(thrid);
996+
997+
int x = k * cell_width + a_base_x;
998+
int y = m * cell_height + a_base_y;
999+
1000+
int thr_x = x + cell_width / 2;
1001+
int thr_y = y + cell_height / 4;
1002+
int val_x = x + cell_width / 2;
1003+
int val_y = y + cell_height * 3 / 4;
1004+
1005+
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
1006+
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
1007+
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
1008+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1009+
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
1010+
thr_x, thr_y, thr_idx);
1011+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1012+
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
1013+
val_x, val_y, val_idx);
1014+
}
1015+
}
1016+
1017+
// B
1018+
int b_base_x = (size<1>(A) + 2) * cell_width;
1019+
int b_base_y = cell_height;
1020+
for (int n = 0; n < size<0>(B); ++n) {
1021+
for (int k = 0; k < size<1>(B); ++k) {
1022+
int thrid = B(n, k) % size(TB);
1023+
int val_idx = B(n, k) / size(TB);
1024+
int thr_idx = TB(thrid);
1025+
1026+
int x = n * cell_width + b_base_x;
1027+
int y = k * cell_height + b_base_y;
1028+
1029+
int thr_x = x + cell_width / 2;
1030+
int thr_y = y + cell_height / 4;
1031+
int val_x = x + cell_width / 2;
1032+
int val_y = y + cell_height * 3 / 4;
1033+
1034+
printf("<rect x=\"%d\" y=\"%d\" width=\"%d\" height=\"%d\" "
1035+
"fill=\"rgb(%s)\" stroke=\"black\" />\n",
1036+
x, y, cell_width, cell_height, color_map[thr_idx % 8]);
1037+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1038+
"alignment-baseline=\"central\" font-size=\"8\">T%d</text>\n",
1039+
thr_x, thr_y, thr_idx);
1040+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1041+
"alignment-baseline=\"central\" font-size=\"8\">V%d</text>\n",
1042+
val_x, val_y, val_idx);
1043+
}
1044+
}
1045+
1046+
// A labels
1047+
for (int m = 0; m < size<0>(A); ++m) {
1048+
int x = cell_width / 2;
1049+
int y = m * cell_height + cell_height / 2 + a_base_y;
1050+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1051+
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
1052+
x, y, m);
1053+
}
1054+
for (int k = 0; k < size<1>(A); ++k) {
1055+
int x = cell_width + k * cell_width + cell_width / 2;
1056+
int y = -cell_height / 2 + a_base_y;
1057+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1058+
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
1059+
x, y, k);
1060+
}
1061+
1062+
// B labels
1063+
for (int n = 0; n < size<0>(B); ++n) {
1064+
int x = b_base_x + cell_width * n + cell_width / 2;
1065+
int y = cell_height / 2;
1066+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1067+
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
1068+
x, y, n);
1069+
}
1070+
for (int k = 0; k < size<1>(B); ++k) {
1071+
int x = b_base_x - cell_width / 2;
1072+
int y = cell_height * (k + 1) + cell_height / 2;
1073+
printf("<text x=\"%d\" y=\"%d\" text-anchor=\"middle\" "
1074+
"alignment-baseline=\"central\" font-size=\"12\">%d</text>\n",
1075+
x, y, k);
1076+
}
1077+
1078+
// footer
1079+
printf("</svg>");
1080+
}
1081+
1082+
template <class... Args>
1083+
CUTE_HOST_DEVICE
1084+
void
1085+
print_svg(MMA_Atom<Args...> const &mma_atom) {
1086+
print_svg(make_tiled_mma(mma_atom));
1087+
}
1088+
1089+
template <class... Args>
1090+
CUTE_HOST_DEVICE
1091+
void
1092+
print_svg(TiledMMA<Args...> const &mma) {
1093+
auto layout_and_thrid_C = mma.get_layoutC_MN();
1094+
auto layoutC_MN = get<0>(layout_and_thrid_C);
1095+
auto thrID_C = get<1>(layout_and_thrid_C);
1096+
1097+
auto layout_and_thrid_A = mma.get_layoutA_MK();
1098+
auto layoutA_MK = get<0>(layout_and_thrid_A);
1099+
auto thrID_A = get<1>(layout_and_thrid_A);
1100+
1101+
auto layout_and_thrid_B = mma.get_layoutB_NK();
1102+
auto layoutB_NK = get<0>(layout_and_thrid_B);
1103+
auto thrID_B = get<1>(layout_and_thrid_B);
1104+
1105+
print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B);
1106+
}
1107+
9311108
} // namespace cute
9321109

9331110
////////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)