From 3263c1d71aafa815e8242ac95eb3b71a534a229e Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 29 May 2025 15:12:48 -0700
Subject: [PATCH] Fix generate.py for fbgemm int4 integration

Summary:

Updated table:

| quantization - batch size      | overall tokens/sec | TTFT (s) | Peak Memory | Model Size |
| ------------------------------ | ------------------ | -------- | ----------- | ---------- |
| baseline - 1                   | 131.65             | 0.0220   | 16.24 GB    | 15.01 GB   |
| baseline - 128                 | 76.38              | 0.0544   | 26.92 GB    | 15.01 GB   |
| int4wo - 1                     | 207.69             | 0.0288   | 6.41 GB     | 3.99 GB    |
| int4wo - 128                   | 12.85              | 0.4223   | 16.01 GB    | 3.99 GB    |
| fbgemm-int4 - 1 (w/ compile)   | 61.12              | 0.0212   | 7.59 GB     | 3.99 GB    |
| fbgemm-int4 - 128 (w/ compile) | 71.23              | 0.0576   | 16.13 GB    | 3.99 GB    |

Verified that compile works:

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization fbgemm-int4-128 --batch_size 1 --compile
==========
Average overall tokens/sec: 61.12
Average decode tokens/sec: 61.5512 s
Average TTFT: 0.0212 s
Average tokens/sec: 61.12
Average Bandwidth: 243.70 GB/s
Peak Memory Usage: 7.59 GB
Model Size: 3.99 GB

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization fbgemm-int4-128 --batch_size 128 --compile
==========
Average overall tokens/sec: 71.23
Average decode tokens/sec: 72.8871 s
Average TTFT: 0.0576 s
Average tokens/sec: 71.23
Average tokens/sec including batches 9116.81
Average Bandwidth: 284.00 GB/s
Peak Memory Usage: 16.13 GB
Model Size: 3.99 GB

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/_models/llama/generate.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index c17de52028..0e74d5fc98 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -444,8 +444,14 @@ def ffn_or_attn_only(mod, fqn):
             _, precision, group_size = quantization.split("-")
             group_size = int(group_size)
+            block_size = [1, group_size]
             if precision == "int4":
-                quantize_(model, FbgemmConfig("bf16i4bf16", group_size))
+                quantize_(
+                    model,
+                    FbgemmConfig(
+                        torch.bfloat16, torch.int4, torch.bfloat16, block_size
+                    ),
+                )
             else:
                 raise NotImplementedError(
                     f"FbegemmConfig({precision=}) not supported yet"
                 )
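
Note for reviewers: below is a minimal standalone sketch of the new dtype-based config shape exercised by this patch, for trying it outside of generate.py. The toy model, sizes, device, and the `from torchao.quantization import quantize_, FbgemmConfig` import path are assumptions for illustration; only the `FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size)` call shape comes from the diff above.

```python
# Minimal sketch (assumed setup): apply the dtype-based FbgemmConfig from this
# patch to a toy bf16 linear layer, mirroring the quantize_ call in the diff.
import torch
from torchao.quantization import quantize_, FbgemmConfig  # import path assumed

group_size = 128
# block_size = [1, group_size]: each output row of the weight is quantized in
# groups of `group_size` elements along the input dimension.
block_size = [1, group_size]

# Toy stand-in for the Llama model used by generate.py (hypothetical).
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
)

# bf16 activations, int4 weights, bf16 outputs, group size 128 along the last dim.
quantize_(
    model,
    FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size),
)
```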