## Introduction
BFloat16 is a 16-bit floating-point format obtained by truncating IEEE FP32 numbers to their high 16 bits. BFloat16 numbers have the same exponent range as IEEE FP32 numbers but far fewer mantissa bits, and are primarily used in neural network computations, which are very tolerant of reduced precision. Despite being a non-IEEE format, BFloat16 computations are supported in recent and upcoming processors from Intel (Cooper Lake, Alder Lake, Sapphire Rapids), AMD (Zen 4), and ARM (Cortex-A510, Cortex-A710, Cortex-X2), and can be efficiently emulated on older hardware by zero-padding BFloat16 numbers to IEEE FP32.
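For illustration, here is a minimal C sketch (the helper names are mine, not part of the proposal) of the truncation relationship and the zero-padding emulation path described above:

```c
#include <stdint.h>
#include <string.h>

/* Truncate an IEEE FP32 value to BFloat16 by keeping the high 16 bits.
 * (Production converters usually round to nearest-even instead;
 * truncation is shown because it matches the format definition above.) */
static uint16_t fp32_to_bf16_truncate(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t) (bits >> 16);
}

/* Widen BFloat16 to FP32 by zero-padding the low 16 bits.
 * This is exact: every BFloat16 value is representable in FP32. */
static float bf16_to_fp32(uint16_t bf16) {
    uint32_t bits = (uint32_t) bf16 << 16;
    float f;
    memcpy(&f, &bits, sizeof bits);
    return f;
}
```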
Both the x86 and ARM BFloat16 extensions provide instructions to compute an FP32 dot product of two-element BFloat16 pairs with addition to an FP32 accumulator. However, there are some differences in the details:
x86 introduced support for BFloat16 computations with the AVX512-BF16 extension, where the BFloat16 dot product functionality is exposed via the `VDPBF16PS` (Dot Product of BF16 Pairs Accumulated into Packed Single Precision) instruction. The instruction `VDPBF16PS dest, src1, src2` computes

```
dest.fp32[i] = FMA(cast<fp32>(src1.bf16[2*i]), cast<fp32>(src2.bf16[2*i]), dest.fp32[i])
dest.fp32[i] = FMA(cast<fp32>(src1.bf16[2*i+1]), cast<fp32>(src2.bf16[2*i+1]), dest.fp32[i])
```

for each 32-bit lane `i` in the accumulator register `dest`. Additionally, denormalized numbers in the inputs are treated as zeroes, and denormalized numbers in the outputs are replaced with zeroes.
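For reference, here is a hedged scalar C model of what `VDPBF16PS` computes in a single 32-bit lane (the function names are mine, and `bf16_to_fp32` is the helper from the sketch above; the model covers the fused-FMA and denormal-flushing behavior described here, not every detail of the hardware instruction):

```c
#include <math.h>
#include <stdint.h>

/* Flush FP32 denormals to (signed) zero, modeling DAZ/FTZ behavior. */
static float flush_denormal(float f) {
    return (fpclassify(f) == FP_SUBNORMAL) ? copysignf(0.0f, f) : f;
}

/* One 32-bit lane of VDPBF16PS: two chained fused FP32 FMAs into the
 * accumulator, with denormal inputs treated as zero and denormal
 * outputs replaced with zero. */
static float vdpbf16ps_lane(uint16_t a0, uint16_t a1,
                            uint16_t b0, uint16_t b1, float acc) {
    acc = flush_denormal(acc);   /* denormal accumulator input -> zero */
    acc = fmaf(flush_denormal(bf16_to_fp32(a0)),
               flush_denormal(bf16_to_fp32(b0)), acc);
    acc = flush_denormal(acc);   /* denormal intermediate output -> zero */
    acc = fmaf(flush_denormal(bf16_to_fp32(a1)),
               flush_denormal(bf16_to_fp32(b1)), acc);
    return flush_denormal(acc);  /* denormal final output -> zero */
}
```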
ARM added support for BFloat16 computations with the ARMv8.2-A `BFDOT` instruction (mandatory in ARMv8.6-A), which implements the following computation:

```
temp.fp32[i] = cast<fp32>(src1.bf16[2*i]) * cast<fp32>(src2.bf16[2*i])
temp2.fp32[i] = cast<fp32>(src1.bf16[2*i+1]) * cast<fp32>(src2.bf16[2*i+1])
temp.fp32[i] = temp.fp32[i] + temp2.fp32[i]
dest.fp32[i] = dest.fp32[i] + temp.fp32[i]
```

where all multiplications and additions are unfused and use the non-IEEE Round-to-Odd rounding mode. Additionally, denormalized numbers in the inputs are treated as zeroes, and denormalized numbers in the outputs are replaced with zeroes.
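To make the Round-to-Odd behavior concrete, here is a minimal scalar sketch of one `BFDOT` lane (helper names are mine; denormal flushing, modeled for x86 above, is omitted for brevity; it assumes a C implementation that honors fenv rounding control). Note that the product of two widened BFloat16 values carries at most 16 significant bits and is therefore exact in FP32 (absent underflow), so only the additions actually round:

```c
#include <fenv.h>
#include <stdint.h>
#include <string.h>

#pragma STDC FENV_ACCESS ON

/* FP32 addition with Round-to-Odd: truncate toward zero, then force
 * the mantissa LSB to 1 if the truncation discarded any bits. */
static float add_round_to_odd(float x, float y) {
    const int saved_mode = fegetround();
    fesetround(FE_TOWARDZERO);
    feclearexcept(FE_INEXACT);
    float r = x + y;
    const int inexact = fetestexcept(FE_INEXACT);
    fesetround(saved_mode);
    if (inexact) {
        uint32_t bits;
        memcpy(&bits, &r, sizeof bits);
        bits |= 1;                      /* "sticky" odd bit */
        memcpy(&r, &bits, sizeof bits);
    }
    return r;
}

/* One lane of (non-EBF16) BFDOT: unfused multiplies plus Round-to-Odd
 * additions. Inputs are the FP32-widened BFloat16 elements; the
 * multiplications are exact in FP32, so their rounding mode is moot. */
static float bfdot_lane(float a0, float b0, float a1, float b1, float acc) {
    const float p0 = a0 * b0;
    const float p1 = a1 * b1;
    return add_round_to_odd(acc, add_round_to_odd(p0, p1));
}
```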
ARMv9.2-A introduced an additional "Extended BFloat16" (EBF16) mode, which modifies the behavior of `BFDOT` as follows:

```
temp.fp32[i] = cast<fp32>(src1.bf16[2*i] * src2.bf16[2*i] + src1.bf16[2*i+1] * src2.bf16[2*i+1])
dest.fp32[i] = dest.fp32[i] + temp.fp32[i]
```

where `temp.fp32[i]` is calculated as a fused sum-of-products operation with only a single (standard, Round-to-Nearest-Even) rounding at the end.
Furthermore, ARMv8.6-A introduced the `BFMLALB`/`BFMLALT` instructions, which perform a widening fused multiply-add of the even/odd BFloat16 elements into an IEEE FP32 accumulator. A pair of these instructions can similarly be used to implement the BFloat16 dot product primitive.
## What are the instructions being proposed?
I propose a 2-element dot-product-with-accumulation instruction with BFloat16 input elements and FP32 accumulator and output elements. The instruction has relaxed semantics to allow lowering to the native `VDPBF16PS` (x86) and `BFDOT` (ARM) instructions, as well as to FMA instructions where available.

I suggest `f32x4.relaxed_dot_bf16x8_add_f32x4` as a tentative name for the instruction. In a sense it is the floating-point equivalent of the `i32x4.dot_i16x8_add_s` instruction proposed in WebAssembly/simd#127.
## What are the semantics of these instructions?
`y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` computes

```
y.fp32[i] = c.fp32[i] + cast<fp32>(a.bf16[2*i]) * cast<fp32>(b.bf16[2*i]) + cast<fp32>(a.bf16[2*i+1]) * cast<fp32>(b.bf16[2*i+1])
```
The relaxed nature of the instruction manifests in several allowable options:
### Evaluation ordering options

We permit two evaluation orders for the computation (illustrated in the scalar sketch after the rounding options below):

- Compute `cast<fp32>(a.bf16[2*i]) * cast<fp32>(b.bf16[2*i]) + cast<fp32>(a.bf16[2*i+1]) * cast<fp32>(b.bf16[2*i+1])` in the first step, then compute the sum with `c.fp32[i]` in the second step.
- Compute `y.fp32[i] = c.fp32[i] + cast<fp32>(a.bf16[2*i]) * cast<fp32>(b.bf16[2*i])` in the first step, then compute `y.fp32[i] += cast<fp32>(a.bf16[2*i+1]) * cast<fp32>(b.bf16[2*i+1])` in the second step.
### Fusion options

- Either both steps of the computation are fused, or both are unfused.
### Rounding options

- The operations that comprise the computation are rounded with either the Round-to-Nearest-Even or the Round-to-Odd rounding mode.
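In scalar form, the two permitted evaluation orders look as follows (a sketch of mine; `a0`/`a1`/`b0`/`b1` stand for the FP32-widened BFloat16 elements of one lane, and the fusion and rounding options are further degrees of freedom on top of this):

```c
/* Order 1 (BFDOT-style): sum the two products first, then add the
 * accumulator lane. */
static float order_dot_then_add(float a0, float b0, float a1, float b1, float c) {
    return c + (a0 * b0 + a1 * b1);
}

/* Order 2 (VDPBF16PS/FMA-style): fold each product into the running
 * accumulator one at a time. */
static float order_accumulate_each(float a0, float b0, float a1, float b1, float c) {
    c += a0 * b0;
    c += a1 * b1;
    return c;
}
```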
## How will these instructions be implemented?
### x86/x86-64 processors with the AVX512-BF16 instruction set

- `f32x4.relaxed_dot_bf16x8_add_f32x4`
  - `c = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to `VDPBF16PS xmm_c, xmm_a, xmm_b`
  - `y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to `VMOVDQA xmm_y, xmm_c` + `VDPBF16PS xmm_y, xmm_a, xmm_b`
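In C, the same lowering is reachable through the AVX512-BF16 intrinsics from `immintrin.h` (this sketch assumes compilation with AVX512-BF16 and AVX512-VL enabled; the wrapper name is mine):

```c
#include <immintrin.h>

/* 128-bit BFloat16 dot product with FP32 accumulation, mirroring the
 * in-place VDPBF16PS lowering above. Each __m128bh argument holds
 * eight BFloat16 elements. */
static __m128 relaxed_dot_bf16x8_add_f32x4(__m128bh a, __m128bh b, __m128 c) {
    return _mm_dpbf16_ps(c, a, b);
}
```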
### ARM64 processors with the BF16 extension (using the `BFDOT` instruction)

- `f32x4.relaxed_dot_bf16x8_add_f32x4`
  - `y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to `MOV Vy.16B, Vc.16B` + `BFDOT Vy.4S, Va.8H, Vb.8H`
  - `c = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to `BFDOT Vc.4S, Va.8H, Vb.8H`
### ARM64 processors with the BF16 extension (using the `BFMLALB`/`BFMLALT` instructions)

- `f32x4.relaxed_dot_bf16x8_add_f32x4`
  - `y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to `MOV Vy.16B, Vc.16B` + `BFMLALB Vy.4S, Va.8H, Vb.8H` + `BFMLALT Vy.4S, Va.8H, Vb.8H`
  - `c = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to `BFMLALB Vc.4S, Va.8H, Vb.8H` + `BFMLALT Vc.4S, Va.8H, Vb.8H`
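Both ARM64 lowerings can also be expressed via the ARM C Language Extensions intrinsics in `arm_neon.h` (assuming the `+bf16` target feature; the wrapper names are mine):

```c
#include <arm_neon.h>

/* BFDOT-based lowering: pairwise BFloat16 dot products accumulated
 * into four FP32 lanes. */
static float32x4_t dot_via_bfdot(bfloat16x8_t a, bfloat16x8_t b, float32x4_t c) {
    return vbfdotq_f32(c, a, b);
}

/* BFMLALB/BFMLALT-based lowering: widening fused multiply-add of the
 * even elements, then of the odd elements. */
static float32x4_t dot_via_bfmlal(bfloat16x8_t a, bfloat16x8_t b, float32x4_t c) {
    c = vbfmlalbq_f32(c, a, b);
    return vbfmlaltq_f32(c, a, b);
}
```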
### Reference lowering through the WAsm Relaxed SIMD instruction set

- `f32x4.relaxed_dot_bf16x8_add_f32x4`
  - `y = f32x4.relaxed_dot_bf16x8_add_f32x4(a, b, c)` is lowered to:

```c
const v128_t a_lo = wasm_i32x4_shl(a, 16);
const v128_t b_lo = wasm_i32x4_shl(b, 16);
const v128_t a_hi = wasm_v128_and(a, wasm_i32x4_const_splat(0xFFFF0000));
const v128_t b_hi = wasm_v128_and(b, wasm_i32x4_const_splat(0xFFFF0000));
y = f32x4.relaxed_fma(a_lo, b_lo, c);
y = f32x4.relaxed_fma(a_hi, b_hi, y);
```
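Packaged as a complete function, and assuming the `wasm_f32x4_relaxed_madd` intrinsic that recent Clang/Emscripten headers expose for `f32x4.relaxed_madd` (the later name of `f32x4.relaxed_fma`), the generic lowering might look like this sketch:

```c
#include <stdint.h>
#include <wasm_simd128.h>

/* Generic lowering: widen the low BFloat16 half of each 32-bit lane
 * by shifting it into the high 16 bits, widen the high half by masking
 * off the low 16 bits, then accumulate with two relaxed FMAs. */
static v128_t dot_bf16x8_add_f32x4(v128_t a, v128_t b, v128_t c) {
    const v128_t a_lo = wasm_i32x4_shl(a, 16);
    const v128_t b_lo = wasm_i32x4_shl(b, 16);
    const v128_t a_hi = wasm_v128_and(a, wasm_i32x4_const_splat((int32_t) 0xFFFF0000));
    const v128_t b_hi = wasm_v128_and(b, wasm_i32x4_const_splat((int32_t) 0xFFFF0000));
    const v128_t y = wasm_f32x4_relaxed_madd(a_lo, b_lo, c);
    return wasm_f32x4_relaxed_madd(a_hi, b_hi, y);
}
```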
## How does behavior differ across processors? What new fingerprinting surfaces will be exposed?
The use of the AVX512-BF16 `VDPBF16PS` instruction can be detected by testing its behavior on denormal inputs and outputs.

The use of the ARM `BFDOT` instruction can be detected by testing its behavior on denormal inputs and outputs, or by testing for Round-to-Odd rounding. However, an ARM implementation using a pair of `BFMLALB`/`BFMLALT` instructions would be indistinguishable from a generic lowering using FMA instructions (though probably slower than the `BFDOT` lowering).
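As an illustration, a denormal-based probe could look like the following hypothetical sketch (in practice the test would invoke the WebAssembly instruction itself rather than a C callback):

```c
#include <stdint.h>
#include <string.h>

/* Widen BFloat16 bits to FP32 by zero-padding, as in the Introduction. */
static float bf16_bits_to_fp32(uint16_t bf16) {
    uint32_t bits = (uint32_t) bf16 << 16;
    float f;
    memcpy(&f, &bits, sizeof bits);
    return f;
}

/* Probe one dot-product lane with a denormal input: hardware BF16 dot
 * instructions (VDPBF16PS, BFDOT) treat the denormal as zero and yield
 * exactly 0.0f, while an FMA-based lowering preserves the tiny nonzero
 * product. `dot2_add` stands for whatever lane evaluator is available. */
static int flushes_denormal_inputs(
        float (*dot2_add)(float a0, float b0, float a1, float b1, float c)) {
    const float denormal = bf16_bits_to_fp32(0x0001); /* smallest BF16 denormal */
    return dot2_add(denormal, 1.0f, 0.0f, 0.0f, 0.0f) == 0.0f;
}
```

A Round-to-Odd probe would work analogously: pick inputs whose intermediate sum is inexact in FP32 and compare the lane result against the Round-to-Nearest-Even expectation.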