@@ -112,7 +112,8 @@ def concatenate(
112
112
113
113
# Ensure we handle axis being passed as a negative integer
114
114
first_arr = arrays [0 ]
115
- axis = axis % first_arr .ndim
115
+ if axis < 0 :
116
+ axis = axis % first_arr .ndim
116
117
117
118
arr_shapes = [arr .shape for arr in arrays ]
118
119
_check_same_shapes_except_on_concat_axis (arr_shapes , axis )
@@ -154,6 +155,7 @@ def _check_same_ndims(ndims: list[int]) -> None:
154
155
155
156
def _check_same_shapes_except_on_concat_axis (shapes : list [tuple [int , ...]], axis : int ):
156
157
"""Check that shapes are compatible for concatenation"""
158
+
157
159
shapes_without_concat_axis = [
158
160
_remove_element_at_position (shape , axis ) for shape in shapes
159
161
]
@@ -198,7 +200,8 @@ def stack(
198
200
199
201
# Ensure we handle axis being passed as a negative integer
200
202
first_arr = arrays [0 ]
201
- axis = axis % first_arr .ndim
203
+ if axis < 0 :
204
+ axis = axis % first_arr .ndim
202
205
203
206
# find what new array shape must be
204
207
length_along_new_stacked_axis = len (arrays )
@@ -267,8 +270,13 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
267
270
if d == d_requested :
268
271
pass
269
272
elif d is None :
270
- # stack same array upon itself d_requested number of times, which inserts a new axis at axis=0
271
- result = stack ([result ] * d_requested , axis = 0 )
273
+ if result .shape == ():
274
+ # scalars are a special case because their manifests already have a chunk key with one dimension
275
+ # see https://github.com/TomNicholas/VirtualiZarr/issues/100#issuecomment-2097058282
276
+ result = _broadcast_scalar (result , new_axis_length = d_requested )
277
+ else :
278
+ # stack same array upon itself d_requested number of times, which inserts a new axis at axis=0
279
+ result = stack ([result ] * d_requested , axis = 0 )
272
280
elif d == 1 :
273
281
# concatenate same array upon itself d_requested number of times along existing axis
274
282
result = concatenate ([result ] * d_requested , axis = axis )
@@ -280,6 +288,41 @@ def broadcast_to(x: "ManifestArray", /, shape: Tuple[int, ...]) -> "ManifestArra
280
288
return result
281
289
282
290
291
+ def _broadcast_scalar (x : "ManifestArray" , new_axis_length : int ) -> "ManifestArray" :
292
+ """
293
+ Add an axis to a scalar ManifestArray, but without adding a new axis to the keys of the chunk manifest.
294
+
295
+ This is not the same as concatenation, because there is no existing axis along which to concatenate.
296
+ It's also not the same as stacking, because we don't want to insert a new axis into the chunk keys.
297
+
298
+ Scalars are a special case because their manifests still have a chunk key with one dimension.
299
+ See https://github.com/TomNicholas/VirtualiZarr/issues/100#issuecomment-2097058282
300
+ """
301
+
302
+ from .array import ManifestArray
303
+
304
+ new_shape = (new_axis_length ,)
305
+ new_chunks = (new_axis_length ,)
306
+
307
+ concatenated_manifest = concat_manifests (
308
+ [x .manifest ] * new_axis_length ,
309
+ axis = 0 ,
310
+ )
311
+
312
+ new_zarray = ZArray (
313
+ chunks = new_chunks ,
314
+ compressor = x .zarray .compressor ,
315
+ dtype = x .dtype ,
316
+ fill_value = x .zarray .fill_value ,
317
+ filters = x .zarray .filters ,
318
+ shape = new_shape ,
319
+ order = x .zarray .order ,
320
+ zarr_format = x .zarray .zarr_format ,
321
+ )
322
+
323
+ return ManifestArray (chunkmanifest = concatenated_manifest , zarray = new_zarray )
324
+
325
+
283
326
# TODO broadcast_arrays, squeeze, permute_dims
284
327
285
328
0 commit comments