
Building a differentiable compute graph dynamically, à la PyTorch #17715

Answered by jakevdp
eguiraud asked this question in Q&A

I don't see any reason why you can't evaluate the expression directly in JAX, just as you do in PyTorch. For example, this uses the same `compute` function from your PyTorch example:

```python
from functools import partial
import jax.numpy as jnp
import jax

ast = {"+": [{"*": [{"in": "x"}, {"in": "x"}]}, {"in": "y"}]}

def compute(ast: dict, inputs: dict) -> jax.Array:
    """Walk the AST, apply the corresponding operations to the inputs"""
    key, value = next(iter(ast.items()))
    if key == "+":
        return compute(value[0], inputs) + compute(value[1], inputs)
    elif key == "*":
        return compute(value[0], inputs) * compute(value[1], inputs)
    elif key == "in":
        return in…
```
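The answer above is cut off, so here is a minimal, self-contained sketch of the same idea. It assumes that the truncated leaf case simply looks up the named entry in the `inputs` dict (`inputs[value]`). It also binds the AST with `functools.partial` before `jax.jit`; that is one plausible reason for the `partial` import, not necessarily what the original answer does. The point is that the AST is plain Python data: it is walked at trace time, and JAX only ever sees the resulting arithmetic, so `jax.grad` and `jax.jit` work on the dynamically built expression.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Same nested-dict AST as in the question: x * x + y
ast = {"+": [{"*": [{"in": "x"}, {"in": "x"}]}, {"in": "y"}]}


def compute(ast: dict, inputs: dict) -> jax.Array:
    """Recursively walk the AST and apply each operation to the inputs."""
    key, value = next(iter(ast.items()))
    if key == "+":
        return compute(value[0], inputs) + compute(value[1], inputs)
    elif key == "*":
        return compute(value[0], inputs) * compute(value[1], inputs)
    elif key == "in":
        # Assumption: a leaf node names an entry of the `inputs` dict.
        return inputs[value]
    raise ValueError(f"unknown op: {key!r}")


# Bind the AST so the resulting function takes only `inputs`. The AST is
# consumed while tracing and never becomes a traced value, so jit is fine.
f = jax.jit(partial(compute, ast))

inputs = {"x": jnp.array(3.0), "y": jnp.array(2.0)}

# value_and_grad differentiates with respect to the first argument. That
# argument is a dict (a pytree), so the gradient is a dict with the same keys.
value, grads = jax.value_and_grad(f)(inputs)
print(float(value))       # 11.0
print(float(grads["x"]))  # 6.0  (d/dx of x*x + y is 2x)
print(float(grads["y"]))  # 1.0
```

Because the inputs are passed as a dict, there is no need to track `requires_grad` flags as in PyTorch: every leaf of the first argument gets a gradient.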

Replies: 1 comment 5 replies

@eguiraud
@jakevdp
Answer selected by eguiraud
@eguiraud
@jakevdp
@eguiraud
Category
Q&A
Labels
None yet
2 participants