Open
Labels: enhancement (New features or improvements)
Description
NumPy does not natively support bfloat16, but JAX and TensorFlow each define a bfloat16 NumPy dtype. It would be great if the zarr format provided a way to store it.
The zarr library can already write bfloat16 data, but because the dtype is stored as "<V2", reading it back returns raw 2-byte void values unless .view is called explicitly after opening:
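import zarr
import jax.numpy as jnp
import numpy as np
bfloat16 = jnp.bfloat16
# Register the name so that np.dtype('bfloat16') resolves.
# Note: np.typeDict was an alias of np.sctypeDict and is not available in recent NumPy releases.
np.typeDict['bfloat16'] = bfloat16
my_store = dict()
# Write a single bfloat16 value; zarr records the dtype as "<V2" in .zarray.
z1 = zarr.open(mode='w', shape=(1,), compressor=None, dtype=np.dtype(bfloat16), store=my_store)
z1[0] = np.array(42, dtype=bfloat16)
print('Original array: %r' % (z1[0],))
print('Original array with view: %r' % (z1.view(dtype=bfloat16)[0],))
# Reopen read-only: the stored "<V2" dtype yields void values unless viewed as bfloat16.
z2 = zarr.open(mode='r', store=my_store)
print('Reopening with original dtype: %r' % (z2[0],))
print('Reopening with original dtype with view: %r' % (z2.view(dtype=bfloat16)[0],))
# Patch the stored metadata so the dtype name is "bfloat16" instead of "<V2".
my_store['.zarray'] = my_store['.zarray'].replace(b'<V2', b'bfloat16')
z3 = zarr.open(mode='r', store=my_store)
print('With adjusted dtype: %r' % (z3[0],))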
Output is:
Original array: void(b'\x28\x42')
Original array with view: 42
Reopening with original dtype: void(b'\x28\x42')
Reopening with original dtype with view: 42
With adjusted dtype: 42
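To make the void output above less opaque: a bfloat16 value is the upper 16 bits of an IEEE 754 float32, so the two raw bytes can be decoded by hand. The following is a minimal standard-library sketch (the helper name `bfloat16_bytes_to_float` is made up for illustration and is not part of zarr, NumPy, or JAX); it assumes little-endian byte order, matching the "<" in "<V2".

```python
import struct


def bfloat16_bytes_to_float(raw: bytes) -> float:
    """Decode two little-endian bytes holding a bfloat16 into a Python float.

    Hypothetical helper for illustration only.
    """
    if len(raw) != 2:
        raise ValueError("expected exactly 2 bytes")
    # Read the 16-bit pattern, then place it in the high half of a 32-bit word.
    (bits16,) = struct.unpack("<H", raw)
    bits32 = bits16 << 16
    # Reinterpret the 32-bit pattern as a float32.
    (value,) = struct.unpack("<f", struct.pack("<I", bits32))
    return value


# The bytes printed as void(b'\x28\x42') in the output above.
print(bfloat16_bytes_to_float(b"\x28\x42"))  # 42.0
```

This shows that the data on disk is intact; only the dtype recorded in the metadata prevents zarr from interpreting it.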
Replacing the stored dtype in the .zarray file with "bfloat16" seems to be the only way to get zarr to use bfloat16 as the data type when opening. (That requires registering the data type in np.typeDict, which JAX does not do but probably should.) That approach does not allow specifying the byte order, but supporting big-endian machines may not be too important.
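As a slightly more targeted variant of the byte-level replace in the script, the metadata can be parsed as JSON and only the "dtype" field rewritten. This is a sketch under assumptions: the helper `set_stored_dtype` is hypothetical, the store is a plain dict mapping keys to bytes as in the script, and the sample metadata below is an illustrative hand-written Zarr v2 document rather than output captured from zarr.

```python
import json


def set_stored_dtype(store, new_dtype, key=".zarray"):
    """Rewrite only the 'dtype' field of a JSON metadata document in a dict store.

    Hypothetical helper; returns the previously stored dtype string.
    """
    meta = json.loads(store[key].decode("utf-8"))
    old_dtype = meta["dtype"]
    meta["dtype"] = new_dtype
    store[key] = json.dumps(meta, indent=4, sort_keys=True).encode("utf-8")
    return old_dtype


# Illustrative metadata only (assumed shape of a Zarr v2 .zarray document).
example_store = {
    ".zarray": json.dumps(
        {
            "chunks": [1],
            "compressor": None,
            "dtype": "<V2",
            "fill_value": None,
            "filters": None,
            "order": "C",
            "shape": [1],
            "zarr_format": 2,
        }
    ).encode("utf-8")
}

previous = set_stored_dtype(example_store, "bfloat16")
print(previous)
print(json.loads(example_store[".zarray"])["dtype"])
```

Editing the parsed field avoids accidentally touching other occurrences of the string "<V2" in the metadata, but it has the same limitation noted above: reopening still depends on the name "bfloat16" being resolvable by NumPy.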