Add option to move param to device before quantization #699

Merged
merged 6 commits on Aug 19, 2024
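
For context, here is a minimal usage sketch of the option this PR adds. The toy model, its layer sizes, and the variable names are illustrative and not part of the PR, and the import path assumes torchao's public quantization API; the only behavior introduced by this change is the `device="cuda"` argument to `quantize_`.

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only  # assumed public API path

# Illustrative model (not from the PR); any nn.Module containing nn.Linear layers works.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = TinyModel()  # weights start on CPU

# Baseline path: move the whole float model to CUDA first, then quantize.
# quantize_(model.to(device="cuda"), int8_weight_only())

# With this PR: pass device="cuda" so each module is moved to the GPU right before
# it is quantized; the full float model never has to reside on the GPU at once,
# which lowers peak CUDA memory. The final quantized model ends up on the given device.
quantize_(model, int8_weight_only(), device="cuda")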
24 changes: 24 additions & 0 deletions test/quantization/test_quant_api.py
@@ -54,6 +54,7 @@
from torchao.utils import unwrap_tensor_subclass
import copy
import tempfile
+import gc
from torch.testing._internal.common_utils import TestCase


@@ -680,6 +681,29 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
        res = m_copy(*example_inputs)
        self.assertEqual(res, ref)

+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_model_streaming(self):
+        def reset_memory():
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+
+        reset_memory()
+        m = ToyLinearModel()
+        quantize_(m.to(device="cuda"), int8_weight_only())
+        memory_baseline = torch.cuda.max_memory_allocated()
+
+        del m
+        reset_memory()
+        m = ToyLinearModel()
+        quantize_(m, int8_weight_only(), device="cuda")
+        memory_streaming = torch.cuda.max_memory_allocated()
+
+        for param in m.parameters():
+            assert param.is_cuda
+        self.assertLess(memory_streaming, memory_baseline)


if __name__ == "__main__":
    unittest.main()
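
The test above compares peak CUDA memory for the two paths using a reset-then-measure pattern. A small standalone helper capturing that same pattern (not part of the PR; the helper name is illustrative) might look like this:

import gc
import torch

def measure_peak_cuda_memory(fn):
    # Reset the allocator and the peak-memory counter, run fn, then report the
    # peak number of bytes allocated on the current CUDA device during the call.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated()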
19 changes: 17 additions & 2 deletions torchao/quantization/quant_api.py
@@ -161,6 +161,7 @@ def _replace_with_custom_fn_if_matches_filter(
    replacement_fn,
    filter_fn,
    cur_fqn="",
+    device=None,
) -> None:
    """
    Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
@@ -171,20 +172,25 @@
        replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules.
        filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
        cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
+        device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.

    Returns:
        None
    """
    if filter_fn(model, cur_fqn[:-1]):
+        if device is not None:
+            model.to(device=device)  # move to device before quantization
        model = replacement_fn(model)
        return model
    else:
        for name, child in model.named_children():
            new_child = _replace_with_custom_fn_if_matches_filter(
-                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
+                child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
            )
            if new_child is not child:
                setattr(model, name, new_child)
+        if device is not None:
+            model.to(device=device)  # move parent module to device
        return model


@@ -269,7 +275,13 @@ def insert_subclass(lin):

    return insert_subclass

-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(
+    model: torch.nn.Module,
+    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+    set_inductor_config: bool = True,
+    device: Optional[torch.types.Device] = None,
+):
    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace

    Args:
@@ -278,6 +290,8 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
        the weight of the module
        set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
+        device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
+            Defaults to None (do not change device).

    Example::

@@ -329,6 +343,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
        model,
        apply_tensor_subclass,
        _is_linear if filter_fn is None else filter_fn,
+        device=device,
    )

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
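
To summarize the core mechanism in quant_api.py, here is a simplified, standalone sketch of the move-then-transform recursion, condensed from the diff above; the function name and reduced argument list are illustrative, not the library's actual API.

import torch.nn as nn

def _move_then_transform(module: nn.Module, transform, should_transform, device=None):
    # Simplified form of _replace_with_custom_fn_if_matches_filter: a matching module
    # is moved to `device` immediately before being transformed, so only that module's
    # float weights need to live on the GPU at any one time.
    if should_transform(module):
        if device is not None:
            module.to(device=device)
        return transform(module)
    for name, child in module.named_children():
        new_child = _move_then_transform(child, transform, should_transform, device)
        if new_child is not child:
            setattr(module, name, new_child)
    if device is not None:
        module.to(device=device)  # catch params/buffers owned by non-matching parents
    return module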