From e2452c18231a1fcd7624d80689045638f0703593 Mon Sep 17 00:00:00 2001 From: Zhanyong Wan Date: Thu, 26 Jun 2025 21:13:56 +0000 Subject: [PATCH 1/2] Add a test on OOM error --- test/pjrt/test_runtime_tpu.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 89ad676ca383..0a129cd96278 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -86,6 +86,18 @@ def tearDown(self) -> None: os.environ.pop(xenv.TPU_VISIBLE_CHIPS, None) os.environ.pop(xenv.TPU_PROCESS_BOUNDS, None) + def test_tensor_oom(self): + a = torch.randn((1000000000, 1000000000), device="xla") + with self.assertRaises(RuntimeError) as cm: + # The above tensor doesn't fit in HBM. It should result in a python + # exception instead of crashing. + a.sum() + + self.assertEqual( + str(cm.exception), + "", + ) + def test_xla_devices_multiprocess(self): expected = _ordinal_to_device() From 0d378cb26412350d4902a753972c11a8d4934dd4 Mon Sep 17 00:00:00 2001 From: Zhanyong Wan Date: Thu, 26 Jun 2025 22:05:22 +0000 Subject: [PATCH 2/2] fix --- test/pjrt/test_runtime_tpu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 0a129cd96278..ca45ad00b078 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -87,15 +87,15 @@ def tearDown(self) -> None: os.environ.pop(xenv.TPU_PROCESS_BOUNDS, None) def test_tensor_oom(self): - a = torch.randn((1000000000, 1000000000), device="xla") - with self.assertRaises(RuntimeError) as cm: + a = torch.randn((4000000000, 4000000000), device="xla") + with self.assertRaises(Exception) as cm: # The above tensor doesn't fit in HBM. It should result in a python # exception instead of crashing. a.sum() self.assertEqual( - str(cm.exception), - "", + (type(cm.exception).__name__, str(cm.exception)), + ("RuntimeError", "") ) def test_xla_devices_multiprocess(self):