Skip to content

Commit ab11bce

Browse files
committed
make timeouts configurable
Summary: while training, we need to set higher quorum timeouts to make sure all replicas can finish training these parameters are only exposed through the manager but users may not be able to access the manager directly e.g. when using torchtitan so make the timeouts configurable using env vars that take precedence over the parameters passed through manager
1 parent 1682257 commit ab11bce

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

torchft/manager.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,33 @@
6161
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
6262
REPLICA_ID_KEY: str = "replica_id"
6363

64+
# Environment variables for various timeouts. These can also be passed
65+
# in through the manager but the environment variables take precedence.
66+
TIMEOUT_SEC_ENV: str = "TORCHFT_TIMEOUT_SEC"
67+
QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
68+
CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"
69+
6470
T = TypeVar("T")
6571

6672

73+
def get_timeout(
74+
timeout_sec_env: str | None, default_timeout_sec: timedelta
75+
) -> timedelta:
76+
"""
77+
Get the timeout from the environment variable or the default value.
78+
79+
Args:
80+
timeout_sec_env: The environment variable for the timeout
81+
default_timeout_sec: The default timeout
82+
Returns:
83+
The timeout to use. Environment variable takes precedence.
84+
"""
85+
if timeout_sec_env is not None:
86+
return timedelta(seconds=int(timeout_sec_env))
87+
88+
return default_timeout_sec
89+
90+
6791
class WorldSizeMode(Enum):
6892
"""
6993
This controls the numerics for the job when doing allreduces across replicas
@@ -177,9 +201,17 @@ def __init__(
177201

178202
self._pending_state_dict: Optional[Dict[str, object]] = None
179203
self._use_async_quorum = use_async_quorum
180-
self._timeout = timeout
181-
self._quorum_timeout = quorum_timeout
182-
self._connect_timeout = connect_timeout
204+
205+
self._timeout: timedelta = get_timeout(
206+
os.environ.get(TIMEOUT_SEC_ENV, None), timeout
207+
)
208+
self._quorum_timeout: timedelta = get_timeout(
209+
os.environ.get(QUORUM_TIMEOUT_SEC_ENV, None), quorum_timeout
210+
)
211+
self._connect_timeout: timedelta = get_timeout(
212+
os.environ.get(CONNECT_TIMEOUT_SEC_ENV, None), connect_timeout
213+
)
214+
183215
self._replica_world_size_mode = world_size_mode
184216
self._init_sync = init_sync
185217
self._max_retries = max_retries

0 commit comments

Comments
 (0)