Skip to content

Commit 3b59c96

Browse files
committed
adjust for typename in width
1 parent f0e42ee commit 3b59c96

File tree

3 files changed

+80
-29
lines changed

3 files changed

+80
-29
lines changed

testing/examples.ipynb

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@
3838
},
3939
{
4040
"cell_type": "code",
41-
"execution_count": 3,
41+
"execution_count": 2,
4242
"metadata": {},
4343
"outputs": [
4444
{
4545
"data": {
4646
"text/plain": [
47-
"'0.1.4'"
47+
"'0.2'"
4848
]
4949
},
50-
"execution_count": 3,
50+
"execution_count": 2,
5151
"metadata": {},
5252
"output_type": "execute_result"
5353
}
@@ -66,7 +66,7 @@
6666
},
6767
{
6868
"cell_type": "code",
69-
"execution_count": 4,
69+
"execution_count": 3,
7070
"metadata": {},
7171
"outputs": [],
7272
"source": [
@@ -80,18 +80,18 @@
8080
},
8181
{
8282
"cell_type": "code",
83-
"execution_count": 5,
83+
"execution_count": 4,
8484
"metadata": {},
8585
"outputs": [
8686
{
8787
"data": {
8888
"image/svg+xml": [
89-
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" height=\"45.995716pt\" version=\"1.1\" viewBox=\"0 0 209.980969 45.995716\" width=\"209.980969pt\">\n",
89+
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" height=\"52.941034pt\" version=\"1.1\" viewBox=\"0 0 209.980969 52.941034\" width=\"209.980969pt\">\n",
9090
" <metadata>\n",
9191
" <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
9292
" <cc:Work>\n",
9393
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
94-
" <dc:date>2021-12-09T16:02:20.706887</dc:date>\n",
94+
" <dc:date>2021-12-10T16:20:45.764164</dc:date>\n",
9595
" <dc:format>image/svg+xml</dc:format>\n",
9696
" <dc:creator>\n",
9797
" <cc:Agent>\n",
@@ -106,20 +106,20 @@
106106
" </defs>\n",
107107
" <g id=\"figure_1\">\n",
108108
" <g id=\"patch_1\">\n",
109-
" <path d=\"M 0 45.995716 L 209.980969 45.995716 L 209.980969 0 L 0 0 z \" style=\"fill:none;\"/>\n",
109+
" <path d=\"M 0 52.941034 L 209.980969 52.941034 L 209.980969 0 L 0 0 z \" style=\"fill:none;\"/>\n",
110110
" </g>\n",
111111
" <g id=\"axes_1\">\n",
112112
" <g id=\"patch_2\">\n",
113-
" <path clip-path=\"url(#p31913c9652)\" d=\"M 79.44246 45.180316 L 84.66534 45.180316 L 84.66534 22.919896 L 79.44246 22.919896 z \" style=\"fill:#7fa4d3;fill-opacity:0.75;stroke:#808080;stroke-linejoin:miter;stroke-width:0.7;\"/>\n",
113+
" <path clip-path=\"url(#p6eb96ee2d5)\" d=\"M 79.44246 45.180316 L 84.66534 45.180316 L 84.66534 22.919896 L 79.44246 22.919896 z \" style=\"fill:#7fa4d3;fill-opacity:0.75;stroke:#808080;stroke-linejoin:miter;stroke-width:0.7;\"/>\n",
114114
" </g>\n",
115115
" <g id=\"patch_3\">\n",
116-
" <path clip-path=\"url(#p31913c9652)\" d=\"M 92.49966 45.180316 L 97.72254 45.180316 L 97.72254 22.919896 L 92.49966 22.919896 z \" style=\"fill:#7fa4d3;fill-opacity:0.75;stroke:#808080;stroke-linejoin:miter;stroke-width:0.7;\"/>\n",
116+
" <path clip-path=\"url(#p6eb96ee2d5)\" d=\"M 92.49966 45.180316 L 97.72254 45.180316 L 97.72254 22.919896 L 92.49966 22.919896 z \" style=\"fill:#7fa4d3;fill-opacity:0.75;stroke:#808080;stroke-linejoin:miter;stroke-width:0.7;\"/>\n",
117117
" </g>\n",
118118
" <g id=\"line2d_1\">\n",
119-
" <path clip-path=\"url(#p31913c9652)\" d=\"M 78.13674 13.192653 L 85.97106 13.192653 \" style=\"fill:none;stroke:#c2c2c2;stroke-linecap:square;stroke-width:0.5;\"/>\n",
119+
" <path clip-path=\"url(#p6eb96ee2d5)\" d=\"M 78.13674 13.192653 L 85.97106 13.192653 \" style=\"fill:none;stroke:#c2c2c2;stroke-linecap:square;stroke-width:0.5;\"/>\n",
120120
" </g>\n",
121121
" <g id=\"line2d_2\">\n",
122-
" <path clip-path=\"url(#p31913c9652)\" d=\"M 91.19394 13.192653 L 99.02826 13.192653 \" style=\"fill:none;stroke:#c2c2c2;stroke-linecap:square;stroke-width:0.5;\"/>\n",
122+
" <path clip-path=\"url(#p6eb96ee2d5)\" d=\"M 91.19394 13.192653 L 99.02826 13.192653 \" style=\"fill:none;stroke:#c2c2c2;stroke-linecap:square;stroke-width:0.5;\"/>\n",
123123
" </g>\n",
124124
" <g id=\"text_1\">\n",
125125
" <!-- W -->\n",
@@ -380,21 +380,54 @@
380380
" </g>\n",
381381
" </g>\n",
382382
" <g id=\"text_34\">\n",
383+
" <!-- &lt;${\\mathit{int64}}$&gt; -->\n",
384+
" <g transform=\"translate(68.9289 51.549784)scale(0.07 -0.07)\">\n",
385+
" <defs>\n",
386+
" <path d=\"M 5.46875 31.296875 L 5.46875 39.5 L 52.875 59.515625 L 52.875 50.78125 L 15.28125 35.359375 L 52.875 19.78125 L 52.875 11.03125 z \" id=\"ArialMT-60\"/>\n",
387+
" <path d=\"M 18.3125 75.984375 L 27.296875 75.984375 L 25.09375 64.59375 L 16.109375 64.59375 z M 14.203125 54.6875 L 23.1875 54.6875 L 12.5 0 L 3.515625 0 z \" id=\"DejaVuSans-Oblique-105\"/>\n",
388+
" <path d=\"M 55.71875 33.015625 L 49.3125 0 L 40.28125 0 L 46.6875 32.671875 Q 47.125 34.96875 47.359375 36.71875 Q 47.609375 38.484375 47.609375 39.5 Q 47.609375 43.609375 45.015625 45.890625 Q 42.4375 48.1875 37.796875 48.1875 Q 30.5625 48.1875 25.34375 43.375 Q 20.125 38.578125 18.5 30.328125 L 12.5 0 L 3.515625 0 L 14.109375 54.6875 L 23.09375 54.6875 L 21.296875 46.09375 Q 25.046875 50.828125 30.3125 53.40625 Q 35.59375 56 41.40625 56 Q 48.640625 56 52.609375 52.09375 Q 56.59375 48.1875 56.59375 41.109375 Q 56.59375 39.359375 56.375 37.359375 Q 56.15625 35.359375 55.71875 33.015625 z \" id=\"DejaVuSans-Oblique-110\"/>\n",
389+
" <path d=\"M 42.28125 54.6875 L 40.921875 47.703125 L 23 47.703125 L 17.1875 18.015625 Q 16.890625 16.359375 16.75 15.234375 Q 16.609375 14.109375 16.609375 13.484375 Q 16.609375 10.359375 18.484375 8.9375 Q 20.359375 7.515625 24.515625 7.515625 L 33.59375 7.515625 L 32.078125 0 L 23.484375 0 Q 15.484375 0 11.546875 3.125 Q 7.625 6.25 7.625 12.59375 Q 7.625 13.71875 7.765625 15.0625 Q 7.90625 16.40625 8.203125 18.015625 L 14.015625 47.703125 L 6.390625 47.703125 L 7.8125 54.6875 L 15.28125 54.6875 L 18.3125 70.21875 L 27.296875 70.21875 L 24.3125 54.6875 z \" id=\"DejaVuSans-Oblique-116\"/>\n",
390+
" <path d=\"M 33.015625 40.375 Q 26.375 40.375 22.484375 35.828125 Q 18.609375 31.296875 18.609375 23.390625 Q 18.609375 15.53125 22.484375 10.953125 Q 26.375 6.390625 33.015625 6.390625 Q 39.65625 6.390625 43.53125 10.953125 Q 47.40625 15.53125 47.40625 23.390625 Q 47.40625 31.296875 43.53125 35.828125 Q 39.65625 40.375 33.015625 40.375 z M 52.59375 71.296875 L 52.59375 62.3125 Q 48.875 64.0625 45.09375 64.984375 Q 41.3125 65.921875 37.59375 65.921875 Q 27.828125 65.921875 22.671875 59.328125 Q 17.53125 52.734375 16.796875 39.40625 Q 19.671875 43.65625 24.015625 45.921875 Q 28.375 48.1875 33.59375 48.1875 Q 44.578125 48.1875 50.953125 41.515625 Q 57.328125 34.859375 57.328125 23.390625 Q 57.328125 12.15625 50.6875 5.359375 Q 44.046875 -1.421875 33.015625 -1.421875 Q 20.359375 -1.421875 13.671875 8.265625 Q 6.984375 17.96875 6.984375 36.375 Q 6.984375 53.65625 15.1875 63.9375 Q 23.390625 74.21875 37.203125 74.21875 Q 40.921875 74.21875 44.703125 73.484375 Q 48.484375 72.75 52.59375 71.296875 z \" id=\"DejaVuSans-54\"/>\n",
391+
" <path d=\"M 37.796875 64.3125 L 12.890625 25.390625 L 37.796875 25.390625 z M 35.203125 72.90625 L 47.609375 72.90625 L 47.609375 25.390625 L 58.015625 25.390625 L 58.015625 17.1875 L 47.609375 17.1875 L 47.609375 0 L 37.796875 0 L 37.796875 17.1875 L 4.890625 17.1875 L 4.890625 26.703125 z \" id=\"DejaVuSans-52\"/>\n",
392+
" <path d=\"M 52.875 31.296875 L 5.46875 11.03125 L 5.46875 19.78125 L 43.015625 35.359375 L 5.46875 50.78125 L 5.46875 59.515625 L 52.875 39.5 z \" id=\"ArialMT-62\"/>\n",
393+
" </defs>\n",
394+
" <use transform=\"translate(0 0.015625)\" xlink:href=\"#ArialMT-60\"/>\n",
395+
" <use transform=\"translate(58.398438 0.015625)\" xlink:href=\"#DejaVuSans-Oblique-105\"/>\n",
396+
" <use transform=\"translate(86.181641 0.015625)\" xlink:href=\"#DejaVuSans-Oblique-110\"/>\n",
397+
" <use transform=\"translate(149.560547 0.015625)\" xlink:href=\"#DejaVuSans-Oblique-116\"/>\n",
398+
" <use transform=\"translate(188.769531 0.015625)\" xlink:href=\"#DejaVuSans-54\"/>\n",
399+
" <use transform=\"translate(252.392578 0.015625)\" xlink:href=\"#DejaVuSans-52\"/>\n",
400+
" <use transform=\"translate(316.015625 0.015625)\" xlink:href=\"#ArialMT-62\"/>\n",
401+
" </g>\n",
402+
" </g>\n",
403+
" <g id=\"text_35\">\n",
383404
" <!-- 2 -->\n",
384405
" <g transform=\"translate(90.71091 36.552528)rotate(-90)scale(0.09 -0.09)\">\n",
385406
" <use xlink:href=\"#ArialMT-50\"/>\n",
386407
" </g>\n",
387408
" </g>\n",
388-
" <g id=\"text_35\">\n",
409+
" <g id=\"text_36\">\n",
389410
" <!-- 1 -->\n",
390411
" <g transform=\"translate(92.608678 21.560896)scale(0.09 -0.09)\">\n",
391412
" <use xlink:href=\"#ArialMT-49\"/>\n",
392413
" </g>\n",
393414
" </g>\n",
415+
" <g id=\"text_37\">\n",
416+
" <!-- &lt;${\\mathit{int64}}$&gt; -->\n",
417+
" <g transform=\"translate(81.9861 51.549784)scale(0.07 -0.07)\">\n",
418+
" <use transform=\"translate(0 0.015625)\" xlink:href=\"#ArialMT-60\"/>\n",
419+
" <use transform=\"translate(58.398438 0.015625)\" xlink:href=\"#DejaVuSans-Oblique-105\"/>\n",
420+
" <use transform=\"translate(86.181641 0.015625)\" xlink:href=\"#DejaVuSans-Oblique-110\"/>\n",
421+
" <use transform=\"translate(149.560547 0.015625)\" xlink:href=\"#DejaVuSans-Oblique-116\"/>\n",
422+
" <use transform=\"translate(188.769531 0.015625)\" xlink:href=\"#DejaVuSans-54\"/>\n",
423+
" <use transform=\"translate(252.392578 0.015625)\" xlink:href=\"#DejaVuSans-52\"/>\n",
424+
" <use transform=\"translate(316.015625 0.015625)\" xlink:href=\"#ArialMT-62\"/>\n",
425+
" </g>\n",
426+
" </g>\n",
394427
" </g>\n",
395428
" </g>\n",
396429
" <defs>\n",
397-
" <clipPath id=\"p31913c9652\">\n",
430+
" <clipPath id=\"p6eb96ee2d5\">\n",
398431
" <rect height=\"41.491069\" width=\"209.3616\" x=\"0\" y=\"4.504646\"/>\n",
399432
" </clipPath>\n",
400433
" </defs>\n",
@@ -10458,7 +10491,7 @@
1045810491
],
1045910492
"metadata": {
1046010493
"kernelspec": {
10461-
"display_name": "Python 3 (ipykernel)",
10494+
"display_name": "Python 3",
1046210495
"language": "python",
1046310496
"name": "python3"
1046410497
},

testing/test2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
import tsensor
33
import torch
44

5+
W = np.array([[1, 2], [3, 4]])
6+
b = np.array([9, 10]).reshape(2, 1)
7+
x = np.array([4, 5]).reshape(2, 1)
8+
h = np.array([1, 2])
9+
with tsensor.explain(savefig="/Users/parrt/Desktop/foo.pdf"):
10+
W @ np.dot(b,b) + np.eye(2,2)@x
11+
12+
513
W = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
614
b = torch.tensor([9, 10]).reshape(2, 1)
715
x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1)
@@ -11,6 +19,8 @@
1119
b = torch.rand(size=(2, 20), dtype=torch.float32)
1220
c = torch.rand(size=(2,20,200), dtype=torch.complex64)
1321
d = torch.rand(size=(2,20,200,5), dtype=torch.float16)
22+
23+
1424
with tsensor.explain(savefig="/Users/parrt/Desktop/t2.pdf"):
1525
a + b + x + c[:,:,0] + d[:,:,0,0]
1626

tsensor/viz.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -205,27 +205,31 @@ def boxsize(self, v):
205205
How wide and tall should we draw the box representing a vector or matrix.
206206
"""
207207
sh = tsensor.analysis._shape(v)
208+
ty = tsensor.analysis._dtype(v)
208209
if sh is None: return None
209-
if len(sh)==1: return self.vector_size(sh)
210-
return self.matrix_size(sh)
210+
if len(sh)==1: return self.vector_size(sh, ty)
211+
return self.matrix_size(sh, ty)
211212

212-
def matrix_size(self, sh):
213+
def matrix_size(self, sh, ty):
213214
"""
214215
How wide and tall should we draw the box representing a matrix.
215216
"""
216217
if len(sh)==1 and sh[0]==1:
217-
return self.vector_size(sh)
218-
elif len(sh) > 1 and sh[0] == 1 and sh[1] == 1:
218+
return self.vector_size(sh, ty)
219+
220+
if len(sh) > 1 and sh[0] == 1 and sh[1] == 1:
219221
# A special case where we have a 1x1 matrix extending into the screen.
220222
# Make the 1x1 part a little bit wider than a vector so it's more readable
221-
return 2 * self.vector_size_scaler * self.wchar, 2 * self.vector_size_scaler * self.wchar
223+
w, h = 2 * self.vector_size_scaler * self.wchar, 2 * self.vector_size_scaler * self.wchar
222224
elif len(sh) > 1 and sh[1] == 1:
223-
return self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
225+
w, h = self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
224226
elif len(sh)>1 and sh[0]==1:
225-
return self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar
226-
return self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
227+
w, h = self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar
228+
else:
229+
w, h = self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
230+
return w, h
227231

228-
def vector_size(self, sh):
232+
def vector_size(self, sh, ty):
229233
"""
230234
How wide and tall is a vector? It's not a function of vector length; instead
231235
we make a row vector with same width as a matrix but height of just one char.
@@ -262,7 +266,7 @@ def draw_vector(self,ax,sub, sh, ty: str):
262266

263267
def draw_matrix(self,ax,sub, sh, ty):
264268
mid = (sub.leftx + sub.rightx) / 2
265-
w,h = self.matrix_size(sh)
269+
w,h = self.matrix_size(sh, ty)
266270
box_left = mid - w / 2
267271
color = self.get_dtype_color(ty)
268272

@@ -440,13 +444,17 @@ def pyviz(statement: str, frame=None,
440444
maxh = 0
441445
for sub in subexprs:
442446
w, h = view.boxsize(sub.value)
447+
# update width to include horizontal room for type text like int32
448+
ty = tsensor.analysis._dtype(sub.value)
449+
w_typename = len(ty)/2 * view.wchar
450+
w += w_typename
443451
maxh = max(h, maxh)
444452
nexpr = sub.stop.cstop_idx - sub.start.cstart_idx
445-
if (sub.start.cstart_idx-1)>0 and statement[sub.start.cstart_idx - 1]== ' ': # if char to left is space
453+
if (sub.start.cstart_idx-1)>0 and statement[sub.start.cstart_idx - 1]== ' ': # if char to left is space
446454
nexpr += 1
447-
if sub.stop.cstop_idx<len(statement) and statement[sub.stop.cstop_idx]== ' ': # if char to right is space
455+
if sub.stop.cstop_idx<len(statement) and statement[sub.stop.cstop_idx]== ' ': # if char to right is space
448456
nexpr += 1
449-
if w>view.wchar * nexpr:
457+
if w > view.wchar * nexpr:
450458
lpad[sub.start.cstart_idx] += (w - view.wchar) / 2
451459
rpad[sub.stop.cstop_idx - 1] += (w - view.wchar) / 2
452460

0 commit comments

Comments
 (0)