Support partial reading of Zarr datasets #106

Merged
merged 6 commits into from
Jul 1, 2025
Changes from 4 commits
12 changes: 10 additions & 2 deletions PythonModule/ZarrPy.py
@@ -81,12 +81,16 @@ def writeZarr (kvstore_schema, data):
zarr_file[...] = data


def readZarr (kvstore_schema):
def readZarr (kvstore_schema, starts, ends, strides):
"""
Reads a subset of data from a Zarr file.

Parameters:
- kvstore_schema (dictionary): Schema for the file store (local or remote)
- starts (numpy.ndarray): Array of start indices for each dimension (0-based)
- ends (numpy.ndarray): Array of end indices for each dimension (elements
at the end index will not be read)
- strides (numpy.ndarray): Array of strides for each dimension

Returns:
- numpy.ndarray: The subset of the data read from the Zarr file.
@@ -96,6 +100,10 @@ def readZarr (kvstore_schema):
'kvstore': kvstore_schema,
}).result()

# Construct the indexing slices
slices = tuple(slice(start, end, stride) for start, end, stride in zip(starts, ends, strides))

# Read a subset of the data
data = zarr_file[...].read().result()
data = zarr_file[slices].read().result()

return data
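
The slice construction in readZarr can be tried without a Zarr store or TensorStore installed. The following is a minimal sketch, assuming plain Python `range` objects as a stand-in for each array dimension; the helper names `build_slices` and `selected_indices` are hypothetical and not part of ZarrPy.py.

```python
def build_slices(starts, ends, strides):
    """Build one slice per dimension, mirroring the tuple built in readZarr."""
    return tuple(
        slice(int(start), int(end), int(stride))
        for start, end, stride in zip(starts, ends, strides)
    )


def selected_indices(shape, slices):
    """List the 0-based indices that each slice picks along its dimension."""
    return [list(range(size)[s]) for size, s in zip(shape, slices)]


# A 3x4 dataset: rows 1..2 (end index 3 is excluded), every second column.
shape = (3, 4)
slices = build_slices(starts=[1, 0], ends=[3, 4], strides=[1, 2])
picked = selected_indices(shape, slices)
print(picked)  # [[1, 2], [0, 2]]
```

The end index is exclusive, which matches the note in the docstring that elements at the end index are not read.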
56 changes: 54 additions & 2 deletions Zarr.m
@@ -53,6 +53,38 @@
isZgroup = isfile(fullfile(path, '.zgroup'));
end

function newParams = validatePartialReadParams(params, dims, defaultValues)
% Validate the parameters for partial read (Start, Stride,
% Count)

arguments (Output)
newParams (1,:) int64
end
Member:

Surprised, considering the name of the function, that you didn't also use an arguments block to validate the inputs.

Member Author:

Yeah, I am doing the basic arguments-block validation for Start/Stride/Count in zarrread (better error messages and faster erroring out), so there is not much left to validate here. It's not great to validate different things in different places, though.


if isempty(params)
newParams = defaultValues;
return
end

% Allow using a scalar value for indexing into row or column
% datasets
if isscalar(params) && any(dims==1) && numel(dims)==2
newParams = defaultValues;
% use the provided value for the non-scalar dimension
newParams(dims~=1) = params;
return
end

if numel(params) ~= numel(dims)
error("MATLAB:Zarr:badPartialReadDimensions",...
"Length of parameters for partial reading " +...
"(Start, Stride, Count) must be the same "+...
"as the number of dataset dimensions.")
end

newParams = params;
end

function resolvedPath = getFullPath(path)
% Given a path, resolves it to a full path. The trailing
% directories do not have to exist.
@@ -200,7 +232,7 @@ function makeZarrGroups(existingParentPath, newGroupsPath)
end


function data = read(obj)
function data = read(obj, start, count, stride)
% Function to read the Zarr array

% If the Zarr array is local, verify that it is a valid folder
@@ -214,7 +246,27 @@ function makeZarrGroups(existingParentPath, newGroupsPath)
end
end

ndArrayData = py.ZarrPy.readZarr(obj.KVStoreSchema);
% Validate partial read parameters
info = zarrinfo(obj.Path);
numDims = numel(info.shape);
start = Zarr.validatePartialReadParams(start, info.shape,...
ones([1,numDims]));
stride = Zarr.validatePartialReadParams(stride, info.shape,...
ones([1,numDims]));
maxCount = (int64(info.shape') - start + 1)./stride; % has to be a row vector
count = Zarr.validatePartialReadParams(count, info.shape,...
maxCount);

% Convert partial read parameters to tensorstore-style
% indexing
start = start - 1; % tensorstore is 0-based
% Tensorstore uses end index instead of count
% (it does NOT include element at the end index)
endInds = start + stride.*count;

% Read the data
ndArrayData = py.ZarrPy.readZarr(obj.KVStoreSchema,...
start, endInds, stride);

% Store the datatype
obj.Datatype = ZarrDatatype.fromTensorstoreType(ndArrayData.dtype.name);
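
The parameter handling in Zarr.m (defaults, scalar expansion for row or column datasets, and the conversion from one-based Start/Stride/Count to zero-based start/end/stride) can be sketched in Python as follows. This is an illustration under stated assumptions, not the code in this PR: the function names `expand_param` and `to_zero_based_bounds` are hypothetical, and the default count uses ceiling division, which is an assumption of this sketch and may differ from how the MATLAB `maxCount` expression rounds.

```python
def expand_param(param, shape, defaults):
    """Fill in defaults and expand a scalar for row or column datasets."""
    if param is None:
        return list(defaults)
    values = [param] if isinstance(param, int) else list(param)
    # A scalar applies to the non-singleton dimension of a 2-D vector dataset.
    if len(values) == 1 and len(shape) == 2 and 1 in shape:
        expanded = list(defaults)
        for i, size in enumerate(shape):
            if size != 1:
                expanded[i] = values[0]
        return expanded
    if len(values) != len(shape):
        raise ValueError("Start, Stride, and Count need one entry per dimension.")
    return values


def to_zero_based_bounds(shape, start=None, stride=None, count=None):
    """Convert one-based Start/Stride/Count to zero-based start/end/stride."""
    ones = [1] * len(shape)
    start = expand_param(start, shape, ones)
    stride = expand_param(stride, shape, ones)
    # Assumed default: all elements reachable from start (ceiling division).
    max_count = [-(-(n - s + 1) // k) for n, s, k in zip(shape, start, stride)]
    count = expand_param(count, shape, max_count)
    starts0 = [s - 1 for s in start]  # zero-based start
    ends0 = [s0 + k * c for s0, k, c in zip(starts0, stride, count)]  # exclusive
    return starts0, ends0, stride


print(to_zero_based_bounds((3, 4), start=[2, 1], stride=[1, 2], count=[1, 2]))
# ([1, 0], [2, 4], [1, 2])
print(to_zero_based_bounds((1, 10), start=2, stride=3, count=2))
# ([0, 1], [1, 7], [1, 3])
```

The second call corresponds to the scalar Start/Stride/Count case for a 1x10 dataset: the scalar values apply to the second dimension, and the singleton dimension keeps its defaults.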
1 change: 1 addition & 0 deletions test/dataFiles/grp_v2/smallArr/.zarray
@@ -0,0 +1 @@
{"chunks":[3,4],"compressor":null,"dimension_separator":".","dtype":"<f8","fill_value":null,"filters":null,"order":"C","shape":[3,4],"zarr_format":2}
Binary file added test/dataFiles/grp_v2/smallArr/0.0
Binary file not shown.
1 change: 1 addition & 0 deletions test/dataFiles/grp_v2/vectorData/.zarray
@@ -0,0 +1 @@
{"chunks":[1,10],"compressor":null,"dimension_separator":".","dtype":"<f8","fill_value":null,"filters":null,"order":"C","shape":[1,10],"zarr_format":2}
Binary file added test/dataFiles/grp_v2/vectorData/0.0
Binary file not shown.
72 changes: 72 additions & 0 deletions test/tZarrRead.m
@@ -7,6 +7,8 @@
% Path for read functions
GrpPathRead = "dataFiles/grp_v2"
ArrPathRead = "dataFiles/grp_v2/arr_v2"
ArrPathReadSmall = "dataFiles/grp_v2/smallArr"
ArrPathReadVector = "dataFiles/grp_v2/vectorData"
ArrPathReadV3 = "dataFiles/grp_v3/arr_v3"

ExpData = load(fullfile(pwd,"dataFiles","expZarrArrData.mat"))
@@ -28,6 +30,54 @@ function verifyArrayData(testcase)
testcase.verifyEqual(actArrData,expArrData,'Failed to verify array data.');
end

function verifyPartialArrayData(testcase)
% Verify array data using zarrread function with Start/Stride/Count.

% The full data in the small array is
%
% 1 4 7 10
% 2 5 8 11
% 3 6 9 12
zpath = testcase.ArrPathReadSmall;

% Start
actData = zarrread(zpath, Start=[2, 3]);
expData = [8, 11; 9, 12];
testcase.verifyEqual(actData,expData,...
'Failed to verify reading with Start.');

% Count
actData = zarrread(zpath, Count=[2, 1]);
expData = [1;2];
testcase.verifyEqual(actData,expData,...
'Failed to verify reading with Count.');

% Stride
actData = zarrread(zpath, Stride=[3, 2]);
expData = [1, 7];
testcase.verifyEqual(actData,expData,...
'Failed to verify reading with Stride.');

% Start, Stride, and Count
actData = zarrread(zpath,...
Start=[2, 1], Stride=[1, 2], Count=[1, 2]);
expData = [2, 8];
testcase.verifyEqual(actData,expData,...
'Failed to verify reading with Start, Stride, and Count.');
end

function verifyPartialVectorData(testcase)
% Verify that specifying a scalar value for Start/Stride/Count
% for vector datasets works as expected

zpath = testcase.ArrPathReadVector; % data is 1:10

expData = [2,5];
actData = zarrread(zpath, Start=2, Stride=3, Count=2);
testcase.verifyEqual(actData,expData,...
'Failed to verify using scalar Start, Stride, and Count.');
end

function verifyArrayDataRelativePath(testcase)
% Verify array data if the input is using relative path to the
% array.
@@ -83,5 +133,27 @@ function invalidFilePath(testcase)
errID = 'MATLAB:Zarr:invalidZarrObject';
testcase.verifyError(@()zarrread(inpPath),errID);
end

function invalidPartialReadParams(testcase)
% Verify zarrread errors when invalid partial read
% Start/Stride/Count are used

zpath = testcase.ArrPathReadSmall; % a 2D array, 3x4

errID = 'MATLAB:Zarr:badPartialReadDimensions';
wrongNumberOfDimensions = [1,1,1];
testcase.verifyError(...
@()zarrread(zpath,Start=wrongNumberOfDimensions),...
errID);
testcase.verifyError(...
@()zarrread(zpath,Stride=wrongNumberOfDimensions),...
errID);
testcase.verifyError(...
@()zarrread(zpath,Count=wrongNumberOfDimensions),...
errID);

%TODO: negative values, wrong datatypes, out of bounds

end
end
end
31 changes: 26 additions & 5 deletions zarrread.m
@@ -1,16 +1,37 @@
function data = zarrread(filepath)
function data = zarrread(filepath, options)
%ZARRREAD Read data from Zarr array
% DATA = ZARRREAD(FILEPATH) retrieves all the data from the Zarr array
% located at FILEPATH.
% The datatype of DATA is the MATLAB equivalent of the Zarr datatype of the
% array located at FILEPATH.
% located at FILEPATH. The datatype of DATA is the MATLAB equivalent of
% the Zarr datatype of the array located at FILEPATH.
%
% DATA = ZARRREAD(FILEPATH, Start=start) retrieves a subset of the data
% from the Zarr array located at FILEPATH. Start is a row vector of
% one-based indices of the first element to be read in each dimension.
% Default is to read all the elements starting from the first
% (Start=[1,1,...]).
%
% DATA = ZARRREAD(FILEPATH, Count=count) retrieves a subset of the data
% from the Zarr array located at FILEPATH. Count is a row vector
% of the number of elements to read in each dimension. Default is to read
% all the available elements (based on dimension size and the specified
% Start and Stride).
%
% DATA = ZARRREAD(FILEPATH, Stride=stride) retrieves a subset of the data
% from the Zarr array located at FILEPATH. Stride is a row vector of
% spaces between indices along each dimension. A value of 1 accesses
% adjacent elements in the corresponding dimension, a value of 2
% accesses every other element in the corresponding dimension, etc.
% Default is to read all elements without skipping (Stride=[1,1,...]).

% Copyright 2025 The MathWorks, Inc.

arguments
filepath {mustBeTextScalar, mustBeNonzeroLengthText}
options.Start (1,:) {mustBeInteger, mustBePositive} = [];
options.Count (1,:) {mustBeInteger, mustBePositive} = [];
options.Stride (1,:) {mustBeInteger, mustBePositive} = [];
end

zarrObj = Zarr(filepath);
data = zarrObj.read;
data = zarrObj.read(options.Start, options.Count, options.Stride);
end
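
The Start/Stride/Count semantics described in the zarrread help text can be checked against the expected values in tZarrRead.m with a small model. This is a sketch only, assuming a nested Python list in place of the 3x4 test array; `partial_read` is a hypothetical helper and does not call zarrread or TensorStore.

```python
def partial_read(rows, start, stride, count):
    """Select from a 2-D list of rows using one-based Start, Stride, Count."""
    row_idx = [start[0] - 1 + i * stride[0] for i in range(count[0])]
    col_idx = [start[1] - 1 + j * stride[1] for j in range(count[1])]
    return [[rows[r][c] for c in col_idx] for r in row_idx]


# The full data in the small test array, as listed in verifyPartialArrayData.
small = [
    [1, 4, 7, 10],
    [2, 5, 8, 11],
    [3, 6, 9, 12],
]

print(partial_read(small, start=[2, 1], stride=[1, 2], count=[1, 2]))
# [[2, 8]]
print(partial_read(small, start=[2, 3], stride=[1, 1], count=[2, 2]))
# [[8, 11], [9, 12]]
```

Both results agree with the expected data in the test for Start, Stride, and Count combined and in the test for Start alone.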