From ab11bcee85e75139698daaf6ad376d75623e76d6 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 10 Jul 2025 15:25:38 -0700
Subject: [PATCH] make timeouts configurable

Summary:
- While training, we need to set higher quorum timeouts to make sure all
  replicas can finish training.
- These parameters are only exposed through the manager, but users may not be
  able to access the manager directly, e.g. when using torchtitan.
- So make the timeouts configurable using env vars that take precedence over
  the parameters passed through the manager.
- The env var values are in seconds; fractional values are accepted and a
  blank value is treated as unset.

---
 torchft/manager.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 47 insertions(+), 3 deletions(-)

diff --git a/torchft/manager.py b/torchft/manager.py
index 8a4590f3..07d37453 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -61,9 +61,45 @@
 MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
 REPLICA_ID_KEY: str = "replica_id"
 
+# Environment variables for various timeouts. These can also be passed
+# in through the manager but the environment variables take precedence.
+TIMEOUT_SEC_ENV: str = "TORCHFT_TIMEOUT_SEC"
+QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
+CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"
+
 T = TypeVar("T")
 
 
+def get_timeout(
+    timeout_sec_env: Optional[str], default_timeout_sec: timedelta
+) -> timedelta:
+    """
+    Get the timeout from the environment variable or the default value.
+
+    Args:
+        timeout_sec_env: The raw value of the timeout environment variable,
+            in seconds (fractional values are accepted), or None if unset
+        default_timeout_sec: The timeout to use when the environment
+            variable is unset or blank
+    Returns:
+        The timeout to use. Environment variable takes precedence.
+    Raises:
+        ValueError: If the environment variable is not a number of seconds
+    """
+    # A variable exported with an empty value is treated as unset.
+    if timeout_sec_env is None or not timeout_sec_env.strip():
+        return default_timeout_sec
+
+    try:
+        # float() keeps whole-second values working and also allows
+        # sub-second timeouts such as "0.5".
+        return timedelta(seconds=float(timeout_sec_env))
+    except (ValueError, OverflowError) as e:
+        raise ValueError(
+            f"invalid timeout {timeout_sec_env!r}: expected a number of seconds"
+        ) from e
+
+
 class WorldSizeMode(Enum):
     """
     This controls the numerics for the job when doing allreduces across replicas
@@ -177,9 +213,17 @@ def __init__(
         self._pending_state_dict: Optional[Dict[str, object]] = None
 
         self._use_async_quorum = use_async_quorum
-        self._timeout = timeout
-        self._quorum_timeout = quorum_timeout
-        self._connect_timeout = connect_timeout
+
+        self._timeout: timedelta = get_timeout(
+            os.environ.get(TIMEOUT_SEC_ENV, None), timeout
+        )
+        self._quorum_timeout: timedelta = get_timeout(
+            os.environ.get(QUORUM_TIMEOUT_SEC_ENV, None), quorum_timeout
+        )
+        self._connect_timeout: timedelta = get_timeout(
+            os.environ.get(CONNECT_TIMEOUT_SEC_ENV, None), connect_timeout
+        )
+
         self._replica_world_size_mode = world_size_mode
         self._init_sync = init_sync
         self._max_retries = max_retries