Skip to content

Add async oindex and vindex methods to AsyncArray #3083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3083.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for async vectorized and orthogonal indexing.
67 changes: 66 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ZarrFormat,
_default_zarr_format,
_warn_order_kwarg,
ceildiv,
concurrent_map,
parse_shapelike,
product,
Expand All @@ -76,6 +77,8 @@
)
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
from zarr.core.indexing import (
AsyncOIndex,
AsyncVIndex,
BasicIndexer,
BasicSelection,
BlockIndex,
Expand All @@ -92,7 +95,6 @@
Selection,
VIndex,
_iter_grid,
ceildiv,
check_fields,
check_no_multi_fields,
is_pure_fancy_indexing,
Expand Down Expand Up @@ -1425,6 +1427,56 @@ async def getitem(
)
return await self._get_selection(indexer, prototype=prototype)

async def get_orthogonal_selection(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_basic_selection also doesn't exist on AsyncArray - should I add that too?

self,
selection: OrthogonalSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

async def get_mask_selection(
self,
mask: MaskSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
return await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

async def get_coordinate_selection(
self,
selection: CoordinateSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype | None = None,
) -> NDArrayLikeOrScalar:
if prototype is None:
prototype = default_buffer_prototype()
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
out_array = await self._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)

if hasattr(out_array, "shape"):
# restore shape
out_array = np.array(out_array).reshape(indexer.sel_shape)
return out_array

async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
"""
Asynchronously save the array metadata.
Expand Down Expand Up @@ -1556,6 +1608,19 @@ async def setitem(
)
return await self._set_selection(indexer, value, prototype=prototype)

@property
def oindex(self) -> AsyncOIndex[T_ArrayMetadata]:
Comment on lines +1611 to +1612
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose this API to try to follow this pattern:

  • Array.__getitem__ (exists)
  • Array.oindex.__getitem__ (exists)
  • Array.vindex.__getitem__ (exists)
  • AsyncArray.getitem (exists)
  • AsyncArray.oindex.getitem (new)
  • AsyncArray.vindex.getitem (new)

because python doesn't let you make an async version of the __getitem__ magic method.

"""Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and
:func:`set_orthogonal_selection` for documentation and examples."""
return AsyncOIndex(self)

@property
def vindex(self) -> AsyncVIndex[T_ArrayMetadata]:
"""Shortcut for vectorized (inner) indexing, see :func:`get_coordinate_selection`,
:func:`set_coordinate_selection`, :func:`get_mask_selection` and
:func:`set_mask_selection` for documentation and examples."""
return AsyncVIndex(self)

async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
"""
Asynchronously resize the array to a new shape.
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/core/chunk_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
ChunkCoords,
ChunkCoordsLike,
ShapeLike,
ceildiv,
parse_named_configuration,
parse_shapelike,
)
from zarr.core.indexing import ceildiv

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import functools
import math
import operator
import warnings
from collections.abc import Iterable, Mapping, Sequence
Expand Down Expand Up @@ -69,6 +70,12 @@ def product(tup: ChunkCoords) -> int:
return functools.reduce(operator.mul, tup, 1)


def ceildiv(a: float, b: float) -> int:
if a == 0:
return 0
return math.ceil(a / b)


T = TypeVar("T", bound=tuple[Any, ...])
V = TypeVar("V")

Expand Down
56 changes: 48 additions & 8 deletions src/zarr/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
NamedTuple,
Protocol,
Expand All @@ -25,14 +26,16 @@
import numpy as np
import numpy.typing as npt

from zarr.core.common import product
from zarr.core.common import ceildiv, product
from zarr.core.metadata import T_ArrayMetadata

if TYPE_CHECKING:
from zarr.core.array import Array
from zarr.core.array import Array, AsyncArray
from zarr.core.buffer import NDArrayLikeOrScalar
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import ChunkCoords


IntSequence = list[int] | npt.NDArray[np.intp]
ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_]
BasicSelector = int | slice | EllipsisType
Expand Down Expand Up @@ -93,12 +96,6 @@ class Indexer(Protocol):
def __iter__(self) -> Iterator[ChunkProjection]: ...


def ceildiv(a: float, b: float) -> int:
if a == 0:
return 0
return math.ceil(a / b)


_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"]


Expand Down Expand Up @@ -960,6 +957,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
)


@dataclass(frozen=True)
class AsyncOIndex(Generic[T_ArrayMetadata]):
array: AsyncArray[T_ArrayMetadata]

