Skip to content

Commit d6d5c80

Browse files
committed
Merge branch 'sbrugman-patch-2'
2 parents dac5119 + 879d5b2 commit d6d5c80

File tree

8 files changed

+853
-7929
lines changed

8 files changed

+853
-7929
lines changed

testing/examples.ipynb

Lines changed: 501 additions & 7822 deletions
Large diffs are not rendered by default.

testing/play_cmaps.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
3+
import matplotlib.pyplot as plt
4+
import matplotlib.colors as mc
5+
6+
bits = [4,8,16,32,64,128]
7+
bits = [8,16,32,64,128]
8+
nhues = len(bits)
9+
10+
blueish = '#3B75AF'
11+
greenish = '#519E3E'
12+
13+
orangeish = '#FDDB7D'
14+
limeish = '#C1E1C5'
15+
limeish = '#A8E1B0'
16+
yellowish = '#FFFFAD'
17+
18+
print(mc.hex2color(limeish))
19+
20+
type_colors = {'float':limeish, 'int':blueish, 'complex':orangeish}
21+
22+
# Derived from https://stackoverflow.com/questions/47222585/matplotlib-generic-colormap-from-tab10
23+
24+
def categorical_cmap(color, nsc):
25+
# ccolors = plt.get_cmap(cmap)(np.arange(nc, dtype=int))
26+
# print(ccolors[0:4])
27+
cols = np.zeros((nsc, 3))
28+
# chsv = mc.rgb_to_hsv(c[:3])
29+
chsv = mc.rgb_to_hsv(mc.hex2color(color))
30+
arhsv = np.tile(chsv,nsc).reshape(nsc,3)
31+
arhsv[:,1] = np.linspace(chsv[1],0.25,nsc)
32+
arhsv[:,2] = np.linspace(chsv[2],1,nsc)
33+
rgb = mc.hsv_to_rgb(arhsv)
34+
cols[0:nsc,:] = rgb
35+
cmap = mc.ListedColormap(cols)
36+
return cmap
37+
38+
plt.figure(figsize=(3,3))
39+
c1 = categorical_cmap(blueish,nhues)
40+
plt.scatter(np.arange(nhues),[1]*nhues, c=np.arange(nhues), s=1080, cmap=c1, linewidths=.5, edgecolors='grey')
41+
c1 = categorical_cmap(limeish,nhues)
42+
plt.scatter(np.arange(nhues),[2]*nhues, c=np.arange(nhues), s=1080, cmap=c1, linewidths=.5, edgecolors='grey')
43+
c1 = categorical_cmap(yellowish,nhues)
44+
plt.scatter(np.arange(nhues),[3]*nhues, c=np.arange(nhues), s=1080, cmap=c1, linewidths=.5, edgecolors='grey')
45+
46+
plt.margins(y=3)
47+
plt.xticks([])
48+
plt.yticks([0,1,2],["(5, 4)", "(2, 5)", "(4, 3)"])
49+
plt.ylim(0, 4)
50+
plt.axis('off')
51+
52+
plt.savefig("/Users/parrt/Desktop/colors.pdf")
53+
plt.show()

testing/test2.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
1-
import torch
1+
import numpy as np
22
import tsensor
3-
nhidden = 256
4-
n = 200 # number of instances
5-
d = 764 # number of instance features
6-
n_neurons = 100 # how many neurons in this layer?
7-
8-
Whh_ = torch.eye(nhidden, nhidden)
9-
Uxh_ = torch.randn(nhidden, d)
10-
bh_ = torch.zeros(nhidden, 1)
11-
h = torch.randn(nhidden, 1) # fake previous hidden state h
12-
r = torch.randn(nhidden, 1) # fake this computation
13-
X = torch.rand(n,d) # fake input
14-
15-
g = tsensor.astviz("h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)")
16-
g.savefig("/tmp/torch-gru-ast-shapes.svg")
3+
import torch
4+
5+
W = np.array([[1, 2], [3, 4]])
6+
b = np.array([9, 10]).reshape(2, 1)
7+
x = np.array([4, 5]).reshape(2, 1)
8+
h = np.array([1, 2])
9+
with tsensor.explain(savefig="/Users/parrt/Desktop/foo.pdf"):
10+
W @ np.dot(b,b) + np.eye(2,2)@x
11+
12+
13+
W = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
14+
b = torch.tensor([9, 10]).reshape(2, 1)
15+
x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1)
16+
h = torch.tensor([1,2])
17+
18+
a = torch.rand(size=(2, 20), dtype=torch.float64)
19+
b = torch.rand(size=(2, 20), dtype=torch.float32)
20+
c = torch.rand(size=(2,20,200), dtype=torch.complex64)
21+
d = torch.rand(size=(2,20,200,5), dtype=torch.float16)
22+
23+
24+
with tsensor.explain(savefig="/Users/parrt/Desktop/t2.pdf"):
25+
a + b + x + c[:,:,0] + d[:,:,0,0]
26+
27+
with tsensor.explain(savefig="/Users/parrt/Desktop/t3.pdf"):
28+
c
29+
30+
with tsensor.explain(savefig="/Users/parrt/Desktop/t4.pdf"):
31+
d
32+
33+
# with tsensor.explain(legend=True, savefig="/Users/parrt/Desktop/t.pdf") as e:
34+
# W @ torch.dot(b, b) + torch.eye(2, 2) @ x

