Skip to content

Commit 0c7f8ea

Browse files
committed
resolve conflict
1 parent 1c1f890 commit 0c7f8ea

File tree

1 file changed

+0
-54
lines changed

1 file changed

+0
-54
lines changed

test/float8/test_compile.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -392,59 +392,5 @@ def test_dynamic_scale_numeric_parity(
392392
assert torch.equal(float8_eager._data, float8_compile._data)
393393

394394

395-
@pytest.mark.parametrize(
396-
"float8_dtype",
397-
[
398-
torch.float8_e4m3fn,
399-
torch.float8_e5m2,
400-
],
401-
)
402-
@pytest.mark.parametrize(
403-
"hp_dtype",
404-
[
405-
torch.float32,
406-
torch.float16,
407-
torch.bfloat16,
408-
],
409-
)
410-
def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
411-
quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
412-
dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
413-
input = torch.randn(10, 10)
414-
with torch.no_grad():
415-
torch._dynamo.reset()
416-
expected_scale = torch.tensor(2.0)
417-
expected_quantized = quantize_affine_float8(
418-
input,
419-
expected_scale,
420-
float8_dtype=float8_dtype,
421-
)
422-
expected_dequantized = dequantize_affine_float8(
423-
expected_quantized,
424-
expected_scale,
425-
output_dtype=hp_dtype,
426-
)
427-
test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
428-
torch.compile(quantize_affine_float8),
429-
input,
430-
expected_scale,
431-
float8_dtype=float8_dtype,
432-
)
433-
torch.testing.FileCheck().check(
434-
"torch.ops.torchao.quantize_affine_float8.default"
435-
).run(code_q)
436-
test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
437-
torch.compile(dequantize_affine_float8),
438-
test_q,
439-
expected_scale,
440-
hp_dtype,
441-
)
442-
torch.testing.FileCheck().check(
443-
"torch.ops.torchao.dequantize_affine_float8.default"
444-
).run(code_dq)
445-
torch.testing.assert_close(expected_quantized, test_q)
446-
torch.testing.assert_close(expected_dequantized, test_dq)
447-
448-
449395
if __name__ == "__main__":
450396
pytest.main([__file__])

0 commit comments

Comments
 (0)