# Environment variables for various timeouts. These can also be passed
# in through the manager but the environment variables take precedence.
# Each value is a number of seconds; fractional values are accepted.
TIMEOUT_SEC_ENV: str = "TORCHFT_TIMEOUT_SEC"
QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"


def get_timeout(
    timeout_sec_env: Optional[str], default_timeout_sec: timedelta
) -> timedelta:
    """
    Get the timeout from an environment variable value or the default value.

    Args:
        timeout_sec_env: The raw string value read from the timeout
            environment variable (not the variable's name), or None when the
            variable is unset. It is parsed as a number of seconds and may be
            fractional (e.g. "0.5"). An empty or whitespace-only string is
            treated the same as None.
        default_timeout_sec: The timeout to use when no value is provided.
    Returns:
        The timeout to use. Environment variable takes precedence.
    Raises:
        ValueError: If the value is not a finite number of seconds that a
            timedelta can represent.
    """
    # `VAR=` in a shell exports an empty string rather than unsetting the
    # variable, so treat a blank value as "not provided".
    if timeout_sec_env is None or not timeout_sec_env.strip():
        return default_timeout_sec

    try:
        # float() accepts everything int() did for plain decimal input and
        # additionally allows sub-second timeouts.
        return timedelta(seconds=float(timeout_sec_env))
    except (ValueError, OverflowError) as e:
        # timedelta raises ValueError for NaN and OverflowError for inf or
        # out-of-range values; surface all of them as one clear error.
        raise ValueError(
            f"invalid timeout {timeout_sec_env!r}: expected a number of seconds"
        ) from e