
Optimization: Convert Hidden States to Bfloat16 #560

@FabianSchuetze

Description


In line with #543 and #164, I have experimented with changing the precision of pre_att_rms_out from float to BF16. As a first step, as also suggested there, I made only a tiny code change. The patch is:

diff --git a/gemma/activations.h b/gemma/activations.h
index 86345e2..a994e98 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -41,7 +41,7 @@ struct Activations {
   RowVectorBatch<float> logits;
 
   // Attention
-  RowVectorBatch<float> pre_att_rms_out;
+  RowVectorBatch<BF16> pre_att_rms_out;
   RowVectorBatch<float> att;      // attention vector
   RowVectorBatch<float> att_out;  // attention output
   // Accumulation of attention outputs over heads
@@ -88,7 +88,7 @@ struct Activations {
       logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
     }
 
-    pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+    pre_att_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
     att = RowVectorBatch<float>(
         Extents2D(batch_size, heads * weights_config.seq_len));
     att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ccb34f0..e7eece4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -293,7 +293,7 @@ class GemmaAttention {
         // Proceed row by row because there will be wraparound.
         for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
              ++interleaved_idx) {
-          const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
+          const hwy::bfloat16_t* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
           const size_t query_idx = interleaved_idx % num_queries_;
           const size_t batch_idx = interleaved_idx / num_queries_;
           KVCache& kv_cache = kv_caches_[query_idx];

I wonder whether this change was too small.
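For context, a fuller change would presumably also touch the producer side: the RMSNorm output that currently fills pre_att_rms_out as float would have to be demoted to BF16 before being stored. Below is only a minimal scalar sketch using the conversions from hwy/base.h, not the actual gemma.cpp code path (which would likely use Highway's vectorized demotion); the helper name StoreRowAsBF16 is hypothetical.

// Minimal sketch, not part of gemma.cpp: demote a float row (e.g. an RMSNorm
// result) to BF16 before storing it into a BF16 pre_att_rms_out buffer.
// Uses the scalar conversions from hwy/base.h.
#include <cstddef>

#include "hwy/base.h"  // hwy::bfloat16_t, hwy::BF16FromF32, HWY_RESTRICT

void StoreRowAsBF16(const float* HWY_RESTRICT src, size_t len,
                    hwy::bfloat16_t* HWY_RESTRICT dst) {
  for (size_t i = 0; i < len; ++i) {
    dst[i] = hwy::BF16FromF32(src[i]);  // f32 -> bf16 conversion per element
  }
}

If the existing RMSNorm and matmul kernels are already templated on the activation type, the one-line buffer change might suffice; otherwise, call sites that still expect float would need explicit conversions like the one above.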

My (unscientific) measurements showed no speed gain from this change. I tried it on an x86 laptop without native BF16 support and on an ARM chip with it (a Cortex-X3 in a tablet). I am a bit surprised that latency did not improve on either.

I understand the matmul is done with BF16 types. Decompressing B seems to take up a major chunk of the runtime, whereas "decompressing" A (whose cost depends on the type of pre_att_rms_out) is negligible. The profiler says:

MM.NT.DecB                              :  140318080 x              26 = 190.491139
MM.NT_K.DecB                            :   12238848 x             121 = 77.107746
MM.NT                                   :     107076 x           10869 = 60.612900
Gen.input                               :          2 x       485175727 = 50.539138
MM.NT_K                                 :      21248 x           15169 = 16.787086
...
MM.DecompressA                          :     128324 x               3 =  0.022761

This makes me wonder whether B should be loaded in BF16 instead of being decompressed. If the precision of pre_att_rms_out predominantly affects the decompression time of A and the RMSNorm computation time, it is perhaps unsurprising that its effect on latency is not first-order.
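To put rough numbers on that (assuming the right-hand column of the profiler output is a total time in consistent units): MM.NT.DecB and MM.NT_K.DecB together account for about 190.5 + 77.1 ≈ 267.6 out of roughly 395.6 for the entries shown (the list is truncated), i.e. around two thirds of the listed time, while MM.DecompressA contributes about 0.02, well below 0.01%. Even eliminating A's decompression entirely would therefore be invisible end-to-end, which is consistent with the BF16 change to pre_att_rms_out not showing up in my measurements.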
