diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 89ad676ca383..ca45ad00b078 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -86,6 +86,18 @@ def tearDown(self) -> None: os.environ.pop(xenv.TPU_VISIBLE_CHIPS, None) os.environ.pop(xenv.TPU_PROCESS_BOUNDS, None) + def test_tensor_oom(self): + a = torch.randn((4000000000, 4000000000), device="xla") + with self.assertRaises(Exception) as cm: + # The above tensor doesn't fit in HBM. It should result in a python + # exception instead of crashing. + a.sum() + + self.assertEqual( + (type(cm.exception).__name__, str(cm.exception)), + ("RuntimeError", "") + ) + def test_xla_devices_multiprocess(self): expected = _ordinal_to_device()