testing/test3.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import tsensor
2-
import tensorflow as tf
31
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import tsensor
4+
import torch
45

5-
W = tf.constant([[1, 2], [3, 4]])
6-
b = tf.reshape(tf.constant([[9, 10]]), (2, 1))
7-
x = tf.reshape(tf.constant([[8, 5, 7]]), (3, 1))
8-
z = 0
9-
10-
# tsensor.parse("z /= b + x * 3", hush_errors=False)
11-
12-
# with tsensor.clarify(show='viz'):
13-
# b + x * 3
6+
W = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
7+
b = torch.tensor([9, 10]).reshape(2, 1)
8+
x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1)
9+
h = torch.tensor([1,2])
1410

1511
fig, ax = plt.subplots(1,1)
16-
tsensor.pyviz("b + x", ax=ax)
12+
view = tsensor.pyviz("b + x", ax=ax, legend=True)
13+
# view.savefig("/Users/parrt/Desktop/foo.pdf")
1714
plt.show()
1815
# with tsensor.explain():
1916
# b + x

testing/test_dtype.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import numpy as np
2+
import jax.numpy as jnp
3+
import tensorflow as tf
4+
import torch
5+
import pytest
6+
7+
import tsensor as ts
8+
9+
10+
@pytest.mark.parametrize(
11+
"value,expected",
12+
[
13+
# Numpy
14+
(np.random.randint(1, 10, size=(10, 2, 5)), "int64"),
15+
(np.random.randint(1, 10, size=(10, 2, 5), dtype="int8"), "int8"),
16+
(np.random.normal(size=(5, 1)).astype(np.float32), "float32"),
17+
(np.random.normal(size=(5, 1)).astype(np.float32), "float32"),
18+
(np.array([('Rex', 9, 81.0), ('Fido', 3, 27.0)], dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')]),
19+
"str320,int32,float32"),
20+
# Jax
21+
(jnp.array([[1, 2], [3, 4]]), "int32"),
22+
(jnp.array([[1, 2], [3, 4]], dtype="int8"), "int8"),
23+
# Tensorflow
24+
(tf.constant([[1, 2], [3, 4]]), "int32"),
25+
(tf.constant([[1, 2], [3, 4]], dtype="int64"), "int64"),
26+
# Pytorch
27+
(torch.tensor([[1, 2], [3, 4]]), "int64"),
28+
(torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "int32"),
29+
],
30+
)
31+
def test_dtypes(value, expected):
32+
assert ts.analysis._dtype(value) == expected

tsensor/analysis.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def __init__(self,
4848
vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
4949
underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
5050
show:(None,'viz')='viz',
51-
hush_errors=True):
51+
hush_errors=True,
52+
dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
5253
"""
5354
Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
5455
Also display a visual representation of the offending Python line that
@@ -109,15 +110,23 @@ def __init__(self,
109110
unhandled code caught by my parser are ignored. Turn this off
110111
to see what the error messages are coming from my parser.
111112
:param show: Show visualization upon tensor error if show='viz'.
113+
:param dtype_colors: map from dtype w/o precision like 'int' to color
114+
:param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
115+
:param dtype_alpha_range: all tensors of the same type are drawn to the same color,
116+
and the alpha channel is used to show precision; the
117+
smaller the bit size, the lower the alpha channel. You
118+
can play with the range to get better visual dynamic range
119+
depending on how many precisions you want to display.
112120
"""
113-
self.show = show
114-
self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
121+
self.show, self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
115122
self.matrixcolor, self.vectorcolor, self.char_sep_scale,\
116123
self.fontcolor, self.underline_color, self.ignored_color, \
117-
self.error_op_color, self.hush_errors = \
118-
fontname, fontsize, dimfontname, dimfontsize, \
124+
self.error_op_color, self.hush_errors, \
125+
self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
126+
show, fontname, fontsize, dimfontname, dimfontsize, \
119127
matrixcolor, vectorcolor, char_sep_scale, \
120-
fontcolor, underline_color, ignored_color, error_op_color, hush_errors
128+
fontcolor, underline_color, ignored_color, error_op_color, hush_errors, \
129+
dtype_colors, dtype_precisions, dtype_alpha_range
121130

122131
def __enter__(self):
123132
self.frame = sys._getframe().f_back # where do we start tracking? Hmm...not sure we use this
@@ -146,7 +155,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
146155
self.char_sep_scale, self.fontcolor,
147156
self.underline_color, self.ignored_color,
148157
self.error_op_color,
149-
hush_errors=self.hush_errors)
158+
hush_errors=self.hush_errors,
159+
dtype_colors=self.dtype_colors,
160+
dtype_precisions=self.dtype_precisions,
161+
dtype_alpha_range=self.dtype_alpha_range)
150162
if self.view is not None: # Ignore if we can't process code causing exception (I use a subparser)
151163
if self.show=='viz':
152164
self.view.show()
@@ -159,8 +171,8 @@ def __init__(self,
159171
dimfontname='Arial', dimfontsize=9, matrixcolor="#cfe2d4",
160172
vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
161173
underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
162-
savefig=None,
163-
hush_errors=True):
174+
savefig=None, hush_errors=True,
175+
dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
164176
"""
165177
As the Python virtual machine executes lines of code, generate a
166178
visualization for tensor-related expressions using from numpy, pytorch,
@@ -229,15 +241,23 @@ def __init__(self,
229241
to see what the error messages are coming from my parser.
230242
:param savefig: A string indicating where to save the visualization; don't save
231243
a file if None.
244+
:param dtype_colors: map from dtype w/o precision like 'int' to color
245+
:param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
246+
:param dtype_alpha_range: all tensors of the same type are drawn to the same color,
247+
and the alpha channel is used to show precision; the
248+
smaller the bit size, the lower the alpha channel. You
249+
can play with the range to get better visual dynamic range
250+
depending on how many precisions you want to display.
232251
"""
233-
self.savefig = savefig
234-
self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
252+
self.savefig, self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
235253
self.matrixcolor, self.vectorcolor, self.char_sep_scale,\
236254
self.fontcolor, self.underline_color, self.ignored_color, \
237-
self.error_op_color, self.hush_errors = \
238-
fontname, fontsize, dimfontname, dimfontsize, \
255+
self.error_op_color, self.hush_errors, \
256+
self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
257+
savefig, fontname, fontsize, dimfontname, dimfontsize, \
239258
matrixcolor, vectorcolor, char_sep_scale, \
240-
fontcolor, underline_color, ignored_color, error_op_color, hush_errors
259+
fontcolor, underline_color, ignored_color, error_op_color, hush_errors, \
260+
dtype_colors, dtype_precisions, dtype_alpha_range
241261

