Skip to content

Commit 8694be5

Browse files
authored
Merge pull request #20 from parrt/jax
Add JAX support
2 parents c1da933 + e5ed697 commit 8694be5

File tree

7 files changed

+1637
-433
lines changed

7 files changed

+1637
-433
lines changed

README.md

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Tensor Sensor
22

3-
<img src="https://explained.ai/tensor-sensor/images/teaser.png" width="50%" align="right">One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined [Tensorflow](https://www.tensorflow.org/) network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less than helpful exception messages. To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/).
3+
<img src="https://explained.ai/tensor-sensor/images/teaser.png" width="50%" align="right">One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined [Tensorflow](https://www.tensorflow.org/) network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less than helpful exception messages.
4+
5+
To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/).
46

57
Please read the complete description in article [Clarifying exceptions and visualizing tensor operations in deep learning code](https://explained.ai/tensor-sensor/index.html).
68

@@ -35,14 +37,24 @@ TensorSensor augments the message with more information about which operator cau
3537
Cause: @ on tensor operand W w/shape [764, 100] and operand X.T w/shape [764, 200]
3638
```
3739

40+
You can also get the full computation graph for an expression that includes all of these sub result shapes.
41+
42+
```python
43+
tsensor.astviz("b = W@b + (h+3).dot(h) + torch.abs(torch.tensor(34))", sys._getframe())
44+
```
45+
46+
yields the following abstract syntax tree with shapes:
47+
48+
<img src="images/ast.svg" width="400">
3849

3950
## Install
4051

4152
```
4253
pip install tensor-sensor # This will only install the library for you
4354
pip install tensor-sensor[torch] # install pytorch related dependency
4455
pip install tensor-sensor[tensorflow] # install tensorflow related dependency
45-
pip install tensor-sensor[all] # install both tensorflow and pytorch
56+
pip install tensor-sensor[jax] # install jax, jaxlib
57+
pip install tensor-sensor[all] # install tensorflow, pytorch, jax
4658
```
4759

4860
which gives you module `tsensor`. I developed and tested with the following versions
@@ -56,6 +68,9 @@ numpy 1.18.5
5668
numpydoc 1.1.0
5769
$ pip list | grep -i torch
5870
torch 1.6.0
71+
$ pip list | grep -i jax
72+
jax 0.2.6
73+
jaxlib 0.1.57
5974
```
6075

6176
### Graphviz for tsensor.astviz()
@@ -115,5 +130,3 @@ $ pip install .
115130
### TODO
116131

117132
* can i call pyviz in debugger?
118-
* try on real examples
119-
* `dict(W=[3,0,1,2], b=[1,0])` that would indicate (300, 30, 60, 3) would best be displayed as (30,60,3, 300) and b would be first dimension last and last dimension first

0 commit comments

Comments
 (0)