
Extend data type support (for bfloat16 in particular) #2656

@nenb

Problem
I would like to read and write arrays with NumPy extension dtypes (such as bfloat16) using zarr version 2. I am using the ml_dtypes package from the JAX team to provide these dtypes.

import numpy as np
import ml_dtypes
import zarr

arr = np.array([ml_dtypes.bfloat16(1)])
zarr.save('example.zarr', arr)  # ValueError: No cast function available. 

I run into a similar issue when reading arrays with such dtypes.

The problem stems from the limited extensibility of NumPy's dtype kind codes: new types such as bfloat16 cannot easily be described by them. The JAX team describes this limitation well.
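One possible stopgap while bfloat16 lacks native support is to store the raw 16-bit patterns as uint16, a dtype zarr v2 already handles, and reinterpret them after reading. This is a sketch under stated assumptions, not an API provided by zarr, NumPy, or ml_dtypes. The helpers `float32_to_bf16_bits` and `bf16_bits_to_float32` are hypothetical names. They rely on the standard bfloat16 layout (the upper 16 bits of an IEEE 754 float32) and use only NumPy, rounding to nearest-even when encoding.

```python
import numpy as np


def float32_to_bf16_bits(x):
    """Hypothetical helper: round float32 values to the nearest bfloat16
    (ties to even) and return the raw 16-bit patterns as a uint16 array.
    NaN payloads are not preserved (simplified handling)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Add 0x7FFF plus the lowest kept bit, so exact halfway cases round to even.
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    # Keep the upper 16 bits (sign, 8 exponent bits, 7 mantissa bits).
    return ((bits + rounding_bias) >> 16).astype(np.uint16)


def bf16_bits_to_float32(b):
    """Hypothetical helper: expand a uint16 array of bfloat16 bit patterns
    back to float32. This direction is exact."""
    b = np.asarray(b, dtype=np.uint16)
    # Place the 16 bits in the high half of a 32-bit word, then reinterpret.
    return (b.astype(np.uint32) << 16).view(np.float32)


# Usage: encode, then (in practice) save the uint16 array with zarr and decode after loading.
values = np.array([1.0, 3.14159265, -2.5], dtype=np.float32)
bits = float32_to_bf16_bits(values)
restored = bf16_bits_to_float32(bits)
```

The design trade-off is that the stored metadata only says uint16, so readers must know out of band that the data is bfloat16. If the data is already a bfloat16 array from ml_dtypes, the manual rounding is unnecessary: `arr.view(np.uint16)` gives the same bit patterns to pass to `zarr.save`, and `.view(ml_dtypes.bfloat16)` restores the dtype after loading.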

Background
bfloat16 is a very important dtype in the AI/ML community. I would like to use zarr (specifically the Python implementation) to share models such as large language models (LLMs), and the lack of bfloat16 support is a major blocker.

Questions

  1. Is this something that could be resolved with zarr v2?
  2. Is this something that I could resolve using zarr v3 today?
  3. If the answer to both questions is no, what would be required to support it in zarr v3 in the future?

Related issues
#711

cc @jhamman (as suggested by @TomNicholas)
