Commit cd53802

Check N & K % 32 == 0; update UT
1 parent 953ac13 commit cd53802

File tree

3 files changed (+59, -7 lines)


test/quantization/test_dynamic_float8_linear_cpu.py

Lines changed: 53 additions & 4 deletions
@@ -36,9 +36,8 @@ def __init__(self, K=64, N=32, bias=False):
 
     def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
         return (
-            torch.randn(
-                batch_size, self.linear1.in_features, dtype=dtype, device=device
-            ),
+            torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
+            * 0.1,
         )
 
     def forward(self, x):
@@ -88,7 +87,7 @@ def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs):
         )
         torch._dynamo.reset()  # may segfault without this
         y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
-        atol, rtol = 1e-6, 1e-6
+        atol, rtol = 1e-4, 1e-6
         if dtype == torch.bfloat16:
             atol, rtol = 1.6e-2, 3e-3
         elif dtype == torch.half:
@@ -102,6 +101,56 @@ def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs):
         assert torch.allclose(dqw1, dqw1_ref)
         assert torch.allclose(dqw2, dqw2_ref)
 
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias):
+        device = "cpu"
+        # the shape is not supported by cpp kernel, so the ref path will be used.
+        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
+        m2 = copy.deepcopy(m)
+        bs = 4
+        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+
+        with torch.no_grad():
+            quantize_(
+                m,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(),
+                    layout=Float8DynamicActFloat8WeightCPULayout(),
+                ),
+            )
+            y, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            # ensure the op is not in the code
+            assert "torch.ops.torchao.float8_linear_cpu.default" not in code[0]
+            quantize_(
+                m2,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(),
+                    layout=PlainLayout(),
+                ),
+            )
+            torch._dynamo.reset()  # may segfault without this
+            y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
+            assert torch.allclose(y, y2)
+            # Test get_plain by dequantize()
+            dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
+            dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
+            dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
+            dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
+            assert torch.allclose(dqw1, dqw1_ref)
+            assert torch.allclose(dqw2, dqw2_ref)
+
 
 common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear)
 
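Note: the new ref-path test deliberately uses N = K = 120, which is not a multiple of 32, so the packed CPU layout is skipped and the compiled model never calls torch.ops.torchao.float8_linear_cpu. A minimal sketch of that shape gate, assuming block_n == 32 (the value implied by the "N should be divisible by 32" check below); uses_cpp_kernel is a hypothetical helper, not part of the codebase:

    def uses_cpp_kernel(N: int, K: int, block_n: int = 32) -> bool:
        # Packed [N / block_n, K / block_k, ...] layout needs both dims to be multiples of 32.
        return N % block_n == 0 and K % block_n == 0

    assert uses_cpp_kernel(32, 64)        # ToyLinearModel() default (K=64, N=32): packed path
    assert not uses_cpp_kernel(120, 120)  # shape used by test_dynamic_float8_linear_ref_cpu: ref path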

torchao/csrc/cpu/float8_linear.cpp

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@ float8_linear_prepack_impl(
   int N = weight.size(0);
   int K = weight.size(1);
   int G = scales.size(1);
+  TORCH_CHECK(K % G == 0, "K should be divisible by num_groups");
   int group_size = K / G;
   int block_k = group_size > 128 ? 128 : group_size;
   while (K % block_k != 0) {
@@ -52,6 +53,7 @@ float8_linear_prepack_impl(
   TORCH_CHECK(block_k > 0 && block_k <= group_size,
       "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]");
   constexpr int block_n = BLOCK_N;
+  TORCH_CHECK(N % block_n == 0, "N should be divisible by 32");
   int Nc = N / block_n;
   int Kc = K / block_k;
 
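For context, the two new TORCH_CHECKs sit inside the prepack routine that derives the blocking from the weight and scale shapes. A rough Python rendering of that logic, as a sketch only: BLOCK_N = 32 is inferred from the error message, and the body of the while loop (not shown in the hunk) is assumed to shrink block_k until it divides K:

    BLOCK_N = 32  # assumed value of the BLOCK_N constant

    def check_prepack_shapes(N, K, G):
        # New check: per-group scales must tile K evenly.
        assert K % G == 0, "K should be divisible by num_groups"
        group_size = K // G
        block_k = 128 if group_size > 128 else group_size
        while K % block_k != 0:  # loop body assumed: halve block_k until it divides K
            block_k //= 2
        assert 0 < block_k <= group_size, "Invalid block_k size, should be in (0, group_size]"
        # New check: N must tile into block_n columns.
        assert N % BLOCK_N == 0, "N should be divisible by 32"
        return N // BLOCK_N, K // block_k  # Nc, Kc

    check_prepack_shapes(N=32, K=64, G=1)      # per-row scales: OK
    # check_prepack_shapes(N=120, K=120, G=1)  # would fail the N % 32 check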

torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py

Lines changed: 4 additions & 3 deletions
@@ -84,7 +84,7 @@ def __tensor_unflatten__(
             tensor_data_dict["packed_weight"],
             tensor_data_dict["scales"],
         )
-        (_layout, transposed) = tensor_attributes
+        (transposed, _layout) = tensor_attributes
         return cls(packed_weight, scales, transposed, _layout)
 
     @classmethod
@@ -103,8 +103,9 @@ def from_plain(
         scale.unsqueeze_(-1)
         scale = scale.to(torch.float)
 
+        N = data.size(0)
         K = data.size(-1)
-        if K % 32 == 0:
+        if N % 32 == 0 and K % 32 == 0:
             # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n].
             # Pack inner blocks [block_k, block_n] to VNNI layout if AMX is available.
             # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n].
@@ -178,7 +179,7 @@ def block_size(self):
         return (1, group_size)
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        if self._layout == PlainLayout:
+        if isinstance(self._layout, PlainLayout):
             # If the layout is PlainLayout, return the packed weight and scales directly
             return (
                 self.packed_weight,
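The get_plain change fixes a comparison bug rather than a cosmetic issue: self._layout == PlainLayout compares a layout instance against the class object, which is always False, so the PlainLayout branch could never be taken; isinstance is the check that was intended. A toy illustration (the class below is a stand-in, not the real torchao PlainLayout):

    class PlainLayout:  # stand-in for torchao.dtypes.PlainLayout
        pass

    layout = PlainLayout()
    print(layout == PlainLayout)            # False: instance compared with the class itself
    print(isinstance(layout, PlainLayout))  # True: matches any instance of the class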
