-
-
Notifications
You must be signed in to change notification settings - Fork 346
Add async oindex and vindex methods to AsyncArray #3083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4f51d23
535ebaa
6f25f82
e595f76
bdbdd61
fec243d
320e6d2
ea0f657
870b6b6
a7e9e43
102e411
0cd96aa
b503969
9b8ebde
125ebdf
b6d5b6d
e7cbaef
d5d5494
c0026e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Added support for async vectorized and orthogonal indexing. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ | |
ZarrFormat, | ||
_default_zarr_format, | ||
_warn_order_kwarg, | ||
ceildiv, | ||
concurrent_map, | ||
parse_shapelike, | ||
product, | ||
|
@@ -76,6 +77,8 @@ | |
) | ||
from zarr.core.dtype.common import HasEndianness, HasItemSize, HasObjectCodec | ||
from zarr.core.indexing import ( | ||
AsyncOIndex, | ||
AsyncVIndex, | ||
BasicIndexer, | ||
BasicSelection, | ||
BlockIndex, | ||
|
@@ -92,7 +95,6 @@ | |
Selection, | ||
VIndex, | ||
_iter_grid, | ||
ceildiv, | ||
check_fields, | ||
check_no_multi_fields, | ||
is_pure_fancy_indexing, | ||
|
@@ -1425,6 +1427,56 @@ async def getitem( | |
) | ||
return await self._get_selection(indexer, prototype=prototype) | ||
|
||
async def get_orthogonal_selection( | ||
self, | ||
selection: OrthogonalSelection, | ||
*, | ||
out: NDBuffer | None = None, | ||
fields: Fields | None = None, | ||
prototype: BufferPrototype | None = None, | ||
) -> NDArrayLikeOrScalar: | ||
if prototype is None: | ||
prototype = default_buffer_prototype() | ||
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) | ||
return await self._get_selection( | ||
indexer=indexer, out=out, fields=fields, prototype=prototype | ||
) | ||
|
||
async def get_mask_selection( | ||
self, | ||
mask: MaskSelection, | ||
*, | ||
out: NDBuffer | None = None, | ||
fields: Fields | None = None, | ||
prototype: BufferPrototype | None = None, | ||
) -> NDArrayLikeOrScalar: | ||
if prototype is None: | ||
prototype = default_buffer_prototype() | ||
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) | ||
return await self._get_selection( | ||
indexer=indexer, out=out, fields=fields, prototype=prototype | ||
) | ||
|
||
async def get_coordinate_selection( | ||
self, | ||
selection: CoordinateSelection, | ||
*, | ||
out: NDBuffer | None = None, | ||
fields: Fields | None = None, | ||
prototype: BufferPrototype | None = None, | ||
) -> NDArrayLikeOrScalar: | ||
if prototype is None: | ||
prototype = default_buffer_prototype() | ||
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) | ||
out_array = await self._get_selection( | ||
indexer=indexer, out=out, fields=fields, prototype=prototype | ||
) | ||
|
||
if hasattr(out_array, "shape"): | ||
# restore shape | ||
out_array = np.array(out_array).reshape(indexer.sel_shape) | ||
return out_array | ||
|
||
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None: | ||
""" | ||
Asynchronously save the array metadata. | ||
|
@@ -1556,6 +1608,19 @@ async def setitem( | |
) | ||
return await self._set_selection(indexer, value, prototype=prototype) | ||
|
||
@property | ||
def oindex(self) -> AsyncOIndex[T_ArrayMetadata]: | ||
Comment on lines
+1611
to
+1612
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I chose this API to try to follow this pattern:
because python doesn't let you make an async version of the |
||
"""Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and | ||
:func:`set_orthogonal_selection` for documentation and examples.""" | ||
return AsyncOIndex(self) | ||
|
||
@property | ||
def vindex(self) -> AsyncVIndex[T_ArrayMetadata]: | ||
"""Shortcut for vectorized (inner) indexing, see :func:`get_coordinate_selection`, | ||
:func:`set_coordinate_selection`, :func:`get_mask_selection` and | ||
:func:`set_mask_selection` for documentation and examples.""" | ||
return AsyncVIndex(self) | ||
|
||
async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) -> None: | ||
""" | ||
Asynchronously resize the array to a new shape. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Generic, | ||
Literal, | ||
NamedTuple, | ||
Protocol, | ||
|
@@ -25,14 +26,16 @@ | |
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from zarr.core.common import product | ||
from zarr.core.common import ceildiv, product | ||
from zarr.core.metadata import T_ArrayMetadata | ||
|
||
if TYPE_CHECKING: | ||
from zarr.core.array import Array | ||
from zarr.core.array import Array, AsyncArray | ||
from zarr.core.buffer import NDArrayLikeOrScalar | ||
from zarr.core.chunk_grids import ChunkGrid | ||
from zarr.core.common import ChunkCoords | ||
|
||
|
||
IntSequence = list[int] | npt.NDArray[np.intp] | ||
ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_] | ||
BasicSelector = int | slice | EllipsisType | ||
|
@@ -93,12 +96,6 @@ class Indexer(Protocol): | |
def __iter__(self) -> Iterator[ChunkProjection]: ... | ||
|
||
|
||
def ceildiv(a: float, b: float) -> int: | ||
if a == 0: | ||
return 0 | ||
return math.ceil(a / b) | ||
|
||
|
||
_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"] | ||
|
||
|
||
|
@@ -960,6 +957,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N | |
) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class AsyncOIndex(Generic[T_ArrayMetadata]): | ||
array: AsyncArray[T_ArrayMetadata] | ||
|
||
async def getitem(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar: | ||
from zarr.core.array import Array | ||
|
||
# if input is a Zarr array, we materialize it now. | ||
if isinstance(selection, Array): | ||
selection = _zarr_array_to_int_or_bool_array(selection) | ||
|
||
fields, new_selection = pop_fields(selection) | ||
new_selection = ensure_tuple(new_selection) | ||
new_selection = replace_lists(new_selection) | ||
return await self.array.get_orthogonal_selection( | ||
cast(OrthogonalSelection, new_selection), fields=fields | ||
) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class BlockIndexer(Indexer): | ||
dim_indexers: list[SliceDimIndexer] | ||
|
@@ -1268,6 +1284,30 @@ def __setitem__( | |
raise VindexInvalidSelectionError(new_selection) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class AsyncVIndex(Generic[T_ArrayMetadata]): | ||
array: AsyncArray[T_ArrayMetadata] | ||
|
||
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool | ||
async def getitem( | ||
self, selection: CoordinateSelection | MaskSelection | Array | ||
) -> NDArrayLikeOrScalar: | ||
from zarr.core.array import Array | ||
|
||
# if input is a Zarr array, we materialize it now. | ||
if isinstance(selection, Array): | ||
selection = _zarr_array_to_int_or_bool_array(selection) | ||
fields, new_selection = pop_fields(selection) | ||
new_selection = ensure_tuple(new_selection) | ||
new_selection = replace_lists(new_selection) | ||
if is_coordinate_selection(new_selection, self.array.shape): | ||
return await self.array.get_coordinate_selection(new_selection, fields=fields) | ||
elif is_mask_selection(new_selection, self.array.shape): | ||
return await self.array.get_mask_selection(new_selection, fields=fields) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had to add
|
||
else: | ||
raise VindexInvalidSelectionError(new_selection) | ||
|
||
|
||
def check_fields(fields: Fields | None, dtype: np.dtype[Any]) -> np.dtype[Any]: | ||
# early out | ||
if fields is None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_basic_selection
also doesn't exist onAsyncArray
- should I add that too?