Description
In line with #543 and #164, I have experimented a bit with changing the precision of pre_att_rms_out
from float to BF16. As a first step, as also suggested there, I made only a tiny code change. The patch is:
diff --git a/gemma/activations.h b/gemma/activations.h
index 86345e2..a994e98 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -41,7 +41,7 @@ struct Activations {
RowVectorBatch<float> logits;
// Attention
- RowVectorBatch<float> pre_att_rms_out;
+ RowVectorBatch<BF16> pre_att_rms_out;
RowVectorBatch<float> att; // attention vector
RowVectorBatch<float> att_out; // attention output
// Accumulation of attention outputs over heads
@@ -88,7 +88,7 @@ struct Activations {
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
}
- pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+ pre_att_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
att = RowVectorBatch<float>(
Extents2D(batch_size, heads * weights_config.seq_len));
att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ccb34f0..e7eece4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -293,7 +293,7 @@ class GemmaAttention {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
- const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
+ const hwy::bfloat16_t* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
KVCache& kv_cache = kv_caches_[query_idx];
I wonder if this change was too small?
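For context on what the buffer type change implies for the kernels around it, here is a minimal, standalone scalar sketch (not the actual gemma.cpp kernel; the function names are illustrative) of an RMSNorm that reads float activations and stores its result as BF16, using Gemma's (1 + weight) scaling and a round-to-nearest-even conversion:

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Round-to-nearest-even float -> bf16 bit pattern (upper 16 bits of binary32).
// NaN handling is omitted to keep the sketch short.
static inline uint16_t F32ToBF16Bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u += 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<uint16_t>(u >> 16);
}

// out[i] = x[i] * (1 + weight[i]) / sqrt(mean(x^2) + eps), stored as bf16 bits.
void RMSNormToBF16(const float* x, const float* weight, uint16_t* out,
                   size_t dim, float eps = 1e-6f) {
  double ss = 0.0;
  for (size_t i = 0; i < dim; ++i) ss += static_cast<double>(x[i]) * x[i];
  const float inv_rms =
      1.0f / std::sqrt(static_cast<float>(ss / static_cast<double>(dim)) + eps);
  for (size_t i = 0; i < dim; ++i) {
    out[i] = F32ToBF16Bits(x[i] * (1.0f + weight[i]) * inv_rms);
  }
}
```

Since the one-line type change above compiles and runs, the batched RMSNorm and the matmul entry points are presumably templated on the element type, so the swap mostly selects a BF16 store/load path rather than requiring new kernels like the one sketched here.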
My (unscientific) measurements showed no speed gain from this change. I tried it on an x86 laptop that doesn't support BF16 and on an ARM chip that does (a Cortex-X3, inside a tablet). I am a bit surprised that the latency hasn't improved.
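As a rough back-of-envelope (using Gemma-2B-like numbers, model_dim ≈ 2048, purely as an example): one row of pre_att_rms_out is model_dim values, i.e. about 8 KB in f32 or 4 KB in BF16 per token, while a single 2048×2048 projection matrix is already about 4 MB even at one byte per weight. Halving the A side therefore changes well under 1% of the bytes each attention matmul touches, which would be consistent with not seeing a measurable latency difference.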
I understand the matmul itself is done with bf16 types. The decompression of B takes up a major chunk of the runtime, while the "decompression" of A (which is where the type of pre_att_rms_out matters) is barely noticeable. The profiler says:
MM.NT.DecB : 140318080 x 26 = 190.491139
MM.NT_K.DecB : 12238848 x 121 = 77.107746
MM.NT : 107076 x 10869 = 60.612900
Gen.input : 2 x 485175727 = 50.539138
MM.NT_K : 21248 x 15169 = 16.787086
...
MM.DecompressA : 128324 x 3 = 0.022761
This makes me wonder whether B should be stored and loaded as BF16 instead of being decompressed on every call. And if the precision of pre_att_rms_out mainly affects the (already negligible) decompression of A and the RMSNorm computation, maybe it is unsurprising that its effect on latency is not first-order.
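If the goal is to avoid paying MM.NT.DecB on every call, one option would be to decompress B to BF16 once (e.g. at load time) and keep that copy for the matmul to read directly, trading memory for decode time. Below is a rough, hypothetical sketch of the idea; the Compressed interface (Rows/Cols/DecompressRowToF32) and BF16Matrix are placeholders, not gemma.cpp's actual compression API:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Round-to-nearest-even float -> bf16 bit pattern (as in the sketch above).
static inline uint16_t F32ToBF16Bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u += 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<uint16_t>(u >> 16);
}

// Row-major BF16 copy of a weight matrix, built once and reused across calls.
struct BF16Matrix {
  size_t rows = 0, cols = 0;
  std::vector<uint16_t> data;  // bf16 bit patterns
};

// `Compressed` is a placeholder for whatever holds the packed weights; the
// only requirement here is a per-row decode into f32.
template <class Compressed>
BF16Matrix DecompressBOnce(const Compressed& b) {
  BF16Matrix out;
  out.rows = b.Rows();
  out.cols = b.Cols();
  out.data.resize(out.rows * out.cols);
  std::vector<float> row(out.cols);
  for (size_t r = 0; r < out.rows; ++r) {
    b.DecompressRowToF32(r, row.data());  // hypothetical per-row decode
    for (size_t c = 0; c < out.cols; ++c) {
      out.data[r * out.cols + c] = F32ToBF16Bits(row[c]);
    }
  }
  return out;  // the matmul's B operand can then point at out.data
}
```

The obvious cost is memory: with 8-bit weights such as SFP, a BF16 copy roughly doubles the weight footprint, so it would only pay off where the per-call decode really is the bottleneck.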