diff --git a/changes/3083.feature.rst b/changes/3083.feature.rst
new file mode 100644
index 0000000000..4403224df1
--- /dev/null
+++ b/changes/3083.feature.rst
@@ -0,0 +1 @@
+Added support for async vectorized and orthogonal indexing.
\ No newline at end of file
diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py
index 78dddf3669..260e94bc88 100644
--- a/src/zarr/core/array.py
+++ b/src/zarr/core/array.py
@@ -61,6 +61,7 @@
     ZarrFormat,
     _default_zarr_format,
     _warn_order_kwarg,
+    ceildiv,
     concurrent_map,
     parse_shapelike,
     product,
@@ -76,6 +77,8 @@
 )
 from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec
 from zarr.core.indexing import (
+    AsyncOIndex,
+    AsyncVIndex,
     BasicIndexer,
     BasicSelection,
     BlockIndex,
@@ -92,7 +95,6 @@
     Selection,
     VIndex,
     _iter_grid,
-    ceildiv,
     check_fields,
     check_no_multi_fields,
     is_pure_fancy_indexing,
@@ -1425,6 +1427,76 @@ async def getitem(
         )
         return await self._get_selection(indexer, prototype=prototype)
 
+    async def get_orthogonal_selection(
+        self,
+        selection: OrthogonalSelection,
+        *,
+        out: NDBuffer | None = None,
+        fields: Fields | None = None,
+        prototype: BufferPrototype | None = None,
+    ) -> NDArrayLikeOrScalar:
+        """
+        Asynchronously read data using orthogonal (outer) indexing.
+
+        Builds an ``OrthogonalIndexer`` from ``selection`` and delegates to
+        ``_get_selection``. ``prototype`` defaults to ``default_buffer_prototype()``.
+        """
+        if prototype is None:
+            prototype = default_buffer_prototype()
+        indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
+        return await self._get_selection(
+            indexer=indexer, out=out, fields=fields, prototype=prototype
+        )
+
+    async def get_mask_selection(
+        self,
+        mask: MaskSelection,
+        *,
+        out: NDBuffer | None = None,
+        fields: Fields | None = None,
+        prototype: BufferPrototype | None = None,
+    ) -> NDArrayLikeOrScalar:
+        """
+        Asynchronously read data using a mask selection.
+
+        Builds a ``MaskIndexer`` from ``mask`` and delegates to
+        ``_get_selection``. ``prototype`` defaults to ``default_buffer_prototype()``.
+        """
+        if prototype is None:
+            prototype = default_buffer_prototype()
+        indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
+        return await self._get_selection(
+            indexer=indexer, out=out, fields=fields, prototype=prototype
+        )
+
+    async def get_coordinate_selection(
+        self,
+        selection: CoordinateSelection,
+        *,
+        out: NDBuffer | None = None,
+        fields: Fields | None = None,
+        prototype: BufferPrototype | None = None,
+    ) -> NDArrayLikeOrScalar:
+        """
+        Asynchronously read data using a coordinate (point) selection.
+
+        Builds a ``CoordinateIndexer`` from ``selection`` and delegates to
+        ``_get_selection``. ``prototype`` defaults to ``default_buffer_prototype()``.
+        A result that has a ``shape`` attribute is converted with ``np.array`` and
+        reshaped to ``indexer.sel_shape`` before being returned.
+        """
+        if prototype is None:
+            prototype = default_buffer_prototype()
+        indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
+        out_array = await self._get_selection(
+            indexer=indexer, out=out, fields=fields, prototype=prototype
+        )
+
+        if hasattr(out_array, "shape"):
+            # restore shape
+            out_array = np.array(out_array).reshape(indexer.sel_shape)
+        return out_array
+
     async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
         """
         Asynchronously save the array metadata.
@@ -1556,6 +1628,19 @@ async def setitem(
         )
         return await self._set_selection(indexer, value, prototype=prototype)
 
+    @property
+    def oindex(self) -> AsyncOIndex[T_ArrayMetadata]:
+        """Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` for
+        documentation. Read-only for now: ``await arr.oindex.getitem(selection)``."""
+        return AsyncOIndex(self)
+
+    @property
+    def vindex(self) -> AsyncVIndex[T_ArrayMetadata]:
+        """Shortcut for vectorized (inner) indexing, see :func:`get_coordinate_selection` and
+        :func:`get_mask_selection` for documentation. Read-only for now:
+        ``await arr.vindex.getitem(selection)``."""
+        return AsyncVIndex(self)
+
     async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None:
         """
         Asynchronously resize the array to a new shape.
diff --git a/src/zarr/core/chunk_grids.py b/src/zarr/core/chunk_grids.py
index 4bf03c89de..6a3d6816a6 100644
--- a/src/zarr/core/chunk_grids.py
+++ b/src/zarr/core/chunk_grids.py
@@ -18,10 +18,10 @@
     ChunkCoords,
     ChunkCoordsLike,
     ShapeLike,
+    ceildiv,
     parse_named_configuration,
     parse_shapelike,
 )
-from zarr.core.indexing import ceildiv
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py
index e86347d808..33590c83a5 100644
--- a/src/zarr/core/common.py
+++ b/src/zarr/core/common.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import functools
+import math
 import operator
 import warnings
 from collections.abc import Iterable, Mapping, Sequence
@@ -69,6 +70,22 @@ def product(tup: ChunkCoords) -> int:
     return functools.reduce(operator.mul, tup, 1)
 
 
+def ceildiv(a: float, b: float) -> int:
+    """
+    Return the ceiling of ``a / b`` as an ``int``.
+
+    Returns ``0`` when ``a == 0`` without dividing. When both arguments are ``int``
+    the result is computed with integer arithmetic, so it stays exact for values too
+    large to be represented exactly as a ``float``.
+    """
+    if a == 0:
+        return 0
+    if isinstance(a, int) and isinstance(b, int):
+        # -(-a // b) is ceil(a / b) without the rounding of true division
+        return -(-a // b)
+    return math.ceil(a / b)
+
+
 T = TypeVar("T", bound=tuple[Any, ...])
 V = TypeVar("V")
 
diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py
index c11889f7f4..00814a8863 100644
--- a/src/zarr/core/indexing.py
+++ b/src/zarr/core/indexing.py
@@ -12,6 +12,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Generic,
     Literal,
     NamedTuple,
     Protocol,
@@ -25,14 +26,16 @@
 import numpy as np
 import numpy.typing as npt
 
-from zarr.core.common import product
+from zarr.core.common import ceildiv, product
+from zarr.core.metadata import T_ArrayMetadata
 
 if TYPE_CHECKING:
-    from zarr.core.array import Array
+    from zarr.core.array import Array, AsyncArray
     from zarr.core.buffer import NDArrayLikeOrScalar
     from zarr.core.chunk_grids import ChunkGrid
     from zarr.core.common import ChunkCoords
 
+
 IntSequence = list[int] | npt.NDArray[np.intp]
 ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_]
 BasicSelector = int | slice | EllipsisType
@@ -93,12 +96,6 @@ class Indexer(Protocol):
     def __iter__(self) -> Iterator[ChunkProjection]: ...
 
 
-def ceildiv(a: float, b: float) -> int:
-    if a == 0:
-        return 0
-    return math.ceil(a / b)
-
-
 _ArrayIndexingOrder: TypeAlias = Literal["lexicographic"]
 
 
@@ -960,6 +957,34 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
         )
 
 
+@dataclass(frozen=True)
+class AsyncOIndex(Generic[T_ArrayMetadata]):
+    """Orthogonal (outer) indexing helper for an ``AsyncArray``; only reads are implemented."""
+
+    array: AsyncArray[T_ArrayMetadata]
+
+    async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar:
+        """
+        Read an orthogonal selection from ``self.array``.
+
+        A Zarr ``Array`` selection is materialized first; fields are then split off with
+        ``pop_fields`` and the remaining selection is normalized to a tuple and passed to
+        ``AsyncArray.get_orthogonal_selection``.
+        """
+        from zarr.core.array import Array
+
+        # if input is a Zarr array, we materialize it now.
+        if isinstance(selection, Array):
+            selection = _zarr_array_to_int_or_bool_array(selection)
+
+        fields, new_selection = pop_fields(selection)
+        new_selection = ensure_tuple(new_selection)
+        new_selection = replace_lists(new_selection)
+        return await self.array.get_orthogonal_selection(
+            cast(OrthogonalSelection, new_selection), fields=fields
+        )
+
+
 @dataclass(frozen=True)
 class BlockIndexer(Indexer):
     dim_indexers: list[SliceDimIndexer]
@@ -1268,6 +1293,38 @@ def __setitem__(
             raise VindexInvalidSelectionError(new_selection)
 
 
+@dataclass(frozen=True)
+class AsyncVIndex(Generic[T_ArrayMetadata]):
+    """Vectorized (inner) indexing helper for an ``AsyncArray``; only reads are implemented."""
+
+    array: AsyncArray[T_ArrayMetadata]
+
+    # TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
+    async def getitem(
+        self, selection: CoordinateSelection | MaskSelection | Array
+    ) -> NDArrayLikeOrScalar:
+        """
+        Read a coordinate or mask selection from ``self.array``.
+
+        Dispatches to ``get_coordinate_selection`` or ``get_mask_selection`` depending on
+        the normalized selection; raises ``VindexInvalidSelectionError`` if it is neither.
+        """
+        from zarr.core.array import Array
+
+        # if input is a Zarr array, we materialize it now.
+        if isinstance(selection, Array):
+            selection = _zarr_array_to_int_or_bool_array(selection)
+        fields, new_selection = pop_fields(selection)
+        new_selection = ensure_tuple(new_selection)
+        new_selection = replace_lists(new_selection)
+        if is_coordinate_selection(new_selection, self.array.shape):
+            return await self.array.get_coordinate_selection(new_selection, fields=fields)
+        elif is_mask_selection(new_selection, self.array.shape):
+            return await self.array.get_mask_selection(new_selection, fields=fields)
+        else:
+            raise VindexInvalidSelectionError(new_selection)
+
+
 def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]:
     # early out
     if fields is None:
diff --git a/tests/test_array.py b/tests/test_array.py
index 42f4a1cbdd..f672006f9a 100644
--- a/tests/test_array.py
+++ b/tests/test_array.py
@@ -41,7 +41,7 @@
 from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar, default_buffer_prototype
 from zarr.core.chunk_grids import _auto_partition
 from zarr.core.chunk_key_encodings import ChunkKeyEncodingParams
-from zarr.core.common import JSON, ZarrFormat
+from zarr.core.common import JSON, ZarrFormat, ceildiv
 from zarr.core.dtype import (
     DateTime64,
     Float32,
@@ -59,7 +59,7 @@
 from zarr.core.dtype.npy.common import NUMPY_ENDIANNESS_STR, endianness_from_numpy_str
 from zarr.core.dtype.npy.string import UTF8Base
 from zarr.core.group import AsyncGroup
-from zarr.core.indexing import BasicIndexer, ceildiv
+from zarr.core.indexing import BasicIndexer
 from zarr.core.metadata.v2 import ArrayV2Metadata
 from zarr.core.metadata.v3 import ArrayV3Metadata
 from zarr.core.sync import sync
diff --git a/tests/test_properties.py b/tests/test_properties.py
index 27f847fa69..705cfd1b59 100644
--- a/tests/test_properties.py
+++ b/tests/test_properties.py
@@ -105,33 +105,52 @@ def test_array_creates_implicit_groups(array):
 
 
 # this decorator removes timeout; not ideal but it should avoid intermittent CI failures
+@pytest.mark.asyncio
 @settings(deadline=None)
 @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
 @given(data=st.data())
-def test_basic_indexing(data: st.DataObject) -> None:
+async def test_basic_indexing(data: st.DataObject) -> None:
     zarray = data.draw(simple_arrays())
     nparray = zarray[:]
     indexer = data.draw(basic_indices(shape=nparray.shape))
+
+    # sync get
     actual = zarray[indexer]
     assert_array_equal(nparray[indexer], actual)
 
+    # async get
+    async_zarray = zarray._async_array
+    actual = await async_zarray.getitem(indexer)
+    assert_array_equal(nparray[indexer], actual)
+
+    # sync set
     new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
     zarray[indexer] = new_data
     nparray[indexer] = new_data
     assert_array_equal(nparray, zarray[:])
 
+    # TODO test async setitem?
+
 
+@pytest.mark.asyncio
 @given(data=st.data())
 @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
-def test_oindex(data: st.DataObject) -> None:
+async def test_oindex(data: st.DataObject) -> None:
     # integer_array_indices can't handle 0-size dimensions.
     zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
     nparray = zarray[:]
-
     zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
+
+    # sync get
     actual = zarray.oindex[zindexer]
     assert_array_equal(nparray[npindexer], actual)
 
+    # async get
+    async_zarray = zarray._async_array
+    actual = await async_zarray.oindex.getitem(zindexer)
+    assert_array_equal(nparray[npindexer], actual)
+
+    # sync set
     assume(zarray.shards is None)  # GH2834
     for idxr in npindexer:
         if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size:
@@ -142,22 +161,32 @@ def test_oindex(data: st.DataObject) -> None:
     zarray.oindex[zindexer] = new_data
     assert_array_equal(nparray, zarray[:])
 
+    # note: async oindex setitem not yet implemented
+
 
+@pytest.mark.asyncio
 @given(data=st.data())
 @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
-def test_vindex(data: st.DataObject) -> None:
+async def test_vindex(data: st.DataObject) -> None:
     # integer_array_indices can't handle 0-size dimensions.
     zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
     nparray = zarray[:]
-
     indexer = data.draw(
         npst.integer_array_indices(
             shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None)
         )
     )
+
+    # sync get
     actual = zarray.vindex[indexer]
     assert_array_equal(nparray[indexer], actual)
 
+    # async get
+    async_zarray = zarray._async_array
+    actual = await async_zarray.vindex.getitem(indexer)
+    assert_array_equal(nparray[indexer], actual)
+
+    # sync set
     # FIXME!
     # when the indexer is such that a value gets overwritten multiple times,
     # I think the output depends on chunking.
@@ -166,6 +195,8 @@ def test_vindex(data: st.DataObject) -> None:
     # zarray.vindex[indexer] = new_data
     # assert_array_equal(nparray, zarray[:])
 
+    # note: async vindex setitem not yet implemented
+
 
 @given(store=stores, meta=array_metadata())  # type: ignore[misc]
 @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")