From 8afdb1f65e5d172266439d3dbf2c4e9e71043392 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 17 Jun 2025 21:24:28 -0700
Subject: [PATCH] Make AWQ more general

Summary:
* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan:
Tested on Phi4-mini and Qwen3-8B

Qwen3-8B

| Task                       | calibration_limit | no-awq | awq    |
|----------------------------|-------------------|--------|--------|
| leaderboard_math_hard (v3) | 2                 | 0.3543 | 0.4371 |
| gpqa_main_zeroshot         | 50                | 0.32   | 0.36   |
| mmlu                       | 5                 | 0.7372 | 0.7463 |
| bbh                        | 1                 | 0.7385 | 0.7556 |

Phi4-mini

| Task     | calibration_limit | no-awq | awq    |
|----------|-------------------|--------|--------|
| mmlu_pro | 2                 | 0.4057 | 0.4757 |
| gsm8k    | 5                 | 0.72   | 0.76   |

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/prototype/test_awq.py                     | 187 ++++++++++------
 .../quantization/test_config_serialization.py  |   5 +
 torchao/_models/_eval.py                       |  20 +-
 torchao/_models/llama/eval.py                  |  40 ++++
 torchao/core/config.py                         |   1 +
 torchao/prototype/awq/__init__.py              |   8 +-
 torchao/prototype/awq/api.py                   | 200 +++++-------------
 torchao/prototype/awq/core.py                  | 143 ++++---------
 torchao/prototype/awq/example.py               | 159 ++++++++------
 torchao/prototype/moe_quant/utils.py           |  13 +-
 .../quantization/linear_activation_scale.py    |  60 +-----
 torchao/quantization/quant_api.py              |   6 +-
 torchao/utils.py                               |  13 +-
 13 files changed, 407 insertions(+), 448 deletions(-)

diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index 34ddd9c5e9..22c9229bf0 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -3,21 +3,20 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
 
 import pytest
 import torch
 
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
+from torchao.quantization import FbgemmConfig, quantize_
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
+    from torchao.prototype.awq import AWQConfig, AWQStep
 
 
 class ToyLinearModel(torch.nn.Module):
@@ -25,7 +24,7 @@ def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,36 +43,74 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
     yield
     torch._dynamo.reset()  # reset cache between tests
 
 
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightiles")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
+def test_awq_functionality():
+    device = "cuda"
+    dataset_size = 100
+    l1, l2, l3 = 512, 256, 128
+    original_dtype = torch.bfloat16
# tinygemm kernel only uses bfloat16 inputs + group_size = 128 + n_calibration_examples = 10 + sequence_length = 5 + + m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) + + # baseline quantization + base_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, + ) + m_baseline = copy.deepcopy(m) + quantize_(m_baseline, base_config) + + # awq quantization + dataset = m.example_inputs( + dataset_size, + sequence_length=sequence_length, + dtype=original_dtype, + device=device, + ) + ref_out = torch.cat([m(d.squeeze(0)) for d in dataset]) + + calibration_data = dataset[:n_calibration_examples] + quant_config = AWQConfig(base_config, step=AWQStep.PREPARE) + quantize_(m, quant_config) + + for example in calibration_data: + print("device:", example.device) + m(example) + + quant_config = AWQConfig(base_config, step=AWQStep.CONVERT) + quantize_(m, quant_config) + + awq_out = torch.cat([m(d.squeeze(0)) for d in dataset]) + baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset]) + + loss_awq = (ref_out - awq_out).pow(2).mean().item() + loss_base = (ref_out - baseline_out).pow(2).mean().item() + assert loss_awq < loss_base + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") +def test_awq_loading(): + device = "cuda" dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs - quant_dtype = qdtype group_size = 128 n_calibration_examples = 10 - n_validation_examples = 10 sequence_length = 5 m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) @@ -86,56 +123,60 @@ def test_awq_loading(device, qdtype): calibration_data = dataset[:n_calibration_examples] # calibrate - insert_awq_observer_( - m, - n_validation_examples, - sequence_length, - quant_dtype=quant_dtype, - group_size=group_size, + base_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, ) + quant_config = AWQConfig(base_config, step=AWQStep.PREPARE) + quantize_(m, quant_config) for example in calibration_data: - m(example.to(device)) + m(example) # quantize - is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - quantize_( - m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear - ) + quant_config = AWQConfig(base_config, step=AWQStep.CONVERT) + quantize_(m, quant_config) - model_save_path = "awq_model.pth" - torch.save(m, model_save_path) - loaded_model = torch.load(model_save_path) - os.remove(model_save_path) + with tempfile.NamedTemporaryFile() as f: + torch.save(m.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) - if torch.cuda.is_available(): - m = torch.compile(m, fullgraph=True) - loaded_model = torch.compile(loaded_model, fullgraph=True) + loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) + loaded_model.load_state_dict(state_dict, assign=True) - awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset]) + m = torch.compile(m, fullgraph=True) + loaded_model = torch.compile(loaded_model, fullgraph=True) + + awq_out = torch.cat([m(d.squeeze(0)) for d in dataset]) + awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset]) assert 
awq_out is not None assert awq_save_load_out is not None assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2) -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@skip_if_rocm("ROCm enablement in progress") -def test_save_weights_only(): +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") +def test_awq_loading_vllm(): + """Simulate weight loading in vllm: + * prepare model weight to the same format (awq weight) + * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint + + There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo + """ + device = "cuda" dataset_size = 100 l1, l2, l3 = 512, 256, 128 - original_dtype = torch.bfloat16 - quant_dtype = torch.uint4 - device = "cuda" + original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs group_size = 128 n_calibration_examples = 10 - n_validation_examples = 10 sequence_length = 5 m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) - m2 = deepcopy(m) dataset = m.example_inputs( dataset_size, sequence_length=sequence_length, @@ -145,35 +186,41 @@ def test_save_weights_only(): calibration_data = dataset[:n_calibration_examples] # calibrate - insert_awq_observer_( - m, - n_validation_examples, - sequence_length, - quant_dtype=quant_dtype, - group_size=group_size, + base_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, ) + quant_config = AWQConfig(base_config, step=AWQStep.PREPARE) + quantize_(m, quant_config) for example in calibration_data: - m(example.to(device)) + m(example) # quantize - is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - quantize_( - m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear - ) + quant_config = AWQConfig(base_config, step=AWQStep.CONVERT) + quantize_(m, quant_config) + + with tempfile.NamedTemporaryFile() as f: + torch.save(m.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + + loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) + quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING) + quantize_(loaded_model, quant_config) - model_save_path = "awq_model.pth" - torch.save(m.state_dict(), model_save_path) - m2.load_state_dict( - torch.load(model_save_path), assign=True - ) # load weights only.torch.load(model_save_path) - os.remove(model_save_path) + loaded_model.linear1.weight.copy_(state_dict["linear1.weight"]) + loaded_model.linear2.weight.copy_(state_dict["linear2.weight"]) + loaded_model.linear3.weight.copy_(state_dict["linear3.weight"]) m = torch.compile(m, fullgraph=True) - m2 = torch.compile(m2, fullgraph=True) + loaded_model = torch.compile(loaded_model, fullgraph=True) - awq_out = torch.cat([m(i.squeeze(0)) for i in dataset]) - awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset]) + awq_out = torch.cat([m(d.squeeze(0)) for d in dataset]) + awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset]) assert awq_out is not None assert awq_save_load_out is not None diff --git a/test/quantization/test_config_serialization.py b/test/quantization/test_config_serialization.py index 71cf8e144d..0674680506 100644 --- a/test/quantization/test_config_serialization.py +++ 
b/test/quantization/test_config_serialization.py @@ -19,6 +19,10 @@ config_from_dict, config_to_dict, ) +from torchao.prototype.awq import ( + AWQConfig, + AWQStep, +) from torchao.quantization.quant_api import ( FbgemmConfig, Float8DynamicActivationFloat8WeightConfig, @@ -79,6 +83,7 @@ "linear2": Int8DynamicActivationInt4WeightConfig(), } ), + AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOAD), ] if TORCH_VERSION_AT_LEAST_2_6: diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index faf059c400..de7f010035 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -57,8 +57,13 @@ def _model_call(self, inps): max_seq_length = min(max(inps.size()), self.max_length) with torch.device(self._device): - self._model.setup_caches(self.batch_size, max_seq_length) + if hasattr(self._model, "setup_caches"): + self._model.setup_caches(self.batch_size, max_seq_length) logits = self._model(*input) + from transformers.modeling_outputs import CausalLMOutputWithPast + + if isinstance(logits, CausalLMOutputWithPast): + logits = logits.logits return logits def run_eval(self, tasks, limit): @@ -84,7 +89,11 @@ def eot_token_id(self): try: return self.tokenizer.eos_id() except: - return self.tokenizer.eos_id + try: + return self.tokenizer.eos_id + except: + idx = self.tokenizer.all_special_tokens.index("<|endoftext|>") + return self.tokenizer.all_special_ids[idx] @property def max_length(self): @@ -102,8 +111,8 @@ def batch_size(self): def device(self): return self._device - def tok_decode(self, tokens): - decoded = self.tokenizer.decode(tokens) + def tok_decode(self, tokens, **kwargs): + decoded = self.tokenizer.decode(tokens, **kwargs) return decoded def tok_encode(self, string: str, **kwargs): @@ -115,9 +124,6 @@ def tok_encode(self, string: str, **kwargs): tokens = [self.tokenizer.bos_id] + tokens return tokens - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - class LMEvalInputRecorder(TransformerEvalWrapper): def __init__( diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 8ee15f1fd3..cc4e439a49 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -237,6 +237,46 @@ def run_evaluation( quantize_( model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) + elif quantization.startswith("awq-uintx"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import ( + AWQObservedLinear, + awq_uintx, + insert_awq_observer_, + ) + + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model = model.to(device) + # get calibration data + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=256, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=["wikitext"], + limit=1, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) if compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/core/config.py 
b/torchao/core/config.py index 024b29baa3..a51060accd 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -191,6 +191,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: "torchao.prototype.quantization", "torchao.prototype.mx_formats", "torchao.dtypes", + "torchao.prototype.awq", } diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 570b0821d4..cd5c447d4c 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,8 +1,8 @@ -from .api import awq_uintx, insert_awq_observer_ -from .core import AWQObservedLinear +from .api import AWQConfig +from .core import AWQObservedLinear, AWQStep __all__ = [ - "awq_uintx", - "insert_awq_observer_", "AWQObservedLinear", + "AWQConfig", + "AWQStep", ] diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 5806c29ce6..58e2659893 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -5,182 +5,94 @@ # LICENSE file in the root directory of this source tree. import types from dataclasses import dataclass -from typing import Optional import torch -import torchao from torchao.core.config import AOBaseConfig -from torchao.dtypes import ( - Int4XPULayout, - Layout, - TensorCoreTiledLayout, - to_affine_quantized_intx, -) -from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata -from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_api import ( _linear_extra_repr, - _replace_with_custom_fn_if_matches_filter, -) -from torchao.quantization.quant_primitives import ( - _DTYPE_TO_QVALUE_BOUNDS, - MappingType, - ZeroPointDomain, ) from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.utils import DummyModule from .core import ( AWQObservedLinear, AWQObserver, + AWQStep, ) -assert len(_DTYPE_TO_BIT_WIDTH) > 0, ( - "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" -) - - -def insert_awq_observer_( - model: torch.nn.Module, - n_validation_examples: int, - validation_sequence_len: int, - quant_dtype: torch.dtype = torch.uint4, - scale_search_space_size: int = 20, - group_size: int = 128, -): - """ - Inserts AWQObserver into Linear layers of a given model. - - Args: - model: The model to be modified (in place). Ensure model is on the desired device for calibration - n_validation_examples: Number of examples used to validate scale options - validation_sequence_len: Number of tokens in each validation example - quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 - scale search space size: how many different scale options to try. Original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate - group_size: Quantization granularity. Use -1 for channel wise quantization - """ - _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, ( - "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8"
-    )
-    # AQT config
-    mapping_type = MappingType.ASYMMETRIC
-    quantization_granularity = PerGroup(group_size)
-    quant_min = 0
-    quant_max = (
-        255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1
-    )
-    eps = torch.finfo(torch.float32).eps
-    preserve_zero = True
-    zero_point_dtype = torch.int64
-    zero_point_domain = ZeroPointDomain.INT
-
-    def replace_with_observer(layer):
-        # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
-        return AWQObservedLinear.from_float(layer, observer)
-
-    _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
-
 
 @dataclass
-class AWQUIntXConfig(AOBaseConfig):
+class AWQConfig(AOBaseConfig):
     """
     Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
-        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
-        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
-        group_size: Quantization granularity. Use -1 for channel wise quantization
-        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
-        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
+        base_config (AOBaseConfig): the base quantization config that AWQ is applied on top of, e.g. 8da4w, int4 weight only
+        step (AWQStep): the step of the AWQ process to run, one of PREPARE, CONVERT, and PREPARE_FOR_LOADING:
+            PREPARE: insert AWQ observers into linear layers
+            CONVERT: convert the observed linear modules to linear modules with awq quantized weights
+            PREPARE_FOR_LOADING: convert the floating point model to a dummy awq quantized model, so we can
+              load the quantized weights through copy_ later
+        scale_search_space_size (int): the number of scales to search for
     """
 
-    quant_dtype: torch.dtype = torch.uint4
-    layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8)
-    group_size: int = 64
-    use_hqq: bool = False
-    set_inductor_config: bool = True
-
+    base_config: AOBaseConfig
+    step: AWQStep
+    scale_search_space_size: int = 20
 
-# for bc
-awq_uintx = AWQUIntXConfig
-
 
-@register_quantize_module_handler(AWQUIntXConfig)
-def _awq_uintx_transform(
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
     module: torch.nn.Module,
-    config: AWQUIntXConfig,
+    config: AWQConfig,
 ) -> torch.nn.Module:
-    quant_dtype = config.quant_dtype
-    group_size = config.group_size
-    use_hqq = config.use_hqq
-    if config.set_inductor_config:
-        torchao.quantization.utils.recommended_inductor_config_setter()
-    observed_linear = module
-
-    assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, (
-        "Invalid quant_dtype. Please use torch.uint1 ..
torch.uint8" - ) + step = config.step + scale_search_space_size = config.scale_search_space_size + observed_linear = None + base_config = config.base_config - equalization_scale = observed_linear.act_obs.calculate_qparams() - # AQT config - if quant_dtype == torch.uint4: - target_dtype = torch.int32 - eps = 1e-6 - preserve_zero = False - _layout = config.layout - if isinstance(_layout, Int4XPULayout): - zero_point_dtype = torch.int8 - zero_point_domain = ZeroPointDomain.INT - else: - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + if step == AWQStep.PREPARE: + observer = AWQObserver( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + return AWQObservedLinear.from_float(module, observer) + elif step == AWQStep.PREPARE_FOR_LOADING: + # loading from pre-quantized checkpoint + observer = AWQObserver( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + observed_linear = AWQObservedLinear.from_float(module, observer) + example_input = torch.randn( + (1, module.weight.shape[1]), + device=module.weight.device, + dtype=module.weight.dtype, + ) + observed_linear(example_input) else: - target_dtype = torch.uint8 - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - _layout = UintxLayout(quant_dtype) + assert step == AWQStep.CONVERT, f"Unexpected step: {step}" + if not isinstance(module, AWQObservedLinear): + print(f"convert: module is not AWQObservedLinear, skipping: {type(module)}") + return module + observed_linear = module - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] - qw = to_affine_quantized_intx( - observed_linear.weight * equalization_scale, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - _layout=_layout, - use_hqq=use_hqq, - ) + assert observed_linear is not None + equalization_scale = observed_linear.act_obs.calculate_qparams() + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(observed_linear.weight * equalization_scale) + quant_mod = base_config_handler(dummy_mod, config.base_config) + qw = quant_mod.weight qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) linear = torch.nn.Linear( @@ -191,6 +103,6 @@ def _awq_uintx_transform( dtype=observed_linear.weight.dtype, ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) - linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) linear.bias = observed_linear.bias return linear diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index e5ee96fea2..c26a036733 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -3,145 +3,94 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+from enum import Enum from typing import Optional import torch import torch.nn.functional as F -from torchao.dtypes import to_affine_quantized_intx -from torchao.dtypes.uintx.uintx_layout import UintxLayout -from torchao.quantization.granularity import Granularity -from torchao.quantization.observer import ( - AffineQuantizedObserverBase, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, ) +from torchao.utils import DummyModule + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class AWQStep(str, Enum): + PREPARE = "prepare" + CONVERT = "convert" + PREPARE_FOR_LOADING = "prepare_for_loading" + +@torch.no_grad() +def get_act_scale(x): + return x.abs().view(-1, x.shape[-1]).mean(0) -class AWQObserver(AffineQuantizedObserverBase): + +class AWQObserver(torch.nn.Module): def __init__( self, weight: torch.Tensor, - bias: torch.Tensor, - quantization_granularity: Granularity, - mapping_type: MappingType, - target_dtype: torch.dtype, - n_validation_examples: int, - validation_sequence_len: int, + bias: Optional[torch.Tensor], + base_config: AOBaseConfig, scale_search_space_size: int = 20, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: Optional[bool] = True, - zero_point_domain=ZeroPointDomain.INT, ): """ A custom observer for Activation aware Weight Quantization (AWQ) + Note: this only applies to weight only quantization: https://github.com/pytorch/ao/issues/2388#issuecomment-3062863647 Args: - weight: The weight tensor to be observed. - bias: The bias tensor to be observed. - quantization_granularity: Granularity which specifies how many weights share the same scale/zero point - input_dtype: The data type of the input tensor. - mapping_type: Always set to asymmetric - target_dtype: The target data type of the quantized tensor - n_validation_examples: Number of examples used to calibrate observer - validation_sequence_len: Number of tokens in each example - scale_search_space_size: The number of scales to search for. - quant_min: The minimum quantized value - quant_max: The maximum quantized value - eps: The minimum scale. - scale_dtype: The data type of the scale tensor. - zero_point_dtype: The data type of the zero point tensor. - preserve_zero: A flag to indicate whether we need zero to be exactly - representable or not. - zero_point_domain: The domain of the zero point. + weight (torch.Tensor: The weight tensor to be observed. + bias (Optional[torch.Tensor]): The bias tensor to be observed. 
+ config (AOBaseConfig): the configuration for quantize_, that we'll use to apply awq on top of + scale_search_space_size (int): search space size for searching the best scale for weight and input activation """ - super().__init__( - mapping_type, - target_dtype, - quantization_granularity, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - self.quantization_granularity = quantization_granularity + super().__init__() + self.base_config = base_config self.weight = weight self.bias = bias - self.n_validation_examples = n_validation_examples - self.validation_sequence_len = validation_sequence_len - self.calibration_token_count = 0 self.inputs = [] - self.outputs = [] self.scale_options = scale_search_space_size self.device = self.weight.device - self.average = torch.zeros((1, weight.shape[1]), device=self.device) if self.bias is not None: self.bias.to(self.device) @torch.no_grad() def forward(self, input: torch.Tensor, output: torch.Tensor): - # import pdb - # pdb.set_trace() - # print(input.shape, input.abs().sum(1).shape, self.average.shape) - if len(self.inputs) < self.n_validation_examples: - self.inputs.append(input.to("cpu")) - self.outputs.append(output.to("cpu")) - self.calibration_token_count += input.shape[-2] - self.average += input.abs().sum(-2) + self.inputs.append(input.to("cpu")) def calculate_qparams(self): - # import pdb - # pdb.set_trace() - assert self.outputs != None, ( + assert self.inputs != None, ( "calibrate observer first by running model on exemplar data" ) - self.average /= self.calibration_token_count - for i in range(self.n_validation_examples): + for i in range(len(self.inputs)): self.inputs[i] = self.inputs[i].to(self.device) - self.outputs[i] = self.outputs[i].to(self.device) + if self.bias is not None: + self.bias = self.bias.to(self.device) + + acc = torch.cat(self.inputs, dim=-2) + x_max = get_act_scale(acc) best_loss = float("inf") best_scales = None for i in range(self.scale_options): ratio = i * 1 / self.scale_options - scales = self.average.pow(ratio).to(self.weight.dtype) + scales = x_max.pow(ratio).to(self.weight.dtype).clamp(min=1e-4).view(-1) + if best_scales is None: + best_scales = scales scales = scales / (scales.max() * scales.min()).sqrt() - layout = UintxLayout(self.target_dtype) - # regardless of weight dtype, we have to store as packed uint8 tensors - tensor_dtype = torch.uint8 - w = to_affine_quantized_intx( - self.weight * scales, - self.mapping_type, - (1, self.quantization_granularity.group_size), - tensor_dtype, - quant_min=self.quant_min, - quant_max=self.quant_max, - eps=self.eps, - scale_dtype=self.scale_dtype, - zero_point_dtype=self.zero_point_dtype, - preserve_zero=self.preserve_zero, - zero_point_domain=self.zero_point_domain, - _layout=layout, - ) - loss = 0 - for i in range(self.n_validation_examples): - q_out = F.linear(self.inputs[i] / scales, w, self.bias) - loss += (self.outputs[i] - q_out).pow(2).mean().item() + config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)] + dummy_mod = DummyModule(self.weight * scales) + quant_mod = config_handler(dummy_mod, self.base_config) + w = quant_mod.weight + orig_out = F.linear(acc, self.weight, self.bias) + q_out = F.linear(acc / scales, w, self.bias) + loss = (orig_out - q_out).pow(2).mean().item() if loss < best_loss: best_scales = scales best_loss = loss - for i in range(self.n_validation_examples): - self.inputs[i].to("cpu") - 
self.outputs[i].to("cpu") return best_scales.detach() diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 7ff6092b05..533e174740 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -9,11 +9,15 @@ import torch from datasets import load_dataset from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig -from torchao.dtypes import Int4XPULayout -from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ -from torchao.quantization import int4_weight_only, quantize_ +from torchao.prototype.awq import ( + AWQConfig, + AWQStep, +) +from torchao.quantization import ( + quantize_, +) # adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255 @@ -111,6 +115,7 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): "hellaswag", "gsm8k", "mmlu", + "bbh", ] results = {} if "PPL" in tasks: @@ -180,20 +185,30 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): print("MMLU avg acc", np.mean(k)) results["mmlu"] = np.mean(k) + if "bbh" in tasks: + for task in [("leaderboard_bbh", 3)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + results["bbh"] = results[tag] + return results -def wikitext2_ppl( +def quantize_and_eval( repo_id: str, quant: str, tasks: list[str], - calibration_size: int, + max_seq_length: int, + calibration_limit: int, validation_size: int, device: str, precision: torch.dtype, - sequence_length: int, compile: bool, model_save_path: str, + model_save_hf_hub_path: str, ): print(f"Loading model on {device}...") torch.manual_seed(34) @@ -206,60 +221,78 @@ def wikitext2_ppl( .to(device) ) print(f"Time to load model: {time.time() - t0:.02f} seconds") - if quant.startswith("awq"): - quant_dtype = quant.split("-")[1] + if quant.startswith("awq-int4wo"): group_size = int(quant.split("-")[2]) - quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) - print(f"running {quant_dtype} calibration") - t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_awq_observer_( - model, - validation_size, - sequence_length, - quant_dtype=quant_dtype, - group_size=group_size, - ) - calibration_data = get_calib_dataset( - tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length + print(f"running {quant} quantization with group size {group_size}") + # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon + from torchao.quantization import FbgemmConfig + + # use_hqq = True + # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + base_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, ) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - print(f"time for calibration: {time.time() - t0:.02f} seconds") - - is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - use_hqq = "hqq" in quant - print(f"running {quant_dtype} quantization") + print(f"running {quant} prepare and calibrate") t0 = time.time() - awq_uintx_config = awq_uintx( - quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq - ) - if "xpu" in device: - awq_uintx_config.layout = 
Int4XPULayout() + quant_config = AWQConfig(base_config, step=AWQStep.PREPARE) + quantize_( model, - awq_uintx_config, - is_observed_linear, + quant_config, ) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - if model_save_path is not None: - print(f"Saving model to {model_save_path}") - torch.save(model, model_save_path) + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=max_seq_length, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, + ) + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + print(f"running {quant} convert") + t0 = time.time() + quant_config = AWQConfig(base_config, step=AWQStep.CONVERT) + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING) + model.config.quantization_config = TorchAoConfig(quant_config) + elif quant.startswith("int4wo"): group_size = int(quant.split("-")[1]) - use_hqq = "hqq" in quant print(f"running {quant} quantization with group size {group_size}") - int4_weight_only_config = int4_weight_only( - group_size=group_size, use_hqq=use_hqq + # TODO: enable after refactor: https://github.com/pytorch/ao/pull/2474 + # use_hqq = "hqq" in quant + # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + int4_weight_only_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, ) - if "xpu" in device: - int4_weight_only_config.layout = Int4XPULayout() quantize_(model, int4_weight_only_config) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + if compile: model = torch.compile(model) - return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device) + return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device) if __name__ == "__main__": @@ -268,20 +301,21 @@ def wikitext2_ppl( ) # Optional arguments with default values - parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument("--repo", type=str, help="Repository ID of the model.") parser.add_argument( - "quant", + "--quant", type=str, - help="Quantization method. Options are either awq-uint- for x =[1..8], int4wo-, or int4wo--hqq.", + help="Quantization method. Options are either awq-int4wo-, or int4wo-.", ) parser.add_argument( "--tasks", - type=list[str], + nargs="+", + type=str, help="Task to benchmark model on. Either PPL or QA", default=["PPL"], ) parser.add_argument( - "--calibration_samples", + "--calibration_limit", type=int, default=10, help="Number of samples to use for calibration. Default is 10.", @@ -302,10 +336,10 @@ def wikitext2_ppl( help="Precision type. Default is 'bfloat16'.", ) parser.add_argument( - "--seq_len", + "--max_seq_length", type=int, - default=512, - help="Length of examples to calibrate and evaluate model on. Default is 512", + default=2048, + help="Maximum sequence length of examples to calibrate and evaluate model on. 
Default is 2048", ) parser.add_argument( "--compile", @@ -318,22 +352,29 @@ def wikitext2_ppl( default=None, help="Path to store the scale values.", ) + parser.add_argument( + "--model_save_hf_hub_path", + type=str, + default=None, + help="Huggingface hub path to store the quantized model and tokenizer.", + ) args = parser.parse_args() # Convert precision argument to torch dtype precision_dtype = getattr(torch, args.precision, torch.bfloat16) - ppl = wikitext2_ppl( + result = quantize_and_eval( args.repo, args.quant, args.tasks, - args.calibration_samples, + args.max_seq_length, + args.calibration_limit, args.validation_size, args.device, args.precision, - args.seq_len, args.compile, args.model_save_path, + args.model_save_hf_hub_path, ) - print(f"{args.quant} Results: {ppl}") + print(f"{args.quant} Results: {result}") diff --git a/torchao/prototype/moe_quant/utils.py b/torchao/prototype/moe_quant/utils.py index 0e75de2ee4..28291afdf4 100644 --- a/torchao/prototype/moe_quant/utils.py +++ b/torchao/prototype/moe_quant/utils.py @@ -20,18 +20,7 @@ dataclass, register_quantize_module_handler, ) -from torchao.utils import fill_defaults - - -class DummyModule(torch.nn.Module): - """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a - DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. - """ - - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = weight - self.bias = bias +from torchao.utils import DummyModule, fill_defaults class FakeExtraDimTensor(torch.Tensor): diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 6c433844a6..005bc8d32d 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -33,8 +33,8 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor): scale (torch.Tensor): The scale tensor to be applied to activation. 
""" - original_weight_tensor: torch.Tensor - scale: torch.Tensor + tensor_data_names = ["original_weight_tensor", "scale"] + tensor_attribute_names = [] def __new__( cls, @@ -57,21 +57,8 @@ def __init__( self.original_weight_tensor = original_weight_tensor self.scale = scale - def __repr__(self): - return f"WeightTensorWithLinearActivationScaleMetadata({self.original_weight_tensor}, scale={self.scale}" - - def __tensor_flatten__(self): - tensor_data = ["original_weight_tensor", "scale"] - return tensor_data, [] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - return cls( - tensor_data_dict["original_weight_tensor"], - tensor_data_dict["scale"], - ) + def _quantization_type(self): + return f"{self.__class__}" @staticmethod def _quantized_linear_op( @@ -93,20 +80,6 @@ def from_float( ): return cls(input_float, scale) - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.original_weight_tensor), - fn(self.scale), - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.original_weight_tensor.to(device), - self.scale.to(device), - ) - implements = WeightTensorWithLinearActivationScaleMetadata.implements @@ -126,28 +99,13 @@ def _(func, types, args, kwargs): ) -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -@implements(aten._to_copy.default) +@implements(aten.slice.Tensor) def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + self = args[0] + new = self.__class__( + func(self.original_weight_tensor, *args[1:], **kwargs), self.scale ) + return return_and_correct_aliasing(func, args, kwargs, new) @implements(aten.t.default) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ab820193b8..33439552a0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -545,10 +545,10 @@ def _quantization_type(weight: torch.Tensor): if hasattr(weight, "_quantization_type"): return f"{weight.__class__.__name__}({weight._quantization_type()})" - if type(weight) is torch.Tensor: - return "not quantized" + if type(weight) is torch.Tensor or isinstance(weight, torch.nn.Parameter): + return f"Tensor: {type(weight)}" - return "not recognized" + return f"not recognized: {type(weight)}" def _linear_extra_repr(self): diff --git a/torchao/utils.py b/torchao/utils.py index 40a7b6ed16..7bd1276605 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -11,7 +11,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable +from typing import Any, Callable, Optional import torch import torch.nn.utils.parametrize as parametrize @@ -43,6 +43,7 @@ "is_sm_at_least_89", "is_sm_at_least_90", "is_package_at_least", + "DummyModule", ] @@ -882,3 +883,13 @@ def _is_fbgemm_genai_gpu_available(): return False return True + +class DummyModule(torch.nn.Module): + """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a + DummyModule with 
the target tensor, apply the transformation to the module, and then extract the transformed tensor.
+    """
+
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight = weight
+        self.bias = bias
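
A condensed usage sketch of the new flow (distilled from test_awq_functionality and
test_awq_loading_vllm above; the FbgemmConfig values mirror the tests, and `model`,
`calibration_data`, `fresh_model`, and `state_dict` are placeholders):

    import torch
    from torchao.prototype.awq import AWQConfig, AWQStep
    from torchao.quantization import FbgemmConfig, quantize_

    # base quantization scheme that AWQ is layered on top of (same values as the tests)
    base_config = FbgemmConfig(
        input_dtype=torch.bfloat16,
        weight_dtype=torch.int4,
        output_dtype=torch.bfloat16,
        block_size=[1, 128],  # [1, group_size]
        preshuffle=False,
    )

    # PREPARE: swap nn.Linear for AWQObservedLinear so activations are recorded
    quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))
    for example in calibration_data:  # placeholder: representative input batches
        model(example)

    # CONVERT: search for equalization scales and apply the base quantization
    quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))

    # PREPARE_FOR_LOADING: build dummy awq-quantized weights on a fresh float model so a
    # pre-quantized checkpoint can be copied in per layer, as in test_awq_loading_vllm
    quantize_(fresh_model, AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING))
    fresh_model.linear1.weight.copy_(state_dict["linear1.weight"])  # repeat per linear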