Add Float8Tensor #2463
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2463
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 2 Unrelated Failures as of commit c5f71e6 with merge base 975bd57:
CANCELLED JOB - The following job was cancelled. Please retry.
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary: Splits out the float8 rowwise quantized path (both activation and weight) of AQT to Float8RowwiseTensor.
Next: could potentially incorporate the per-tensor activation path there as well.
Next: we can split the per-tensor weight path into another Tensor as well, so we can deprecate the AQT path for float8.
Test Plan:
python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/test_float8_rowwise_tensor.py
Reviewers:
Subscribers:
Tasks:
Tags:
stack-info: PR: #2463, branch: jerryzh168/stack/9
Force-pushed from da79207 to 5cae4d0 (Compare)
Force-pushed from 5cae4d0 to 33ca58e (Compare)
Force-pushed from 33ca58e to 897ec7e (Compare)
Force-pushed from 7897dcf to 99a1bb1 (Compare)
Force-pushed from 99a1bb1 to 7e9f224 (Compare)
Force-pushed from 7e9f224 to 442bd6c (Compare)
torchao/quantization/quant_api.py (Outdated)

    quantized_weight = to_linear_activation_quantized(
        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
    act_quant_config = Float8TensorConfig(
I think the user should specify this; it's confusing to have some kinds of configs specified by the user and some kinds specified in glue code. I.e., can we just remove Float8TensorConfig and inline the logic?
We can expose this to the workflow config as a user-facing arg, I think. But another purpose of this config is that DynamicActivationQuantizationWrapper can be used for different types of activation quantization, instead of passing around the input_quant_func and kwargs as we did before. What do you have in mind for how to work with DynamicActivationQuantizationWrapper when we inline this config?
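For context, based on the dispatch code visible later in the diff, the wrapper pairs the already-quantized weight with a config describing how to quantize the activation, so its linear override conceptually does something like the following. This is a paraphrased sketch using the attribute names that appear in the diff (original_weight_tensor, act_tensor_config), not the exact code in the PR:

    import torch.nn.functional as F

    def _linear_impl(input_tensor, weight_tensor, bias):
        # weight_tensor is a DynamicActivationQuantizationWrapper
        original_weight_tensor = weight_tensor.original_weight_tensor
        act_config = weight_tensor.act_tensor_config
        # quantize the activation according to the stored config
        # (quantize_tensor is the dispatcher defined in this PR; inlining it is what this thread debates)
        quantized_input = quantize_tensor(input_tensor, act_config)
        return F.linear(quantized_input, original_weight_tensor, bias)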
Force-pushed from 1f0adfc to 2568d3d (Compare)
    ] = {}

    def register_quantize_tensor_handler(config_type):
I think we should remove this handler abstraction and just inline the code:

    if isinstance(config_type, Float8ConfigNameTBD):
        tensor_q = Float8Tensor.to_float8(...)
    ...
I would delete this and inline, see comment below, to simplify and make the code easier to follow
        return input_tensor
    original_weight_tensor = weight_tensor.original_weight_tensor
    c = weight_tensor.act_tensor_config
    quantized_input = quantize_tensor(input_tensor, c)
Please inline the logic here so there is no need to look at quantize_tensor and think about what gets dispatched where:

    if isinstance(c, Float8ConfigNameTBD):
        quantized_input = Float8Tensor.to_float8(input_tensor, ...)
    elif ...:
        ...

This way it's 10x easier to follow the code.
This would mean that for each new type of activation we'll have to modify dynamic_activation_quantization_wrapper.py, but I'll change it assuming you have considered this. Also, this would mean inlining the packing_format dispatch code as well, btw, so it would be:

    if isinstance(c, Float8ConfigNameTBD):
        if c.packing_format == "plain":
            quantized_input = Float8Tensor.to_float8(input_tensor, ...)
        elif ...:
            ...
    elif ...:
        ...
Actually, also for this one, I'm not exactly sure where the Float8Config should be defined. Currently:

    quantization/
        quantize_/
            common/
                dynamic_activation_quantization_wrapper.py
                packing_format.py
                ...
            workflows/
                float8/
                    float8_tensor.py  (Float8Tensor)
                    # tried to define Float8Config here but seems like there are some circular deps
        # previously
        quant_api.py  (imports from quantize_.common and workflows)

Float8Config refers to `packing_format` from quantize_.common.
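One layout that might avoid the cycle (just a sketch; the module name and the default value are hypothetical, and it assumes PackingFormat has a PLAIN member as the summary suggests) is to define the config next to packing_format.py in quantize_/common, so that quant_api.py and workflows/float8/float8_tensor.py both only import "downward":

    # quantize_/common/float8_config.py (hypothetical module)
    from dataclasses import dataclass

    from .packing_format import PackingFormat


    @dataclass
    class Float8Config:
        # which packing the quantized tensor uses, e.g. plain or preshuffled
        packing_format: PackingFormat = PackingFormat.PLAIN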
Force-pushed from 38a093c to 5fa937b (Compare)
Summary:
* Added Float8Tensor that's using fbgemm kernels:
  * per-row activation + per-row weight linear, calling the torch.ops.fbgemm.f8f8bf16_rowwise kernel
  * per-tensor activation + per-tensor weight linear, calling the torch.ops.fbgemm.f8f8bf16 kernel
  * per-row activation + per-row weight bmm, calling the torch.ops.fbgemm.f8f8bf16_rowwise_batched kernel
* Added DynamicActivationQuantizationWrapper (trying to have a better name than the previous to_activation_quantized wrapper) for dynamic activation quantization
* Added AOBaseTensorConfig and Float8TensorConfig to store the activation config (instead of a callable)
* Added PackingFormat to represent different packing formats (plain, preshuffled, etc.)
* Added a common folder to host the above objects
* Reusing Float8DynamicActivationFloat8WeightConfig for the above; not using _scaled_mm right now since it does not support the bmm op yet, could move to use scaled_mm in the future if it's more mature: pytorch/pytorch#157950

Test Plan:
python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Reviewers:
Subscribers:
Tasks:
Tags:
stack-info: PR: #2463, branch: jerryzh168/stack/9
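A minimal usage sketch of the workflow this enables, assuming the public quantize_ API and the config/granularity names used in the tests (exact import paths may differ on this branch):

    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )

    # bf16 linear layer on GPU; the fbgemm fp8 kernels expect CUDA
    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).cuda()

    # replaces each Linear weight with a Float8Tensor wrapped for dynamic activation quantization
    quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

    x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
    y = model(x)  # per-row fp8 activation x per-row fp8 weight (f8f8bf16_rowwise path)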
Force-pushed from 5fa937b to c5f71e6 (Compare)
        return decorator

    def quantize_tensor(tensor, config):
How about this instead:

quantize_tensor.py:

    def _choose_quant_func_and_quantize_tensor(
        tensor: torch.Tensor,
        quantize_tensor_kwargs: AOQuantizeTensorKwargs,
    ) -> AOBaseTensor:
        if isinstance(quantize_tensor_kwargs, QuantizeTensorToFloat8Kwargs):
            quantized_tensor = Float8Tensor.to_float8(tensor, quantize_tensor_kwargs)
        elif isinstance(quantize_tensor_kwargs, QuantizeTensorToInt4Kwargs):
            quantized_tensor = Int4Tensor.to_int4(tensor, quantize_tensor_kwargs)
        ...
        else:
            raise Exception(...)
        return quantized_tensor

float8_tensor.py:

    ... define QuantizeTensorToFloat8Kwargs ...
    ... define Float8Tensor.to_float8(tensor, to_float8_kwargs) ...

All the code is local to this function, so it's easy to see which config corresponds to which constructor without having to jump around the codebase.
It seems like there is a bit of a circular reference here if we reuse _choose_quant_func_and_quantize_tensor in float8_tensor.py to quantize the activation, right? We can probably work around this, just flagging.
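One common way around that kind of cycle is to defer the import to call time; a sketch only, and the module path below is an assumption about where the dispatcher ends up:

    # in float8_tensor.py: import lazily so that importing this module does not
    # pull in the dispatcher module at module-import time
    def _quantize_activation(tensor, quantize_tensor_kwargs):
        from torchao.quantization.quantize_.common.quantize_tensor import (
            _choose_quant_func_and_quantize_tensor,
        )

        return _choose_quant_func_and_quantize_tensor(tensor, quantize_tensor_kwargs)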
    class TestFloat8Tensor(TestCase):
        def setUp(self):
            self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
            self.CONFIG = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
should we test all supported combinations?
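For example, something like this (a sketch assuming the torch common_utils parametrization used in other torchao tests; PerRow/PerTensor are the granularities the config supports):

    from torch.testing._internal import common_utils

    # note: the test class also needs @common_utils.instantiate_parametrized_tests
    @common_utils.parametrize("granularity", [PerRow(), PerTensor()])
    def test_linear_variants(self, granularity):
        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
        ...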
    config = self.CONFIG
    dtype = torch.bfloat16
    device = "cuda"
    input = torch.randn(1, 128, dtype=dtype, device=device)
also test rank 3 and rank 4 for input?
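For example (just a sketch of the extra shapes; sizes are arbitrary):

    # cover rank 2, 3, and 4 inputs
    for shape in [(1, 128), (2, 32, 128), (2, 2, 32, 128)]:
        input = torch.randn(*shape, dtype=dtype, device=device)
        # run the quantized module and compare against the bf16 baseline as above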
    @@ -1489,16 +1500,13 @@ class Float8WeightOnlyConfig(AOBaseConfig):

    def _float8_weight_only_quant_tensor(weight, config):
        from torchao.dtypes import to_affine_quantized_floatx
        if config.set_inductor_config:
I think it's better to keep this in the workflow logic, here: _float8_dynamic_activation_float8_weight_transform. It's weird for a function which quantizes a tensor to have side effects.
    def from_float_hp(
        cls,
        hp_tensor: torch.Tensor,
        dtype: torch.dtype,
float8_dtype or lp_dtype? dtype is ambiguous.
    data = _quantize_affine_float8(hp_tensor, scale, dtype)

    dtype = hp_tensor.dtype
    del hp_tensor
nit: remove? feels side-effect'y
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if isinstance(weight_tensor, DynamicActivationQuantizationWrapper):
is it possible for this to be False? if not, just remove the if?
    )

    if TORCH_VERSION_AT_LEAST_2_5:
nit: use more recent PyTorch version?
    @implements([torch.nn.functional.linear, aten.linear.default])
    def _(func, types, args, kwargs):
        input_tensor, weight_tensor, bias = (
Here's what would be intuitive to me:

    input_tensor, quantize_activation_wrapper, maybe_bias = ...

    # recover information to quantize activation
    quantize_activation_kwargs = quantize_activation_wrapper.quantize_tensor_kwargs
    # unwrap the weight
    inner_weight = quantize_activation_wrapper.original_weight_tensor
    # quantize the activation
    quantized_activation = _choose_quant_func_and_quantize_tensor(input_tensor, quantize_activation_kwargs)
    # dispatch the quantized activation to the unwrapped weight's linear override
    output = F.linear(quantized_activation, inner_weight, maybe_bias)
    else:
        activation_dtype = input_tensor.dtype

    if activation_dtype == torch.float8_e4m3fn:
nit: check for Float8Tensor directly?
    # can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
    # after python 3.10 is end of life (https://devguide.python.org/versions/)
    class PackingFormat(str, Enum):
        """Packing format for Tensor subclasses in torchao, enum for how
for "tensor subclass", or for "individual plain tensor attribute of a tensor subclass"?
        preshuffled is referring to the preshuffled format used by fbgemm kernels
        """
        PRESHUFFLED = "preshuffled"
        _LEGACY = "_legacy"
is this used? if not, remove?
    @implements([aten.mm.default, aten.addmm.default])
    def _(func, types, args, kwargs):
can we make this just quantize the activation and make sure all the other logic (transposing weight, checking dims, etc) is done on the weight tensor override?
Even further, how about combining the logic for all math ops, something like this:

    @implements([mm, addmm, bmm, linear, ...])
    def _(...):
        input = args[0]
        act_kwargs = ...
        inner_weight = ...
        quantized_input = _quantize_with_kwargs(input, act_kwargs)
        output = func([quantized_input, inner_weight, ...], ...)
from jerry: sounds good
        sqnr = compute_error(original, quantized)
        self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")

    def test_slice(self):
can we have an integration test which tests all of these ops and works for all supported tensors? that would also help understand why all these ops are needed.
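Something along these lines, perhaps (a rough sketch only; which ops actually need coverage, and whether each one is supported, depends on the dispatches implemented in this PR):

    def test_quantized_weight_ops_integration(self):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
        x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
        _ = linear(x)            # exercises the linear dispatch
        _ = linear.weight[0:64]  # exercises aten.slice.Tensor, e.g. for tensor-parallel sharding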
    @implements(aten.slice.Tensor)
    def _(func, types, args, kwargs):
nit: combine dispatches which are doing the same thing? for example slice and select.int, etc
Stacked PRs:
Add Float8Tensor
Summary:
Reusing Float8DynamicActivationFloat8WeightConfig; not using _scaled_mm right now since it does not support the bmm op yet, could move to use scaled_mm in the future if it's more mature: [RFC]: PyTorch Low-Precision GEMMs Public API pytorch#157950
Test Plan:
python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
Reviewers:
Subscribers:
Tasks:
Tags: