@@ -36,9 +36,8 @@ def __init__(self, K=64, N=32, bias=False):

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (
-            torch.randn(
-                batch_size, self.linear1.in_features, dtype=dtype, device=device
-            ),
+            torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
+            * 0.1,
        )

    def forward(self, x):
@@ -88,7 +87,7 @@ def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs):
        )
        torch._dynamo.reset()  # may segfault without this
        y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
-        atol, rtol = 1e-6, 1e-6
+        atol, rtol = 1e-4, 1e-6
        if dtype == torch.bfloat16:
            atol, rtol = 1.6e-2, 3e-3
        elif dtype == torch.half:
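Note on the two changes above: example_inputs now draws from torch.rand scaled by 0.1, and the float32 tolerance is relaxed from atol=1e-6 to 1e-4. The PR does not state the rationale, but both are consistent with float8 rounding error dominating at these magnitudes. A minimal, illustrative sketch (not part of the PR; shapes and the 0.1 scale only mirror the updated example_inputs) of the round-trip error of float8_e4m3fn on inputs in [0, 0.1):

import torch

# Illustrative only: cast small activations to float8_e4m3fn and back, then
# measure the worst-case round-trip error.
x = torch.rand(4, 64) * 0.1
x_fp8 = x.to(torch.float8_e4m3fn)
round_trip_err = (x_fp8.to(torch.float32) - x).abs().max()
print(round_trip_err)  # roughly 1e-3 for values near 0.1, far above 1e-6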
@@ -102,6 +101,56 @@ def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs):
        assert torch.allclose(dqw1, dqw1_ref)
        assert torch.allclose(dqw2, dqw2_ref)

+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias):
+        device = "cpu"
+        # the shape is not supported by the cpp kernel, so the ref path will be used.
+        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
+        m2 = copy.deepcopy(m)
+        bs = 4
+        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+
+        with torch.no_grad():
+            quantize_(
+                m,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(),
+                    layout=Float8DynamicActFloat8WeightCPULayout(),
+                ),
+            )
+            y, code = torch._inductor.utils.run_and_get_code(
+                torch.compile(m, fullgraph=True, dynamic=True),
+                *example_inputs,
+            )
+            # ensure the op is not in the generated code
+            assert "torch.ops.torchao.float8_linear_cpu.default" not in code[0]
+            quantize_(
+                m2,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(),
+                    layout=PlainLayout(),
+                ),
+            )
+            torch._dynamo.reset()  # may segfault without this
+            y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
+            assert torch.allclose(y, y2)
+            # Test get_plain by dequantize()
+            dqw1 = m.linear1.weight.original_weight_tensor.dequantize()
+            dqw2 = m.linear2.weight.original_weight_tensor.dequantize()
+            dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize()
+            dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize()
+            assert torch.allclose(dqw1, dqw1_ref)
+            assert torch.allclose(dqw2, dqw2_ref)
+

common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear)
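The new ref-path test leans on two inspection utilities: torch._C._dispatch_dump to check whether the C++ kernel is registered at all, and torch._inductor.utils.run_and_get_code to verify that the compiled graph does not call it. A minimal sketch of the run_and_get_code pattern (not from the PR; the toy function is illustrative, the real test compiles the quantized ToyLinearModel):

import torch
from torch._inductor.utils import run_and_get_code

# Illustrative toy function standing in for the quantized model.
def f(x):
    return torch.nn.functional.relu(x) + 1.0

# run_and_get_code executes the compiled callable and also returns the
# Inductor-generated source as a list of strings, which can then be searched
# for a specific op, e.g. "torch.ops.torchao.float8_linear_cpu.default".
y, code = run_and_get_code(torch.compile(f, fullgraph=True), torch.randn(8))
assert isinstance(code, list) and isinstance(code[0], str)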