242262
def __enter__(self):
243263
# print("ON trace", sys._getframe())
@@ -301,6 +321,7 @@ def listener(self, frame, event, arg):
301321
filename, line = info.filename, info.lineno
302322
name = info.function
303323

324+
# Note: always true since L292 above...
304325
if event=='line':
305326
self.line_listener(module, name, filename, line, info, frame)
306327

@@ -338,7 +359,10 @@ def viz_statement(self, code, frame):
338359
self.explainer.char_sep_scale, self.explainer.fontcolor,
339360
self.explainer.underline_color, self.explainer.ignored_color,
340361
self.explainer.error_op_color,
341-
hush_errors=self.explainer.hush_errors)
362+
hush_errors=self.explainer.hush_errors,
363+
dtype_colors=self.explainer.dtype_colors,
364+
dtype_precisions=self.explainer.dtype_precisions,
365+
dtype_alpha_range=self.explainer.dtype_alpha_range)
342366
self.views.append(view)
343367
if self.explainer.savefig is not None:
344368
file_path = Path(self.explainer.savefig)
@@ -495,6 +519,23 @@ def istensor(x):
495519
return _shape(x) is not None
496520

497521

522+
def _dtype(v) -> str:
523+
if hasattr(v, "dtype"):
524+
dtype = v.dtype
525+
elif "dtype" in v.__class__.__name__:
526+
dtype = v
527+
else:
528+
return None
529+
530+
if dtype.__class__.__module__ == "torch":
531+
# ugly but works
532+
return str(dtype).replace("torch.", "")
533+
if hasattr(dtype, "names") and dtype.names is not None and hasattr(dtype, "fields"):
534+
# structured dtype: https://numpy.org/devdocs/user/basics.rec.html
535+
return ",".join([_dtype(val) for val, _ in dtype.fields.values()])
536+
return dtype.name
537+
538+
498539
def _shape(v):
499540
# do we have a shape and it answers len()? Should get stuff right.
500541
if hasattr(v, "shape") and hasattr(v.shape, "__len__"):

tsensor/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2222
SOFTWARE.
2323
"""
24-
__version__ = '0.1.3'
24+
__version__ = '0.2'

0 commit comments

Comments
 (0)