@@ -48,7 +48,8 @@ def __init__(self,
48
48
vectorcolor = "#fefecd" , char_sep_scale = 1.8 , fontcolor = '#444443' ,
49
49
underline_color = '#C2C2C2' , ignored_color = '#B4B4B4' , error_op_color = '#A40227' ,
50
50
show :(None ,'viz' )= 'viz' ,
51
- hush_errors = True ):
51
+ hush_errors = True ,
52
+ dtype_colors = None , dtype_precisions = None , dtype_alpha_range = None ):
52
53
"""
53
54
Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
54
55
Also display a visual representation of the offending Python line that
@@ -109,15 +110,23 @@ def __init__(self,
109
110
unhandled code caught by my parser are ignored. Turn this off
110
111
to see what the error messages are coming from my parser.
111
112
:param show: Show visualization upon tensor error if show='viz'.
113
+ :param dtype_colors: map from dtype w/o precision like 'int' to color
114
+ :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
115
+ :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
116
+ and the alpha channel is used to show precision; the
117
+ smaller the bit size, the lower the alpha channel. You
118
+ can play with the range to get better visual dynamic range
119
+ depending on how many precisions you want to display.
112
120
"""
113
- self .show = show
114
- self .fontname , self .fontsize , self .dimfontname , self .dimfontsize , \
121
+ self .show , self .fontname , self .fontsize , self .dimfontname , self .dimfontsize , \
115
122
self .matrixcolor , self .vectorcolor , self .char_sep_scale ,\
116
123
self .fontcolor , self .underline_color , self .ignored_color , \
117
- self .error_op_color , self .hush_errors = \
118
- fontname , fontsize , dimfontname , dimfontsize , \
124
+ self .error_op_color , self .hush_errors , \
125
+ self .dtype_colors , self .dtype_precisions , self .dtype_alpha_range = \
126
+ show , fontname , fontsize , dimfontname , dimfontsize , \
119
127
matrixcolor , vectorcolor , char_sep_scale , \
120
- fontcolor , underline_color , ignored_color , error_op_color , hush_errors
128
+ fontcolor , underline_color , ignored_color , error_op_color , hush_errors , \
129
+ dtype_colors , dtype_precisions , dtype_alpha_range
121
130
122
131
def __enter__ (self ):
123
132
self .frame = sys ._getframe ().f_back # where do we start tracking? Hmm...not sure we use this
@@ -146,7 +155,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
146
155
self .char_sep_scale , self .fontcolor ,
147
156
self .underline_color , self .ignored_color ,
148
157
self .error_op_color ,
149
- hush_errors = self .hush_errors )
158
+ hush_errors = self .hush_errors ,
159
+ dtype_colors = self .dtype_colors ,
160
+ dtype_precisions = self .dtype_precisions ,
161
+ dtype_alpha_range = self .dtype_alpha_range )
150
162
if self .view is not None : # Ignore if we can't process code causing exception (I use a subparser)
151
163
if self .show == 'viz' :
152
164
self .view .show ()
@@ -159,8 +171,8 @@ def __init__(self,
159
171
dimfontname = 'Arial' , dimfontsize = 9 , matrixcolor = "#cfe2d4" ,
160
172
vectorcolor = "#fefecd" , char_sep_scale = 1.8 , fontcolor = '#444443' ,
161
173
underline_color = '#C2C2C2' , ignored_color = '#B4B4B4' , error_op_color = '#A40227' ,
162
- savefig = None ,
163
- hush_errors = True ):
174
+ savefig = None , hush_errors = True ,
175
+ dtype_colors = None , dtype_precisions = None , dtype_alpha_range = None ):
164
176
"""
165
177
As the Python virtual machine executes lines of code, generate a
166
178
visualization for tensor-related expressions using from numpy, pytorch,
@@ -229,15 +241,23 @@ def __init__(self,
229
241
to see what the error messages are coming from my parser.
230
242
:param savefig: A string indicating where to save the visualization; don't save
231
243
a file if None.
244
+ :param dtype_colors: map from dtype w/o precision like 'int' to color
245
+ :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
246
+ :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
247
+ and the alpha channel is used to show precision; the
248
+ smaller the bit size, the lower the alpha channel. You
249
+ can play with the range to get better visual dynamic range
250
+ depending on how many precisions you want to display.
232
251
"""
233
- self .savefig = savefig
234
- self .fontname , self .fontsize , self .dimfontname , self .dimfontsize , \
252
+ self .savefig , self .fontname , self .fontsize , self .dimfontname , self .dimfontsize , \
235
253
self .matrixcolor , self .vectorcolor , self .char_sep_scale ,\
236
254
self .fontcolor , self .underline_color , self .ignored_color , \
237
- self .error_op_color , self .hush_errors = \
238
- fontname , fontsize , dimfontname , dimfontsize , \
255
+ self .error_op_color , self .hush_errors , \
256
+ self .dtype_colors , self .dtype_precisions , self .dtype_alpha_range = \
257
+ savefig , fontname , fontsize , dimfontname , dimfontsize , \
239
258
matrixcolor , vectorcolor , char_sep_scale , \
240
- fontcolor , underline_color , ignored_color , error_op_color , hush_errors
259
+ fontcolor , underline_color , ignored_color , error_op_color , hush_errors , \
260
+ dtype_colors , dtype_precisions , dtype_alpha_range
241
261
242
262
def __enter__ (self ):
243
263
# print("ON trace", sys._getframe())
@@ -301,6 +321,7 @@ def listener(self, frame, event, arg):
301
321
filename , line = info .filename , info .lineno
302
322
name = info .function
303
323
324
+ # Note: always true since L292 above...
304
325
if event == 'line' :
305
326
self .line_listener (module , name , filename , line , info , frame )
306
327
@@ -338,7 +359,10 @@ def viz_statement(self, code, frame):
338
359
self .explainer .char_sep_scale , self .explainer .fontcolor ,
339
360
self .explainer .underline_color , self .explainer .ignored_color ,
340
361
self .explainer .error_op_color ,
341
- hush_errors = self .explainer .hush_errors )
362
+ hush_errors = self .explainer .hush_errors ,
363
+ dtype_colors = self .explainer .dtype_colors ,
364
+ dtype_precisions = self .explainer .dtype_precisions ,
365
+ dtype_alpha_range = self .explainer .dtype_alpha_range )
342
366
self .views .append (view )
343
367
if self .explainer .savefig is not None :
344
368
file_path = Path (self .explainer .savefig )
@@ -495,6 +519,23 @@ def istensor(x):
495
519
return _shape (x ) is not None
496
520
497
521
522
+ def _dtype (v ) -> str :
523
+ if hasattr (v , "dtype" ):
524
+ dtype = v .dtype
525
+ elif "dtype" in v .__class__ .__name__ :
526
+ dtype = v
527
+ else :
528
+ return None
529
+
530
+ if dtype .__class__ .__module__ == "torch" :
531
+ # ugly but works
532
+ return str (dtype ).replace ("torch." , "" )
533
+ if hasattr (dtype , "names" ) and dtype .names is not None and hasattr (dtype , "fields" ):
534
+ # structured dtype: https://numpy.org/devdocs/user/basics.rec.html
535
+ return "," .join ([_dtype (val ) for val , _ in dtype .fields .values ()])
536
+ return dtype .name
537
+
538
+
498
539
def _shape (v ):
499
540
# do we have a shape and it answers len()? Should get stuff right.
500
541
if hasattr (v , "shape" ) and hasattr (v .shape , "__len__" ):
0 commit comments