
[WIP] Make AWQ more general #2400

Open: wants to merge 1 commit into base: main
21 changes: 15 additions & 6 deletions torchao/_models/_eval.py
@@ -57,8 +57,13 @@ def _model_call(self, inps):

         max_seq_length = min(max(inps.size()), self.max_length)
         with torch.device(self._device):
-            self._model.setup_caches(self.batch_size, max_seq_length)
+            if hasattr(self._model, "setup_caches"):
+                self._model.setup_caches(self.batch_size, max_seq_length)
         logits = self._model(*input)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        if isinstance(logits, CausalLMOutputWithPast):
+            logits = logits.logits
         return logits

     def run_eval(self, tasks, limit):
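
The guarded setup_caches call and the logits unwrapping are what let this wrapper accept Hugging Face transformers models in addition to the gpt-fast style Transformer it was written for: HF causal LMs have no setup_caches method and return a CausalLMOutputWithPast rather than a bare tensor. A minimal sketch of that unwrapping outside the wrapper (the checkpoint name is just a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

# Placeholder checkpoint; any HF causal LM behaves the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

inputs = tokenizer("AWQ calibration example", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

# HF models return a ModelOutput object, not the raw logits tensor.
logits = out.logits if isinstance(out, CausalLMOutputWithPast) else out
print(logits.shape)  # (batch, sequence_length, vocab_size)
```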
@@ -84,7 +89,11 @@ def eot_token_id(self):
         try:
             return self.tokenizer.eos_id()
         except:
-            return self.tokenizer.eos_id
+            try:
+                return self.tokenizer.eos_id
+            except:
+                idx = self.tokenizer.all_special_tokens.index("<|endoftext|>")
+                return self.tokenizer.all_special_ids[idx]

     @property
     def max_length(self):
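
The widened eot_token_id fallback now covers three tokenizer flavors: gpt-fast tokenizers where eos_id is a method, wrappers that expose it as a plain attribute, and Hugging Face tokenizers where the id has to be looked up among the registered special tokens. Roughly the same logic as a standalone helper (the helper name and the narrower exception types are illustrative, not part of the change):

```python
def resolve_eot_id(tokenizer):
    """Best-effort end-of-text token id lookup across tokenizer implementations."""
    # gpt-fast style tokenizers expose eos_id() as a method.
    try:
        return tokenizer.eos_id()
    except (AttributeError, TypeError):
        pass
    # Some wrappers expose eos_id as a plain integer attribute instead.
    eos_id = getattr(tokenizer, "eos_id", None)
    if isinstance(eos_id, int):
        return eos_id
    # Hugging Face tokenizers: find "<|endoftext|>" among the special tokens.
    idx = tokenizer.all_special_tokens.index("<|endoftext|>")
    return tokenizer.all_special_ids[idx]
```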
@@ -102,8 +111,8 @@ def batch_size(self):
     def device(self):
         return self._device

-    def tok_decode(self, tokens):
-        decoded = self.tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        decoded = self.tokenizer.decode(tokens, **kwargs)
         return decoded

     def tok_encode(self, string: str, **kwargs):
@@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs):
            tokens = [self.tokenizer.bos_id] + tokens
         return tokens

-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
+    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    #     super()._model_generate(context, max_length, stop, **generation_kwargs)


 class LMEvalInputRecorder(TransformerEvalWrapper):
81 changes: 81 additions & 0 deletions torchao/_models/llama/eval.py
@@ -237,6 +237,87 @@ def run_evaluation(
             quantize_(
                 model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
             )
+        elif quantization.startswith("awq-uintx"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )
+
+        elif quantization.startswith("awq-8da4w"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )

     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
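
Both new branches implement the same observe, calibrate, quantize sequence: wrap eligible Linear layers with AWQ observers, drive calibration data through the model (here one wikitext sample via the eval wrapper), then swap the observed linears for quantized ones. A condensed sketch of that flow outside the eval harness; the model and calibration batches are placeholders, and the final check is only an illustrative sanity test, not part of the change:

```python
import torch
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
from torchao.quantization import quantize_

model = ...                 # any torch.nn.Module with nn.Linear layers, on the target device
calibration_batches = ...   # iterable of token tensors (the PR uses one wikitext sample)

# 1. Replace nn.Linear with observed linears (1 calibration example, sequence length 256).
insert_awq_observer_(model, 1, 256, quant_dtype=torch.uint4, group_size=64)

# 2. Run calibration forward passes so the observers collect activation statistics.
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

# 3. Swap observed linears for AWQ-quantized ones.
quantize_(
    model,
    awq_uintx(quant_dtype=torch.uint4, group_size=64, use_hqq=False),
    lambda m, fqn: isinstance(m, AWQObservedLinear),
)

# Sanity check: nothing should still be an AWQObservedLinear after quantize_.
leftover = [n for n, m in model.named_modules() if isinstance(m, AWQObservedLinear)]
assert not leftover, f"unconverted observed linears: {leftover}"
```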
3 changes: 2 additions & 1 deletion torchao/prototype/awq/__init__.py
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear

 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]
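
The __init__.py change also exports the new AWQConfig alongside the existing entry points. Its constructor isn't shown in this diff, so the snippet below only smoke-tests the widened public surface without instantiating anything:

```python
import torchao.prototype.awq as awq

# Every name exported by the prototype package after this change should be present.
expected = {"awq_uintx", "insert_awq_observer_", "AWQObservedLinear", "AWQConfig"}
assert expected <= set(awq.__all__)
```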