
Add Float8Tensor #2463


Conversation

jerryzh168 (Contributor) commented Jun 30, 2025

Stacked PRs:


Add Float8Tensor

Summary:

  • Added Float8Tensor that uses fbgemm kernels:
    • per-row activation + per-row weight linear, calling the torch.ops.fbgemm.f8f8bf16_rowwise kernel
    • per-tensor activation + per-tensor weight linear, calling the torch.ops.fbgemm.f8f8bf16 kernel
    • per-row activation + per-row weight bmm, calling the torch.ops.fbgemm.f8f8bf16_rowwise_batched kernel
  • Added DynamicActivationQuantizationWrapper (intended as a better name than the previous to_activation_quantized wrapper) for dynamic activation quantization
  • Added AOBaseTensorConfig and Float8TensorConfig to store the activation config (instead of a callable)
  • Added PackingFormat to represent different packing formats (plain, preshuffled, etc.)
  • Added a common folder to host the above objects
  • Reusing Float8DynamicActivationFloat8WeightConfig for the above (see the usage sketch below). Not using _scaled_mm right now since it does not support the bmm op yet; we could move to _scaled_mm in the future if it's more mature: [RFC]: PyTorch Low-Precision GEMMs Public API pytorch#157950
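
Below is a minimal usage sketch of the workflow described above. It assumes the torchao quantize_ API and the config/granularity names used in this PR's tests; which fbgemm kernel is called depends on the chosen granularity.

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# bf16 linear on GPU; the fbgemm float8 kernels require CUDA
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)

# per-row activation + per-row weight: linear is expected to call
# torch.ops.fbgemm.f8f8bf16_rowwise under the hood
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
out = model(x)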

Test Plan:
python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot bot commented Jun 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2463

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 2 Unrelated Failures

As of commit c5f71e6 with merge base 975bd57:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Jun 30, 2025
Summary:
Splits out the float8 rowwise quantized path (both act and weight) of AQT to Float8RowwiseTensor

Next: could potentially incorporate the per-tensor activation path there as well
Next: we can split the per-tensor weight path into another Tensor as well, so we can deprecate the AQT path for float8

Test Plan:
python test/dtypes/test_affine_quantized_float.py
python test/quantization/quantize_/test_float8_rowwise_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2463, branch: jerryzh168/stack/9
@jerryzh168 force-pushed the jerryzh168/stack/9 branch from da79207 to 5cae4d0 on June 30, 2025 23:01
@facebook-github-bot added the CLA Signed label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) on Jun 30, 2025
@jerryzh168 added the topic: new feature label (Use this tag if this PR adds a new feature) on Jun 30, 2025
@jerryzh168 changed the base branch from jerryzh168/stack/4 to main on July 2, 2025 01:58
@jerryzh168 force-pushed the jerryzh168/stack/9 branch from 5cae4d0 to 33ca58e on July 2, 2025 01:58
@jerryzh168 changed the title from Add Float8RowwiseTensor to Add Float8Tensor on Jul 2, 2025
@jerryzh168 changed the base branch from main to jerryzh168/stack/4 on July 2, 2025 01:58
@jerryzh168 changed the base branch from jerryzh168/stack/4 to main on July 2, 2025 20:35
@jerryzh168 force-pushed the jerryzh168/stack/9 branch from 33ca58e to 897ec7e on July 2, 2025 20:36
@jerryzh168 changed the base branch from main to jerryzh168/stack/4 on July 2, 2025 20:36
@jerryzh168 changed the base branch from jerryzh168/stack/4 to main on July 2, 2025 21:42
@jerryzh168 force-pushed the jerryzh168/stack/9 branch 2 times, most recently from 7897dcf to 99a1bb1 on July 2, 2025 21:42
@jerryzh168 changed the base branch from main to jerryzh168/stack/11 on July 2, 2025 21:42
@jerryzh168 changed the base branch from jerryzh168/stack/11 to main on July 2, 2025 23:44
@jerryzh168 force-pushed the jerryzh168/stack/9 branch from 99a1bb1 to 7e9f224 on July 2, 2025 23:44
@jerryzh168 changed the base branch from main to jerryzh168/stack/11 on July 2, 2025 23:44
@jerryzh168 changed the base branch from jerryzh168/stack/11 to main on July 3, 2025 00:09
@jerryzh168 force-pushed the jerryzh168/stack/9 branch from 7e9f224 to 442bd6c on July 3, 2025 00:09
@jerryzh168 changed the base branch from main to jerryzh168/stack/11 on July 3, 2025 00:09

quantized_weight = to_linear_activation_quantized(
quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
act_quant_config = Float8TensorConfig(
Contributor:

I think the user should specify this; it's confusing to have some kinds of configs specified by the user and some kinds specified in glue code.

i.e. can we just remove Float8TensorConfig and inline the logic?

jerryzh168 (Author):

I think we can expose this to the workflow config as a user-facing arg. But another purpose of this config is that DynamicActivationQuantizationWrapper can be used for different types of activation quantization, instead of passing around the input_quant_func and kwargs as we did before. What do you have in mind for how DynamicActivationQuantizationWrapper would work when we inline this config?
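
To illustrate the point, a simplified, hypothetical sketch (the real DynamicActivationQuantizationWrapper in this PR is a torch.Tensor subclass and Float8TensorConfig has more fields): the wrapper carries a config object describing how to quantize the activation, and the dispatch code maps that config to the right quantization call, instead of passing around an input_quant_func plus kwargs.

import torch
from dataclasses import dataclass

@dataclass
class Float8TensorConfig:  # placeholder: just enough to key dispatch on
    granularity: str = "per_row"

@dataclass
class DynamicActivationQuantizationWrapper:  # simplified; not a Tensor subclass here
    original_weight_tensor: torch.Tensor
    act_tensor_config: Float8TensorConfig  # a config object, not a callable

def quantize_activation(x: torch.Tensor, config) -> torch.Tensor:
    # stand-in for Float8Tensor.to_float8(x, ...); real code dispatches on type(config)
    if isinstance(config, Float8TensorConfig):
        return x.to(torch.float8_e4m3fn).to(x.dtype)  # toy cast, not the real kernel path
    raise NotImplementedError(type(config))

w = DynamicActivationQuantizationWrapper(torch.randn(256, 128), Float8TensorConfig())
x_q = quantize_activation(torch.randn(4, 128), w.act_tensor_config)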

@jerryzh168 force-pushed the jerryzh168/stack/9 branch from 1f0adfc to 2568d3d on July 15, 2025 18:15
] = {}


def register_quantize_tensor_handler(config_type):
Contributor:

I think we should remove this handler abstraction and just inline the code

if isinstance(config_type, Float8ConfigNameTBD):
    tensor_q = Float8Tensor.to_float8(...)
 ...

Contributor:

I would delete this and inline, see comment below, to simplify and make the code easier to follow

return input_tensor
original_weight_tensor = weight_tensor.original_weight_tensor
c = weight_tensor.act_tensor_config
quantized_input = quantize_tensor(input_tensor, c)
Contributor:

inline the logic here please so there is no need to look at quantize_tensor and think about what gets dispatched where

if isinstance(c, Float8ConfigNameTBD):
    quantized_input = Float8Tensor.to_float8(input_tensor, ...)
elif ...:
    ...

this way it's 10x easier to follow the code

jerryzh168 (Author) commented Jul 15, 2025:

this would mean that for each new type of activation we'll have to modify dynamic_activation_quantization_wrapper.py, but I'll change it assuming you have considered this. Also, this would mean inlining the packing_format dispatch code as well, so it would be:

if isinstance(c, Float8ConfigNameTBD):
    if c.packing_format == "plain":
        quantized_input = Float8Tensor.to_float8(input_tensor, ...)
    elif ...:
        ...
elif ...:
    ...

jerryzh168 (Author) commented Jul 15, 2025:

actually also for this one, I'm not exactly sure where the Float8Config should be defined, currently:

quantization
  quantize_
    common
      dynamic_activation_quantization_wrapper.py
      packing_format.py
      ...
    workflows
      float8
        float8_tensor.py (Float8Tensor)
        # tried to define Float8Config here but seems like there are some circular deps
  # previously
  quant_api.py  (imports from quantize_.common and workflows)
    Float8Config refers to `packing_format` from quantize_.common

@jerryzh168 force-pushed the jerryzh168/stack/9 branch 2 times, most recently from 38a093c to 5fa937b on July 15, 2025 23:12
@jerryzh168 force-pushed the jerryzh168/stack/9 branch from 5fa937b to c5f71e6 on July 16, 2025 00:18
return decorator


def quantize_tensor(tensor, config):
vkuzo (Contributor) commented Jul 16, 2025:

how about this instead:

quantize_tensor.py

def _choose_quant_func_and_quantize_tensor(
    tensor: torch.Tensor,
    quantize_tensor_kwargs: AOQuantizeTensorKwargs,
) -> AOBaseTensor:
    if isinstance(quantize_tensor_kwargs, QuantizeTensorToFloat8Kwargs):
        quantized_tensor = Float8Tensor.to_float8(tensor, quantize_tensor_kwargs)
    elif isinstance(quantize_tensor_kwargs, QuantizeTensorToInt4Kwargs):
        quantized_tensor = Int4Tensor.to_int4(tensor, quantize_tensor_kwargs)
    ...
    else:
        raise Exception(...)
    return quantized_tensor

float8_tensor.py

... define QuantizeTensorToFloat8Kwargs...
... define Float8Tensor.to_float8(tensor, to_float8_kwargs)...

all the code is local to this function, so it's easy to see which config corresponds to which constructor without having to jump around the codebase
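
For concreteness, a hedged guess at what such a kwargs container could look like; the field names below are illustrative, not what the PR or torchao actually defines.

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class QuantizeTensorToFloat8Kwargs:
    # illustrative fields only
    float8_dtype: torch.dtype = torch.float8_e4m3fn
    granularity: str = "per_row"          # or "per_tensor"
    scale: Optional[torch.Tensor] = None  # None means compute the scale dynamically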

jerryzh168 (Author):

seems like there is a bit of a circular reference here if we reuse _choose_quant_func_and_quantize_tensor in float8_tensor.py to quantize the activation, right? We can probably work around this; just flagging.
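
A common way to work around that kind of cycle (a general Python pattern, not necessarily what this PR ends up doing; the module path below is an assumption based on the layout sketched earlier) is to defer the import into the function that needs it:

# float8_tensor.py (hypothetical layout)

def _quantize_activation(tensor, quantize_tensor_kwargs):
    # function-local import breaks the module-level cycle between
    # float8_tensor.py and quantize_tensor.py
    from torchao.quantization.quantize_.common.quantize_tensor import (
        _choose_quant_func_and_quantize_tensor,
    )
    return _choose_quant_func_and_quantize_tensor(tensor, quantize_tensor_kwargs)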

class TestFloat8Tensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
        self.CONFIG = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
Contributor:

should we test all supported combinations?
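
A sketch of what covering more combinations could look like, using pytest parametrization; the granularity names follow the PR, and the exact supported set is an assumption.

import pytest
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    PerTensor,
    quantize_,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
@pytest.mark.parametrize("granularity", [PerRow(), PerTensor()])
def test_linear_variants(granularity):
    m = torch.nn.Sequential(
        torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    )
    quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
    x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
    assert m(x).shape == (1, 256)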

config = self.CONFIG
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
Contributor:

also test rank 3 and rank 4 for input?

@@ -1489,16 +1500,13 @@ class Float8WeightOnlyConfig(AOBaseConfig):


def _float8_weight_only_quant_tensor(weight, config):
    from torchao.dtypes import to_affine_quantized_floatx
    if config.set_inductor_config:
Contributor:

I think it's better to keep this in workflow logic, here: _float8_dynamic_activation_float8_weight_transform

it's weird for a function which quantizes a tensor to have side effects

def from_float_hp(
    cls,
    hp_tensor: torch.Tensor,
    dtype: torch.dtype,
Contributor:

float8_dtype or lp_dtype? dtype is ambiguous

data = _quantize_affine_float8(hp_tensor, scale, dtype)

dtype = hp_tensor.dtype
del hp_tensor
Contributor:

nit: remove? feels side-effect'y

args[1],
args[2] if len(args) > 2 else None,
)
if isinstance(weight_tensor, DynamicActivationQuantizationWrapper):
Contributor:

is it possible for this to be False? if not, just remove the if?

)


if TORCH_VERSION_AT_LEAST_2_5:
Contributor:

nit: use more recent PyTorch version?


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
Contributor:

here's what would be intuitive to me:

input_tensor, quantize_activation_wrapper, maybe_bias = ...

# recover information to quantize activation
quantize_activation_kwargs = quantize_activation_wrapper.quantize_tensor_kwargs

# unwrap the weight
inner_weight = quantize_activation_wrapper.original_weight_tensor

# quantize the activation
quantized_activation = _choose_quant_func_and_quantize_tensor(input_tensor, quantize_activation_kwargs)

# dispatch the quantized activation to the unwrapped weight's linear override
output = F.linear(quantized_activation, inner_weight, maybe_bias)

else:
activation_dtype = input_tensor.dtype

if activation_dtype == torch.float8_e4m3fn:
Contributor:

nit: check for Float8Tensor directly?

# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
# after python 3.10 is end of life (https://devguide.python.org/versions/)
class PackingFormat(str, Enum):
    """Packing format for Tensor subclasses in torchao, enum for how
Contributor:

for "tensor subclass", or for "individual plain tensor attribute of a tensor subclass"?

    preshuffled is referring to the preshuffled format used by fbgemm kernels
    """
    PRESHUFFLED = "preshuffled"
    _LEGACY = "_legacy"
Contributor:

is this used? if not, remove?



@implements([aten.mm.default, aten.addmm.default])
def _(func, types, args, kwargs):
Contributor:

can we make this just quantize the activation and make sure all the other logic (transposing weight, checking dims, etc) is done on the weight tensor override?

Contributor:

even further, how about combining the logic for all math ops, such as this:

@implements([mm, addmm, bmm, linear, ...])
def _(...):
    input = args[0]
    act_kwargs = ...
    inner_weight = ...
    quantized_input = _quantize_with_kwargs(input, act_kwargs)
    output = func([quantized_input, inner_weight, ...], ...)

Contributor:

from jerry: sounds good

        sqnr = compute_error(original, quantized)
        self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")

    def test_slice(self):
Contributor:

can we have an integration test which tests all of these ops and works for all supported tensors? that would also help understand why all these ops are needed.
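
A rough idea of such an integration test (a sketch only; it assumes the slice override covers plain weight indexing and that CUDA plus the fbgemm kernels are available):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

def integration_smoke_test():
    lin = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(lin, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

    # linear override, 2D and 3D activations
    _ = lin(torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"))
    _ = lin(torch.randn(2, 8, 128, dtype=torch.bfloat16, device="cuda"))

    # slice override, as used by tensor-parallel style sharding of the weight
    _ = lin.weight[0:128]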



@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
Contributor:

nit: combine dispatches which are doing the same thing? for example slice and select.int, etc
