Skip to content

Address more comments #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 26, 2025
Merged

Conversation

lithomas1
Copy link

Reference issue

What does this implement/fix?

Additional information

@lithomas1
Copy link
Author

cleared about half of the comments, more to follow.

@lithomas1 lithomas1 marked this pull request as ready for review January 22, 2025 23:29
Comment on lines 2523 to 2524
# array created by xp.zeros is non-writeable for dask
# and jax

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not true. All arrays are non-writeabla for jax. All arrays are writeable for dask.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r.e. writeability, I am not referring to the mutability of the original dask array, but the numpy array created from the dask array with np.asarray.

MRE

import dask.array as da
a = da.zeros(10)
a_np = np.asarray(a)
a_np[1] = 10
Traceback (most recent call last):
  File "<python-input-4>", line 1, in <module>
    a_np[1] = 10
    ~~~~^^^
ValueError: assignment destination is read-only

I'll remove the jax comment (and I'll raise an issue on scipy to disallow the output keyword for jax since it looks like there's no way to make that work).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's because the chunks of a are broadcasted, and so is a_np unless it's a concatenation of chunks.

>>> da.zeros(10).__array__().strides
(0,)
>>> da.zeros(10, chunk=5).__array__().strides
()

Broadcasted arrays are read-only as the whole thing is only 8 bytes in size.

Copy link

@crusaderky crusaderky Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole idea of having an output= parameter makes no sense if you are going to convert it to numpy.

This makes sense:

def f(x: Array, output: Array) -> None:
    output[:] = x * 2

This does not:

def f(x: Array, output: Array) -> None:
    x = np.asarray(x)
    output = np.asarray(output)
    output[:] = x * 2

In the second example, even without the broadcasting issue you're facing, the output item that the user is holding will remain completely blank, unless np.asarray returns a view of the original memory. which, for dask, can never be.

Copy link

@crusaderky crusaderky Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that np.asarray(output, copy=False) will not crash, while it should. This is a bug in Dask.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, makes sense.

I'll put in a separate PR to the main scipy repo disabling the output kwarg for jax (to keep the diff down here).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 2578 to 2579
# This output array is read-only for dask and jax
# TODO: investigate why for dask?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this makes no sense to me.

Copy link

@crusaderky crusaderky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Co-authored-by: Guido Imperiale <crusaderky@gmail.com>
@lithomas1 lithomas1 requested a review from crusaderky January 23, 2025 17:05
@@ -534,7 +534,7 @@ def test_correlate22(self, dtype_array, dtype_output, xp):
assert_array_almost_equal(output, expected)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fully Array API compliant function would be able to suppor the output= kwarg with dask.
The problem is that you're calling np.asarray(x) will always return a buffer that is not shared with the input parameter.

Please change reason to "converts dask output array to numpy"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the comment.

@lithomas1 lithomas1 requested a review from crusaderky January 24, 2025 16:17
Copy link

@crusaderky crusaderky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only a small nit

Co-authored-by: Guido Imperiale <crusaderky@gmail.com>
@lithomas1
Copy link
Author

Thanks for the review
@lucascolley this is ready for merge

@lucascolley lucascolley merged commit 4b14378 into lucascolley:dask-new Jan 26, 2025
@lithomas1 lithomas1 deleted the dask-new branch January 26, 2025 14:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants