Skip to content

Commit

Permalink
Using dummy store when user skips barrier to save open file descripto…
Browse files Browse the repository at this point in the history
…rs (#6834)
  • Loading branch information
barry-jin committed Aug 1, 2024
1 parent b9e7539 commit eef7bb4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
8 changes: 8 additions & 0 deletions configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ variables:
- Compiler cache size for the op by op executor.
type: int
default_value: 2048
XLA_USE_DUMMY_STORE:
description:
- If set to 1, and the user skips the store-based barrier by
setting TORCH_DIST_INIT_BARRIER=0, the `pjrt_rendezvous_handler`
will create a DummyStore instead of a TCPStore to save open file
descriptors.
type: bool
default_value: false
device_variables:
TPU_NUM_DEVICES:
description:
Expand Down
15 changes: 14 additions & 1 deletion torch_xla/_internal/rendezvous.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
_store_lock = threading.Lock()


class DummyStore(dist.Store):
  """Placeholder `dist.Store` that opens no connection of its own.

  `pjrt_rendezvous_handler` installs an instance of this class in place of
  a `dist.TCPStore` when TORCH_DIST_INIT_BARRIER=0 and
  XLA_USE_DUMMY_STORE=1, so that no TCPStore file descriptors are opened.

  The class overrides none of the `dist.Store` methods; it only
  constructs the base object.
  """

  def __init__(self, *args, **kwargs):
    # Accept and ignore any arguments; only the base Store is initialized.
    super().__init__()


def pjrt_rendezvous_handler(url: str,
timeout: datetime.timedelta = ...,
**kwargs):
Expand All @@ -34,7 +40,14 @@ def pjrt_rendezvous_handler(url: str,
with _store_lock:
global _store
if not _store:
if xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True':
# Create DummyStore when user skips store based barrier by setting TORCH_DIST_INIT_BARRIER=0
# and enables XLA_USE_DUMMY_STORE=1. It's safe to do so because store created by _pjrt_rendezvous_handler
# is only used as a barrier in process groups. If store is needed, user can set XLA_USE_DUMMY_STORE=0 to
# use TCPStore.
if xu.getenv_as('TORCH_DIST_INIT_BARRIER', int, 1) == 0 and xu.getenv_as(
'XLA_USE_DUMMY_STORE', int, 0) == 1:
_store = DummyStore()
elif xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True':
attempt = xu.getenv_as('TORCHELASTIC_RESTART_COUNT', int, defval=0)
tcp_store = dist.TCPStore(
master_ip, master_port, xr.process_count(), is_master=False)
Expand Down

0 comments on commit eef7bb4

Please sign in to comment.