async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar:
from zarr.core.array import Array

# if input is a Zarr array, we materialize it now.
if isinstance(selection, Array):
selection = _zarr_array_to_int_or_bool_array(selection)

fields, new_selection = pop_fields(selection)
new_selection = ensure_tuple(new_selection)
new_selection = replace_lists(new_selection)
return await self.array.get_orthogonal_selection(
cast(OrthogonalSelection, new_selection), fields=fields
)


@dataclass(frozen=True)
class BlockIndexer(Indexer):
dim_indexers: list[SliceDimIndexer]
Expand Down Expand Up @@ -1268,6 +1284,30 @@ def __setitem__(
raise VindexInvalidSelectionError(new_selection)


@dataclass(frozen=True)
class AsyncVIndex(Generic[T_ArrayMetadata]):
array: AsyncArray[T_ArrayMetadata]

# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
async def getitem(
self, selection: CoordinateSelection | MaskSelection | Array
) -> NDArrayLikeOrScalar:
from zarr.core.array import Array

# if input is a Zarr array, we materialize it now.
if isinstance(selection, Array):
selection = _zarr_array_to_int_or_bool_array(selection)
fields, new_selection = pop_fields(selection)
new_selection = ensure_tuple(new_selection)
new_selection = replace_lists(new_selection)
if is_coordinate_selection(new_selection, self.array.shape):
return await self.array.get_coordinate_selection(new_selection, fields=fields)
elif is_mask_selection(new_selection, self.array.shape):
return await self.array.get_mask_selection(new_selection, fields=fields)
Copy link
Member Author

@TomNicholas TomNicholas Jul 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add .get_mask_selection to AsyncArray to cover this codepath. But I only realised I needed to thanks to mypy. This means that this codepath is

  • not needed for me right now (I think)
  • definitely not covered by the property tests

else:
raise VindexInvalidSelectionError(new_selection)


def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]:
# early out
if fields is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype
from zarr.core.chunk_grids import _auto_partition
from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams
from zarr.core.common import JSON, ZarrFormat
from zarr.core.common import JSON, ZarrFormat, ceildiv
from zarr.core.dtype import (
DateTime64,
Float32,
Expand All @@ -59,7 +59,7 @@
from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
from zarr.core.dtype.npy.string import UTF8Base
from zarr.core.group import AsyncGroup
from zarr.core.indexing import BasicIndexer, ceildiv
from zarr.core.indexing import BasicIndexer
from zarr.core.metadata.v2 import ArrayV2Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.core.sync import sync
Expand Down
55 changes: 55 additions & 0 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,27 @@ def test_basic_indexing(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@pytest.mark.asyncio
@settings(deadline=None)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
async def test_basic_indexing_async(data: st.DataObject) -> None:
zarray = data.draw(simple_arrays())
nparray = zarray[:]
indexer = data.draw(basic_indices(shape=nparray.shape))
async_zarray = zarray._async_array

actual = await async_zarray.getitem(indexer)
assert_array_equal(nparray[indexer], actual)

# TODO test async setitem
# new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
# asyncio.run(async_zarray.setitem(indexer, new_data))
# nparray[indexer] = new_data
# result = asyncio.run(async_zarray.getitem(indexer))
# assert_array_equal(nparray, result)


@given(data=st.data())
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
def test_oindex(data: st.DataObject) -> None:
Expand All @@ -143,6 +164,22 @@ def test_oindex(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@pytest.mark.asyncio
@given(data=st.data())
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_oindex_async(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
nparray = zarray[:]
async_zarray = zarray._async_array

zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
actual = await async_zarray.oindex.getitem(zindexer)
assert_array_equal(nparray[npindexer], actual)

# note: async oindex setting not yet implemented


@given(data=st.data())
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
def test_vindex(data: st.DataObject) -> None:
Expand All @@ -167,6 +204,24 @@ def test_vindex(data: st.DataObject) -> None:
# assert_array_equal(nparray, zarray[:])


@pytest.mark.asyncio
@given(data=st.data())
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_vindex_async(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
nparray = zarray[:]
async_zarray = zarray._async_array

indexer = data.draw(
npst.integer_array_indices(
shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None)
)
)
actual = await async_zarray.vindex.getitem(indexer)
assert_array_equal(nparray[indexer], actual)


@given(store=stores, meta=array_metadata()) # type: ignore[misc]
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_roundtrip_array_metadata_from_store(
Expand Down
Loading