@@ -392,59 +392,5 @@ def test_dynamic_scale_numeric_parity(
     assert torch.equal(float8_eager._data, float8_compile._data)
 
 
-@pytest.mark.parametrize(
-    "float8_dtype",
-    [
-        torch.float8_e4m3fn,
-        torch.float8_e5m2,
-    ],
-)
-@pytest.mark.parametrize(
-    "hp_dtype",
-    [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ],
-)
-def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
-    quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-    dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-    input = torch.randn(10, 10)
-    with torch.no_grad():
-        torch._dynamo.reset()
-        expected_scale = torch.tensor(2.0)
-        expected_quantized = quantize_affine_float8(
-            input,
-            expected_scale,
-            float8_dtype=float8_dtype,
-        )
-        expected_dequantized = dequantize_affine_float8(
-            expected_quantized,
-            expected_scale,
-            output_dtype=hp_dtype,
-        )
-        test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-            torch.compile(quantize_affine_float8),
-            input,
-            expected_scale,
-            float8_dtype=float8_dtype,
-        )
-        torch.testing.FileCheck().check(
-            "torch.ops.torchao.quantize_affine_float8.default"
-        ).run(code_q)
-        test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-            torch.compile(dequantize_affine_float8),
-            test_q,
-            expected_scale,
-            hp_dtype,
-        )
-        torch.testing.FileCheck().check(
-            "torch.ops.torchao.dequantize_affine_float8.default"
-        ).run(code_dq)
-        torch.testing.assert_close(expected_quantized, test_q)
-        torch.testing.assert_close(expected_dequantized, test_dq)
-
-
 if __name__ == "__main__":
     pytest.main([__file__])
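For context, a minimal standalone sketch of what the deleted test covered: eager-vs-torch.compile numeric parity for the torchao float8 quantize/dequantize ops. It uses only the call signatures visible in the diff above; the import of torchao to register the custom ops is an assumption, and the dtypes are one arbitrary pick from the deleted parametrization.

# Minimal sketch of the deleted parity check, not the test as shipped.
import torch
import torchao  # noqa: F401  (assumed to register the torchao custom ops)

quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8

x = torch.randn(10, 10)
scale = torch.tensor(2.0)

with torch.no_grad():
    # Eager reference values.
    q = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
    dq = dequantize_affine_float8(q, scale, output_dtype=torch.float32)

    # Compiled path should produce numerically identical tensors.
    q_c = torch.compile(quantize_affine_float8)(
        x, scale, float8_dtype=torch.float8_e4m3fn
    )
    dq_c = torch.compile(dequantize_affine_float8)(
        q_c, scale, output_dtype=torch.float32
    )

    torch.testing.assert_close(q, q_c)
    torch.testing.assert_close(dq, dq_c)

The deleted test additionally ran run_and_get_code and FileCheck to assert that Inductor kept the ops as single torch.ops.torchao.*.default calls in the generated code rather than decomposing them; the sketch above checks only the numerics.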