diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index c17de52028..0e74d5fc98 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -444,8 +444,14 @@ def ffn_or_attn_only(mod, fqn):
             _, precision, group_size = quantization.split("-")
             group_size = int(group_size)
+            block_size = [1, group_size]
             if precision == "int4":
-                quantize_(model, FbgemmConfig("bf16i4bf16", group_size))
+                quantize_(
+                    model,
+                    FbgemmConfig(
+                        torch.bfloat16, torch.int4, torch.bfloat16, block_size
+                    ),
+                )
             else:
                 raise NotImplementedError(
                     f"FbegemmConfig({precision=}) not supported yet"
                 )
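
For context on the API change: the old string-based constructor `FbgemmConfig("bf16i4bf16", group_size)` is replaced by explicit input/weight/output dtypes plus a `block_size` list. Below is a minimal sketch of the new call pattern; it assumes a CUDA machine with the fbgemm kernels available, and the `ToyLinear` module and tensor sizes are illustrative only, not part of this PR.

```python
# Sketch only: exercises the new FbgemmConfig signature shown in the diff.
# Assumes a torchao build with fbgemm support and a CUDA device; ToyLinear
# is a hypothetical stand-in for the Llama model quantized in generate.py.
import torch
from torchao.quantization import FbgemmConfig, quantize_


class ToyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256, bias=False, dtype=torch.bfloat16)

    def forward(self, x):
        return self.linear(x)


model = ToyLinear().to("cuda")
group_size = 128
block_size = [1, group_size]  # one group per row, `group_size` columns wide

# New API: input/weight/output dtypes are passed explicitly instead of the
# "bf16i4bf16" string, and the group size is carried in a block_size list.
quantize_(
    model,
    FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size),
)

x = torch.randn(1, 256, device="cuda", dtype=torch.bfloat16)
print(model(x).shape)
```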