Skip to content

Commit f0e42ee

Browse files
committed
rm legend
1 parent ab921c1 commit f0e42ee

File tree

3 files changed

+15
-39
lines changed

3 files changed

+15
-39
lines changed

testing/test2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
b = torch.rand(size=(2, 20), dtype=torch.float32)
1212
c = torch.rand(size=(2,20,200), dtype=torch.complex64)
1313
d = torch.rand(size=(2,20,200,5), dtype=torch.float16)
14-
with tsensor.explain(legend=False, savefig="/Users/parrt/Desktop/t2.pdf"):
14+
with tsensor.explain(savefig="/Users/parrt/Desktop/t2.pdf"):
1515
a + b + x + c[:,:,0] + d[:,:,0,0]
1616

17-
with tsensor.explain(legend=False, savefig="/Users/parrt/Desktop/t3.pdf"):
17+
with tsensor.explain(savefig="/Users/parrt/Desktop/t3.pdf"):
1818
c
1919

20-
with tsensor.explain(legend=False, savefig="/Users/parrt/Desktop/t4.pdf"):
20+
with tsensor.explain(savefig="/Users/parrt/Desktop/t4.pdf"):
2121
d
2222

2323
# with tsensor.explain(legend=True, savefig="/Users/parrt/Desktop/t.pdf") as e:

tsensor/analysis.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ def __init__(self,
4949
underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
5050
show:(None,'viz')='viz',
5151
hush_errors=True,
52-
dtype_colors=None, dtype_precisions=None,
53-
dtype_alpha_range=None, legend=False):
52+
dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
5453
"""
5554
Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
5655
Also display a visual representation of the offending Python line that
@@ -118,17 +117,16 @@ def __init__(self,
118117
smaller the bit size, the lower the alpha channel. You
119118
can play with the range to get better visual dynamic range
120119
depending on how many precisions you want to display.
121-
:param legend: boolean: should a legend for the types encountered be presented?
122120
"""
123121
self.show, self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
124122
self.matrixcolor, self.vectorcolor, self.char_sep_scale,\
125123
self.fontcolor, self.underline_color, self.ignored_color, \
126124
self.error_op_color, self.hush_errors, \
127-
self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range, self.legend = \
125+
self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
128126
show, fontname, fontsize, dimfontname, dimfontsize, \
129127
matrixcolor, vectorcolor, char_sep_scale, \
130128
fontcolor, underline_color, ignored_color, error_op_color, hush_errors, \
131-
dtype_colors, dtype_precisions, dtype_alpha_range, legend
129+
dtype_colors, dtype_precisions, dtype_alpha_range
132130

133131
def __enter__(self):
134132
self.frame = sys._getframe().f_back # where do we start tracking? Hmm...not sure we use this
@@ -159,7 +157,8 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
159157
self.error_op_color,
160158
hush_errors=self.hush_errors,
161159
dtype_colors=self.dtype_colors,
162-
legend=self.legend)
160+
dtype_precisions=self.dtype_precisions,
161+
dtype_alpha_range=self.dtype_alpha_range)
163162
if self.view is not None: # Ignore if we can't process code causing exception (I use a subparser)
164163
if self.show=='viz':
165164
self.view.show()
@@ -173,8 +172,7 @@ def __init__(self,
173172
vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
174173
underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
175174
savefig=None, hush_errors=True,
176-
dtype_colors=None, dtype_precisions=None,
177-
dtype_alpha_range=None, legend=False):
175+
dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
178176
"""
179177
As the Python virtual machine executes lines of code, generate a
180178
visualization for tensor-related expressions using from numpy, pytorch,
@@ -250,17 +248,16 @@ def __init__(self,
250248
smaller the bit size, the lower the alpha channel. You
251249
can play with the range to get better visual dynamic range
252250
depending on how many precisions you want to display.
253-
:param legend: boolean: should a legend for the types encountered be presented?
254251
"""
255252
self.savefig, self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
256253
self.matrixcolor, self.vectorcolor, self.char_sep_scale,\
257254
self.fontcolor, self.underline_color, self.ignored_color, \
258255
self.error_op_color, self.hush_errors, \
259-
self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range, self.legend = \
256+
self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
260257
savefig, fontname, fontsize, dimfontname, dimfontsize, \
261258
matrixcolor, vectorcolor, char_sep_scale, \
262259
fontcolor, underline_color, ignored_color, error_op_color, hush_errors, \
263-
dtype_colors, dtype_precisions, dtype_alpha_range, legend
260+
dtype_colors, dtype_precisions, dtype_alpha_range
264261

