|
61 | 61 | MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
|
62 | 62 | REPLICA_ID_KEY: str = "replica_id"
|
63 | 63 |
|
| 64 | +# Environment variables for various timeouts. These can also be passed |
| 65 | +# in through the manager but the environment variables take precedence. |
| 66 | +TIMEOUT_SEC_ENV: str = "TORCHFT_TIMEOUT_SEC" |
| 67 | +QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC" |
| 68 | +CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC" |
| 69 | + |
64 | 70 | T = TypeVar("T")
|
65 | 71 |
|
66 | 72 |
|
| 73 | +def get_timeout( |
| 74 | + timeout_sec_env: str | None, default_timeout_sec: timedelta |
| 75 | +) -> timedelta: |
| 76 | + """ |
| 77 | + Get the timeout from the environment variable or the default value. |
| 78 | +
|
| 79 | + Args: |
| 80 | + timeout_sec_env: The environment variable for the timeout |
| 81 | + default_timeout_sec: The default timeout |
| 82 | + Returns: |
| 83 | + The timeout to use. Environment variable takes precedence. |
| 84 | + """ |
| 85 | + if timeout_sec_env is not None: |
| 86 | + return timedelta(seconds=int(timeout_sec_env)) |
| 87 | + |
| 88 | + return default_timeout_sec |
| 89 | + |
| 90 | + |
67 | 91 | class WorldSizeMode(Enum):
|
68 | 92 | """
|
69 | 93 | This controls the numerics for the job when doing allreduces across replicas
|
@@ -177,9 +201,17 @@ def __init__(
|
177 | 201 |
|
178 | 202 | self._pending_state_dict: Optional[Dict[str, object]] = None
|
179 | 203 | self._use_async_quorum = use_async_quorum
|
180 |
| - self._timeout = timeout |
181 |
| - self._quorum_timeout = quorum_timeout |
182 |
| - self._connect_timeout = connect_timeout |
| 204 | + |
| 205 | + self._timeout: timedelta = get_timeout( |
| 206 | + os.environ.get(TIMEOUT_SEC_ENV, None), timeout |
| 207 | + ) |
| 208 | + self._quorum_timeout: timedelta = get_timeout( |
| 209 | + os.environ.get(QUORUM_TIMEOUT_SEC_ENV, None), quorum_timeout |
| 210 | + ) |
| 211 | + self._connect_timeout: timedelta = get_timeout( |
| 212 | + os.environ.get(CONNECT_TIMEOUT_SEC_ENV, None), connect_timeout |
| 213 | + ) |
| 214 | + |
183 | 215 | self._replica_world_size_mode = world_size_mode
|
184 | 216 | self._init_sync = init_sync
|
185 | 217 | self._max_retries = max_retries
|
|
0 commit comments