Support for bfloat16 data type #711

@jbms

NumPy does not natively support bfloat16, but JAX and TensorFlow each define a bfloat16 NumPy dtype. It would be great if the zarr format provided a way to store it.

The zarr library can already write bfloat16 data, but the dtype is stored as "<V2", so reading the data back as bfloat16 requires explicitly calling .view after opening:

import zarr
import jax.numpy as jnp
import numpy as np

bfloat16 = jnp.bfloat16
np.typeDict['bfloat16'] = bfloat16

my_store = dict()

z1 = zarr.open(mode='w', shape=(1,), compressor=None, dtype=np.dtype(bfloat16), store=my_store)
z1[0] = np.array(42, dtype=bfloat16)
print('Original array: %r' % (z1[0],))
print('Original array with view: %r' % (z1.view(dtype=bfloat16)[0],))

z2 = zarr.open(mode='r', store=my_store)
print('Reopening with original dtype: %r' % (z2[0],))
print('Reopening with original dtype with view: %r' % (z2.view(dtype=bfloat16)[0],))

my_store['.zarray'] = my_store['.zarray'].replace(b'<V2', b'bfloat16')
z3 = zarr.open(mode='r', store=my_store)
print('With adjusted dtype: %r' % (z3[0],))

Output is:

Original array: void(b'\x28\x42')
Original array with view: 42
Reopening with original dtype: void(b'\x28\x42')
Reopening with original dtype with view: 42
With adjusted dtype: 42
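
For readers wondering why the raw value shows up as `void(b'\x28\x42')`: bfloat16 is the upper 16 bits of an IEEE float32, so the two stored bytes are the top half of 42.0 as a float32 in little-endian order. The sketch below reproduces this with the standard library only. The helper names are hypothetical, and the conversion truncates rather than rounds, which is an assumption that happens to be exact for 42.0 but is not necessarily how JAX or TensorFlow convert other values.

```python
import struct


def float_to_bfloat16_bytes(value: float) -> bytes:
    """Return the 2 little-endian bytes of `value` as bfloat16 (by truncation).

    Hypothetical helper for illustration; real libraries typically round
    instead of truncating.
    """
    # Little-endian float32: 4 bytes, least significant byte first.
    f32 = struct.pack("<f", value)
    # bfloat16 keeps the upper 16 bits of the float32, i.e. the last two bytes.
    return f32[2:]


def bfloat16_bytes_to_float(raw: bytes) -> float:
    """Decode 2 little-endian bfloat16 bytes back to a Python float."""
    # Pad the missing lower 16 bits with zeros to rebuild a float32.
    return struct.unpack("<f", b"\x00\x00" + raw)[0]


raw = float_to_bfloat16_bytes(42.0)
print(raw.hex())                     # 2842
print(bfloat16_bytes_to_float(raw))  # 42.0
```

This is the same information `.view(dtype=bfloat16)` recovers; the "<V2" dtype only tells zarr that each element is two opaque bytes.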

Replacing the stored dtype in the .zarray file with "bfloat16" seems to be the only way to get zarr to use bfloat16 as the data type when opening. (That requires registering the data type in np.typeDict, which JAX does not do but probably should.) This approach does not allow specifying the byte order, but supporting big-endian machines may not be too important.
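
As a slightly more defensive variant of the byte-level `.replace` in the reproduction, the metadata can be parsed as JSON and only the `dtype` field rewritten. This is a minimal sketch, not part of zarr: `patch_zarray_dtype` is a hypothetical helper, and the metadata dictionary is a hand-written stand-in rather than what zarr actually emits for this array. Opening the patched store as bfloat16 still depends on NumPy being able to resolve the name "bfloat16", as noted above.

```python
import json


def patch_zarray_dtype(store, old="<V2", new="bfloat16"):
    """Rewrite the `dtype` field of the `.zarray` JSON document in `store`.

    Hypothetical helper: `store` is any mutable mapping from key to bytes,
    such as the plain dict used in the reproduction.
    """
    meta = json.loads(store[".zarray"])
    if meta.get("dtype") != old:
        # Refuse to touch metadata that does not look like the expected case.
        raise ValueError("unexpected dtype: %r" % (meta.get("dtype"),))
    meta["dtype"] = new
    store[".zarray"] = json.dumps(meta, indent=4).encode("ascii")
    return meta


# Stand-in metadata for illustration only; the real .zarray written by zarr
# may contain different values for these fields.
example_store = {
    ".zarray": json.dumps(
        {"shape": [1], "chunks": [1], "dtype": "<V2", "zarr_format": 2}
    ).encode("ascii")
}

patched = patch_zarray_dtype(example_store)
print(patched["dtype"])  # bfloat16
```

Parsing the JSON avoids accidentally rewriting a "<V2" substring that appears elsewhere in the metadata, and the explicit check makes the helper fail loudly on arrays that were not stored as two-byte void values.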

Labels: enhancement
