Looking at line 551 in a9d875e (@register(jax.lax.while_p)), scan is supposed to be implemented generically for all Quax objects. I took a quick try at an implementation like this:
@quax.register(lax.scan_p)
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    consts = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))
    quax_jaxpr = jax.make_jaxpr(quax_f)(*consts, *init, *xs)
    const_leaves, _ = jtu.tree_flatten(consts)
    init_leaves, init_treedef = jtu.tree_flatten(init)
    xs_leaves, _ = jtu.tree_flatten(xs)
    out_flat = lax.scan_p.bind(
        *const_leaves,
        *init_leaves,
        *xs_leaves,
        reverse=reverse,
        length=length,
        jaxpr=quax_jaxpr,
        num_consts=num_consts,
        num_carry=num_carry,
        linear=linear,
        unroll=unroll,
        _split_transpose=_split_transpose,
    )
    # _initial_style_jaxpr(quax_f, , , "scan")
    carry_nvals = len(init_leaves)
    carry, ys = out_flat[:carry_nvals], out_flat[carry_nvals:]
    carry_out = jtu.tree_unflatten(init_treedef, carry)
    return carry_out, None
But there are at least two problems:
- When I actually use it, I get:
  File /usr/local/lib/python3.11/dist-packages/quax/_core.py:194, in <listcomp>(.0)
      192 out = method(*values, **params)
      193 if primitive.multiple_results:
  --> 194     return [_QuaxTracer(self, _wrap_if_array(x)) for x in out] # pyright: ignore
      195 else:
      196     return _QuaxTracer(self, _wrap_if_array(out))
  File /usr/local/lib/python3.11/dist-packages/quax/_core.py:84, in _QuaxTracer.__init__(self, trace, value)
       83 def __init__(self, trace: "_QuaxTrace", value: "Value"):
  ---> 84     assert _is_value(value)
       85     self._trace = trace
       86     self.value = value
- I haven't implemented the part to extract the out_treedef for the second element of the return value. I think this would be possible by following the same strategy as what's in https://github.com/google/jax/blob/ebc6c1815297c79bc1c9c907aaf858d70caef5e6/jax/_src/lax/control_flow/loops.py#L123, but I'm not sure if there's a simpler way.
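Both points seem to come down to the calling convention at the primitive level. The sketch below uses plain JAX only (no Quax), and the step function, init, and xs are made up for illustration. It suggests two things: the scan primitive returns a flat list of outputs (carry leaves first, then the stacked per-step outputs), so a rule that returns a tuple containing None would likely fail the _is_value assertion; and jax.make_jaxpr with return_shape=True may be a simpler way to recover the output tree structure than mirroring loops.py. Whether this carries over unchanged to a quaxified jaxpr is an assumption, not something verified here.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax


def step(carry, x):
    # Hypothetical scan body: dict-valued carry, tuple-valued per-step output.
    new_carry = {"h": carry["h"] + x}
    return new_carry, (new_carry["h"] * 2.0, x)


init = {"h": jnp.zeros(3)}
xs = jnp.ones((5, 3))

# 1. Inspect the scan equation: its outputs are a flat list of variables,
#    the first num_carry of which belong to the carry.
outer = jax.make_jaxpr(lambda c, s: lax.scan(step, c, s))(init, xs)
scan_eqn = next(e for e in outer.jaxpr.eqns if e.primitive.name == "scan")
num_carry = scan_eqn.params["num_carry"]
num_outs = len(scan_eqn.outvars)

# 2. return_shape=True also returns the output pytree of the traced function
#    (as ShapeDtypeStructs), from which the treedefs can be read off.
_, (carry_shape, y_shape) = jax.make_jaxpr(step, return_shape=True)(init, xs[0])
carry_treedef = jtu.tree_structure(carry_shape)
ys_treedef = jtu.tree_structure(y_shape)

# 3. Rebuild structured outputs from a flat list, as a rule would have to
#    do after binding the primitive.
flat_out = jtu.tree_leaves(lax.scan(step, init, xs))
carry_out = jtu.tree_unflatten(carry_treedef, flat_out[:num_carry])
ys_out = jtu.tree_unflatten(ys_treedef, flat_out[num_carry:])
```

In a registered rule the final step would presumably go the other way: flatten whatever structured values come out and return them as one flat sequence, since the primitive has multiple_results set.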
Having scan available is pretty handy for use with the scan over layers technique.