Problem
I would like to read/write NumPy dtype extensions (such as bfloat16) with zarr version 2. I am using ml_dtypes from JAX for the dtype extensions.
```python
import numpy as np
import ml_dtypes
import zarr

arr = np.array([ml_dtypes.bfloat16(1)])
zarr.save('example.zarr', arr)  # ValueError: No cast function available.
```
I hit a similar issue when trying to read arrays stored with such dtype extensions.
The problem stems from the limited extensibility of NumPy's dtype kind codes. The JAX team describes it well.
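To make the kind-code limitation concrete, here is a small NumPy-only sketch (it imports neither zarr nor ml_dtypes). Built-in numeric dtypes carry a dedicated one-letter kind code ('f', 'i', 'u') and a typestr such as '<f4'. As we understand it, zarr v2 records that typestr in its array metadata, and ml_dtypes registers bfloat16 under the generic void kind 'V' because NumPy has no kind code for extension floats. Parsing a 2-byte void typestr therefore gives back opaque bytes rather than bfloat16, even though the bytes themselves are intact. The helper name describe is hypothetical.

```python
import numpy as np


def describe(dtype):
    """Return the (kind, typestr, itemsize) NumPy reports for a dtype.

    `describe` is a hypothetical helper used only for illustration.
    """
    dt = np.dtype(dtype)
    return dt.kind, dt.str, dt.itemsize


# Built-in numeric dtypes have dedicated kind codes: 'f', 'i', 'u'.
for name in ("float32", "float16", "int16", "uint16"):
    print(name, describe(name))

# A 2-byte type without its own kind code can only be described as
# opaque bytes. Parsing such a typestr yields a plain void dtype:
# the bfloat16 name and its float semantics are gone.
print("V2", describe("V2"))

# The raw bytes still survive: reinterpreting them as little-endian
# uint16 recovers 0x3F80, the bfloat16 bit pattern for 1.0.
raw = np.frombuffer(b"\x80\x3f", dtype="V2")
print(hex(int(raw.view("<u2")[0])))
```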
Background
bfloat16 is a very important dtype in the AI/ML community. I would like to use zarr (and specifically the Python implementation) to share models such as LLMs. However, the lack of bfloat16 support is a major blocker.
Questions
- Is this something that could be resolved with zarr v2?
- Is this something that I could resolve using zarr v3 today?
- If the answer to the previous questions is no, what would be required to support it in zarr v3 in the future?
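On the zarr v2 question, one possible stopgap (a sketch under our own assumptions, not an official zarr feature) is to store the raw bfloat16 bit patterns as uint16, a dtype zarr v2 already handles, and reinterpret them after reading. bfloat16 keeps the sign bit, all 8 exponent bits and the top 7 mantissa bits of an IEEE float32, so it is the upper half of a float32 word. The NumPy-only sketch below converts float32 values to bfloat16 bits with round-to-nearest-even and widens them back; both function names are hypothetical. With ml_dtypes installed, arr.view(np.uint16) exposes the raw bits of a bfloat16 array directly, since both types are 2 bytes wide.

```python
import numpy as np


def float32_to_bfloat16_bits(values):
    """Round float32 values to bfloat16; return the raw 16-bit patterns as uint16.

    Hypothetical helper. Uses round-to-nearest-even on the discarded low
    16 bits. NaN payloads are not handled specially in this sketch.
    """
    bits = np.asarray(values, dtype=np.float32).reshape(-1).view(np.uint32)
    # Lowest kept mantissa bit; adding it makes exact ties round to even.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = bits + np.uint32(0x7FFF) + lsb
    return (rounded >> np.uint32(16)).astype(np.uint16)


def bfloat16_bits_to_float32(bits):
    """Widen raw bfloat16 bit patterns (uint16) back to float32; this is exact."""
    wide = np.asarray(bits, dtype=np.uint16).reshape(-1).astype(np.uint32)
    return (wide << np.uint32(16)).view(np.float32)


# Usage: encode to uint16 (storable as a plain zarr v2 array), then decode.
x = np.array([1.0, -2.0, 3.14159], dtype=np.float32)
bits = float32_to_bfloat16_bits(x)
print([hex(int(b)) for b in bits])
print(bfloat16_bits_to_float32(bits))
```

The catch is that nothing in the v2 metadata says the uint16 data is really bfloat16, so writers and readers need an out-of-band convention (for example, a custom attribute), which is exactly what first-class dtype support would make unnecessary.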
Related issues
#711
cc @jhamman (as suggested by @TomNicholas)