Relaxed BFloat16 Dot Product instruction #77

@Maratyszcza

Description

Introduction

BFloat16 is a 16-bit floating-point format obtained by truncating an IEEE FP32 number to its high 16 bits. BFloat16 numbers have the same exponent range as IEEE FP32 numbers but far fewer mantissa bits, and the format is primarily used in neural network computations, which are highly tolerant of reduced precision. Despite being a non-IEEE format, BFloat16 computations are supported in recent and upcoming processors from Intel (Cooper Lake, Alder Lake, Sapphire Rapids), AMD (Zen 4), and ARM (Cortex-A510, Cortex-A710, Cortex-X2), and can be efficiently emulated on older hardware by zero-extending BFloat16 numbers to IEEE FP32.
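
A minimal C sketch of that zero-extension trick (the helper names are illustrative, not part of any API):

#include <stdint.h>
#include <string.h>

/* Illustrative sketch: convert BFloat16 <-> FP32 by moving the 16-bit
   pattern into/out of the high half of a 32-bit word. */
static float bf16_to_f32(uint16_t b) {
    uint32_t bits = (uint32_t)b << 16;  /* zero-extend the low 16 bits */
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

static uint16_t f32_to_bf16_truncate(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);      /* keep only the high 16 bits */
}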

Both the x86 and ARM BFloat16 extensions provide instructions to compute an FP32 dot product of pairs of BFloat16 elements, accumulated into an FP32 accumulator. However, the details differ:

x86 introduced support for BFloat16 computations with the AVX512-BF16 extension, and the BFloat16 dot product functionality is exposed via the VDPBF16PS (Dot Product of BF16 Pairs Accumulated into Packed Single Precision) instruction. The instruction VDPBF16PS dest, src1, src2 computes

dest.fp32[i] = FMA(cast<fp32>(src1.bf16[2*i]), cast<fp32>(src2.bf16[2*i]), dest.fp32[i])
dest.fp32[i] = FMA(cast<fp32>(src1.bf16[2*i+1]), cast<fp32>(src2.bf16[2*i+1]), dest.fp32[i])

for each 32-bit lane i in the accumulator register dest. Additionally, denormalized numbers in inputs are treated as zeroes and denormalized numbers in outputs are replaced with zeroes.
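
A scalar C model of one 32-bit lane, assuming the fused-FMA and denormal-flushing behavior described above (an illustrative sketch, not the architectural pseudocode):

#include <math.h>
#include <stdint.h>
#include <string.h>

/* Flush denormal values to signed zero (models DAZ on inputs / FTZ on outputs). */
static float daz_ftz(float x) {
    return fpclassify(x) == FP_SUBNORMAL ? copysignf(0.0f, x) : x;
}

static float bf16_bits_to_f32(uint16_t b) {
    uint32_t bits = (uint32_t)b << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* One lane of VDPBF16PS: two fused multiply-adds with denormal flushing. */
static float vdpbf16ps_lane(uint16_t a0, uint16_t a1,
                            uint16_t b0, uint16_t b1, float acc) {
    acc = daz_ftz(fmaf(daz_ftz(bf16_bits_to_f32(a0)),
                       daz_ftz(bf16_bits_to_f32(b0)), daz_ftz(acc)));
    acc = daz_ftz(fmaf(daz_ftz(bf16_bits_to_f32(a1)),
                       daz_ftz(bf16_bits_to_f32(b1)), acc));
    return acc;
}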

ARM added support for BFloat16 computations with the BF16 extension (optional from ARMv8.2-A, mandatory from ARMv8.6-A); its BFDOT instruction implements the following computation:

temp.fp32[i] = cast<fp32>(src1.bf16[2*i]) * cast<fp32>(src2.bf16[2*i])
temp2.fp32[i] = cast<fp32>(src1.bf16[2*i+1]) * cast<fp32>(src2.bf16[2*i+1])
temp.fp32[i] = temp.fp32[i] + temp2.fp32[i]
dest.fp32[i] = dest.fp32[i] + temp.fp32[i]

where all multiplications and additions are unfused and use the non-IEEE Round-to-Odd rounding mode. Additionally, denormalized numbers in inputs are treated as zeroes, and denormalized numbers in outputs are replaced with zeroes.
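
Round-to-Odd is uncommon enough to warrant a sketch. A minimal C illustration, ignoring overflow and denormal corner cases: an FP32 × FP32 product is exact in double (24-bit significands fit in 53 bits), and rounding that double to float with Round-to-Odd picks the neighbouring float with an odd significand whenever the conversion is inexact:

#include <math.h>
#include <stdint.h>
#include <string.h>

/* Illustrative Round-to-Odd conversion from double to float.
   If the conversion is inexact, exactly one of the two bracketing
   floats has an odd significand; Round-to-Odd selects that one. */
static float f64_to_f32_round_to_odd(double v) {
    float t = (float)v;              /* round-to-nearest-even first */
    if ((double)t == v) return t;    /* exact: no rounding happened */
    uint32_t bits;
    memcpy(&bits, &t, sizeof bits);
    if ((bits & 1) == 0)             /* even significand: take the odd neighbour */
        t = nextafterf(t, (double)t < v ? HUGE_VALF : -HUGE_VALF);
    return t;
}

static float mul_round_to_odd(float a, float b) {
    return f64_to_f32_round_to_odd((double)a * (double)b);  /* exact product */
}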

ARMv9.2-A introduced an additional "Extended BFloat16" (EBF16) mode, which modifies the behavior of BFDOT as follows:

temp.fp32[i] = cast<fp32>(src1.bf16[2*i] * src2.bf16[2*i] + src1.bf16[2*i+1] * src2.bf16[2*i+1])
dest.fp32[i] = dest.fp32[i] + temp.fp32[i]

where temp.fp32[i] is calculated as a fused sum-of-products operation with only a single (standard, Round-to-Nearest-Even) rounding at the end.

Furthermore, ARMv8.6-A introduced the BFMLALB/BFMLALT instructions, which perform a widening fused multiply-add of even/odd BFloat16 elements into an IEEE FP32 accumulator. A pair of these instructions can similarly be used to implement the BFloat16 dot product primitive.

What are the instructions being proposed?

I propose a 2-element dot product with accumulation instruction with BFloat16 input elements and FP32 accumulator & output elements. The instruction has relaxed semantics to allow lowering to native VDPBF16PS (x86) and BFDOT (ARM) instructions, as well as FMA instructions where available.

I suggest f32x4.relaxed_dot_bf16x8_add_f32x4 as a tentative name for the instruction. In a sense it is a floating-point equivalent of the i32x4.dot_i16x8_add_s instruction proposed in WebAssembly/simd#127.

What are the semantics of these instructions?

y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) computes

y.fp32[i] = c.fp32[i] + cast<fp32>(a.bf16[2*i]) * cast<fp32>(b.bf16[2*i]) + cast<fp32>(a.bf16[2*i+1]) * cast<fp32>(b.bf16[2*i+1])

The relaxed nature of the instruction manifests in several allowable options (a scalar sketch of one permitted combination follows the list):

Evaluation ordering options

We permit two evaluation orders for the computation:

  • Compute cast<fp32>(a.bf16[2*i]) * cast<fp32>(b.bf16[2*i]) + cast<fp32>(a.bf16[2*i+1]) * cast<fp32>(b.bf16[2*i+1]) in the first step, then add c.fp32[i] in the second step.
  • Initialize y.fp32[i] = c.fp32[i], then compute y.fp32[i] += cast<fp32>(a.bf16[2*i]) * cast<fp32>(b.bf16[2*i]) in the first step and y.fp32[i] += cast<fp32>(a.bf16[2*i+1]) * cast<fp32>(b.bf16[2*i+1]) in the second step.

Fusion options

  • Either both steps of the computation are fused, or both are unfused.

Rounding options

  • Operations that comprise the computation are rounded with either Round-to-Nearest-Even or Round-to-Odd rounding modes.
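
For concreteness, here is a scalar C sketch of one combination these options permit (unfused multiplications, pairwise sum first, Round-to-Nearest-Even throughout); all names are illustrative:

#include <stdint.h>
#include <string.h>

static float bf16_to_f32(uint16_t b) {
    uint32_t bits = (uint32_t)b << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* One 32-bit lane of the proposed instruction, under one permitted
   behavior: unfused products, pairwise sum, then accumulation. */
static float relaxed_dot_bf16_add_f32_lane(uint16_t a0, uint16_t a1,
                                           uint16_t b0, uint16_t b1,
                                           float c) {
    float even = bf16_to_f32(a0) * bf16_to_f32(b0);
    float odd  = bf16_to_f32(a1) * bf16_to_f32(b1);
    return c + (even + odd);
}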

How will these instructions be implemented?

x86/x86-64 processors with AVX512-BF16 instruction set

  • f32x4.relaxed_dot_bf16x8_add_f32x4
    • c = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to VDPBF16PS xmm_c, xmm_a, xmm_b
    • y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to VMOVDQA xmm_y, xmm_c + VDPBF16PS xmm_y, xmm_a, xmm_b
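
For reference, the same lowering is reachable from C through the AVX512-BF16 intrinsics (assuming a compiler with -mavx512bf16 -mavx512vl support; the wrapper name is illustrative):

#include <immintrin.h>

/* Illustrative wrapper: _mm_dpbf16_ps compiles to VDPBF16PS.
   The __m128bh vectors hold 8 BFloat16 elements each. */
__m128 dot_bf16x8_add_f32x4(__m128bh a, __m128bh b, __m128 c) {
    return _mm_dpbf16_ps(c, a, b);
}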

ARM64 processors with BF16 extension (using BFDOT instructions)

  • f32x4.relaxed_dot_bf16x8_add_f32x4
    • y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to MOV Vy.16B, Vc.16B + BFDOT Vy.4S, Va.8H, Vb.8H
    • c = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to BFDOT Vc.4S, Va.8H, Vb.8H
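
The equivalent C intrinsic, assuming a compiler targeting the BF16 extension (e.g. -march=armv8.6-a+bf16); the wrapper name is illustrative:

#include <arm_neon.h>

/* Illustrative wrapper: vbfdotq_f32 compiles to BFDOT. */
float32x4_t dot_bf16x8_add_f32x4(bfloat16x8_t a, bfloat16x8_t b,
                                 float32x4_t c) {
    return vbfdotq_f32(c, a, b);
}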

ARM64 processors with BF16 extension (using BFMLALB/BFMLALT instructions)

  • f32x4.relaxed_dot_bf16x8_add_f32x4
    • y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to MOV Vy.16B, Vc.16B + BFMLALB Vy.4S, Va.8H, Vb.8H + BFMLALT Vy.4S, Va.8H, Vb.8H
    • c = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to BFMLALB Vc.4S, Va.8H, Vb.8H + BFMLALT Vc.4S, Va.8H, Vb.8H
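
The corresponding C intrinsics sketch (wrapper name illustrative; requires the BF16 extension):

#include <arm_neon.h>

/* Illustrative wrapper: the BFMLALB/BFMLALT pair widens even and odd
   BFloat16 elements and fuses each multiply-add into the FP32 accumulator. */
float32x4_t dot_bf16x8_add_f32x4_fma(bfloat16x8_t a, bfloat16x8_t b,
                                     float32x4_t c) {
    c = vbfmlalbq_f32(c, a, b);  /* even elements */
    c = vbfmlaltq_f32(c, a, b);  /* odd elements */
    return c;
}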

Reference lowering through the WebAssembly Relaxed SIMD instruction set

  • f32x4.relaxed_dot_bf16x8_add_f32x4
    • y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c) is lowered to:
      • const v128_t a_lo = wasm_i32x4_shl(a, 16)
      • const v128_t b_lo = wasm_i32x4_shl(b, 16)
      • const v128_t a_hi = wasm_v128_and(a, wasm_i32x4_const_splat((int32_t)0xFFFF0000))
      • const v128_t b_hi = wasm_v128_and(b, wasm_i32x4_const_splat((int32_t)0xFFFF0000))
      • y = f32x4.relaxed_fma(a_lo, b_lo, c)
      • y = f32x4.relaxed_fma(a_hi, b_hi, y)
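
Collected into a self-contained C function, assuming a toolchain whose wasm_simd128.h exposes f32x4.relaxed_fma under the wasm_f32x4_relaxed_madd spelling; the function name is illustrative:

#include <stdint.h>
#include <wasm_simd128.h>

v128_t relaxed_dot_bf16x8_add_f32x4(v128_t a, v128_t b, v128_t c) {
    /* Even BFloat16 elements: shift into the high half of each 32-bit lane,
       producing the corresponding FP32 values. */
    const v128_t a_lo = wasm_i32x4_shl(a, 16);
    const v128_t b_lo = wasm_i32x4_shl(b, 16);
    /* Odd BFloat16 elements: mask off the low half of each 32-bit lane. */
    const v128_t mask = wasm_i32x4_const_splat((int32_t)0xFFFF0000);
    const v128_t a_hi = wasm_v128_and(a, mask);
    const v128_t b_hi = wasm_v128_and(b, mask);
    /* Two relaxed FMAs accumulate both products into c. */
    const v128_t y = wasm_f32x4_relaxed_madd(a_lo, b_lo, c);
    return wasm_f32x4_relaxed_madd(a_hi, b_hi, y);
}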

How does behavior differ across processors? What new fingerprinting surfaces will be exposed?

The use of the AVX512-BF16 VDPBF16PS instruction can be detected by testing behavior on denormal inputs and outputs.
The use of the ARM BFDOT instruction can be detected by testing behavior on denormal inputs and outputs, or by testing for Round-to-Odd rounding. However, an ARM implementation using a pair of BFMLALB/BFMLALT instructions would be indistinguishable from a generic lowering using FMA instructions (though likely slower than the BFDOT lowering).

What use cases are there?
