TensorFlow is a heavy dependency (over 500 MB). It would be nice to replace it with JAX, which can also be used to build neural networks (e.g. with Flax). @jaspreetj @flaport @sequoiap @SkandanC @AustP