Skip to content

Commit ab921c1

Browse files
committed
put type text below matrices.
1 parent 9dac96e commit ab921c1

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

testing/test2.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,18 @@
77
x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1)
88
h = torch.tensor([1,2])
99

10-
with tsensor.explain(legend=True, savefig="/Users/parrt/Desktop/t2.pdf"):
11-
torch.rand(size=(2,20,2000,10))
10+
a = torch.rand(size=(2, 20), dtype=torch.float64)
11+
b = torch.rand(size=(2, 20), dtype=torch.float32)
12+
c = torch.rand(size=(2,20,200), dtype=torch.complex64)
13+
d = torch.rand(size=(2,20,200,5), dtype=torch.float16)
14+
with tsensor.explain(legend=False, savefig="/Users/parrt/Desktop/t2.pdf"):
15+
a + b + x + c[:,:,0] + d[:,:,0,0]
16+
17+
with tsensor.explain(legend=False, savefig="/Users/parrt/Desktop/t3.pdf"):
18+
c
19+
20+
with tsensor.explain(legend=False, savefig="/Users/parrt/Desktop/t4.pdf"):
21+
d
1222

1323
# with tsensor.explain(legend=True, savefig="/Users/parrt/Desktop/t.pdf") as e:
1424
# W @ torch.dot(b, b) + torch.eye(2, 2) @ x

tsensor/viz.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,20 @@ def draw_matrix(self,ax,sub, sh, ty):
316316
fontname=self.dimfontname, fontsize=self.dimfontsize,
317317
rotation=45)
318318

319+
bottom_text_line = self.box_topy - h - self.dim_ypadding
319320
if len(sh) > 3:
320321
# Text below
321322
remaining = r"$\cdots\mathsf{x}$"+r"$\mathsf{x}$".join([self.nabbrev(sh[i]) for i in range(3,len(sh))])
322-
ax.text(mid, self.box_topy - h - self.dim_ypadding, remaining,
323+
bottom_text_line = self.box_topy - h - self.dim_ypadding
324+
ax.text(mid, bottom_text_line, remaining,
323325
verticalalignment='top', horizontalalignment='center',
324326
fontname=self.dimfontname, fontsize=self.dimfontsize)
327+
bottom_text_line -= self.hchar
328+
329+
# Type info at the bottom of everything
330+
ax.text(mid, bottom_text_line, '<${\mathit{'+ty+'}}$>',
331+
verticalalignment='top', horizontalalignment='center',
332+
fontname=self.dimfontname, fontsize=self.dimfontsize-2)
325333

326334
@staticmethod
327335
def nabbrev(n: int) -> str:

0 commit comments

Comments
 (0)