Move to JAX #63

@arunoruto

Description

While Numba enables JIT-compiled and parallelized Python code on the CPU, its GPU support is lacking in features and regressing (see the deprecation of AMD support). One aspect of this problem is the batching of data (#24).

JAX is a solid alternative to Numba: it ships with similar features and extends them even further. Some notable differences are:

  • Code can be executed on the CPU, GPU, and TPU (Tensor Processing Unit, usually found in Google Cloud offerings) by switching an environment variable (first sketch after this list)
  • GPU support covers Nvidia for now, but an experimental AMD integration is available and works to a certain degree
  • Vector mapping via jax.vmap enables easy parallelization of the code, and jax.lax.map additionally allows batching (second sketch below)
  • An environment variable toggles between single and double precision; single precision is usually enough and speeds things up (third sketch below)
  • Compared to Numba, JAX has a whole "ecosystem" of modules built on top of it, making scientific computing much easier than with Numba
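
As a first sketch of the backend switch mentioned in the list: JAX picks its backend from the `JAX_PLATFORMS` environment variable, which must be set before the first JAX import; the same code then runs unchanged on CPU, GPU, or TPU. The toy computation below is only an illustration, not code from this package.

```python
import os

# Must be set before JAX initializes its backend;
# valid values include "cpu", "gpu", and "tpu".
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

print(jax.devices())  # e.g. [CpuDevice(id=0)]

# The same code runs unchanged on whichever backend was selected.
x = jnp.arange(10.0)
print(jnp.sum(x ** 2))
```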
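
For the vector-mapping point, a small sketch: `jax.vmap` vectorizes a per-sample function over a batch axis in one fused computation, while `jax.lax.map` evaluates sequentially (and, in recent JAX releases, in chunks via `batch_size`), capping peak memory. The `loss` function here is a hypothetical stand-in for the actual kernel.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Hypothetical per-sample computation standing in for the real kernel.
    return jnp.sum((params * x) ** 2)

params = jnp.array([1.0, 2.0, 3.0])
batch = jnp.ones((1000, 3))

# jax.vmap fuses the whole batch into one vectorized computation.
batched = jax.vmap(loss, in_axes=(None, 0))(params, batch)
print(batched.shape)  # (1000,)

# jax.lax.map evaluates in chunks instead, trading throughput for a
# smaller peak memory footprint (batch_size needs a recent JAX release).
chunked = jax.lax.map(lambda x: loss(params, x), batch, batch_size=100)
print(chunked.shape)  # (1000,)
```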
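
And for the precision toggle: JAX computes in single precision by default; setting `JAX_ENABLE_X64` before the first import (or calling `jax.config.update("jax_enable_x64", True)` at runtime) enables double precision. A minimal sketch:

```python
import os

# Set before the first JAX import; jax.config.update("jax_enable_x64", True)
# achieves the same at runtime.
os.environ["JAX_ENABLE_X64"] = "1"

import jax.numpy as jnp

x = jnp.array(1.0)
print(x.dtype)  # float64 here; float32 without the flag
```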

A package rework is planned; it will be carried out on a separate branch until all tests pass!
