From 60f88528fa7d1b43e6ec9531e10754ba670dd160 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 17 Jun 2025 21:24:28 -0700 Subject: [PATCH] Make AWQ more general Summary: * Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow Test Plan: Tested on Phi4-mini and Qwen3-8B Qwen3-8B |Task | calibration_limit | no-awq | awq | |-----+------------------+ ------+ ------+ |leaderboard_math_hard (v3) | 2 | 0.3543 | 0.4371 | |gpqa_main_zeroshot | 50 | 0.32 | 0.36 | |mmlu | 5 | 0.7372 | 0.7463 | |bbh | 1 | 0.7385 | 0.7556| Phi4-mini | Task | calibration_limit | no-awq | awq | |------+------------------+--------+------| | mmlu_pro | 2 | 0.4057 | 0.4757 | | gsm8k | 5 | 0.72 | 0.76 | Reviewers: Subscribers: Tasks: Tags: --- torchao/_models/_eval.py | 21 +- torchao/_models/llama/eval.py | 81 +++++ torchao/prototype/awq/__init__.py | 3 +- torchao/prototype/awq/api.py | 199 ++++-------- torchao/prototype/awq/core.py | 132 ++------ torchao/prototype/awq/example.py | 303 +++++++++++++++--- torchao/prototype/moe_quant/utils.py | 13 +- .../quantization/linear_activation_scale.py | 7 +- torchao/quantization/quant_api.py | 6 +- torchao/utils.py | 13 +- 10 files changed, 470 insertions(+), 308 deletions(-) diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index faf059c400..fca5b7f0af 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -57,8 +57,13 @@ def _model_call(self, inps): max_seq_length = min(max(inps.size()), self.max_length) with torch.device(self._device): - self._model.setup_caches(self.batch_size, max_seq_length) + if hasattr(self._model, "setup_caches"): + self._model.setup_caches(self.batch_size, max_seq_length) logits = self._model(*input) + from transformers.modeling_outputs import CausalLMOutputWithPast + + if isinstance(logits, CausalLMOutputWithPast): + logits = logits.logits return logits def run_eval(self, tasks, limit): @@ -84,7 +89,11 @@ def eot_token_id(self): try: return self.tokenizer.eos_id() except: - return self.tokenizer.eos_id + try: + return self.tokenizer.eos_id + except: + idx = self.tokenizer.all_special_tokens.index("<|endoftext|>") + return self.tokenizer.all_special_ids[idx] @property def max_length(self): @@ -102,8 +111,8 @@ def batch_size(self): def device(self): return self._device - def tok_decode(self, tokens): - decoded = self.tokenizer.decode(tokens) + def tok_decode(self, tokens, **kwargs): + decoded = self.tokenizer.decode(tokens, **kwargs) return decoded def tok_encode(self, string: str, **kwargs): @@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs): tokens = [self.tokenizer.bos_id] + tokens return tokens - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") + # def _model_generate(self, context, max_length, stop, **generation_kwargs): + # super()._model_generate(context, max_length, stop, **generation_kwargs) class LMEvalInputRecorder(TransformerEvalWrapper): diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 8ee15f1fd3..49e46c3d48 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -237,6 +237,87 @@ def run_evaluation( quantize_( model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) + elif quantization.startswith("awq-uintx"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq 
import ( + AWQObservedLinear, + awq_uintx, + insert_awq_observer_, + ) + + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model = model.to(device) + # get calibration data + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=256, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=["wikitext"], + limit=1, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) + + elif quantization.startswith("awq-8da4w"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + + if not TORCH_VERSION_AT_LEAST_2_3: + print("Awq requires torch2.3+") + exit() + from torchao.prototype.awq import ( + AWQObservedLinear, + awq_uintx, + insert_awq_observer_, + ) + + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model = model.to(device) + # get calibration data + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=256, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=["wikitext"], + limit=1, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) if compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 570b0821d4..4f34d5375a 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,8 +1,9 @@ -from .api import awq_uintx, insert_awq_observer_ +from .api import AWQConfig, awq_uintx, insert_awq_observer_ from .core import AWQObservedLinear __all__ = [ "awq_uintx", "insert_awq_observer_", "AWQObservedLinear", + "AWQConfig", ] diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 5806c29ce6..5120147979 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -5,182 +5,107 @@ # LICENSE file in the root directory of this source tree. 
import types from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import torch import torchao from torchao.core.config import AOBaseConfig -from torchao.dtypes import ( - Int4XPULayout, - Layout, - TensorCoreTiledLayout, - to_affine_quantized_intx, -) -from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata -from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_api import ( _linear_extra_repr, - _replace_with_custom_fn_if_matches_filter, -) -from torchao.quantization.quant_primitives import ( - _DTYPE_TO_QVALUE_BOUNDS, - MappingType, - ZeroPointDomain, ) from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.utils import DummyModule from .core import ( AWQObservedLinear, AWQObserver, ) -assert len(_DTYPE_TO_BIT_WIDTH) > 0, ( - "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" -) - - -def insert_awq_observer_( - model: torch.nn.Module, - n_validation_examples: int, - validation_sequence_len: int, - quant_dtype: torch.dtype = torch.uint4, - scale_search_space_size: int = 20, - group_size: int = 128, -): - """ - Inserts AWQObserver into Linear layers of a given model. - - Args: - model: The model to be modified (in place). Ensure model is on the desired device for calibration - n_validation_examples: Number of examples used to validate scale options - validation_sequence_len: Number of tokens in each validation example - quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 - scale search space size: how many different scale options to try. Original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate - group_size: Quantization granularity. Use -1 for channel wise quantization - """ - _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, ( - "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" - ) - # AQT config - mapping_type = MappingType.ASYMMETRIC - quantization_granularity = PerGroup(group_size) - quant_min = 0 - quant_max = ( - 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 - ) - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - - def replace_with_observer(layer): - # creates observer and replaces linear layers with AWQObservedLinear layers - observer = AWQObserver( - layer.weight, - layer.bias, - quantization_granularity, - mapping_type, - quant_dtype, - n_validation_examples, - validation_sequence_len, - scale_search_space_size, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - zero_point_dtype=zero_point_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) - return AWQObservedLinear.from_float(layer, observer) - - _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) - @dataclass -class AWQUIntXConfig(AOBaseConfig): +class AWQConfig(AOBaseConfig): """ Configuration for quantizing linear layers when passed into quantize_() Args: - quant_dtype: The data type of the quantized weights. 
Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 - `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)` - group_size: Quantization granularity. Use -1 for channel wise quantization - weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used + base_config (AOBaseConfig): The quantization config that we can apply awq on top of, e.g. 8da4w, int4 weight only + step (str): a string of "prepare", "convert" or "load" indicating the step of AWQ process + prepare: insert AWQ Observers to linear + convert: convert the observed linear modules to linear modules with awq quantized weights + load: convert the floating point model to a dummy awq quantized model + example_input_shape (Optional[List[int]])): This is used for load step to initialize a random example input + scale_search_space_size (int): the number of scales to search for set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. """ - quant_dtype: torch.dtype = torch.uint4 - layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8) - group_size: int = 64 - use_hqq: bool = False + base_config: AOBaseConfig + step: str + example_input_shape: Optional[List[int]] = None + scale_search_space_size: int = 20 set_inductor_config: bool = True + def __post_init__(self): + OPTIONS = ["prepare", "convert", "load"] + assert self.step in OPTIONS, f"Only {OPTIONS} are supported, got {self.step}" -# for bc -awq_uintx = AWQUIntXConfig - -@register_quantize_module_handler(AWQUIntXConfig) -def _awq_uintx_transform( +@register_quantize_module_handler(AWQConfig) +def _awq_transform( module: torch.nn.Module, - config: AWQUIntXConfig, + config: AWQConfig, ) -> torch.nn.Module: - quant_dtype = config.quant_dtype - group_size = config.group_size - use_hqq = config.use_hqq if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - observed_linear = module - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, ( - "Invalid quant_dtype. Please use torch.uint1 .. 
torch.uint8" - ) + step = config.step + scale_search_space_size = config.scale_search_space_size + observed_linear = None + base_config = config.base_config - equalization_scale = observed_linear.act_obs.calculate_qparams() - # AQT config - if quant_dtype == torch.uint4: - target_dtype = torch.int32 - eps = 1e-6 - preserve_zero = False - _layout = config.layout - if isinstance(_layout, Int4XPULayout): - zero_point_dtype = torch.int8 - zero_point_domain = ZeroPointDomain.INT - else: - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + if step == "prepare": + observer = AWQObserver( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + return AWQObservedLinear.from_float(module, observer) + elif step == "load": + # loading from pre-quantized checkpoint + observer = AWQObserver( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + observed_linear = AWQObservedLinear.from_float(module, observer) + assert config.example_input_shape is not None, ( + "When step is load, we expect example_input_shape to be specified as well" + ) + example_input = torch.randn( + config.example_input_shape, + device=module.weight.device, + dtype=module.weight.dtype, + ) + observed_linear(example_input) else: - target_dtype = torch.uint8 - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - _layout = UintxLayout(quant_dtype) + if not isinstance(module, AWQObservedLinear): + print(f"convert: module is not AWQObservedLinear, skipping: {type(module)}") + return module + observed_linear = module - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] - qw = to_affine_quantized_intx( - observed_linear.weight * equalization_scale, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - _layout=_layout, - use_hqq=use_hqq, - ) + assert observed_linear is not None + equalization_scale = observed_linear.act_obs.calculate_qparams() + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(observed_linear.weight * equalization_scale) + quant_mod = base_config_handler(dummy_mod, config.base_config) + qw = quant_mod.weight qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) linear = torch.nn.Linear( @@ -191,6 +116,6 @@ def _awq_uintx_transform( dtype=observed_linear.weight.dtype, ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) - linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) linear.bias = observed_linear.bias return linear diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index e5ee96fea2..bbe825298f 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -8,140 +8,78 @@ import torch import torch.nn.functional as F -from torchao.dtypes import to_affine_quantized_intx -from torchao.dtypes.uintx.uintx_layout import UintxLayout -from torchao.quantization.granularity import Granularity -from torchao.quantization.observer import ( - AffineQuantizedObserverBase, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, +from torchao.core.config import AOBaseConfig +from 
torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, ) +from torchao.utils import DummyModule + + +@torch.no_grad() +def get_act_scale(x): + return x.abs().view(-1, x.shape[-1]).mean(0) -class AWQObserver(AffineQuantizedObserverBase): +class AWQObserver(torch.nn.Module): def __init__( self, weight: torch.Tensor, - bias: torch.Tensor, - quantization_granularity: Granularity, - mapping_type: MappingType, - target_dtype: torch.dtype, - n_validation_examples: int, - validation_sequence_len: int, + bias: Optional[torch.Tensor], + base_config: AOBaseConfig, scale_search_space_size: int = 20, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: Optional[bool] = True, - zero_point_domain=ZeroPointDomain.INT, ): """ A custom observer for Activation aware Weight Quantization (AWQ) + Note: this only applies to weight only quantization: https://github.com/pytorch/ao/issues/2388#issuecomment-3062863647 Args: - weight: The weight tensor to be observed. - bias: The bias tensor to be observed. - quantization_granularity: Granularity which specifies how many weights share the same scale/zero point - input_dtype: The data type of the input tensor. - mapping_type: Always set to asymmetric - target_dtype: The target data type of the quantized tensor - n_validation_examples: Number of examples used to calibrate observer - validation_sequence_len: Number of tokens in each example - scale_search_space_size: The number of scales to search for. - quant_min: The minimum quantized value - quant_max: The maximum quantized value - eps: The minimum scale. - scale_dtype: The data type of the scale tensor. - zero_point_dtype: The data type of the zero point tensor. - preserve_zero: A flag to indicate whether we need zero to be exactly - representable or not. - zero_point_domain: The domain of the zero point. + weight (torch.Tensor: The weight tensor to be observed. + bias (Optional[torch.Tensor]): The bias tensor to be observed. 
+ config (AOBaseConfig): the configuration for quantize_, that we'll use to apply awq on top of + scale_search_space_size (int): search space size for searching the best scale for weight and input activation """ - super().__init__( - mapping_type, - target_dtype, - quantization_granularity, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - self.quantization_granularity = quantization_granularity + super().__init__() + self.base_config = base_config self.weight = weight self.bias = bias - self.n_validation_examples = n_validation_examples - self.validation_sequence_len = validation_sequence_len - self.calibration_token_count = 0 self.inputs = [] - self.outputs = [] self.scale_options = scale_search_space_size self.device = self.weight.device - self.average = torch.zeros((1, weight.shape[1]), device=self.device) if self.bias is not None: self.bias.to(self.device) @torch.no_grad() def forward(self, input: torch.Tensor, output: torch.Tensor): - # import pdb - # pdb.set_trace() - # print(input.shape, input.abs().sum(1).shape, self.average.shape) - if len(self.inputs) < self.n_validation_examples: - self.inputs.append(input.to("cpu")) - self.outputs.append(output.to("cpu")) - self.calibration_token_count += input.shape[-2] - self.average += input.abs().sum(-2) + self.inputs.append(input.to("cpu")) def calculate_qparams(self): - # import pdb - # pdb.set_trace() - assert self.outputs != None, ( + assert self.inputs != None, ( "calibrate observer first by running model on exemplar data" ) - self.average /= self.calibration_token_count - for i in range(self.n_validation_examples): + for i in range(len(self.inputs)): self.inputs[i] = self.inputs[i].to(self.device) - self.outputs[i] = self.outputs[i].to(self.device) + if self.bias is not None: + self.bias = self.bias.to(self.device) + + acc = torch.cat(self.inputs, dim=-2) + x_max = get_act_scale(acc) best_loss = float("inf") best_scales = None for i in range(self.scale_options): ratio = i * 1 / self.scale_options - scales = self.average.pow(ratio).to(self.weight.dtype) + scales = x_max.pow(ratio).to(self.weight.dtype).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt() - layout = UintxLayout(self.target_dtype) - # regardless of weight dtype, we have to store as packed uint8 tensors - tensor_dtype = torch.uint8 - w = to_affine_quantized_intx( - self.weight * scales, - self.mapping_type, - (1, self.quantization_granularity.group_size), - tensor_dtype, - quant_min=self.quant_min, - quant_max=self.quant_max, - eps=self.eps, - scale_dtype=self.scale_dtype, - zero_point_dtype=self.zero_point_dtype, - preserve_zero=self.preserve_zero, - zero_point_domain=self.zero_point_domain, - _layout=layout, - ) - loss = 0 - for i in range(self.n_validation_examples): - q_out = F.linear(self.inputs[i] / scales, w, self.bias) - loss += (self.outputs[i] - q_out).pow(2).mean().item() + config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)] + dummy_mod = DummyModule(self.weight * scales) + quant_mod = config_handler(dummy_mod, self.base_config) + w = quant_mod.weight + orig_out = F.linear(acc, self.weight, self.bias) + q_out = F.linear(acc / scales, w, self.bias) + loss = (orig_out - q_out).pow(2).mean().item() if loss < best_loss: best_scales = scales best_loss = loss - for i in range(self.n_validation_examples): - self.inputs[i].to("cpu") - self.outputs[i].to("cpu") return 
best_scales.detach() diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 7ff6092b05..f3c9c1aa33 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -9,11 +9,21 @@ import torch from datasets import load_dataset from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer - -from torchao.dtypes import Int4XPULayout -from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ -from torchao.quantization import int4_weight_only, quantize_ +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +from torchao.dtypes import QDQLayout +from torchao.prototype.awq import ( + AWQConfig, +) +from torchao.quantization import ( + GemliteUIntXWeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + ModuleFqnToConfig, + PerAxis, + quantize_, +) +from torchao.quantization.granularity import PerGroup # adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255 @@ -111,6 +121,7 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): "hellaswag", "gsm8k", "mmlu", + "bbh", ] results = {} if "PPL" in tasks: @@ -180,6 +191,15 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): print("MMLU avg acc", np.mean(k)) results["mmlu"] = np.mean(k) + if "bbh" in tasks: + for task in [("leaderboard_bbh", 3)]: + tag, fewshot = task + results[tag] = lm_eval.evaluator.simple_evaluate( + model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + )["results"] + print(tag, results[tag]) + results["bbh"] = results[tag] + return results @@ -187,13 +207,14 @@ def wikitext2_ppl( repo_id: str, quant: str, tasks: list[str], - calibration_size: int, + max_seq_length: int, + calibration_limit: int, validation_size: int, device: str, precision: torch.dtype, - sequence_length: int, compile: bool, model_save_path: str, + model_save_hf_hub_path: str, ): print(f"Loading model on {device}...") torch.manual_seed(34) @@ -206,60 +227,234 @@ def wikitext2_ppl( .to(device) ) print(f"Time to load model: {time.time() - t0:.02f} seconds") - if quant.startswith("awq"): - quant_dtype = quant.split("-")[1] + if quant.startswith("awq-8da4w"): + # applying int8 dynamic activation + int4 weight quant to linear + # and int8 weight only quant to `model.embed_tokens` + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + group_size = int(quant.split("-")[2]) - quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) - print(f"running {quant_dtype} calibration") + quant_dtype = torch.int4 + base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=quant_dtype, + weight_granularity=PerGroup(group_size), + weight_scale_dtype=torch.bfloat16, + layout=QDQLayout(), + ) + print(f"running {quant_dtype} prepare and calibrate") t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_awq_observer_( + awq_config = AWQConfig(base_config, step="prepare") + + quant_config = ModuleFqnToConfig( + {"_default": awq_config, "model.embed_tokens": embedding_config} + ) + quantize_( model, - validation_size, - sequence_length, - quant_dtype=quant_dtype, - group_size=group_size, + quant_config, ) - calibration_data = get_calib_dataset( - tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + 
tokenizer=tokenizer, + max_seq_length=max_seq_length, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, ) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - print(f"time for calibration: {time.time() - t0:.02f} seconds") - - is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - use_hqq = "hqq" in quant - print(f"running {quant_dtype} quantization") + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + + print(f"running {quant_dtype} convert") t0 = time.time() - awq_uintx_config = awq_uintx( - quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + awq_config = AWQConfig(base_config, step="convert") + quant_config = ModuleFqnToConfig( + {"_default": awq_config, "model.embed_tokens": None} ) - if "xpu" in device: - awq_uintx_config.layout = Int4XPULayout() + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + print("model after awq:", model) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + awq_load_config = AWQConfig(base_config, step="load") + quant_config = ModuleFqnToConfig( + {"_default": awq_load_config, "model.embed_tokens": None} + ) + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + + elif quant.startswith("8da4w"): + group_size = int(quant.split("-")[1]) + quant_dtype = torch.int4 + base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=quant_dtype, + weight_granularity=PerGroup(group_size), + weight_scale_dtype=torch.bfloat16, + layout=QDQLayout(), + ) + + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + + quant_config = ModuleFqnToConfig( + {"_default": base_config, "model.embed_tokens": embedding_config} + ) + quantize_(model, quant_config) + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + + elif quant.startswith("awq-int4wo"): + group_size = int(quant.split("-")[2]) + print(f"running {quant} quantization with group size {group_size}") + from torchao.quantization import FbgemmConfig + + # use_hqq = True + # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + base_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, + ) + print(f"running {quant} prepare and calibrate") + t0 = time.time() + awq_config = AWQConfig(base_config, step="prepare") + + quant_config = awq_config quantize_( model, - awq_uintx_config, - is_observed_linear, + quant_config, + ) + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=max_seq_length, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, ) - print(f"time for quantization: {time.time() - t0:.02f} seconds") + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + print(f"running {quant} convert") + t0 = time.time() + awq_config = AWQConfig(base_config, 
step="convert") + quant_config = awq_config + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + print("model after awq:", model) + if model_save_path is not None: print(f"Saving model to {model_save_path}") torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + quant_config = AWQConfig(base_config, step="load") + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) elif quant.startswith("int4wo"): group_size = int(quant.split("-")[1]) - use_hqq = "hqq" in quant print(f"running {quant} quantization with group size {group_size}") - int4_weight_only_config = int4_weight_only( - group_size=group_size, use_hqq=use_hqq + # TODO: enable after refactor: https://github.com/pytorch/ao/pull/2474 + # use_hqq = "hqq" in quant + # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + int4_weight_only_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, ) - if "xpu" in device: - int4_weight_only_config.layout = Int4XPULayout() + # TODO: enable this again when it's supported + # if "xpu" in device: + # int4_weight_only_config.layout = Int4XPULayout() quantize_(model, int4_weight_only_config) + + elif quant.startswith("awq-gemlite"): + group_size = int(quant.split("-")[2]) + print(f"running {quant} quantization with group size {group_size}") + base_config = GemliteUIntXWeightOnlyConfig(group_size=group_size) + + print(f"running {quant} prepare and calibrate") + t0 = time.time() + awq_config = AWQConfig(base_config, step="prepare") + + quant_config = awq_config + quantize_( + model, + quant_config, + ) + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=max_seq_length, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, + ) + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + print(f"running {quant} convert") + t0 = time.time() + awq_config = AWQConfig(base_config, step="convert") + quant_config = awq_config + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + print("model after awq:", model) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + quant_config = AWQConfig(base_config, step="load") + model.quantization_config = TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + elif quant.startswith("gemlite"): + group_size = int(quant.split("-")[1]) + print(f"running {quant} quantization with group size {group_size}") + config = GemliteUIntXWeightOnlyConfig(group_size=group_size) + + print(f"running {quant} prepare and calibrate") + t0 = time.time() + quantize_(model, config) + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + quant_config = AWQConfig(config, step="load") + model.quantization_config = 
TorchAoConfig(quant_config) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) if compile: model = torch.compile(model) - return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device) + return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device) if __name__ == "__main__": @@ -268,20 +463,21 @@ def wikitext2_ppl( ) # Optional arguments with default values - parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument("--repo", type=str, help="Repository ID of the model.") parser.add_argument( - "quant", + "--quant", type=str, - help="Quantization method. Options are either awq-uint- for x =[1..8], int4wo-, or int4wo--hqq.", + help="Quantization method. Options are either awq-uint- for x =[1..8], awq-8da4w-, int4wo-, or int4wo--hqq.", ) parser.add_argument( "--tasks", - type=list[str], + nargs="+", + type=str, help="Task to benchmark model on. Either PPL or QA", default=["PPL"], ) parser.add_argument( - "--calibration_samples", + "--calibration_limit", type=int, default=10, help="Number of samples to use for calibration. Default is 10.", @@ -302,10 +498,10 @@ def wikitext2_ppl( help="Precision type. Default is 'bfloat16'.", ) parser.add_argument( - "--seq_len", + "--max_seq_len", type=int, - default=512, - help="Length of examples to calibrate and evaluate model on. Default is 512", + default=2048, + help="Maximum sequence length of examples to calibrate and evaluate model on. Default is 2048", ) parser.add_argument( "--compile", @@ -318,6 +514,12 @@ def wikitext2_ppl( default=None, help="Path to store the scale values.", ) + parser.add_argument( + "--model_save_hf_hub_path", + type=str, + default=None, + help="Huggingface hub path to store the quantized model and tokenizer.", + ) args = parser.parse_args() @@ -327,13 +529,14 @@ def wikitext2_ppl( args.repo, args.quant, args.tasks, - args.calibration_samples, + args.max_seq_length, + args.calibration_limit, args.validation_size, args.device, args.precision, - args.seq_len, args.compile, args.model_save_path, + args.model_save_hf_hub_path, ) print(f"{args.quant} Results: {ppl}") diff --git a/torchao/prototype/moe_quant/utils.py b/torchao/prototype/moe_quant/utils.py index 0e75de2ee4..28291afdf4 100644 --- a/torchao/prototype/moe_quant/utils.py +++ b/torchao/prototype/moe_quant/utils.py @@ -20,18 +20,7 @@ dataclass, register_quantize_module_handler, ) -from torchao.utils import fill_defaults - - -class DummyModule(torch.nn.Module): - """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a - DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. 
- """ - - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = weight - self.bias = bias +from torchao.utils import DummyModule, fill_defaults class FakeExtraDimTensor(torch.Tensor): diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 6c433844a6..a794ae8289 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -73,6 +73,11 @@ def __tensor_unflatten__( tensor_data_dict["scale"], ) + def _quantization_type(self): + return ( + f"original_weight_tensor={self.original_weight_tensor}, scale={self.scale}" + ) + @staticmethod def _quantized_linear_op( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor @@ -126,7 +131,7 @@ def _(func, types, args, kwargs): ) -@implements(aten.detach.default) +@implements([aten.detach.default, aten.alias.default]) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index d692b52bdc..e7ed3a5030 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -545,10 +545,10 @@ def _quantization_type(weight: torch.Tensor): if hasattr(weight, "_quantization_type"): return f"{weight.__class__.__name__}({weight._quantization_type()})" - if type(weight) is torch.Tensor: - return "not quantized" + if type(weight) is torch.Tensor or isinstance(weight, torch.nn.Parameter): + return f"Tensor: {type(weight)}" - return "not recognized" + return f"not recognized: {type(weight)}" def _linear_extra_repr(self): diff --git a/torchao/utils.py b/torchao/utils.py index c56b607b7b..0a06a8ef4f 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -11,7 +11,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable +from typing import Any, Callable, Optional import torch import torch.nn.utils.parametrize as parametrize @@ -42,6 +42,7 @@ "is_sm_at_least_89", "is_sm_at_least_90", "is_package_at_least", + "DummyModule", ] @@ -732,3 +733,13 @@ def _is_fbgemm_genai_gpu_available(): return False return True + +class DummyModule(torch.nn.Module): + """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a + DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. + """ + + def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): + super().__init__() + self.weight = weight + self.bias = bias
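
For reference, a minimal sketch of how the new AWQConfig flow introduced by this patch is intended to be driven end to end (prepare, calibrate, convert). Names follow the API added here (AWQConfig, quantize_, Int8DynamicActivationIntxWeightConfig, PerGroup, QDQLayout); the model, the calibration batches, and the group size of 128 are illustrative placeholders, not part of the patch:

```python
import torch
from typing import Iterable

from torchao.dtypes import QDQLayout
from torchao.prototype.awq import AWQConfig
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup


def awq_quantize(
    model: torch.nn.Module, calibration_batches: Iterable[torch.Tensor]
) -> torch.nn.Module:
    """Sketch: AWQ layered on top of an 8da4w base config, per this patch's API."""
    # The base scheme AWQ is applied on top of; any AOBaseConfig with a
    # registered handler should work (8da4w is one combination tested in the PR).
    base_config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(128),  # group size chosen for illustration
        weight_scale_dtype=torch.bfloat16,
        layout=QDQLayout(),
    )

    # step "prepare": swap nn.Linear modules for AWQObservedLinear so that
    # AWQObserver records activations during calibration
    quantize_(model, AWQConfig(base_config, step="prepare"))

    # calibration: run representative data through the observed model
    with torch.no_grad():
        for batch in calibration_batches:
            model(batch)

    # step "convert": search for the best per-linear equalization scale and
    # apply base_config to the scaled weight
    quantize_(model, AWQConfig(base_config, step="convert"))
    return model
```

The third step value, "load", is used when deserializing an already-quantized checkpoint (e.g. via TorchAoConfig in the example script) and additionally requires `example_input_shape` so a random example input can drive the observer once before conversion.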
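
The core of the generalization is that the observer no longer hard-codes uintx affine quantization: it scales the weight, wraps it in a DummyModule, and reuses whatever handler is registered for the base config. A condensed illustration of that mechanism, mirroring what AWQObserver.calculate_qparams and _awq_transform do in this patch (`weight`, `scales`, and `base_config` stand in for the observer's state):

```python
import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import _QUANTIZE_CONFIG_HANDLER
from torchao.utils import DummyModule


def quantize_scaled_weight(
    weight: torch.Tensor, scales: torch.Tensor, base_config: AOBaseConfig
) -> torch.Tensor:
    """Apply an arbitrary base config to an AWQ-scaled weight tensor.

    DummyModule exists so that module-level quantization handlers can be
    reused on a bare tensor: wrap the scaled weight, run the handler for
    the base config, and take the resulting quantized weight back out.
    """
    handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
    dummy = DummyModule(weight * scales)
    quantized_module = handler(dummy, base_config)
    return quantized_module.weight
```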