scan primitive implementation missing #29

@colehaus

Looking at `@register(jax.lax.while_p)`, I assume scan is supposed to be implemented generically for all Quax objects. I took a quick try at an implementation like this:

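from typing import Optional, Union

import jax
import jax.tree_util as jtu
import quax
from jax import lax
from jax.typing import ArrayLike


@quax.register(lax.scan_p)
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    # scan_p receives its operands flat: consts, then the initial carry, then the scanned inputs.
    consts = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]

    # Re-trace the scan body under quaxify so that it accepts Quax values.
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))
    quax_jaxpr = jax.make_jaxpr(quax_f)(*consts, *init, *xs)

    # Flatten each operand group to array leaves before re-binding the primitive.
    const_leaves, _ = jtu.tree_flatten(consts)
    init_leaves, init_treedef = jtu.tree_flatten(init)
    xs_leaves, _ = jtu.tree_flatten(xs)

    out_flat = lax.scan_p.bind(
        *const_leaves,
        *init_leaves,
        *xs_leaves,
        reverse=reverse,
        length=length,
        jaxpr=quax_jaxpr,
        num_consts=num_consts,
        num_carry=num_carry,
        linear=linear,
        unroll=unroll,
        _split_transpose=_split_transpose,
    )

    # _initial_style_jaxpr(quax_f, , , "scan")
    # Split the flat outputs back into the carry and the stacked per-step outputs.
    carry_nvals = len(init_leaves)
    carry, ys = out_flat[:carry_nvals], out_flat[carry_nvals:]

    carry_out = jtu.tree_unflatten(init_treedef, carry)

    # Only the carry is returned; `ys` is dropped.
    return carry_out, None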

But there are at least two problems:

  • When I actually use it, I get:
File /usr/local/lib/python3.11/dist-packages/quax/_core.py:194, in <listcomp>(.0)
    192         out = method(*values, **params)
    193 if primitive.multiple_results:
--> 194     return [_QuaxTracer(self, _wrap_if_array(x)) for x in out]  # pyright: ignore
    195 else:
    196     return _QuaxTracer(self, _wrap_if_array(out))

File /usr/local/lib/python3.11/dist-packages/quax/_core.py:84, in _QuaxTracer.__init__(self, trace, value)
     83 def __init__(self, trace: "_QuaxTrace", value: "Value"):
---> 84     assert _is_value(value)
     85     self._trace = trace
     86     self.value = value
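
For reference, here is a minimal sketch in plain JAX (no Quax) of how `scan_p` lays out its flat operands and results. It only inspects a traced `lax.scan`; the function name `scaled_cumsum` and the inputs are made up for illustration. It may be relevant to the assertion above, since the traceback shows Quax wrapping each element of the rule's return value when `primitive.multiple_results` is set, while the rule above returns a nested `(carry_out, None)`. That connection is a guess, not something I have confirmed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scaled_cumsum(scale, xs):
    # `scale` is closed over by `step`, so it reaches scan_p as a const operand.
    def step(carry, x):
        new_carry = carry + scale * x
        return new_carry, new_carry

    return lax.scan(step, jnp.zeros(()), xs)


# Trace the function and pick out the scan equation from the resulting jaxpr.
closed_jaxpr = jax.make_jaxpr(scaled_cumsum)(2.0, jnp.arange(4.0))
scan_eqn = next(e for e in closed_jaxpr.jaxpr.eqns if e.primitive.name == "scan")

# Operands are flat and ordered: consts, then initial carry, then scanned inputs.
num_consts = scan_eqn.params["num_consts"]
num_carry = scan_eqn.params["num_carry"]
num_xs = len(scan_eqn.invars) - num_consts - num_carry

# Results are flat too: final carry values first, then the stacked per-step outputs.
num_ys = len(scan_eqn.outvars) - num_carry

print("multiple_results:", scan_eqn.primitive.multiple_results)
print("consts / carry / xs:", num_consts, num_carry, num_xs)
print("carry outputs / ys outputs:", num_carry, num_ys)
```

If that reading is right, a rule registered for `lax.scan_p` would presumably need to hand back one flat sequence of `num_carry + num_ys` values rather than a `(carry, ys)` pair.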

Having scan available is pretty handy for use with the scan over layers technique.
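
To make the use case concrete, here is a minimal sketch of scan over layers in plain JAX, assuming a stack of identical dense `tanh` layers. The layer definition and the helper names (`init_stacked_layers`, `apply_layer`, `forward_scan`, `forward_loop`) are hypothetical and not part of Quax. The per-layer parameters are stacked along a leading axis and `lax.scan` iterates over that axis, so the layer body is traced once instead of once per layer.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_stacked_layers(key, num_layers, dim):
    # All layers share a shape, so their parameters stack along a leading axis.
    weights = jax.random.normal(key, (num_layers, dim, dim)) / jnp.sqrt(dim)
    biases = jnp.zeros((num_layers, dim))
    return weights, biases


def apply_layer(h, layer_params):
    w, b = layer_params
    return jnp.tanh(h @ w + b)


def forward_scan(stacked_params, x):
    # The carry is the hidden state; the scanned input is one layer's parameters.
    def step(h, layer_params):
        return apply_layer(h, layer_params), None

    h_final, _ = lax.scan(step, x, stacked_params)
    return h_final


def forward_loop(stacked_params, x):
    # Reference implementation: an unrolled Python loop over the layers.
    weights, biases = stacked_params
    h = x
    for i in range(weights.shape[0]):
        h = apply_layer(h, (weights[i], biases[i]))
    return h


params = init_stacked_layers(jax.random.PRNGKey(0), num_layers=3, dim=4)
x = jnp.ones((4,))
out_scan = forward_scan(params, x)
out_loop = forward_loop(params, x)
```

With a `scan_p` rule in Quax, the stacked parameters here could in principle be Quax values rather than plain arrays, which is the situation this issue is about.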