265262
def __enter__(self):
266263
# print("ON trace", sys._getframe())
@@ -365,8 +362,7 @@ def viz_statement(self, code, frame):
365362
hush_errors=self.explainer.hush_errors,
366363
dtype_colors=self.explainer.dtype_colors,
367364
dtype_precisions=self.explainer.dtype_precisions,
368-
dtype_alpha_range=self.explainer.dtype_alpha_range,
369-
legend=self.explainer.legend)
365+
dtype_alpha_range=self.explainer.dtype_alpha_range)
370366
self.views.append(view)
371367
if self.explainer.savefig is not None:
372368
file_path = Path(self.explainer.savefig)

tsensor/viz.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ class PyVizView:
4848
"""
4949
def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize,
5050
matrixcolor, vectorcolor, char_sep_scale, dpi,
51-
dtype_colors=None, dtype_precisions=None,
52-
dtype_alpha_range=None, legend=False):
51+
dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
5352
if dtype_colors is None:
5453
orangeish = '#FDD66C'
5554
limeish = '#A8E1B0'
@@ -83,7 +82,6 @@ def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize,
8382
for c, v in dtype_colors.items():
8483
self._dtype_shades[c] = \
8584
PyVizView._get_alpha_shades(v, n=nshades, alpha_range=dtype_alpha_range)
86-
self.legend = legend
8785
self.wchar = self.char_sep_scale * self.fontsize
8886
self.hchar = self.char_sep_scale * self.fontsize
8987
self.dim_ypadding = 5
@@ -127,13 +125,6 @@ def _split_dtype_precision(s):
127125
tail = s[len(head):]
128126
return head, tail
129127

130-
def get_dtype_legend_patches(self):
131-
labels, colors = [], []
132-
for name in self._dtype_encountered:
133-
labels.append(name)
134-
colors.append(self.get_dtype_color(name))
135-
return labels, colors
136-
137128
def get_dtype_color(self, dtype):
138129
"""Get color based on type and precision."""
139130
dtype_name, dtype_precision = self._split_dtype_precision(dtype)
@@ -346,8 +337,7 @@ def pyviz(statement: str, frame=None,
346337
vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
347338
underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
348339
ax=None, dpi=200, hush_errors=True,
349-
dtype_colors=None, dtype_precisions=None,
350-
dtype_alpha_range=None, legend=False) -> PyVizView:
340+
dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> PyVizView:
351341
"""
352342
Parse and evaluate the Python code in the statement string passed in using
353343
the indicated execution frame. The execution frame of the invoking function
@@ -407,13 +397,12 @@ def pyviz(statement: str, frame=None,
407397
smaller the bit size, the lower the alpha channel. You
408398
can play with the range to get better visual dynamic range
409399
depending on how many precisions you want to display.
410-
:param legend: boolean: should a legend for the types encountered be presented?
411400
:return: Returns a PyVizView holding info about the visualization; from a notebook
412401
an SVG image will appear. Return none upon parsing error in statement.
413402
"""
414403
view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, matrixcolor,
415404
vectorcolor, char_sep_scale, dpi,
416-
dtype_colors, dtype_precisions, dtype_alpha_range, legend)
405+
dtype_colors, dtype_precisions, dtype_alpha_range)
417406

418407
if frame is None: # use frame of caller if not passed in
419408
frame = sys._getframe().f_back
@@ -514,15 +503,6 @@ def pyviz(statement: str, frame=None,
514503
ax.set_xlim(0, fig_width)
515504
ax.set_ylim(0, view.maxy)
516505

517-
if view.legend:
518-
labels, colors = view.get_dtype_legend_patches()
519-
legend_patches = [
520-
patches.Patch(facecolor=c, label=l, edgecolor='grey')
521-
for c, l in zip(colors, labels)
522-
]
523-
view.legend = fig.legend(legend_patches, labels, loc='center left', fontsize=8, bbox_to_anchor=(1, 0.5))
524-
ax.add_artist(view.legend)
525-
526506
return view
527507

528508

0 commit comments

Comments
 (0)