diff --git a/async_timeout/__init__.py b/async_timeout/__init__.py index 7ce667c..514b05d 100644 --- a/async_timeout/__init__.py +++ b/async_timeout/__init__.py @@ -14,6 +14,9 @@ __all__ = ("timeout", "timeout_at") +_SENTINEL = object() + + def timeout(delay: Optional[float]) -> "Timeout": """timeout context manager. @@ -112,7 +115,7 @@ def __exit__( exc_val: BaseException, exc_tb: TracebackType, ) -> Optional[bool]: - self._do_exit(exc_type) + self._do_exit(exc_type, exc_val) return None async def __aenter__(self) -> "Timeout": @@ -125,7 +128,7 @@ async def __aexit__( exc_val: BaseException, exc_tb: TracebackType, ) -> Optional[bool]: - self._do_exit(exc_type) + self._do_exit(exc_type, exc_val) return None @property @@ -188,20 +191,34 @@ def _do_enter(self) -> None: raise RuntimeError(f"invalid state {self._state.value}") self._state = _State.ENTER - def _do_exit(self, exc_type: Type[BaseException]) -> None: - if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + def _do_exit(self, exc_type: Type[BaseException], exc_val: BaseException) -> None: + if sys.version_info >= (3, 9): + + def was_timeout_cancelled() -> bool: + return _SENTINEL in exc_val.args + + else: + + def was_timeout_cancelled() -> bool: + return self._state == _State.TIMEOUT + + if exc_type is asyncio.CancelledError and was_timeout_cancelled(): self._timeout_handler = None raise asyncio.TimeoutError - # timeout is not expired + # timeout has not expired self._state = _State.EXIT self._reject() return None def _on_timeout(self, task: "asyncio.Task[None]") -> None: # See Issue #229 and PR #230 for details - if task._fut_waiter and task._fut_waiter.cancelled(): # type: ignore[attr-defined] # noqa: E501 + if sys.version_info < (3, 9) and task._fut_waiter and task._fut_waiter.cancelled(): # type: ignore[attr-defined] # noqa: E501 return - task.cancel() + + if sys.version_info >= (3, 9): + task.cancel(_SENTINEL) + else: + task.cancel() self._state = _State.TIMEOUT