From cbc1a7f8acefae9a6d856cfd33da53fac433a8cd Mon Sep 17 00:00:00 2001 From: torchxlabot2 Date: Thu, 1 Aug 2024 00:17:09 +0000 Subject: [PATCH] Update doc from commit 4e94ff676ca3999e48cf90ba1a8da73d0bae37a3 --- master/_modules/index.html | 2 +- master/_modules/torch_xla/core/xla_model.html | 2 +- master/_modules/torch_xla/debug/metrics.html | 2 +- .../distributed/parallel_loader.html | 2 +- .../distributed/spmd/xla_sharding.html | 2 +- .../distributed/xla_multiprocessing.html | 2 +- .../torch_xla/experimental/eager.html | 2 +- master/_modules/torch_xla/runtime.html | 2 +- master/_modules/torch_xla/torch_xla.html | 31 +++++++++++++----- master/debug.html | 2 +- master/eager_mode.html | 2 +- master/genindex.html | 2 +- master/gpu.html | 2 +- master/index.html | 9 +++-- master/multi_process_distributed.html | 2 +- master/notes/source_of_recompilation.html | 2 +- master/objects.inv | Bin 1047 -> 1047 bytes master/py-modindex.html | 2 +- master/runtime.html | 2 +- master/search.html | 2 +- master/searchindex.js | 2 +- master/spmd.html | 2 +- master/torch_compile.html | 2 +- 23 files changed, 50 insertions(+), 30 deletions(-) diff --git a/master/_modules/index.html b/master/_modules/index.html index cf05e5790c9..94972f549c1 100644 --- a/master/_modules/index.html +++ b/master/_modules/index.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/core/xla_model.html b/master/_modules/torch_xla/core/xla_model.html index b894c278453..a266c4fbfa8 100644 --- a/master/_modules/torch_xla/core/xla_model.html +++ b/master/_modules/torch_xla/core/xla_model.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/debug/metrics.html b/master/_modules/torch_xla/debug/metrics.html index 281ca24f5ff..d1cb30b08d3 100644 --- a/master/_modules/torch_xla/debug/metrics.html +++ b/master/_modules/torch_xla/debug/metrics.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/distributed/parallel_loader.html b/master/_modules/torch_xla/distributed/parallel_loader.html index 6f1c52409a9..e603ea4b6cf 100644 --- a/master/_modules/torch_xla/distributed/parallel_loader.html +++ b/master/_modules/torch_xla/distributed/parallel_loader.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/distributed/spmd/xla_sharding.html b/master/_modules/torch_xla/distributed/spmd/xla_sharding.html index 228b6cd7b2a..3e99a8147dd 100644 --- a/master/_modules/torch_xla/distributed/spmd/xla_sharding.html +++ b/master/_modules/torch_xla/distributed/spmd/xla_sharding.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/distributed/xla_multiprocessing.html b/master/_modules/torch_xla/distributed/xla_multiprocessing.html index 903f68d03c1..8664d1de581 100644 --- a/master/_modules/torch_xla/distributed/xla_multiprocessing.html +++ b/master/_modules/torch_xla/distributed/xla_multiprocessing.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/experimental/eager.html b/master/_modules/torch_xla/experimental/eager.html index ac8ba97cc33..c0021e6bad8 100644 --- a/master/_modules/torch_xla/experimental/eager.html +++ b/master/_modules/torch_xla/experimental/eager.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/runtime.html b/master/_modules/torch_xla/runtime.html index 93e8ab26ed3..0cc92bb5b25 100644 --- a/master/_modules/torch_xla/runtime.html +++ b/master/_modules/torch_xla/runtime.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/_modules/torch_xla/torch_xla.html b/master/_modules/torch_xla/torch_xla.html index 18eccf76017..71bd84013f5 100644 --- a/master/_modules/torch_xla/torch_xla.html +++ b/master/_modules/torch_xla/torch_xla.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
@@ -442,9 +442,20 @@

Source code for torch_xla.torch_xla

   return len(real_devices())
-
[docs]def sync(): - """Launches all pending graph operations.""" - xm.mark_step()
+
[docs]def sync(wait: bool = False): + """Launches all pending graph operations. + + Args: + wait (bool): whether to block the current process until the execution finished. + + """ + torch_xla._XLAC._xla_step_marker( + torch_xla._XLAC._xla_get_default_device(), + [], + wait=wait, + ) + devctx = xm._run_step_closures() + torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
def step(): @@ -489,13 +500,15 @@

Source code for torch_xla.torch_xla

         res = foo2(x)
   """
 
+  def _clear_pending_ops_before_compile():
+    sync()
+
   @contextlib.contextmanager
-  def _step():
+  def _compile():
     saved_eager_mode_status = torch_xla._XLAC._get_use_eager_mode()
     saved_allow_execution = torch_xla._XLAC._get_allow_execution()
     torch_xla._XLAC._set_use_eager_mode(False)
-    # Clear pending operations
-    sync()
+    _clear_pending_ops_before_compile()
 
     # if full_graph sets to true execution can not happen before the sync below
     torch_xla._XLAC._set_allow_execution(not full_graph)
@@ -504,10 +517,12 @@ 

Source code for torch_xla.torch_xla

       yield
     finally:
       torch_xla._XLAC._set_allow_execution(saved_allow_execution)
+      # Collect the traced graph after running the target function and
+      # execute the graph.
       sync()
       torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
 
-  return _step() if not f else _step()(f)
+ return _compile() if not f else _compile()(f)
[docs]def manual_seed(seed, device=None): diff --git a/master/debug.html b/master/debug.html index c32572db24c..7e62040b5f2 100644 --- a/master/debug.html +++ b/master/debug.html @@ -267,7 +267,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/eager_mode.html b/master/eager_mode.html index 8c0a55f8d15..bb304dcf713 100644 --- a/master/eager_mode.html +++ b/master/eager_mode.html @@ -267,7 +267,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/genindex.html b/master/genindex.html index 5e43cda6316..8c529b80dfb 100644 --- a/master/genindex.html +++ b/master/genindex.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/gpu.html b/master/gpu.html index 91cc6d944cc..56555d18744 100644 --- a/master/gpu.html +++ b/master/gpu.html @@ -267,7 +267,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/index.html b/master/index.html index c95c8dcc7ae..ce8e347b930 100644 --- a/master/index.html +++ b/master/index.html @@ -266,7 +266,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
@@ -750,8 +750,13 @@

PyTorch/XLA API
-torch_xla.sync()[source]
+torch_xla.sync(wait: bool = False)[source]

Launches all pending graph operations.

+
+
Parameters
+

wait (bool) – whether to block the current process until the execution finished.

+
+
diff --git a/master/multi_process_distributed.html b/master/multi_process_distributed.html index 0bcddfeac3a..2c389fd5d97 100644 --- a/master/multi_process_distributed.html +++ b/master/multi_process_distributed.html @@ -267,7 +267,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/notes/source_of_recompilation.html b/master/notes/source_of_recompilation.html index 917ad369bc1..e75f591dd00 100644 --- a/master/notes/source_of_recompilation.html +++ b/master/notes/source_of_recompilation.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/objects.inv b/master/objects.inv index 01f52ef3d8f4f6261762f85c97641085250a0c9f..672ae67f0a26779963d4f2df9de9d508a190373a 100644 GIT binary patch delta 18 ZcmbQvF`Z*V0J}-5rAb - master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )

diff --git a/master/runtime.html b/master/runtime.html index 6d5fe5ff58c..5e72cc9fd65 100644 --- a/master/runtime.html +++ b/master/runtime.html @@ -267,7 +267,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/search.html b/master/search.html index 7a3366ec62c..a77225cd1ce 100644 --- a/master/search.html +++ b/master/search.html @@ -265,7 +265,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/searchindex.js b/master/searchindex.js index 651039d1a53..fbc8285722e 100644 --- a/master/searchindex.js +++ b/master/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["debug", "eager_mode", "gpu", "index", "multi_process_distributed", "notes/source_of_recompilation", "runtime", "spmd", "torch_compile"], "filenames": ["debug.rst", "eager_mode.rst", "gpu.rst", "index.rst", "multi_process_distributed.rst", "notes/source_of_recompilation.md", "runtime.rst", "spmd.rst", "torch_compile.rst"], "titles": ["Troubleshooting", "Eager Mode + Compile API", "How to run with PyTorch/XLA:GPU", "PyTorch/XLA documentation", "How to do DistributedDataParallel(DDP)", "Source of recompilations in torch_xla", "PJRT Runtime", "PyTorch/XLA SPMD User Guide", "TorchDynamo(torch.compile) integration in PyTorch XLA"], "terms": {"note": [0, 1, 2, 3, 4, 5, 6, 7, 8], "inform": [0, 3, 5, 6], "thi": [0, 1, 2, 3, 4, 5, 6, 7, 8], "section": [0, 2, 3, 6, 7], "i": [0, 1, 2, 3, 4, 6, 8], "subject": 0, "remov": [0, 6], "futur": [0, 3, 5, 6, 7], "releas": [0, 2, 3, 4, 6, 7, 8], "softwar": 0, "sinc": [0, 3, 4, 5, 6, 7, 8], "mani": [0, 3, 5, 6], "them": [0, 3, 5, 6], "ar": [0, 1, 4, 5, 6, 7, 8], "peculiar": 0, "given": [0, 3, 4, 5, 7], "intern": [0, 3, 5, 6, 7], "implement": [0, 1, 4, 5, 6, 7, 8], "which": [0, 2, 3, 4, 5, 6, 7, 8], "might": [0, 3, 5], "chang": [0, 3, 4, 5, 7], "befor": [0, 3, 4, 5, 6, 7, 8], "ani": [0, 3, 4, 5, 6, 7], "depth": [0, 6], "we": [0, 1, 2, 3, 4, 5, 6, 7, 8], "want": [0, 1, 3, 5, 6, 7, 8], "do": [0, 2, 3, 5, 6, 7], "instal": [0, 2, 6], "should": [0, 1, 2, 3, 4, 5, 6, 7], "match": [0, 3], "out": [0, 1, 3, 6, 7, 8], "our": [0, 2, 3, 4, 5, 6, 7, 8], "readm": 0, "detial": 0, "avail": [0, 3, 4, 5, 6], "vm": [0, 2, 3, 4, 6], "python": [0, 2, 3, 4, 5, 6, 7, 8], "import": [0, 1, 3, 4, 6, 7, 8], "torch": [0, 1, 2, 3, 4, 5], "torch_xla": [0, 1, 2, 4, 6, 7, 8], "print": [0, 2, 3, 4, 5, 6, 7, 8], "__version__": 0, "2": [0, 1, 2, 3, 4, 6, 8], "1": [0, 1, 2, 3, 4, 6, 7, 8], "0": [0, 2, 3, 4, 5, 6, 7, 8], "cu121": 0, "export": [0, 2, 6], "pjrt_devic": [0, 2, 3, 6], "tpu": [0, 2, 8], "python3": [0, 2, 3, 4, 6], "core": [0, 1, 3, 4, 6, 7, 8], "xla_model": [0, 4, 6, 7, 8], "xm": [0, 1, 3, 4, 7, 8], "t1": [0, 3, 7], "100": [0, 2, 4], "devic": [0, 1, 2, 4, 5, 6, 7, 8], "xla_devic": [0, 3, 4, 6, 7, 8], "t2": [0, 7], "200": 0, "300": [0, 1], "For": [0, 2, 3, 4, 5, 6, 7, 8], "nightli": [0, 4, 7], "git": [0, 2, 4, 6], "clone": [0, 2, 6], "http": [0, 2, 3, 4, 6, 7], "github": [0, 2, 4, 6, 7], "com": [0, 2, 4, 6, 7], "test_train_mp_imagenet": [0, 2, 4, 6], "py": [0, 2, 3, 4, 6, 7], "fake_data": [0, 2, 4, 6], "x": [0, 3, 4, 5, 7], "y": [0, 2, 3, 4, 5, 7], "you": [0, 1, 2, 3, 4, 6, 7, 8], "us": [0, 1, 2, 3, 6, 8], "branch": [0, 5, 6], "rx": 0, "exampl": [0, 1, 3, 5, 6, 8], "r2": [0, 6, 7], "If": [0, 2, 3, 5, 6, 7], "can": [0, 1, 2, 3, 4, 6, 7, 8], "conclud": 0, "correctli": [0, 2, 7], "To": [0, 1, 2, 3, 4, 5, 6, 7], "diagnos": 0, "issu": [0, 1, 3, 4, 6, 7], "counter": [0, 3], "provid": [0, 3, 4, 5, 7, 8], "first": [0, 3, 4, 6, 7], "thing": 0, "when": [0, 1, 3, 4, 6, 7, 8], "model": [0, 1, 4, 5, 6, 7, 8], "slow": 0, "gener": [0, 1, 3, 5, 6], "extrem": 0, "help": [0, 5], "pleas": [0, 2, 3, 4, 6, 7], "try": [0, 5], "includ": [0, 2, 3, 5, 6, 7], "your": [0, 2, 3, 4, 5, 6, 7], "bug": [0, 4, 6], "sent": [0, 3], "u": [0, 2, 5, 6, 8], "have": [0, 2, 3, 4, 5, 6, 7, 8], "enabl": [0, 1, 2, 3, 4, 7], "set": [0, 3, 4, 5, 6, 7, 8], "pt_xla_debug_level": 0, "coupl": [0, 3], "featur": [0, 4, 6, 7], "also": [0, 1, 2, 3, 4, 5, 6, 7, 8], "lower": [0, 5], "level": [0, 7, 8], "slip": 0, "analyz": 0, "summari": 0, "some": [0, 1, 3, 4, 6, 7], "output": [0, 3, 4, 6, 8], "would": [0, 3, 5, 6], "pt": [0, 3, 6], "compiletim": 0, "too": [0, 5], "frequent": 0, "21": 0, "count": [0, 3], "dure": [0, 3, 4, 7, 8], "11": [0, 2, 5], "step": [0, 1, 2, 3, 4, 5, 6, 7, 8], "transferfromdevicetim": 0, "op": [0, 1, 3, 5, 7], "": [0, 1, 2, 3, 4, 6, 7, 8], "aten": [0, 5], "_ctc_loss": 0, "_ctc_loss_backward": 0, "open": [0, 6], "abov": [0, 1, 2, 3, 4, 5, 6, 7, 8], "request": [0, 3, 4, 5, 7], "23": [0, 2], "12": [0, 2, 4, 6, 8], "everi": [0, 3, 5, 6, 7, 8], "caus": [0, 1, 3, 5, 6], "mark_step": [0, 1, 3, 4, 6], "parallel": [0, 3, 6], "loader": [0, 3, 8], "end": [0, 2, 3, 4, 6, 7], "graph": [0, 1, 3, 4, 5, 6, 7, 8], "info": [0, 2, 3, 5, 7], "hash": 0, "c74c3b91b855b2b123f833b0d5f86943": 0, "number": [0, 1, 3, 4, 6, 7], "input": [0, 1, 3, 6, 7], "35": [0, 2, 6], "107": 0, "frame": 0, "trigger": [0, 5], "workspac": 0, "dk3": 0, "1055": 0, "next": [0, 3, 5], "distribut": [0, 4], "parallel_load": [0, 3, 6], "44": 0, "__next__": 0, "32": [0, 3], "train_loop_fn": 0, "train_decoder_only_bas": 0, "48": [0, 4], "start_train": 0, "65": [0, 1], "modul": [0, 3, 4, 7], "73": 0, "post": [0, 3], "size": [0, 2, 3, 5, 6, 7], "548000": 0, "gb": 0, "7": [0, 3, 4, 8], "922460": 0, "alias": 0, "547871": 0, "intermedi": [0, 3, 6], "124478": 0, "program": [0, 3, 5, 7, 8], "028210": 0, "user": [0, 1, 2, 3, 5, 6, 8], "manual": [0, 1, 4], "call": [0, 1, 3, 4, 5, 6, 7, 8], "configur": [0, 2, 3, 6, 7], "batch": [0, 3, 6, 7], "exit": [0, 3, 4], "steptrac": 0, "region": [0, 1, 3, 7], "decid": [0, 3, 5], "access": [0, 3, 5, 6, 7], "often": [0, 1, 5], "due": [0, 6], "log": [0, 2], "valu": [0, 3, 5, 6, 7], "4": [0, 2, 3, 4, 5, 6, 7, 8], "expect": [0, 1, 5, 6, 8], "avoid": [0, 2], "5": [0, 2, 3, 4, 5], "either": [0, 2, 3, 5, 6], "reduc": [0, 1, 3, 4, 6], "frequenc": 0, "add": [0, 3, 4, 5, 8], "see": [0, 2, 3, 4, 5, 6, 8], "pair": 0, "after": [0, 2, 3, 5, 6, 7], "stabil": [0, 6], "onli": [0, 1, 2, 3, 5, 6, 7, 8], "disabl": [0, 1, 3], "effici": [0, 8], "same": [0, 1, 3, 5, 6, 7], "code": [0, 1, 3, 4, 5, 6, 7, 8], "happen": [0, 1, 3, 5, 6], "onc": [0, 3, 5, 7, 8], "keep": [0, 5, 6], "dump": [0, 3], "ir": [0, 5], "hlo": [0, 3], "follow": [0, 1, 2, 3, 4, 5, 6, 7], "compar": [0, 1, 3, 4, 6, 8], "each": [0, 3, 4, 5, 6, 7, 8], "sourc": [0, 3], "differ": [0, 3, 4, 5, 7], "explain": [0, 3, 5, 7], "how": [0, 1, 3, 5, 6], "detail": [0, 3, 5, 6], "put": [0, 3, 4], "line": [0, 1, 3, 4, 5], "met": 0, "short": [0, 5], "contain": [0, 2, 3, 5, 6], "few": [0, 3, 4, 5, 7], "kei": [0, 6, 7], "short_metrics_report": [0, 3], "full": [0, 2, 3, 4], "all": [0, 2, 3, 4, 5, 6, 7], "metrics_report": [0, 3], "like": [0, 3, 4, 5, 6, 7], "time": [0, 2, 3, 5, 6, 7, 8], "spent": 0, "handl": [0, 1, 4, 5, 7], "creat": [0, 4, 6, 7], "destroi": 0, "etc": [0, 1, 2, 3, 5, 7], "term": [0, 1, 5], "percentil": 0, "sampl": [0, 3, 6], "an": [0, 4, 5, 6, 7, 8], "totalsampl": 0, "202": 0, "06m09s401ms746": 0, "001u": 0, "valuer": 0, "778ms572": 0, "062u": 0, "second": [0, 4, 6, 7], "rate": [0, 2, 4], "425201": 0, "001ms32": 0, "778u": 0, "001ms61": 0, "283u": 0, "10": [0, 2, 3, 5, 6, 7, 8], "001ms79": 0, "236u": 0, "20": [0, 2, 3, 4], "001ms110": 0, "973u": 0, "50": [0, 2], "001ms228": 0, "773u": 0, "80": [0, 2], "001ms339": 0, "183u": 0, "90": 0, "001ms434": 0, "305u": 0, "95": 0, "002ms921": 0, "063u": 0, "99": [0, 4], "21s102ms853": 0, "173u": 0, "name": [0, 2, 3, 5, 6, 7], "integ": [0, 3], "track": [0, 7], "statu": 0, "cachedsynctensor": 0, "395": [0, 4], "In": [0, 1, 2, 3, 5, 6, 7, 8], "start": [0, 1, 3, 6], "indic": [0, 3, 5], "context": [0, 3, 5, 6], "switch": [0, 3, 4, 5], "between": [0, 2, 3, 4, 5, 6, 7], "cpu": [0, 2, 4, 5, 7], "potenti": [0, 3, 6, 7], "optim": [0, 1, 2, 3, 4, 5, 6, 7, 8], "area": 0, "oper": [0, 3, 6, 7], "rout": 0, "back": [0, 3, 7], "engin": 0, "thei": [0, 3, 5, 6, 7], "fulli": [0, 1, 3, 6], "qualifi": 0, "c": [0, 3, 5, 6], "namespac": 0, "nonzero": [0, 5], "33": [0, 2, 4, 8], "other": [0, 2, 3, 4, 5, 6, 7], "than": [0, 4, 5, 6], "_local_scalar_dens": 0, "usual": [0, 1, 3], "mean": [0, 3, 4, 5, 6, 7], "miss": [0, 3], "feel": [0, 4], "free": [0, 4], "epoch": [0, 2, 4], "clear_al": 0, "xla_dynamo_debug": 0, "workload": [0, 3, 6, 7], "bottleneck": 0, "resourc": [0, 3], "offici": 0, "tutori": [0, 4, 7], "colab": 0, "notebook": 0, "mnist": [0, 2, 3, 6], "train": [0, 2, 3, 7], "script": [0, 3, 6], "util": [0, 2, 3, 4, 7], "captur": [0, 3], "take": [0, 3, 5, 7], "look": [0, 3], "train_resnet_benchmark": 0, "blob": [0, 4, 6, 7], "master": [0, 3, 4, 6, 7], "_": [0, 4, 6, 8], "behav": 0, "semant": [0, 5], "regular": [0, 3], "share": [0, 2, 3, 6, 7], "interfac": [0, 3, 7], "gpu": [0, 3, 7], "howev": [0, 7], "constraint": [0, 6], "hardwar": [0, 3], "lazi": [0, 5, 7, 8], "evalu": [0, 5], "suggest": 0, "certain": [0, 5], "pattern": [0, 5, 8], "result": [0, 3, 4, 6, 7], "bad": 0, "show": [0, 3, 4, 6], "mind": [0, 6], "yield": [0, 3], "degrad": 0, "recompil": [0, 1, 3], "expens": [0, 1, 5], "automat": [0, 3, 4, 5, 6, 7], "new": [0, 1, 3, 5, 7, 8], "shape": [0, 3, 7], "encount": [0, 6], "within": [0, 3, 7], "huge": [0, 4, 5], "speedup": [0, 8], "rest": [0, 5, 6], "order": [0, 2, 3, 7], "must": [0, 3, 6, 7], "constant": [0, 7], "comput": [0, 2, 3, 5, 6, 7], "across": [0, 3, 4, 6, 7], "host": [0, 2, 3, 4, 6, 7], "possibl": [0, 3, 4, 6, 7], "direct": [0, 6], "indirect": 0, "introduc": [0, 1, 4, 6, 7], "dynam": [0, 8], "mask": [0, 5], "index": [0, 3, 6], "base": [0, 1, 2, 3, 4, 5, 6, 7], "where": [0, 3, 4, 5, 6, 7], "loop": [0, 1, 3, 5, 7], "iter": [0, 3, 7, 8], "thu": [0, 2, 6], "requir": [0, 2, 3, 5, 6, 7], "solut": [0, 5], "low": 0, "variat": 0, "pad": [0, 5], "fix": [0, 7, 8], "don": [0, 1, 4, 5, 6], "t": [0, 1, 3, 4, 5, 6, 7], "nativ": [0, 1, 2, 4, 6, 7], "translat": 0, "transfer": [0, 3, 6, 7], "memori": [0, 2, 4, 5], "lead": 0, "signific": [0, 8], "slowdown": [0, 4], "item": 0, "explicitli": [0, 3, 5], "ask": [0, 1, 5], "unless": [0, 5], "necessari": [0, 3], "most": [0, 3, 6, 8], "checkout": [0, 2], "find": [0, 4, 6, 7], "even": [0, 3, 4, 5, 6], "scalar": [0, 5], "substitut": 0, "control": [0, 3, 7], "flow": 0, "applic": [0, 7], "e": [0, 3, 4, 5, 6, 7], "g": [0, 2, 3, 5, 6, 7], "clip_grad": 0, "norm": 0, "problemat": 0, "impact": [0, 3, 4, 5, 6], "so": [0, 2, 3, 4, 5, 6, 7], "patch": 0, "clip_grad_norm_": 0, "instead": [0, 1, 3, 4, 5, 6, 7, 8], "give": [0, 7], "dramat": 0, "improv": [0, 3, 6, 7, 8], "block": [0, 3, 4, 7], "els": [0, 5], "paramet": [0, 3, 6, 7], "total_norm": 0, "zero": [0, 4, 7], "none": [0, 3, 7], "p": [0, 2, 5, 6], "param_norm": 0, "grad": 0, "norm_typ": 0, "add_": 0, "clip_coef": 0, "max_norm": 0, "1e": [0, 8], "6": [0, 2, 3, 5], "mul_": 0, "data_parallel": 0, "mai": [0, 3, 5, 6, 7], "drop": 0, "last": 0, "make": [0, 1, 2, 3, 4, 5, 6, 7, 8], "sure": [0, 2, 3], "amount": [0, 3, 5], "work": [0, 3, 4, 5, 6, 7, 8], "dataset": [0, 4], "small": [0, 1, 4, 5, 8], "therefor": 0, "better": [0, 1, 3, 5, 6, 8], "those": [0, 4], "case": [0, 3, 6, 7, 8], "opaqu": [0, 3], "alwai": [0, 3, 5, 6, 7], "appear": [0, 3], "contigu": [0, 3], "without": [0, 3, 6, 7], "storag": [0, 2, 3, 4, 7], "network": [0, 3, 6, 7], "stride": 0, "move": [0, 4, 5, 6, 7], "save": [0, 4, 7], "directli": [0, 3, 4, 5, 6, 7], "load": [0, 4, 6, 7], "were": [0, 3, 5], "from": [0, 4, 7, 8], "unavail": [0, 3], "fail": [0, 3, 7], "let": [0, 3, 6, 7, 8], "machin": [0, 2, 6], "care": [0, 3, 5], "taken": [0, 3, 4, 5, 7], "type": [0, 2, 3, 4, 6], "doe": [0, 3, 5, 6, 7], "preserv": [0, 3], "view": [0, 3], "relationship": [0, 3], "reconstruct": 0, "copi": [0, 3, 6], "return": [0, 1, 3, 4, 5, 7, 8], "deep": 0, "shallow": 0, "weight": [0, 3, 7], "one": [0, 3, 4, 5, 6, 7, 8], "anoth": [0, 3, 5], "ty": 0, "done": [0, 3, 5], "otherwis": [0, 3, 5, 7], "two": [0, 3, 5, 6, 7], "independ": [0, 3, 6], "made": [0, 5, 7], "But": [0, 3, 5], "submit": 0, "addit": [0, 2, 3, 4, 6], "doesn": [0, 3, 5, 7], "_xlac": [0, 5], "_get_xla_tensors_text": [0, 5], "re": [0, 1, 3, 5, 6, 7], "_get_xla_tensors_hlo": 0, "function": [0, 1, 3, 7, 8], "prior": [0, 7], "alreadi": [0, 2, 3, 4, 5, 7], "materi": [0, 3, 5, 7], "There": [0, 1, 3, 4, 5, 7, 8], "behavior": [0, 3, 6], "stack": [0, 3, 5, 7], "degre": 0, "xla_ir_debug": 0, "trace": [0, 1, 3, 4, 5, 6, 7, 8], "node": [0, 5], "henc": [0, 8], "allow": [0, 3, 7], "wa": [0, 3, 5, 6, 7], "respons": [0, 7, 8], "xla_hlo_debug": [0, 3], "_xla_ir": 0, "activ": [0, 3, 4], "propag": 0, "metadata": 0, "xla_save_tensors_fil": 0, "path": [0, 2, 3, 4, 5], "file": [0, 2, 3, 4, 6], "becom": [0, 5, 6], "realli": [0, 5, 8], "big": [0, 5], "option": [0, 3, 6, 7], "left": 0, "long": [0, 1, 4, 5], "append": 0, "clean": [0, 8], "sheet": 0, "xla_save_tensors_fmt": 0, "format": [0, 3, 8], "store": [0, 3], "_xla_save_tensor": 0, "text": 0, "default": [0, 1, 2, 3, 4, 6, 7], "dot": 0, "graphviz": 0, "xla_flag": 0, "xla_dump_to": 0, "tmp": [0, 4], "dir_nam": 0, "unoptim": 0, "optimz": 0, "per": [0, 2, 3, 4, 6, 8], "xla_metrics_fil": 0, "local": [0, 2, 3, 6, 7], "exist": [0, 1, 3, 6, 7, 8], "xla_save_hlo_fil": 0, "error": [0, 3], "offend": 0, "xla_sync_wait": 0, "forc": [0, 5, 6], "sync": [0, 1, 2, 3], "wait": [0, 3], "its": [0, 3, 4, 6, 7, 8], "complet": [0, 3], "xla_use_eager_debug_mod": 0, "eagerli": [0, 1, 3, 5], "bypass": 0, "overal": 0, "lot": [0, 3, 5], "slower": [0, 4], "usag": [0, 2, 3, 4, 5, 7], "higher": [0, 7], "optimizaiton": 0, "skip": [0, 8], "tf_cpp_log_thread_id": 0, "tf": [0, 5], "thread": [0, 3, 6, 7], "id": [0, 2, 3, 6], "multithread": [0, 3], "process": [0, 1, 2, 4, 6, 7], "tf_cpp_vmodul": 0, "vlog": 0, "form": [0, 5, 6], "tf_cpp_min_log_level": 0, "messag": [0, 3], "turn": 0, "warn": 0, "tf_vlog": 0, "tensorflow": [0, 3, 5, 6], "xla_dump_hlo_graph": 0, "part": [0, 1, 3, 6, 7], "runtim": [0, 2, 4, 7], "rais": 0, "xla_util": 0, "cc": 0, "record": [0, 3], "save1": 0, "xla_graph_executor": 0, "pjrt_computation_cli": 0, "3": [0, 1, 2, 3, 4, 7, 8], "pr": [0, 4], "repo": [0, 3], "dir": 0, "pytorch_test_with_slow": 0, "test_torch": 0, "k": 0, "test_put_xla_uint8": 0, "command": [0, 2, 3, 4, 6], "need": [0, 2, 3, 4, 5, 6, 7], "torch_test_devic": 0, "pytorch_test_bas": 0, "doc": [1, 2, 5, 6, 7], "go": [1, 2, 3, 7], "over": [1, 2, 3, 4, 6, 7], "pytorch": [1, 5, 6], "xla": [1, 5, 6], "experiment": [1, 4, 6, 7, 8], "The": [1, 2, 3, 4, 5, 6, 7, 8], "goal": 1, "experi": [1, 4, 6, 7], "more": [1, 2, 3, 5, 6, 7], "align": 1, "develop": [1, 3, 4, 7, 8], "easier": [1, 5], "current": [1, 2, 3, 4, 5, 6, 7, 8], "run": [1, 4, 5, 6, 8], "lazytensor": [1, 3], "torchvis": [1, 8], "resnet18": [1, 8], "randn": [1, 3, 4, 6, 7, 8], "64": [1, 4, 8], "224": 1, "execut": [1, 2, 3, 4, 5, 6, 7, 8], "actual": [1, 4, 5, 7], "multipl": [1, 5, 8], "drawback": 1, "approach": [1, 4, 5], "confus": 1, "about": [1, 3, 5, 6], "framework": [1, 3, 5], "non": [1, 5, 7], "data": [1, 2, 3, 5, 6, 8], "preprocess": 1, "pend": [1, 3], "get": [1, 2, 3, 4, 5, 6], "leak": 1, "main": [1, 3, 6, 7], "whole": [1, 3, 5, 8], "veri": [1, 2, 3, 5], "It": [1, 2, 3, 4, 5, 7, 8], "hard": [1, 4, 5, 8], "debug": [1, 5], "why": [1, 5], "mitig": 1, "ux": 1, "eager_mod": [1, 3], "true": [1, 3, 4, 5, 6, 7], "mark": [1, 3], "compiled_model": 1, "right": [1, 5, 8], "awai": 1, "ha": [1, 3, 5, 6, 7], "wrap": [1, 3, 4, 7], "pretti": [1, 3, 4, 5], "straight": 1, "forward": [1, 4, 7, 8], "enter": 1, "target": [1, 3, 5, 6, 8], "reenabl": 1, "perfomr": 1, "backend": [1, 3, 5, 6, 7, 8], "openxla": [1, 8], "recommen": 1, "overhad": 1, "def": [1, 3, 4, 6, 7, 8], "step_fn": 1, "loss_fn": [1, 3, 4, 6, 8], "zero_grad": [1, 3, 4, 6], "logit": [1, 7], "loss": [1, 2, 3, 4, 6, 7, 8], "backward": [1, 3, 4, 6, 7, 8], "refactor": 1, "becaus": [1, 3, 6, 7], "togeth": [1, 3, 4, 6, 7], "now": [1, 3, 5, 6, 7], "recommend": [1, 2, 3, 6, 7], "reason": [1, 4, 6], "layer": [1, 4, 7], "decod": 1, "much": [1, 3, 5, 6, 8], "just": [1, 3, 4, 5, 6, 7], "llama2": 1, "fake": [1, 7], "singl": [1, 4, 5, 7, 8], "chip": [1, 6], "v4": [1, 3, 6, 7, 8], "8": [1, 2, 3, 5, 6, 7, 8], "below": [1, 2, 5, 6, 7], "observ": [1, 4, 6], "token": 1, "147": 1, "achiev": [1, 4], "45": [1, 2], "perform": [1, 3, 4, 7, 8], "trainer": 1, "test": [1, 2, 4, 6], "found": [1, 2, 6], "here": [1, 2, 3, 4, 5, 7, 8], "perfomran": 1, "depend": [1, 3, 5], "tri": 1, "resnet50": [1, 3, 6, 8], "exepct": 1, "meant": 1, "logic": [1, 3, 5, 7], "random": [1, 3, 6], "compil": [2, 5, 6], "acceler": [2, 3, 6], "basic": [2, 4, 5], "nvidia": 2, "attach": [2, 7], "cloud": [2, 3, 6, 7, 8], "googl": [2, 3, 6], "cuda": [2, 3, 5, 6], "driver": 2, "publish": 2, "prebuilt": 2, "imag": [2, 4, 5, 6], "cuda11": 2, "correspond": [2, 3, 4, 7], "config": 2, "list": [2, 3, 7], "refer": [2, 3, 4, 6, 7], "sudo": [2, 6], "pull": [2, 4], "central1": 2, "pkg": 2, "dev": [2, 4], "nightly_3": 2, "8_cuda_12": 2, "toolkit": 2, "datacent": 2, "latest": 2, "guid": [2, 3, 4, 6], "html": [2, 4, 6], "curl": 2, "fssl": 2, "io": [2, 4, 6], "libnvidia": 2, "gpgkei": 2, "gpg": 2, "dearmor": 2, "o": [2, 4, 6], "usr": 2, "keyr": 2, "l": 2, "stabl": [2, 4, 6], "deb": 2, "sed": 2, "sign": 2, "tee": 2, "apt": 2, "d": [2, 3, 5], "updat": [2, 3, 5, 7], "ctk": 2, "systemctl": 2, "restart": [2, 6], "shm": 2, "16g": 2, "net": [2, 6], "bin": 2, "bash": [2, 4], "exec": 2, "awk": 2, "nr": 2, "visibl": [2, 3, 5], "smi": 2, "verifi": 2, "root": [2, 3, 5], "20ab2c7a2d06": 2, "dec": 2, "06": 2, "24": 2, "29": [2, 4, 8], "2022": 2, "510": 2, "47": 2, "03": 2, "version": [2, 6, 7], "persist": [2, 3, 7], "m": [2, 4, 5], "bu": 2, "disp": 2, "A": [2, 3, 5, 6, 7], "volatil": 2, "uncorr": 2, "ecc": 2, "fan": 2, "temp": 2, "perf": [2, 5], "pwr": 2, "cap": 2, "mig": 2, "tesla": 2, "v100": 2, "sxm2": 2, "off": 2, "00000000": 2, "00": [2, 4], "04": [2, 8], "n": [2, 3], "36c": 2, "p0": 2, "38w": 2, "300w": 2, "0mib": 2, "16384mib": 2, "gi": 2, "ci": 2, "pid": 2, "No": [2, 5, 6], "ld_library_path": 2, "account": 2, "echo": 2, "link": 2, "bashrc": 2, "lib64": 2, "compat": [2, 3, 6, 7], "x86_64": 2, "linux": 2, "architecutr": 2, "architectur": [2, 4, 6], "system": [2, 7], "unam": 2, "pip3": 2, "whl": 2, "googleapi": 2, "cp310": 2, "manylinux_2_28_x86_64": 2, "repositori": [2, 6], "imagenet": 2, "what": [2, 3], "gpu_num_devic": [2, 6], "recurs": [2, 4, 7], "prepar": 2, "begin": [2, 7], "38": 2, "89059": 2, "82": 2, "globalr": 2, "13": [2, 3, 4, 6], "79297": 2, "117": 2, "16": [2, 3, 4, 7], "84": 2, "36": 2, "40": [2, 4], "43628": 2, "281": 2, "49": [2, 8], "43": [2, 8], "60": [2, 4], "83108": 2, "346": 2, "88": [2, 7], "108": 2, "99023": 2, "373": 2, "62": [2, 8], "132": 2, "56": 2, "92699": 2, "384": 2, "152": 2, "14": 2, "02": [2, 4], "120": 2, "68816": 2, "388": 2, "169": 2, "09": 2, "train_resnet_bas": 2, "35pm": 2, "utc": 2, "jun": 2, "08": 2, "2024": 2, "887794017791748": 2, "746502586051985": 2, "877807140350342": 2, "238": 2, "4789458412044": 2, "867819786071777": 2, "329": 2, "86095958663503": 2, "30": [2, 4, 6], "857839584350586": 2, "367": 2, "3038003653586": 2, "847847938537598": 2, "381": 2, "53141087190835": 2, "837860584259033": 2, "387": 2, "80462249591113": 2, "260": 2, "628140926361084": 2, "391": 2, "135639565343": 2, "270": 2, "618192195892334": 2, "6901797745233": 2, "280": 2, "608224391937256": 2, "1602680460045": 2, "290": 2, "598264217376709": 2, "6731498290759": 2, "36pm": 2, "reus": [2, 3], "rule": 2, "modifi": [2, 7, 8], "insid": [2, 7], "cd": [2, 4], "use_cuda": 2, "bdist_wheel": 2, "hermet": 2, "xla_cuda": 2, "been": [2, 3, 5, 6, 7], "successfulli": 2, "packag": [3, 4], "learn": [3, 6], "connect": [3, 6, 7], "troubleshoot": 3, "eager": [3, 4, 5], "mode": [3, 4, 5], "distributeddataparallel": [3, 6], "ddp": [3, 6], "pjrt": [3, 7], "shard": 3, "fsdp": 3, "via": [3, 4, 6], "advanc": 3, "topic": 3, "checkpoint": [3, 4, 6], "torchdynamo": 3, "integr": 3, "describ": [3, 4, 7], "familiar": [3, 7], "initi": [3, 4, 6, 7], "environ": [3, 4, 6, 7], "ad": [3, 5, 7, 8], "t0": 3, "Or": [3, 5, 6], "matrix": 3, "multipli": [3, 7], "mm": 3, "neural": 3, "l_in": 3, "linear": [3, 4, 6], "nn": [3, 4, 6, 7, 8], "l_out": 3, "floattensor": 3, "throw": 3, "build": [3, 4], "convert": [3, 4], "specif": [3, 4], "snippet": [3, 7], "highlight": 3, "nllloss": 3, "sgd": [3, 4, 6, 8], "lr": [3, 4, 6, 7, 8], "momentum": 3, "train_load": [3, 7], "easi": [3, 5, 6], "definit": [3, 5], "dataload": [3, 4, 7], "acquir": 3, "pl": [3, 6, 7], "_mp_fn": [3, 6], "mp_device_load": 3, "mpdeviceload": [3, 7], "optimizer_step": [3, 4], "__name__": [3, 4, 6], "__main__": [3, 4, 6], "launch": [3, 4, 6, 8], "arg": [3, 4], "three": 3, "previou": [3, 5, 6], "wrapper": [3, 4, 7], "spawn": [3, 6], "torchrun": [3, 6], "abl": [3, 5, 7], "assign": 3, "being": [3, 4, 7], "up": [3, 5, 6, 7], "own": [3, 4], "v2": 3, "v3": 3, "check": [3, 7], "onto": 3, "preload": 3, "overlap": [3, 7, 8], "batches_per_execut": 3, "consolid": [3, 4], "gradient": [3, 4], "all_reduce_gradi": 3, "remain": [3, 5], "retriev": [3, 5, 7, 8], "parent": 3, "multiprocess": [3, 6], "setup": [3, 4], "talk": 3, "bit": 3, "basi": 3, "gcloud": [3, 6], "project": [3, 6], "howto": 3, "focu": [3, 5], "perspect": [3, 6], "assum": [3, 4, 5, 7], "train_mnist_xla": 3, "ssh": [3, 6], "tpuvm": [3, 6, 7], "scp": [3, 6], "alpha": [3, 6], "zone": [3, 6], "worker": [3, 4, 6, 7], "outsid": 3, "underli": 3, "infrastructur": 3, "awar": 3, "global": [3, 6, 7], "topologi": [3, 7], "ordin": 3, "cross": [3, 7], "commun": [3, 6, 7, 8], "regard": [3, 8], "fakedata": 3, "though": [3, 4], "act": 3, "uniqu": [3, 5], "immedi": [3, 7], "hand": 3, "until": [3, 7], "defer": 3, "separ": [3, 4, 7, 8], "fuse": 3, "invis": 3, "caller": 3, "construct": [3, 4, 7], "send": [3, 6, 7], "synchron": [3, 6, 7], "insert": 3, "barrier": [3, 6], "design": [3, 6, 7, 8], "paper": 3, "represent": [3, 7], "expos": [3, 6, 7], "unlik": 3, "adjust": 3, "wai": [3, 4, 5, 6, 7, 8], "again": 3, "appreci": 3, "accommod": 3, "transit": 3, "recreat": 3, "destin": 3, "previous": 3, "state_dict": [3, 4, 7], "limit": [3, 6], "footprint": 3, "serial": [3, 6], "xser": 3, "stream": 3, "restor": [3, 7], "load_state_dict": [3, 7], "under": [3, 4, 6], "consum": [3, 5], "disk": 3, "significantli": [3, 6], "still": [3, 4, 5, 6, 7], "occur": 3, "opt": 3, "through": [3, 5, 7], "initialize_cach": 3, "xr": [3, 4, 6, 7], "your_cache_path": 3, "readonli": 3, "fals": [3, 4, 7], "specifi": [3, 4], "whether": 3, "write": [3, 7], "mount": 3, "int": [3, 5, 6, 7], "instanc": [3, 4, 7], "virtual": [3, 7], "device_count": [3, 7], "address": [3, 6, 7], "f": [3, 4, 7], "callabl": [3, 4], "full_graph": 3, "repres": [3, 5, 6], "funciton": 3, "pass": [3, 4, 6, 7], "manag": [3, 7], "bool": 3, "foo": 3, "sin": 3, "co": 3, "foo2": 3, "compiled_foo2": 3, "manual_se": [3, 6], "seed": 3, "state": [3, 4, 7], "rng": [3, 6], "device_typ": 3, "str": 3, "select": [3, 6, 7], "local_process_count": 3, "local_device_count": 3, "total": [3, 5, 7], "addressable_device_count": 3, "global_device_count": 3, "global_runtime_device_count": [3, 7], "especi": [3, 6, 7, 8], "world_siz": [3, 4, 6, 7], "particip": [3, 6], "job": [3, 8], "global_ordin": [3, 4, 6], "rang": [3, 6, 7], "guarante": 3, "predict": 3, "nor": 3, "local_ordin": 3, "get_master_ip": 3, "ip": [3, 6, 7], "discoveri": 3, "string": [3, 7], "use_spmd": [3, 7], "auto": [3, 4], "is_spmd": 3, "devkind": 3, "custom": [3, 4, 5, 7], "deprec": 3, "xla_device_hw": 3, "map": 3, "real": [3, 8], "is_master_ordin": 3, "replic": [3, 7], "while": [3, 4, 5], "num_host": 3, "boolean": 3, "all_reduc": 3, "reduce_typ": 3, "scale": [3, 6, 7, 8], "group": [3, 4, 6, 7], "pin_layout": 3, "inplac": [3, 7], "One": [3, 4], "reduce_sum": 3, "reduce_mul": 3, "reduce_and": 3, "reduce_or": 3, "reduce_min": 3, "reduce_max": 3, "float": [3, 5], "appli": [3, 4, 7], "replica": [3, 6], "defin": [3, 7], "pin": 3, "pine": 3, "prevent": [3, 7, 8], "corrupt": 3, "slightli": 3, "unpin": 3, "hlomodul": 3, "mix": [3, 7], "constrain": [3, 6], "hold": [3, 7], "tupl": [3, 5, 7], "itself": [3, 4], "all_gath": [3, 6], "dim": 3, "gather": [3, 7], "along": [3, 4], "dimens": [3, 7], "all_to_al": 3, "split_dimens": 3, "concat_dimens": 3, "split_count": 3, "alltoal": 3, "www": 3, "org": [3, 4, 6], "operation_semant": 3, "upon": 3, "split": 3, "concat": 3, "add_step_closur": 3, "closur": 3, "run_async": 3, "ones": [3, 5], "report": 3, "consol": 3, "tensorboard": 3, "content": 3, "intermediari": 3, "inspect": 3, "point": [3, 5], "typic": 3, "ensur": [3, 5, 7], "live": [3, 5], "argument": [3, 4, 8], "queu": 3, "sequenti": 3, "advis": 3, "throttl": 3, "event": 3, "asynchron": [3, 7], "wait_device_op": 3, "async": [3, 8], "whose": 3, "empti": 3, "optimizer_arg": 3, "parallelload": [3, 7], "dataparallel": 3, "support": [3, 4, 5, 6, 7, 8], "dict": [3, 4], "dictionari": 3, "file_or_path": 3, "master_onli": [3, 4], "global_mast": 3, "nest": [3, 4], "combin": [3, 5], "object": [3, 7], "overrid": 3, "locat": 3, "flag": 3, "hang": 3, "rendezv": 3, "tag": [3, 6], "payload": [3, 6], "b": [3, 5, 6, 7, 8], "mesh": [3, 6], "client": [3, 6], "reach": 3, "xrt": 3, "server": [3, 6], "effect": 3, "alia": 3, "xla_rendezv": 3, "join": 3, "byte": 3, "exchang": 3, "posit": 3, "mesh_reduc": 3, "reduce_fn": 3, "reduct": 3, "receiv": 3, "come": [3, 5], "set_rng_stat": 3, "get_rng_stat": 3, "get_memory_info": 3, "memoryinfo": 3, "get_stablehlo": 3, "stablehlo": 3, "todo": 3, "lsy323": 3, "investig": [3, 4], "infer": [3, 6, 7], "straightforward": 3, "identifi": [3, 7], "env": [3, 6, 7], "var": [3, 7], "get_stablehlo_bytecod": 3, "bytecod": [3, 8], "class": [3, 4, 7], "batchdim": 3, "loader_prefetch_s": 3, "device_prefetch_s": 3, "host_to_device_transfer_thread": 3, "input_shard": [3, 7], "background": [3, 7], "upload": [3, 7], "th": [3, 7], "len": 3, "max": [3, 5, 7], "capac": 3, "queue": 3, "deposit": 3, "shardingspec": [3, 7], "spec": 3, "per_device_load": [3, 7], "structur": [3, 4, 7], "resid": 3, "xla_multiprocess": 3, "fn": 3, "nproc": [3, 6], "daemon": 3, "start_method": 3, "At": 3, "moment": 3, "maximum": 3, "creation": 3, "method": [3, 6, 7], "mark_shard": [3, 7], "union": 3, "xlashardedtensor": 3, "partition_spec": [3, 7], "annot": [3, 7], "partit": 3, "xlatensor": [3, 7], "spmdpartition": [3, 7], "param": 3, "device_mesh": [3, 7], "axi": [3, 7], "rank": [3, 4, 6, 7], "mesh_shap": [3, 7], "ax": [3, 7], "row": 3, "wise": 3, "8x10": 3, "column": 3, "dynamo_custom_op": 3, "dynamo": [3, 8], "variant": [3, 5], "recogniz": 3, "traceabl": 3, "num_devic": [3, 7], "device_id": [3, 7], "np": [3, 7], "arrai": [3, 7], "clear_shard": 3, "clear": 3, "cast": 3, "place": [3, 7], "get_1d_mesh": 3, "set_global_mesh": 3, "get_global_mesh": 3, "axis_nam": [3, 7], "helper": 3, "ndarrai": 3, "ravel": 3, "reshap": 3, "fill": 3, "element": [3, 5, 7], "sequenc": 3, "Its": 3, "length": [3, 5], "get_xla_supported_devic": 3, "get_logical_mesh": 3, "ordereddict": [3, 7], "hybridmesh": [3, 7], "ici_mesh_shap": [3, 7], "dcn_mesh_shap": [3, 7], "hybrid": 3, "ici": 3, "dcn": [3, 7], "increas": 3, "intens": 3, "mdl": 3, "inner": [3, 4, 7], "outer": [3, 4, 7], "slice": [3, 7], "metric": [3, 4], "counter_nam": 3, "metric_nam": 3, "counter_valu": 3, "metric_data": 3, "total_sampl": 3, "accumul": 3, "retain": 3, "circular": 3, "buffer": 3, "sum": [3, 4, 7], "document": [4, 6], "further": 4, "against": 4, "minimum": [4, 7], "runnabl": [4, 7], "abil": [4, 5], "api": [4, 5, 6, 7, 8], "And": [4, 5, 7], "who": 4, "know": [4, 5], "xla_backend": [4, 6, 7], "init": [4, 6, 8], "similar": [4, 6], "nccl": 4, "gloo": [4, 6, 7], "dist": [4, 6, 7], "init_process_group": [4, 6, 7], "new_rank": 4, "gradient_as_bucket_view": [4, 6], "ddp_model": 4, "final": [4, 7], "launcher": 4, "demo_fn": 4, "everyth": [4, 5], "touch": [4, 7], "plu": 4, "five": 4, "sy": 4, "tempfil": 4, "master_addr": [4, 6], "localhost": [4, 6], "master_port": [4, 6], "12355": [4, 6], "cleanup": 4, "destroy_process_group": 4, "toymodel": 4, "__init__": 4, "self": [4, 7], "super": [4, 8], "net1": 4, "1000000": 4, "relu": 4, "net2": 4, "demo_bas": 4, "assert": 4, "graident_as_bucket_view": 4, "mseloss": [4, 6], "001": [4, 6], "label": 4, "run_demo": 4, "collect": [4, 6, 7, 8], "num_epoch": [4, 6], "tot": 4, "statist": 4, "produc": [4, 5], "unit": 4, "median": 4, "90th": 4, "std": 4, "cv": 4, "418": 4, "54": 4, "419": 4, "22": 4, "430": 4, "9": [4, 5], "76": 4, "97": 4, "407": 4, "39": 4, "seem": 4, "extra": [4, 7], "overhead": [4, 6, 8], "test_train_mp_mnist": [4, 6], "17864": 4, "19": [4, 8], "20108": 4, "96": 4, "24351": 4, "74": 4, "5866": 4, "83": 4, "10701": 4, "11770": 4, "14313": 4, "78": 4, "3102": 4, "92": 4, "41": [4, 8], "round": 4, "heavili": [4, 8], "sens": 4, "amort": 4, "logdir": 4, "converg": 4, "high": 4, "accuraci": 4, "caution": 4, "interest": 4, "known": 4, "enforc": [4, 5], "crash": 4, "xlafullyshardeddataparallel": 4, "my_modul": [4, 7], "adam": [4, 7], "0001": [4, 7], "individu": [4, 7], "leftov": [4, 7], "both": [4, 5, 6, 7, 8], "arxiv": 4, "ab": 4, "1910": 4, "02054": 4, "reshard_after_forward": 4, "test_train_mp_mnist_fsdp_with_ckpt": 4, "test_train_mp_imagenet_fsdp": 4, "larg": [4, 5, 6, 7], "cannot": [4, 5], "fit": 4, "interleav": 4, "submodul": 4, "fsdpvitmodel": 4, "ronghanghu": 4, "vit_10b_fsdp_exampl": 4, "run_vit_train": 4, "simpl": [4, 6, 7], "checkpoint_modul": [4, 7], "3524": 4, "auto_wrap_polici": [4, 7], "size_based_auto_wrap_polici": 4, "polici": [4, 7], "larger": [4, 8], "100m": 4, "transformer_auto_wrap_polici": [4, 7], "transform": [4, 7], "conv2d": 4, "partial": [4, 5, 7], "transformer_layer_cl": [4, 7], "addition": 4, "auto_wrapper_cal": 4, "remateri": 4, "lambda": 4, "kwarg": [4, 7], "latter": 4, "resum": 4, "get_shard_metadata": 4, "consolidate_sharded_model_checkpoint": 4, "stitch": 4, "ckpt": 4, "shard_metadata": 4, "ckpt_path": 4, "pth": 4, "tool": 4, "consolidate_sharded_ckpt": 4, "ckpt_prefix": 4, "your_sharded_checkpoint_fil": 4, "ckpt_suffix": 4, "_rank": 4, "inspir": 4, "mostli": [4, 6], "fairscal": 4, "fullyshardeddataparallel": 4, "readthedoc": 4, "en": 4, "biggest": [4, 8], "explicit": 4, "resort": 4, "train_resnet_fsdp_auto_wrap": 4, "newer": 4, "wheel": 4, "around": [4, 5, 6], "98": 4, "batch_siz": [4, 6], "drop_last": 4, "use_nested_fsdp": 4, "use_gradient_checkpoint": 4, "final_ckpt": 4, "75": 4, "download": 4, "1k": 4, "datadir": 4, "test_set_batch_s": 4, "eval_interv": 4, "128": [4, 6], "num_warmup_epoch": 4, "lr_scheduler_divide_every_n_epoch": 4, "lr_scheduler_divisor": 4, "residu": 4, "entir": [4, 6], "algorithm": [4, 7], "vision": 4, "vit": 4, "static": 5, "word": 5, "hurt": 5, "understand": 5, "normal": [5, 6, 7], "pov": 5, "sai": 5, "assur": 5, "magic": 5, "gone": 5, "good": [5, 7], "coverag": 5, "aim": [5, 7], "explan": 5, "common": [5, 6, 7], "rid": 5, "mainli": 5, "problem": 5, "beginn": 5, "propos": 5, "reli": 5, "impract": 5, "assumpt": 5, "ye": 5, "sentenc": 5, "vari": [5, 6, 7], "ll": 5, "bucket": [5, 7], "kinda": 5, "anti": 5, "frontend": 5, "matter": 5, "workaround": 5, "okai": 5, "teach": 5, "practic": [5, 6, 7], "enough": 5, "theoret": 5, "trade": 5, "less": [5, 6, 8], "faster": [5, 6, 8], "speed": [5, 8], "well": [5, 6, 7], "sort": 5, "obviou": 5, "shown": [5, 6], "s64": 5, "num_output": 5, "mul": 5, "although": [5, 6], "inde": 5, "_get_xla_tensor_dimension_s": 5, "commonli": 5, "dtype": [5, 6], "cut": 5, "correct": 5, "wrong": 5, "wors": 5, "probabl": 5, "upper": 5, "nit": 5, "simplic": 5, "rand": 5, "solv": 5, "world": [5, 6, 7, 8], "kept": 5, "earli": 5, "accessor": 5, "2d": [5, 7], "implicitli": 5, "doubl": 5, "overload": 5, "easili": [5, 8], "explod": 5, "convers": 5, "cheap": 5, "ve": 5, "hoc": 5, "think": 5, "verison": 5, "bla": 5, "blabla": 5, "interpret": 5, "proce": 5, "choic": 5, "wide": 5, "adopt": 5, "uglier": 5, "win": 5, "pars": 5, "statement": 5, "torchscript": 5, "somehow": 5, "merg": 5, "lazili": [5, 7], "properli": 5, "haven": 5, "thought": 5, "trivial": 5, "effort": [5, 7], "side": 5, "That": 5, "hit": 5, "bandwidth": 5, "automag": 5, "gold": 5, "smart": 5, "trick": 5, "tbh": 5, "longer": 5, "sometim": 5, "unawar": 5, "hope": 5, "smash": 5, "ideal": [5, 8], "blocker": 5, "ahead": 5, "nnc": 5, "symbol": 5, "By": [5, 6], "concret": 5, "kernel": 5, "exactli": 5, "transpos": 5, "With": [5, 6, 8], "brian": 5, "hirsh": 5, "bdhirsh": 5, "question": 5, "comment": 5, "worth": 5, "stick": 5, "torch_warn": 5, "yea": 5, "tell": 5, "hei": 5, "won": 5, "blaze": 5, "fast": 5, "isn": [5, 7], "rewrit": [5, 7], "devirtu": 5, "v": [5, 6], "sound": 5, "great": 5, "carri": [5, 7], "truth": 5, "As": [5, 7], "irvalu": 5, "discrep": 5, "followup": 5, "mention": [5, 8], "1000": 5, "my": [5, 7], "properti": 5, "presenc": 5, "get_dimention_s": 5, "didn": 5, "altern": [5, 6], "condit": 5, "middl": [5, 6], "exponenti": 5, "blowup": 5, "smaller": 5, "fewer": 5, "opportun": 5, "recogn": [5, 8], "could": [5, 6, 7], "break": 5, "feasibl": 5, "annoi": 5, "z": 5, "subgraph": 5, "variabl": [5, 6], "wasn": 5, "materiz": 5, "involv": [5, 7], "combo": 5, "migrat": 6, "jax": 6, "public": 6, "renam": 6, "regist": [6, 7], "init_method": [6, 7], "plugin": 6, "xpu": 6, "neuron": 6, "continu": [6, 8], "xrt_tpu_config": 6, "libtpu": 6, "thousand": 6, "preview": 6, "On": [6, 7], "safe": 6, "broadcast": 6, "broadcast_master_param": 6, "pjrt_backend": 6, "These": [6, 7], "diff": 6, "42": 6, "confirm": 6, "localservic": 6, "51011": 6, "grpc": 6, "torchbench": 6, "2048": 6, "read": 6, "central2": 6, "256": 6, "tpu_process_bound": 6, "tpu_visible_chip": 6, "r1": 6, "preinstal": 6, "docker_imag": 6, "gcr": 6, "authent": 6, "privat": 6, "gcp": 6, "auth": 6, "rm": 6, "privileg": 6, "simpli": 6, "nnode": 6, "num_gpu_devic": 6, "pjrt_distribut": 6, "physic": [6, 7], "number_gpu_vm": 6, "node_rank": 6, "current_node_rank": 6, "nproc_per_nod": 6, "number_local_gpu_devic": 6, "rdzv_endpoint": 6, "internal_ip_address": 6, "port": 6, "multinode_train": 6, "endpoint": 6, "omit": [6, 7], "machine_0": 6, "machine_1": 6, "machine_0_internal_ip_address": 6, "ident": 6, "page": 6, "interchang": 6, "subtl": 6, "importantli": 6, "latenc": 6, "deseri": 6, "gain": 6, "interact": 6, "profil": 6, "plan": 6, "simpler": 6, "xla_dist": 6, "sdk": 6, "reimplement": 6, "enhanc": 6, "xmp": 6, "substanti": 6, "consist": 6, "servic": 6, "unreli": 6, "inbound": 6, "failur": 6, "impos": 6, "unwant": 6, "permit": 6, "subset": 6, "old": 6, "alter": 6, "consid": 6, "all_gather_object": 6, "new_group": 6, "subgroup": 6, "reliabl": 6, "md": 6, "strongli": 6, "queri": 6, "_all_gath": 6, "tensor": [6, 7, 8], "int32": 6, "zeros_lik": 6, "get_world_s": 6, "averag": 6, "task": 6, "175": 6, "chart": 6, "breakdown": 6, "tfrt": 6, "legaci": 6, "streamexecutor": 6, "tpu_legaci": 6, "comparison": [6, 7], "discuss": 7, "gspmd": 7, "overview": 7, "illustr": 7, "ml": 7, "proper": 7, "hint": 7, "figur": 7, "strategi": 7, "numpi": 7, "concept": 7, "librari": 7, "cluster": 7, "interconnect": 7, "almost": 7, "encourag": 7, "fist": 7, "express": 7, "paral": 7, "fsdpv2": 7, "famou": 7, "offer": 7, "enjoi": 7, "benefit": 7, "bring": 7, "tabl": 7, "review": 7, "proceed": 7, "spmd_fully_sharded_data_parallel": 7, "spmdfullyshardeddataparallel": 7, "autowrap": 7, "decoderlay": 7, "functool": 7, "decoder_only_model": 7, "shard_output": 7, "fall": 7, "categori": 7, "0th": 7, "children": 7, "infinit": 7, "fork": 7, "hf": 7, "demonstr": 7, "cover": 7, "proced": 7, "src": 7, "_input_sharding_": 7, "4d": 7, "input_mesh": 7, "_after": 7, "_the": 7, "unnecessari": 7, "forth": 7, "techniqu": 7, "decis": 7, "nice": 7, "abstract": 7, "arrang": 7, "center": 7, "box": 7, "multislic": 7, "accept": 7, "denot": 7, "hardcod": 7, "rfc": 7, "delai": 7, "except": 7, "satisfi": 7, "subclass": 7, "__torch_dispatch__": 7, "invok": [7, 8], "global_tensor": 7, "special": 7, "strictli": 7, "local_shard": 7, "xlashard": 7, "4e8e5511555073ce8b6d1a436bf808c9333dcac6": 7, "xla_sharded_tensor": 7, "l12": 7, "ongo": 7, "distributedtensor": 7, "prototyp": 7, "proof": 7, "distribute_tensor": 7, "devicemesh": 7, "big_tensor": 7, "100000": 7, "my_dtensor": 7, "stai": 7, "tune": 7, "upcom": [7, 8], "dynamo_mark_shard": 7, "placement": 7, "visual": 7, "multi": 7, "visualize_tensor_shard": 7, "visualize_shard": 7, "rich": 7, "2x2": 7, "generated_t": 7, "use_color": 7, "style": 7, "tile": 7, "partial_repl": 7, "envvar": 7, "xla_auto_spmd": 7, "_tensor": 7, "distribute_modul": 7, "auto_polici": 7, "mymodul": 7, "sharded_model": 7, "behvaior": 7, "xla_auto_use_group_shard": 7, "reshard": 7, "xla_auto_spmd_mesh": 7, "unset": 7, "dedic": 7, "planner": 7, "spmdsaveplann": 7, "spmdloadplann": 7, "dist_cp": 7, "distributed_checkpoint": 7, "xc": 7, "storage_writ": 7, "filesystemwrit": 7, "checkpoint_dir": 7, "desir": 7, "storage_read": 7, "filesystemread": 7, "checkpointmanag": 7, "all_step": 7, "save_async": 7, "written": 7, "unblock": 7, "durat": 7, "dispatch": 7, "preemption": 7, "detect": 7, "termin": 7, "provis": 7, "queuedresourc": 7, "autocheckpoint": 7, "chkpt_on_preempt": 7, "fsspec": 7, "filesystem": 7, "gc": 7, "prime_optim": 7, "chkpt_mgr": 7, "tracked_step": 7, "choos": 7, "highest": 7, "best_step": 7, "prime": 7, "enumer": 7, "present": 7, "attempt": 7, "unprim": 7, "destruct": 7, "discov": 7, "jit": 8, "unmodifi": 8, "hook": 8, "bridg": 8, "torchfx": 8, "technologi": 8, "fx": 8, "a_xla": 8, "b_xla": 8, "compiled_cod": 8, "eval_model": 8, "xla_resnet18": 8, "eval": 8, "dynamo_resnet18": 8, "no_grad": 8, "resent18": 8, "binari": 8, "analysi": 8, "bench": 8, "59": 8, "resnext50_32x4d": 8, "91": 8, "alexnet": 8, "28": 8, "mobilenet_v2": 8, "18": 8, "mnasnet1_0": 8, "68": 8, "vgg16": 8, "bert_pytorch": 8, "squeezenet1_1": 8, "timm_vision_transform": 8, "52": 8, "geomean": 8, "team": 8, "train_model": 8, "crossentropyloss": 8, "pred": 8, "train_model_main": 8, "dynamo_train_model": 8, "xla_optim": 8, "weight_decai": 8, "extract": 8, "07": 8, "81": 8, "87": 8, "fwd": 8, "bwd": 8, "e2": 8, "hide": 8, "cost": 8, "scenario": 8, "best": 8, "promis": 8, "complex": 8, "tradit": 8, "seen": 8, "expand": 8, "excit": 8, "invest": 8, "upstream": 8, "matur": 8, "stori": 8}, "objects": {"": [[3, 0, 0, "-", "torch_xla"]], "torch_xla": [[3, 1, 1, "", "compile"], [3, 1, 1, "", "device"], [3, 1, 1, "", "device_count"], [3, 1, 1, "", "devices"], [3, 0, 0, "-", "experimental"], [3, 1, 1, "", "manual_seed"], [3, 0, 0, "-", "runtime"], [3, 1, 1, "", "sync"]], "torch_xla.core": [[3, 0, 0, "-", "xla_model"]], "torch_xla.core.xla_model": [[3, 1, 1, "", "add_step_closure"], [3, 1, 1, "", "all_gather"], [3, 1, 1, "", "all_reduce"], [3, 1, 1, "", "all_to_all"], [3, 1, 1, "", "get_memory_info"], [3, 1, 1, "", "get_rng_state"], [3, 1, 1, "", "get_stablehlo"], [3, 1, 1, "", "get_stablehlo_bytecode"], [3, 1, 1, "", "is_master_ordinal"], [3, 1, 1, "", "mesh_reduce"], [3, 1, 1, "", "optimizer_step"], [3, 1, 1, "", "rendezvous"], [3, 1, 1, "", "save"], [3, 1, 1, "", "set_rng_state"], [3, 1, 1, "", "wait_device_ops"], [3, 1, 1, "", "xla_device"], [3, 1, 1, "", "xla_device_hw"]], "torch_xla.debug": [[3, 0, 0, "-", "metrics"]], "torch_xla.debug.metrics": [[3, 1, 1, "", "counter_names"], [3, 1, 1, "", "counter_value"], [3, 1, 1, "", "metric_data"], [3, 1, 1, "", "metric_names"], [3, 1, 1, "", "metrics_report"], [3, 1, 1, "", "short_metrics_report"]], "torch_xla.distributed": [[3, 0, 0, "-", "parallel_loader"], [3, 0, 0, "-", "spmd"], [3, 0, 0, "-", "xla_multiprocessing"]], "torch_xla.distributed.parallel_loader": [[3, 2, 1, "", "ParallelLoader"]], "torch_xla.distributed.parallel_loader.ParallelLoader": [[3, 3, 1, "", "per_device_loader"]], "torch_xla.distributed.spmd": [[3, 2, 1, "", "HybridMesh"], [3, 2, 1, "", "Mesh"], [3, 1, 1, "", "clear_sharding"], [3, 1, 1, "", "get_1d_mesh"], [3, 1, 1, "", "get_global_mesh"], [3, 1, 1, "", "mark_sharding"], [3, 1, 1, "", "set_global_mesh"]], "torch_xla.distributed.xla_multiprocessing": [[3, 1, 1, "", "spawn"]], "torch_xla.experimental": [[3, 1, 1, "", "eager_mode"]], "torch_xla.runtime": [[3, 1, 1, "", "addressable_device_count"], [3, 1, 1, "", "device_type"], [3, 1, 1, "", "get_master_ip"], [3, 1, 1, "", "global_device_count"], [3, 1, 1, "", "global_ordinal"], [3, 1, 1, "", "global_runtime_device_count"], [3, 1, 1, "", "initialize_cache"], [3, 1, 1, "", "is_spmd"], [3, 1, 1, "", "local_device_count"], [3, 1, 1, "", "local_ordinal"], [3, 1, 1, "", "local_process_count"], [3, 1, 1, "", "use_spmd"], [3, 1, 1, "", "world_size"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class", "3": "py:method"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"], "3": ["py", "method", "Python method"]}, "titleterms": {"troubleshoot": 0, "saniti": 0, "check": [0, 2], "pytorch": [0, 2, 3, 4, 7, 8], "xla": [0, 2, 3, 4, 7, 8], "version": 0, "perform": [0, 6], "A": 0, "simpl": [0, 2], "calcul": 0, "run": [0, 2, 3, 7], "resnet": [0, 2, 4], "With": 0, "fake": [0, 4], "data": [0, 4, 7], "debug": [0, 3, 7], "tool": [0, 7], "auto": [0, 7], "metric": 0, "analysi": 0, "compil": [0, 1, 3, 7, 8], "execut": 0, "get": 0, "report": 0, "understand": 0, "The": 0, "clear": 0, "dynamo": 0, "profil": 0, "benchmark": [0, 1, 4], "known": 0, "caveat": 0, "tensor": [0, 3, 5], "quirk": 0, "more": 0, "environ": [0, 2], "variabl": [0, 2], "common": 0, "combin": 0, "reproduc": 0, "ci": 0, "cd": 0, "unit": 0, "test": 0, "failur": 0, "eager": 1, "mode": [1, 7], "api": [1, 3], "background": [1, 4], "basic": 1, "usag": 1, "infer": [1, 8], "train": [1, 4, 6, 8], "how": [2, 4, 7], "gpu": [2, 6], "creat": [2, 3], "instanc": 2, "setup": 2, "docker": [2, 6], "wheel": 2, "some": [2, 5], "model": [2, 3], "mp_imagenet": 2, "exampl": [2, 4, 7], "amp": 2, "automat": 2, "mix": 2, "precis": 2, "develop": 2, "build": 2, "from": [2, 3, 5, 6], "sourc": [2, 5], "support": 2, "document": 3, "doc": 3, "devic": 3, "an": 3, "ar": 3, "singl": [3, 6], "multipl": 3, "multi": [3, 6], "process": 3, "tpu": [3, 4, 6, 7], "pod": [3, 4, 6, 7], "deep": 3, "dive": 3, "lazi": 3, "memori": 3, "layout": 3, "move": 3, "cpu": [3, 6], "save": 3, "load": 3, "cach": 3, "further": [3, 7], "read": [3, 7], "torch_xla": [3, 5], "runtim": [3, 6], "xla_model": 3, "distribut": [3, 6, 7], "spmd": [3, 7], "experiment": 3, "do": 4, "distributeddataparallel": 4, "ddp": 4, "motiv": 4, "us": [4, 5, 7], "resnet50": 4, "mnist": 4, "real": [4, 5], "disclaim": 4, "fulli": [4, 7], "shard": [4, 7], "parallel": [4, 7], "fsdp": [4, 7], "script": 4, "imagenet": 4, "instal": 4, "clone": 4, "repo": 4, "v3": [4, 6], "8": 4, "50": 4, "10": 4, "billion": 4, "paramet": 4, "recompil": 5, "let": 5, "": 5, "first": 5, "start": 5, "fact": 5, "constraint": 5, "1": 5, "input": 5, "dataset": 5, "2": [5, 7], "oper": 5, "output": [5, 7], "bound": 5, "dynam": 5, "shape": 5, "can": 5, "fix": 5, "case": 5, "when": 5, "you": 5, "without": 5, "queri": 5, "its": 5, "dimens": 5, "what": [5, 7], "i": [5, 7], "3": 5, "control": 5, "flow": 5, "conclus": 5, "appendix": 5, "pjrt": 6, "tl": 6, "dr": 6, "benefit": 6, "quickstart": 6, "node": 6, "differ": 6, "xrt": 6, "multithread": 6, "v2": 6, "chang": 6, "xm": 6, "rendezv": 6, "torch": [6, 7, 8], "new": 6, "user": 7, "guid": 7, "mesh": 7, "partit": 7, "spec": 7, "via": 7, "gradient": 7, "checkpoint": 7, "huggingfac": 7, "llama": 7, "advanc": 7, "topic": 7, "hybrid": 7, "xlashardedtensor": 7, "dtensor": 7, "integr": [7, 8], "activ": 7, "torchdynamo": 8, "featur": 8, "gap": 8, "take": 8, "awai": 8}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"Troubleshooting": [[0, "troubleshooting"]], "Sanity Check": [[0, "sanity-check"]], "Check PyTorch/XLA Version": [[0, "check-pytorch-xla-version"]], "Perform A Simple Calculation": [[0, "perform-a-simple-calculation"]], "Run Resnet With Fake Data": [[0, "run-resnet-with-fake-data"]], "Performance Debugging": [[0, "performance-debugging"]], "PyTorch/XLA Debugging Tool": [[0, "pytorch-xla-debugging-tool"]], "Perform A Auto-Metrics Analysis": [[0, "perform-a-auto-metrics-analysis"]], "Compilation & Execution Analysis": [[0, "compilation-execution-analysis"]], "Get A Metrics Report": [[0, "get-a-metrics-report"]], "Understand The Metrics Report": [[0, "understand-the-metrics-report"]], "Clear The Metrics Report": [[0, "clear-the-metrics-report"]], "PyTorch/XLA + Dynamo Debugging Tool": [[0, "pytorch-xla-dynamo-debugging-tool"]], "Performance Profiling": [[0, "performance-profiling"]], "Simple Benchmarking": [[0, "simple-benchmarking"]], "Known Performance Caveats": [[0, "known-performance-caveats"]], "XLA Tensor Quirks": [[0, "xla-tensor-quirks"]], "More Debugging Tools": [[0, "more-debugging-tools"]], "Environment Variables": [[0, "environment-variables"]], "Common Debugging Environment Variables Combinations": [[0, "common-debugging-environment-variables-combinations"]], "Reproducing PyTorch/XLA CI/CD unit test failures.": [[0, "reproducing-pytorch-xla-ci-cd-unit-test-failures"]], "Eager Mode + Compile API": [[1, "eager-mode-compile-api"]], "Background": [[1, "background"]], "Basic Usage": [[1, "basic-usage"]], "Inference": [[1, "inference"], [8, "inference"]], "Training": [[1, "training"], [8, "training"]], "Benchmark": [[1, "benchmark"]], "How to run with PyTorch/XLA:GPU": [[2, "how-to-run-with-pytorch-xla-gpu"]], "Create a GPU instance": [[2, "create-a-gpu-instance"]], "Environment Setup": [[2, "environment-setup"]], "Docker": [[2, "docker"], [6, "docker"]], "Check environment variable": [[2, "check-environment-variable"]], "Wheel": [[2, "wheel"]], "Run some simple models": [[2, "run-some-simple-models"]], "MP_ImageNet Example": [[2, "mp-imagenet-example"]], "ResNet Example": [[2, "resnet-example"]], "AMP (AUTOMATIC MIXED PRECISION)": [[2, "amp-automatic-mixed-precision"]], "Develop PyTorch/XLA on a GPU instance (build PyTorch/XLA from source with GPU support)": [[2, "develop-pytorch-xla-on-a-gpu-instance-build-pytorch-xla-from-source-with-gpu-support"]], "PyTorch/XLA documentation": [[3, "pytorch-xla-documentation"]], "Docs": [[3, null]], "PyTorch on XLA Devices": [[3, "pytorch-on-xla-devices"]], "Creating an XLA Tensor": [[3, "creating-an-xla-tensor"]], "XLA Tensors are PyTorch Tensors": [[3, "xla-tensors-are-pytorch-tensors"]], "Running Models on XLA Devices": [[3, "running-models-on-xla-devices"]], "Running on a Single XLA Device": [[3, "running-on-a-single-xla-device"]], "Running on Multiple XLA Devices with Multi-processing": [[3, "running-on-multiple-xla-devices-with-multi-processing"]], "Running on TPU Pods": [[3, "running-on-tpu-pods"]], "XLA Tensor Deep Dive": [[3, "id3"]], "XLA Tensors are Lazy": [[3, "xla-tensors-are-lazy"]], "Memory Layout": [[3, "memory-layout"]], "Moving XLA Tensors to and from the CPU": [[3, "moving-xla-tensors-to-and-from-the-cpu"]], "Saving and Loading XLA Tensors": [[3, "saving-and-loading-xla-tensors"]], "Compilation Caching": [[3, "compilation-caching"]], "Further Reading": [[3, "further-reading"], [7, "further-reading"]], "PyTorch/XLA API": [[3, "pytorch-xla-api"]], "torch_xla": [[3, "module-torch_xla"]], "runtime": [[3, "module-torch_xla.runtime"]], "xla_model": [[3, "module-torch_xla.core.xla_model"]], "distributed": [[3, "module-torch_xla.distributed.parallel_loader"]], "spmd": [[3, "module-torch_xla.distributed.spmd"]], "experimental": [[3, "module-torch_xla.experimental"]], "debug": [[3, "module-torch_xla.debug.metrics"]], "How to do DistributedDataParallel(DDP)": [[4, "how-to-do-distributeddataparallel-ddp"]], "Background / Motivation": [[4, "background-motivation"]], "How to use DistributedDataParallel": [[4, "how-to-use-distributeddataparallel"]], "Benchmarking": [[4, "benchmarking"]], "Resnet50 with fake data": [[4, "resnet50-with-fake-data"]], "MNIST with fake data": [[4, "mnist-with-fake-data"]], "MNIST with real data": [[4, "mnist-with-real-data"]], "Disclaimer": [[4, "disclaimer"]], "Fully Sharded Data Parallel (FSDP) in PyTorch XLA": [[4, "fully-sharded-data-parallel-fsdp-in-pytorch-xla"]], "Example training scripts on MNIST and ImageNet": [[4, "example-training-scripts-on-mnist-and-imagenet"]], "Installation": [[4, "installation"]], "Clone PyTorch/XLA repo": [[4, "clone-pytorch-xla-repo"]], "Train MNIST on v3-8 TPU": [[4, "train-mnist-on-v3-8-tpu"]], "Train ImageNet with ResNet-50 on v3-8 TPU": [[4, "train-imagenet-with-resnet-50-on-v3-8-tpu"]], "Example training scripts on TPU pod (with 10 billion parameters)": [[4, "example-training-scripts-on-tpu-pod-with-10-billion-parameters"]], "Source of recompilations in torch_xla": [[5, "source-of-recompilations-in-torch-xla"]], "Let\u2019s first start with some facts/constraints:": [[5, "lets-first-start-with-some-facts-constraints"]], "#1. From input dataset.": [[5, "from-input-dataset"]], "#2. From operator output": [[5, "from-operator-output"]], "2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension.": [[5, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"]], "2.2 what if real dimension is queried on a tensor with dynamic shape?": [[5, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"]], "#3. From control flow": [[5, "from-control-flow"]], "Conclusion:": [[5, "conclusion"]], "Appendix:": [[5, "appendix"]], "PJRT Runtime": [[6, "pjrt-runtime"]], "TL;DR": [[6, "tl-dr"]], "Benefits": [[6, "benefits"]], "Quickstart": [[6, "quickstart"]], "CPU": [[6, "cpu"]], "TPU": [[6, "tpu"]], "Pods": [[6, "pods"]], "GPU": [[6, "gpu"]], "Single-node GPU training": [[6, "single-node-gpu-training"]], "Multi-node GPU training": [[6, "multi-node-gpu-training"]], "Differences from XRT": [[6, "differences-from-xrt"]], "Multithreading on TPU v2/v3": [[6, "id3"]], "Changes to xm.rendezvous": [[6, "changes-to-xm-rendezvous"]], "PJRT and torch.distributed": [[6, "pjrt-and-torch-distributed"]], "Performance": [[6, "performance"]], "New TPU runtime": [[6, "new-tpu-runtime"]], "PyTorch/XLA SPMD User Guide": [[7, "pytorch-xla-spmd-user-guide"]], "What is PyTorch/XLA SPMD?": [[7, "what-is-pytorch-xla-spmd"]], "How to use PyTorch/XLA SPMD?": [[7, "how-to-use-pytorch-xla-spmd"]], "SPMD Mode": [[7, "spmd-mode"]], "Mesh": [[7, "mesh"]], "Partition Spec": [[7, "partition-spec"]], "Fully Sharded Data Parallel(FSDP) via SPMD": [[7, "fully-sharded-data-parallel-fsdp-via-spmd"]], "Sharding output": [[7, "sharding-output"]], "Gradient checkpointing": [[7, "gradient-checkpointing"]], "HuggingFace Llama 2 Example": [[7, "huggingface-llama-2-example"]], "PyTorch/XLA SPMD advanced topics": [[7, "pytorch-xla-spmd-advanced-topics"]], "Hybrid Mesh": [[7, "hybrid-mesh"]], "Running SPMD on TPU Pod": [[7, "running-spmd-on-tpu-pod"]], "XLAShardedTensor": [[7, "xlashardedtensor"]], "DTensor Integration": [[7, "dtensor-integration"]], "Activation Sharding for torch.compile": [[7, "activation-sharding-for-torch-compile"]], "SPMD Debugging Tool": [[7, "spmd-debugging-tool"]], "Auto-Sharding": [[7, "auto-sharding"]], "Distributed Checkpointing": [[7, "distributed-checkpointing"]], "TorchDynamo(torch.compile) integration in PyTorch XLA": [[8, "torchdynamo-torch-compile-integration-in-pytorch-xla"]], "Integration": [[8, "integration"]], "Feature gaps": [[8, "feature-gaps"]], "Take away": [[8, "take-away"]]}, "indexentries": {"hybridmesh (class in torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.HybridMesh"]], "mesh (class in torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.Mesh"]], "parallelloader (class in torch_xla.distributed.parallel_loader)": [[3, "torch_xla.distributed.parallel_loader.ParallelLoader"]], "add_step_closure() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.add_step_closure"]], "addressable_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.addressable_device_count"]], "all_gather() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.all_gather"]], "all_reduce() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.all_reduce"]], "all_to_all() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.all_to_all"]], "clear_sharding() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.clear_sharding"]], "compile() (in module torch_xla)": [[3, "torch_xla.compile"]], "counter_names() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.counter_names"]], "counter_value() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.counter_value"]], "device() (in module torch_xla)": [[3, "torch_xla.device"]], "device_count() (in module torch_xla)": [[3, "torch_xla.device_count"]], "device_type() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.device_type"]], "devices() (in module torch_xla)": [[3, "torch_xla.devices"]], "eager_mode() (in module torch_xla.experimental)": [[3, "torch_xla.experimental.eager_mode"]], "get_1d_mesh() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.get_1d_mesh"]], "get_global_mesh() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.get_global_mesh"]], "get_master_ip() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.get_master_ip"]], "get_memory_info() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_memory_info"]], "get_rng_state() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_rng_state"]], "get_stablehlo() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_stablehlo"]], "get_stablehlo_bytecode() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_stablehlo_bytecode"]], "global_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.global_device_count"]], "global_ordinal() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.global_ordinal"]], "global_runtime_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.global_runtime_device_count"]], "initialize_cache() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.initialize_cache"]], "is_master_ordinal() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.is_master_ordinal"]], "is_spmd() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.is_spmd"]], "local_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.local_device_count"]], "local_ordinal() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.local_ordinal"]], "local_process_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.local_process_count"]], "manual_seed() (in module torch_xla)": [[3, "torch_xla.manual_seed"]], "mark_sharding() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.mark_sharding"]], "mesh_reduce() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.mesh_reduce"]], "metric_data() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.metric_data"]], "metric_names() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.metric_names"]], "metrics_report() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.metrics_report"]], "module": [[3, "module-torch_xla"], [3, "module-torch_xla.core.xla_model"], [3, "module-torch_xla.debug.metrics"], [3, "module-torch_xla.distributed.parallel_loader"], [3, "module-torch_xla.distributed.spmd"], [3, "module-torch_xla.distributed.xla_multiprocessing"], [3, "module-torch_xla.experimental"], [3, "module-torch_xla.runtime"]], "optimizer_step() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.optimizer_step"]], "per_device_loader() (torch_xla.distributed.parallel_loader.parallelloader method)": [[3, "torch_xla.distributed.parallel_loader.ParallelLoader.per_device_loader"]], "rendezvous() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.rendezvous"]], "save() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.save"]], "set_global_mesh() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.set_global_mesh"]], "set_rng_state() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.set_rng_state"]], "short_metrics_report() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.short_metrics_report"]], "spawn() (in module torch_xla.distributed.xla_multiprocessing)": [[3, "torch_xla.distributed.xla_multiprocessing.spawn"]], "sync() (in module torch_xla)": [[3, "torch_xla.sync"]], "torch_xla": [[3, "module-torch_xla"]], "torch_xla.core.xla_model": [[3, "module-torch_xla.core.xla_model"]], "torch_xla.debug.metrics": [[3, "module-torch_xla.debug.metrics"]], "torch_xla.distributed.parallel_loader": [[3, "module-torch_xla.distributed.parallel_loader"]], "torch_xla.distributed.spmd": [[3, "module-torch_xla.distributed.spmd"]], "torch_xla.distributed.xla_multiprocessing": [[3, "module-torch_xla.distributed.xla_multiprocessing"]], "torch_xla.experimental": [[3, "module-torch_xla.experimental"]], "torch_xla.runtime": [[3, "module-torch_xla.runtime"]], "use_spmd() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.use_spmd"]], "wait_device_ops() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.wait_device_ops"]], "world_size() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.world_size"]], "xla_device() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.xla_device"]], "xla_device_hw() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.xla_device_hw"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["debug", "eager_mode", "gpu", "index", "multi_process_distributed", "notes/source_of_recompilation", "runtime", "spmd", "torch_compile"], "filenames": ["debug.rst", "eager_mode.rst", "gpu.rst", "index.rst", "multi_process_distributed.rst", "notes/source_of_recompilation.md", "runtime.rst", "spmd.rst", "torch_compile.rst"], "titles": ["Troubleshooting", "Eager Mode + Compile API", "How to run with PyTorch/XLA:GPU", "PyTorch/XLA documentation", "How to do DistributedDataParallel(DDP)", "Source of recompilations in torch_xla", "PJRT Runtime", "PyTorch/XLA SPMD User Guide", "TorchDynamo(torch.compile) integration in PyTorch XLA"], "terms": {"note": [0, 1, 2, 3, 4, 5, 6, 7, 8], "inform": [0, 3, 5, 6], "thi": [0, 1, 2, 3, 4, 5, 6, 7, 8], "section": [0, 2, 3, 6, 7], "i": [0, 1, 2, 3, 4, 6, 8], "subject": 0, "remov": [0, 6], "futur": [0, 3, 5, 6, 7], "releas": [0, 2, 3, 4, 6, 7, 8], "softwar": 0, "sinc": [0, 3, 4, 5, 6, 7, 8], "mani": [0, 3, 5, 6], "them": [0, 3, 5, 6], "ar": [0, 1, 4, 5, 6, 7, 8], "peculiar": 0, "given": [0, 3, 4, 5, 7], "intern": [0, 3, 5, 6, 7], "implement": [0, 1, 4, 5, 6, 7, 8], "which": [0, 2, 3, 4, 5, 6, 7, 8], "might": [0, 3, 5], "chang": [0, 3, 4, 5, 7], "befor": [0, 3, 4, 5, 6, 7, 8], "ani": [0, 3, 4, 5, 6, 7], "depth": [0, 6], "we": [0, 1, 2, 3, 4, 5, 6, 7, 8], "want": [0, 1, 3, 5, 6, 7, 8], "do": [0, 2, 3, 5, 6, 7], "instal": [0, 2, 6], "should": [0, 1, 2, 3, 4, 5, 6, 7], "match": [0, 3], "out": [0, 1, 3, 6, 7, 8], "our": [0, 2, 3, 4, 5, 6, 7, 8], "readm": 0, "detial": 0, "avail": [0, 3, 4, 5, 6], "vm": [0, 2, 3, 4, 6], "python": [0, 2, 3, 4, 5, 6, 7, 8], "import": [0, 1, 3, 4, 6, 7, 8], "torch": [0, 1, 2, 3, 4, 5], "torch_xla": [0, 1, 2, 4, 6, 7, 8], "print": [0, 2, 3, 4, 5, 6, 7, 8], "__version__": 0, "2": [0, 1, 2, 3, 4, 6, 8], "1": [0, 1, 2, 3, 4, 6, 7, 8], "0": [0, 2, 3, 4, 5, 6, 7, 8], "cu121": 0, "export": [0, 2, 6], "pjrt_devic": [0, 2, 3, 6], "tpu": [0, 2, 8], "python3": [0, 2, 3, 4, 6], "core": [0, 1, 3, 4, 6, 7, 8], "xla_model": [0, 4, 6, 7, 8], "xm": [0, 1, 3, 4, 7, 8], "t1": [0, 3, 7], "100": [0, 2, 4], "devic": [0, 1, 2, 4, 5, 6, 7, 8], "xla_devic": [0, 3, 4, 6, 7, 8], "t2": [0, 7], "200": 0, "300": [0, 1], "For": [0, 2, 3, 4, 5, 6, 7, 8], "nightli": [0, 4, 7], "git": [0, 2, 4, 6], "clone": [0, 2, 6], "http": [0, 2, 3, 4, 6, 7], "github": [0, 2, 4, 6, 7], "com": [0, 2, 4, 6, 7], "test_train_mp_imagenet": [0, 2, 4, 6], "py": [0, 2, 3, 4, 6, 7], "fake_data": [0, 2, 4, 6], "x": [0, 3, 4, 5, 7], "y": [0, 2, 3, 4, 5, 7], "you": [0, 1, 2, 3, 4, 6, 7, 8], "us": [0, 1, 2, 3, 6, 8], "branch": [0, 5, 6], "rx": 0, "exampl": [0, 1, 3, 5, 6, 8], "r2": [0, 6, 7], "If": [0, 2, 3, 5, 6, 7], "can": [0, 1, 2, 3, 4, 6, 7, 8], "conclud": 0, "correctli": [0, 2, 7], "To": [0, 1, 2, 3, 4, 5, 6, 7], "diagnos": 0, "issu": [0, 1, 3, 4, 6, 7], "counter": [0, 3], "provid": [0, 3, 4, 5, 7, 8], "first": [0, 3, 4, 6, 7], "thing": 0, "when": [0, 1, 3, 4, 6, 7, 8], "model": [0, 1, 4, 5, 6, 7, 8], "slow": 0, "gener": [0, 1, 3, 5, 6], "extrem": 0, "help": [0, 5], "pleas": [0, 2, 3, 4, 6, 7], "try": [0, 5], "includ": [0, 2, 3, 5, 6, 7], "your": [0, 2, 3, 4, 5, 6, 7], "bug": [0, 4, 6], "sent": [0, 3], "u": [0, 2, 5, 6, 8], "have": [0, 2, 3, 4, 5, 6, 7, 8], "enabl": [0, 1, 2, 3, 4, 7], "set": [0, 3, 4, 5, 6, 7, 8], "pt_xla_debug_level": 0, "coupl": [0, 3], "featur": [0, 4, 6, 7], "also": [0, 1, 2, 3, 4, 5, 6, 7, 8], "lower": [0, 5], "level": [0, 7, 8], "slip": 0, "analyz": 0, "summari": 0, "some": [0, 1, 3, 4, 6, 7], "output": [0, 3, 4, 6, 8], "would": [0, 3, 5, 6], "pt": [0, 3, 6], "compiletim": 0, "too": [0, 5], "frequent": 0, "21": 0, "count": [0, 3], "dure": [0, 3, 4, 7, 8], "11": [0, 2, 5], "step": [0, 1, 2, 3, 4, 5, 6, 7, 8], "transferfromdevicetim": 0, "op": [0, 1, 3, 5, 7], "": [0, 1, 2, 3, 4, 6, 7, 8], "aten": [0, 5], "_ctc_loss": 0, "_ctc_loss_backward": 0, "open": [0, 6], "abov": [0, 1, 2, 3, 4, 5, 6, 7, 8], "request": [0, 3, 4, 5, 7], "23": [0, 2], "12": [0, 2, 4, 6, 8], "everi": [0, 3, 5, 6, 7, 8], "caus": [0, 1, 3, 5, 6], "mark_step": [0, 1, 3, 4, 6], "parallel": [0, 3, 6], "loader": [0, 3, 8], "end": [0, 2, 3, 4, 6, 7], "graph": [0, 1, 3, 4, 5, 6, 7, 8], "info": [0, 2, 3, 5, 7], "hash": 0, "c74c3b91b855b2b123f833b0d5f86943": 0, "number": [0, 1, 3, 4, 6, 7], "input": [0, 1, 3, 6, 7], "35": [0, 2, 6], "107": 0, "frame": 0, "trigger": [0, 5], "workspac": 0, "dk3": 0, "1055": 0, "next": [0, 3, 5], "distribut": [0, 4], "parallel_load": [0, 3, 6], "44": 0, "__next__": 0, "32": [0, 3], "train_loop_fn": 0, "train_decoder_only_bas": 0, "48": [0, 4], "start_train": 0, "65": [0, 1], "modul": [0, 3, 4, 7], "73": 0, "post": [0, 3], "size": [0, 2, 3, 5, 6, 7], "548000": 0, "gb": 0, "7": [0, 3, 4, 8], "922460": 0, "alias": 0, "547871": 0, "intermedi": [0, 3, 6], "124478": 0, "program": [0, 3, 5, 7, 8], "028210": 0, "user": [0, 1, 2, 3, 5, 6, 8], "manual": [0, 1, 4], "call": [0, 1, 3, 4, 5, 6, 7, 8], "configur": [0, 2, 3, 6, 7], "batch": [0, 3, 6, 7], "exit": [0, 3, 4], "steptrac": 0, "region": [0, 1, 3, 7], "decid": [0, 3, 5], "access": [0, 3, 5, 6, 7], "often": [0, 1, 5], "due": [0, 6], "log": [0, 2], "valu": [0, 3, 5, 6, 7], "4": [0, 2, 3, 4, 5, 6, 7, 8], "expect": [0, 1, 5, 6, 8], "avoid": [0, 2], "5": [0, 2, 3, 4, 5], "either": [0, 2, 3, 5, 6], "reduc": [0, 1, 3, 4, 6], "frequenc": 0, "add": [0, 3, 4, 5, 8], "see": [0, 2, 3, 4, 5, 6, 8], "pair": 0, "after": [0, 2, 3, 5, 6, 7], "stabil": [0, 6], "onli": [0, 1, 2, 3, 5, 6, 7, 8], "disabl": [0, 1, 3], "effici": [0, 8], "same": [0, 1, 3, 5, 6, 7], "code": [0, 1, 3, 4, 5, 6, 7, 8], "happen": [0, 1, 3, 5, 6], "onc": [0, 3, 5, 7, 8], "keep": [0, 5, 6], "dump": [0, 3], "ir": [0, 5], "hlo": [0, 3], "follow": [0, 1, 2, 3, 4, 5, 6, 7], "compar": [0, 1, 3, 4, 6, 8], "each": [0, 3, 4, 5, 6, 7, 8], "sourc": [0, 3], "differ": [0, 3, 4, 5, 7], "explain": [0, 3, 5, 7], "how": [0, 1, 3, 5, 6], "detail": [0, 3, 5, 6], "put": [0, 3, 4], "line": [0, 1, 3, 4, 5], "met": 0, "short": [0, 5], "contain": [0, 2, 3, 5, 6], "few": [0, 3, 4, 5, 7], "kei": [0, 6, 7], "short_metrics_report": [0, 3], "full": [0, 2, 3, 4], "all": [0, 2, 3, 4, 5, 6, 7], "metrics_report": [0, 3], "like": [0, 3, 4, 5, 6, 7], "time": [0, 2, 3, 5, 6, 7, 8], "spent": 0, "handl": [0, 1, 4, 5, 7], "creat": [0, 4, 6, 7], "destroi": 0, "etc": [0, 1, 2, 3, 5, 7], "term": [0, 1, 5], "percentil": 0, "sampl": [0, 3, 6], "an": [0, 4, 5, 6, 7, 8], "totalsampl": 0, "202": 0, "06m09s401ms746": 0, "001u": 0, "valuer": 0, "778ms572": 0, "062u": 0, "second": [0, 4, 6, 7], "rate": [0, 2, 4], "425201": 0, "001ms32": 0, "778u": 0, "001ms61": 0, "283u": 0, "10": [0, 2, 3, 5, 6, 7, 8], "001ms79": 0, "236u": 0, "20": [0, 2, 3, 4], "001ms110": 0, "973u": 0, "50": [0, 2], "001ms228": 0, "773u": 0, "80": [0, 2], "001ms339": 0, "183u": 0, "90": 0, "001ms434": 0, "305u": 0, "95": 0, "002ms921": 0, "063u": 0, "99": [0, 4], "21s102ms853": 0, "173u": 0, "name": [0, 2, 3, 5, 6, 7], "integ": [0, 3], "track": [0, 7], "statu": 0, "cachedsynctensor": 0, "395": [0, 4], "In": [0, 1, 2, 3, 5, 6, 7, 8], "start": [0, 1, 3, 6], "indic": [0, 3, 5], "context": [0, 3, 5, 6], "switch": [0, 3, 4, 5], "between": [0, 2, 3, 4, 5, 6, 7], "cpu": [0, 2, 4, 5, 7], "potenti": [0, 3, 6, 7], "optim": [0, 1, 2, 3, 4, 5, 6, 7, 8], "area": 0, "oper": [0, 3, 6, 7], "rout": 0, "back": [0, 3, 7], "engin": 0, "thei": [0, 3, 5, 6, 7], "fulli": [0, 1, 3, 6], "qualifi": 0, "c": [0, 3, 5, 6], "namespac": 0, "nonzero": [0, 5], "33": [0, 2, 4, 8], "other": [0, 2, 3, 4, 5, 6, 7], "than": [0, 4, 5, 6], "_local_scalar_dens": 0, "usual": [0, 1, 3], "mean": [0, 3, 4, 5, 6, 7], "miss": [0, 3], "feel": [0, 4], "free": [0, 4], "epoch": [0, 2, 4], "clear_al": 0, "xla_dynamo_debug": 0, "workload": [0, 3, 6, 7], "bottleneck": 0, "resourc": [0, 3], "offici": 0, "tutori": [0, 4, 7], "colab": 0, "notebook": 0, "mnist": [0, 2, 3, 6], "train": [0, 2, 3, 7], "script": [0, 3, 6], "util": [0, 2, 3, 4, 7], "captur": [0, 3], "take": [0, 3, 5, 7], "look": [0, 3], "train_resnet_benchmark": 0, "blob": [0, 4, 6, 7], "master": [0, 3, 4, 6, 7], "_": [0, 4, 6, 8], "behav": 0, "semant": [0, 5], "regular": [0, 3], "share": [0, 2, 3, 6, 7], "interfac": [0, 3, 7], "gpu": [0, 3, 7], "howev": [0, 7], "constraint": [0, 6], "hardwar": [0, 3], "lazi": [0, 5, 7, 8], "evalu": [0, 5], "suggest": 0, "certain": [0, 5], "pattern": [0, 5, 8], "result": [0, 3, 4, 6, 7], "bad": 0, "show": [0, 3, 4, 6], "mind": [0, 6], "yield": [0, 3], "degrad": 0, "recompil": [0, 1, 3], "expens": [0, 1, 5], "automat": [0, 3, 4, 5, 6, 7], "new": [0, 1, 3, 5, 7, 8], "shape": [0, 3, 7], "encount": [0, 6], "within": [0, 3, 7], "huge": [0, 4, 5], "speedup": [0, 8], "rest": [0, 5, 6], "order": [0, 2, 3, 7], "must": [0, 3, 6, 7], "constant": [0, 7], "comput": [0, 2, 3, 5, 6, 7], "across": [0, 3, 4, 6, 7], "host": [0, 2, 3, 4, 6, 7], "possibl": [0, 3, 4, 6, 7], "direct": [0, 6], "indirect": 0, "introduc": [0, 1, 4, 6, 7], "dynam": [0, 8], "mask": [0, 5], "index": [0, 3, 6], "base": [0, 1, 2, 3, 4, 5, 6, 7], "where": [0, 3, 4, 5, 6, 7], "loop": [0, 1, 3, 5, 7], "iter": [0, 3, 7, 8], "thu": [0, 2, 6], "requir": [0, 2, 3, 5, 6, 7], "solut": [0, 5], "low": 0, "variat": 0, "pad": [0, 5], "fix": [0, 7, 8], "don": [0, 1, 4, 5, 6], "t": [0, 1, 3, 4, 5, 6, 7], "nativ": [0, 1, 2, 4, 6, 7], "translat": 0, "transfer": [0, 3, 6, 7], "memori": [0, 2, 4, 5], "lead": 0, "signific": [0, 8], "slowdown": [0, 4], "item": 0, "explicitli": [0, 3, 5], "ask": [0, 1, 5], "unless": [0, 5], "necessari": [0, 3], "most": [0, 3, 6, 8], "checkout": [0, 2], "find": [0, 4, 6, 7], "even": [0, 3, 4, 5, 6], "scalar": [0, 5], "substitut": 0, "control": [0, 3, 7], "flow": 0, "applic": [0, 7], "e": [0, 3, 4, 5, 6, 7], "g": [0, 2, 3, 5, 6, 7], "clip_grad": 0, "norm": 0, "problemat": 0, "impact": [0, 3, 4, 5, 6], "so": [0, 2, 3, 4, 5, 6, 7], "patch": 0, "clip_grad_norm_": 0, "instead": [0, 1, 3, 4, 5, 6, 7, 8], "give": [0, 7], "dramat": 0, "improv": [0, 3, 6, 7, 8], "block": [0, 3, 4, 7], "els": [0, 5], "paramet": [0, 3, 6, 7], "total_norm": 0, "zero": [0, 4, 7], "none": [0, 3, 7], "p": [0, 2, 5, 6], "param_norm": 0, "grad": 0, "norm_typ": 0, "add_": 0, "clip_coef": 0, "max_norm": 0, "1e": [0, 8], "6": [0, 2, 3, 5], "mul_": 0, "data_parallel": 0, "mai": [0, 3, 5, 6, 7], "drop": 0, "last": 0, "make": [0, 1, 2, 3, 4, 5, 6, 7, 8], "sure": [0, 2, 3], "amount": [0, 3, 5], "work": [0, 3, 4, 5, 6, 7, 8], "dataset": [0, 4], "small": [0, 1, 4, 5, 8], "therefor": 0, "better": [0, 1, 3, 5, 6, 8], "those": [0, 4], "case": [0, 3, 6, 7, 8], "opaqu": [0, 3], "alwai": [0, 3, 5, 6, 7], "appear": [0, 3], "contigu": [0, 3], "without": [0, 3, 6, 7], "storag": [0, 2, 3, 4, 7], "network": [0, 3, 6, 7], "stride": 0, "move": [0, 4, 5, 6, 7], "save": [0, 4, 7], "directli": [0, 3, 4, 5, 6, 7], "load": [0, 4, 6, 7], "were": [0, 3, 5], "from": [0, 4, 7, 8], "unavail": [0, 3], "fail": [0, 3, 7], "let": [0, 3, 6, 7, 8], "machin": [0, 2, 6], "care": [0, 3, 5], "taken": [0, 3, 4, 5, 7], "type": [0, 2, 3, 4, 6], "doe": [0, 3, 5, 6, 7], "preserv": [0, 3], "view": [0, 3], "relationship": [0, 3], "reconstruct": 0, "copi": [0, 3, 6], "return": [0, 1, 3, 4, 5, 7, 8], "deep": 0, "shallow": 0, "weight": [0, 3, 7], "one": [0, 3, 4, 5, 6, 7, 8], "anoth": [0, 3, 5], "ty": 0, "done": [0, 3, 5], "otherwis": [0, 3, 5, 7], "two": [0, 3, 5, 6, 7], "independ": [0, 3, 6], "made": [0, 5, 7], "But": [0, 3, 5], "submit": 0, "addit": [0, 2, 3, 4, 6], "doesn": [0, 3, 5, 7], "_xlac": [0, 5], "_get_xla_tensors_text": [0, 5], "re": [0, 1, 3, 5, 6, 7], "_get_xla_tensors_hlo": 0, "function": [0, 1, 3, 7, 8], "prior": [0, 7], "alreadi": [0, 2, 3, 4, 5, 7], "materi": [0, 3, 5, 7], "There": [0, 1, 3, 4, 5, 7, 8], "behavior": [0, 3, 6], "stack": [0, 3, 5, 7], "degre": 0, "xla_ir_debug": 0, "trace": [0, 1, 3, 4, 5, 6, 7, 8], "node": [0, 5], "henc": [0, 8], "allow": [0, 3, 7], "wa": [0, 3, 5, 6, 7], "respons": [0, 7, 8], "xla_hlo_debug": [0, 3], "_xla_ir": 0, "activ": [0, 3, 4], "propag": 0, "metadata": 0, "xla_save_tensors_fil": 0, "path": [0, 2, 3, 4, 5], "file": [0, 2, 3, 4, 6], "becom": [0, 5, 6], "realli": [0, 5, 8], "big": [0, 5], "option": [0, 3, 6, 7], "left": 0, "long": [0, 1, 4, 5], "append": 0, "clean": [0, 8], "sheet": 0, "xla_save_tensors_fmt": 0, "format": [0, 3, 8], "store": [0, 3], "_xla_save_tensor": 0, "text": 0, "default": [0, 1, 2, 3, 4, 6, 7], "dot": 0, "graphviz": 0, "xla_flag": 0, "xla_dump_to": 0, "tmp": [0, 4], "dir_nam": 0, "unoptim": 0, "optimz": 0, "per": [0, 2, 3, 4, 6, 8], "xla_metrics_fil": 0, "local": [0, 2, 3, 6, 7], "exist": [0, 1, 3, 6, 7, 8], "xla_save_hlo_fil": 0, "error": [0, 3], "offend": 0, "xla_sync_wait": 0, "forc": [0, 5, 6], "sync": [0, 1, 2, 3], "wait": [0, 3], "its": [0, 3, 4, 6, 7, 8], "complet": [0, 3], "xla_use_eager_debug_mod": 0, "eagerli": [0, 1, 3, 5], "bypass": 0, "overal": 0, "lot": [0, 3, 5], "slower": [0, 4], "usag": [0, 2, 3, 4, 5, 7], "higher": [0, 7], "optimizaiton": 0, "skip": [0, 8], "tf_cpp_log_thread_id": 0, "tf": [0, 5], "thread": [0, 3, 6, 7], "id": [0, 2, 3, 6], "multithread": [0, 3], "process": [0, 1, 2, 4, 6, 7], "tf_cpp_vmodul": 0, "vlog": 0, "form": [0, 5, 6], "tf_cpp_min_log_level": 0, "messag": [0, 3], "turn": 0, "warn": 0, "tf_vlog": 0, "tensorflow": [0, 3, 5, 6], "xla_dump_hlo_graph": 0, "part": [0, 1, 3, 6, 7], "runtim": [0, 2, 4, 7], "rais": 0, "xla_util": 0, "cc": 0, "record": [0, 3], "save1": 0, "xla_graph_executor": 0, "pjrt_computation_cli": 0, "3": [0, 1, 2, 3, 4, 7, 8], "pr": [0, 4], "repo": [0, 3], "dir": 0, "pytorch_test_with_slow": 0, "test_torch": 0, "k": 0, "test_put_xla_uint8": 0, "command": [0, 2, 3, 4, 6], "need": [0, 2, 3, 4, 5, 6, 7], "torch_test_devic": 0, "pytorch_test_bas": 0, "doc": [1, 2, 5, 6, 7], "go": [1, 2, 3, 7], "over": [1, 2, 3, 4, 6, 7], "pytorch": [1, 5, 6], "xla": [1, 5, 6], "experiment": [1, 4, 6, 7, 8], "The": [1, 2, 3, 4, 5, 6, 7, 8], "goal": 1, "experi": [1, 4, 6, 7], "more": [1, 2, 3, 5, 6, 7], "align": 1, "develop": [1, 3, 4, 7, 8], "easier": [1, 5], "current": [1, 2, 3, 4, 5, 6, 7, 8], "run": [1, 4, 5, 6, 8], "lazytensor": [1, 3], "torchvis": [1, 8], "resnet18": [1, 8], "randn": [1, 3, 4, 6, 7, 8], "64": [1, 4, 8], "224": 1, "execut": [1, 2, 3, 4, 5, 6, 7, 8], "actual": [1, 4, 5, 7], "multipl": [1, 5, 8], "drawback": 1, "approach": [1, 4, 5], "confus": 1, "about": [1, 3, 5, 6], "framework": [1, 3, 5], "non": [1, 5, 7], "data": [1, 2, 3, 5, 6, 8], "preprocess": 1, "pend": [1, 3], "get": [1, 2, 3, 4, 5, 6], "leak": 1, "main": [1, 3, 6, 7], "whole": [1, 3, 5, 8], "veri": [1, 2, 3, 5], "It": [1, 2, 3, 4, 5, 7, 8], "hard": [1, 4, 5, 8], "debug": [1, 5], "why": [1, 5], "mitig": 1, "ux": 1, "eager_mod": [1, 3], "true": [1, 3, 4, 5, 6, 7], "mark": [1, 3], "compiled_model": 1, "right": [1, 5, 8], "awai": 1, "ha": [1, 3, 5, 6, 7], "wrap": [1, 3, 4, 7], "pretti": [1, 3, 4, 5], "straight": 1, "forward": [1, 4, 7, 8], "enter": 1, "target": [1, 3, 5, 6, 8], "reenabl": 1, "perfomr": 1, "backend": [1, 3, 5, 6, 7, 8], "openxla": [1, 8], "recommen": 1, "overhad": 1, "def": [1, 3, 4, 6, 7, 8], "step_fn": 1, "loss_fn": [1, 3, 4, 6, 8], "zero_grad": [1, 3, 4, 6], "logit": [1, 7], "loss": [1, 2, 3, 4, 6, 7, 8], "backward": [1, 3, 4, 6, 7, 8], "refactor": 1, "becaus": [1, 3, 6, 7], "togeth": [1, 3, 4, 6, 7], "now": [1, 3, 5, 6, 7], "recommend": [1, 2, 3, 6, 7], "reason": [1, 4, 6], "layer": [1, 4, 7], "decod": 1, "much": [1, 3, 5, 6, 8], "just": [1, 3, 4, 5, 6, 7], "llama2": 1, "fake": [1, 7], "singl": [1, 4, 5, 7, 8], "chip": [1, 6], "v4": [1, 3, 6, 7, 8], "8": [1, 2, 3, 5, 6, 7, 8], "below": [1, 2, 5, 6, 7], "observ": [1, 4, 6], "token": 1, "147": 1, "achiev": [1, 4], "45": [1, 2], "perform": [1, 3, 4, 7, 8], "trainer": 1, "test": [1, 2, 4, 6], "found": [1, 2, 6], "here": [1, 2, 3, 4, 5, 7, 8], "perfomran": 1, "depend": [1, 3, 5], "tri": 1, "resnet50": [1, 3, 6, 8], "exepct": 1, "meant": 1, "logic": [1, 3, 5, 7], "random": [1, 3, 6], "compil": [2, 5, 6], "acceler": [2, 3, 6], "basic": [2, 4, 5], "nvidia": 2, "attach": [2, 7], "cloud": [2, 3, 6, 7, 8], "googl": [2, 3, 6], "cuda": [2, 3, 5, 6], "driver": 2, "publish": 2, "prebuilt": 2, "imag": [2, 4, 5, 6], "cuda11": 2, "correspond": [2, 3, 4, 7], "config": 2, "list": [2, 3, 7], "refer": [2, 3, 4, 6, 7], "sudo": [2, 6], "pull": [2, 4], "central1": 2, "pkg": 2, "dev": [2, 4], "nightly_3": 2, "8_cuda_12": 2, "toolkit": 2, "datacent": 2, "latest": 2, "guid": [2, 3, 4, 6], "html": [2, 4, 6], "curl": 2, "fssl": 2, "io": [2, 4, 6], "libnvidia": 2, "gpgkei": 2, "gpg": 2, "dearmor": 2, "o": [2, 4, 6], "usr": 2, "keyr": 2, "l": 2, "stabl": [2, 4, 6], "deb": 2, "sed": 2, "sign": 2, "tee": 2, "apt": 2, "d": [2, 3, 5], "updat": [2, 3, 5, 7], "ctk": 2, "systemctl": 2, "restart": [2, 6], "shm": 2, "16g": 2, "net": [2, 6], "bin": 2, "bash": [2, 4], "exec": 2, "awk": 2, "nr": 2, "visibl": [2, 3, 5], "smi": 2, "verifi": 2, "root": [2, 3, 5], "20ab2c7a2d06": 2, "dec": 2, "06": 2, "24": 2, "29": [2, 4, 8], "2022": 2, "510": 2, "47": 2, "03": 2, "version": [2, 6, 7], "persist": [2, 3, 7], "m": [2, 4, 5], "bu": 2, "disp": 2, "A": [2, 3, 5, 6, 7], "volatil": 2, "uncorr": 2, "ecc": 2, "fan": 2, "temp": 2, "perf": [2, 5], "pwr": 2, "cap": 2, "mig": 2, "tesla": 2, "v100": 2, "sxm2": 2, "off": 2, "00000000": 2, "00": [2, 4], "04": [2, 8], "n": [2, 3], "36c": 2, "p0": 2, "38w": 2, "300w": 2, "0mib": 2, "16384mib": 2, "gi": 2, "ci": 2, "pid": 2, "No": [2, 5, 6], "ld_library_path": 2, "account": 2, "echo": 2, "link": 2, "bashrc": 2, "lib64": 2, "compat": [2, 3, 6, 7], "x86_64": 2, "linux": 2, "architecutr": 2, "architectur": [2, 4, 6], "system": [2, 7], "unam": 2, "pip3": 2, "whl": 2, "googleapi": 2, "cp310": 2, "manylinux_2_28_x86_64": 2, "repositori": [2, 6], "imagenet": 2, "what": [2, 3], "gpu_num_devic": [2, 6], "recurs": [2, 4, 7], "prepar": 2, "begin": [2, 7], "38": 2, "89059": 2, "82": 2, "globalr": 2, "13": [2, 3, 4, 6], "79297": 2, "117": 2, "16": [2, 3, 4, 7], "84": 2, "36": 2, "40": [2, 4], "43628": 2, "281": 2, "49": [2, 8], "43": [2, 8], "60": [2, 4], "83108": 2, "346": 2, "88": [2, 7], "108": 2, "99023": 2, "373": 2, "62": [2, 8], "132": 2, "56": 2, "92699": 2, "384": 2, "152": 2, "14": 2, "02": [2, 4], "120": 2, "68816": 2, "388": 2, "169": 2, "09": 2, "train_resnet_bas": 2, "35pm": 2, "utc": 2, "jun": 2, "08": 2, "2024": 2, "887794017791748": 2, "746502586051985": 2, "877807140350342": 2, "238": 2, "4789458412044": 2, "867819786071777": 2, "329": 2, "86095958663503": 2, "30": [2, 4, 6], "857839584350586": 2, "367": 2, "3038003653586": 2, "847847938537598": 2, "381": 2, "53141087190835": 2, "837860584259033": 2, "387": 2, "80462249591113": 2, "260": 2, "628140926361084": 2, "391": 2, "135639565343": 2, "270": 2, "618192195892334": 2, "6901797745233": 2, "280": 2, "608224391937256": 2, "1602680460045": 2, "290": 2, "598264217376709": 2, "6731498290759": 2, "36pm": 2, "reus": [2, 3], "rule": 2, "modifi": [2, 7, 8], "insid": [2, 7], "cd": [2, 4], "use_cuda": 2, "bdist_wheel": 2, "hermet": 2, "xla_cuda": 2, "been": [2, 3, 5, 6, 7], "successfulli": 2, "packag": [3, 4], "learn": [3, 6], "connect": [3, 6, 7], "troubleshoot": 3, "eager": [3, 4, 5], "mode": [3, 4, 5], "distributeddataparallel": [3, 6], "ddp": [3, 6], "pjrt": [3, 7], "shard": 3, "fsdp": 3, "via": [3, 4, 6], "advanc": 3, "topic": 3, "checkpoint": [3, 4, 6], "torchdynamo": 3, "integr": 3, "describ": [3, 4, 7], "familiar": [3, 7], "initi": [3, 4, 6, 7], "environ": [3, 4, 6, 7], "ad": [3, 5, 7, 8], "t0": 3, "Or": [3, 5, 6], "matrix": 3, "multipli": [3, 7], "mm": 3, "neural": 3, "l_in": 3, "linear": [3, 4, 6], "nn": [3, 4, 6, 7, 8], "l_out": 3, "floattensor": 3, "throw": 3, "build": [3, 4], "convert": [3, 4], "specif": [3, 4], "snippet": [3, 7], "highlight": 3, "nllloss": 3, "sgd": [3, 4, 6, 8], "lr": [3, 4, 6, 7, 8], "momentum": 3, "train_load": [3, 7], "easi": [3, 5, 6], "definit": [3, 5], "dataload": [3, 4, 7], "acquir": 3, "pl": [3, 6, 7], "_mp_fn": [3, 6], "mp_device_load": 3, "mpdeviceload": [3, 7], "optimizer_step": [3, 4], "__name__": [3, 4, 6], "__main__": [3, 4, 6], "launch": [3, 4, 6, 8], "arg": [3, 4], "three": 3, "previou": [3, 5, 6], "wrapper": [3, 4, 7], "spawn": [3, 6], "torchrun": [3, 6], "abl": [3, 5, 7], "assign": 3, "being": [3, 4, 7], "up": [3, 5, 6, 7], "own": [3, 4], "v2": 3, "v3": 3, "check": [3, 7], "onto": 3, "preload": 3, "overlap": [3, 7, 8], "batches_per_execut": 3, "consolid": [3, 4], "gradient": [3, 4], "all_reduce_gradi": 3, "remain": [3, 5], "retriev": [3, 5, 7, 8], "parent": 3, "multiprocess": [3, 6], "setup": [3, 4], "talk": 3, "bit": 3, "basi": 3, "gcloud": [3, 6], "project": [3, 6], "howto": 3, "focu": [3, 5], "perspect": [3, 6], "assum": [3, 4, 5, 7], "train_mnist_xla": 3, "ssh": [3, 6], "tpuvm": [3, 6, 7], "scp": [3, 6], "alpha": [3, 6], "zone": [3, 6], "worker": [3, 4, 6, 7], "outsid": 3, "underli": 3, "infrastructur": 3, "awar": 3, "global": [3, 6, 7], "topologi": [3, 7], "ordin": 3, "cross": [3, 7], "commun": [3, 6, 7, 8], "regard": [3, 8], "fakedata": 3, "though": [3, 4], "act": 3, "uniqu": [3, 5], "immedi": [3, 7], "hand": 3, "until": [3, 7], "defer": 3, "separ": [3, 4, 7, 8], "fuse": 3, "invis": 3, "caller": 3, "construct": [3, 4, 7], "send": [3, 6, 7], "synchron": [3, 6, 7], "insert": 3, "barrier": [3, 6], "design": [3, 6, 7, 8], "paper": 3, "represent": [3, 7], "expos": [3, 6, 7], "unlik": 3, "adjust": 3, "wai": [3, 4, 5, 6, 7, 8], "again": 3, "appreci": 3, "accommod": 3, "transit": 3, "recreat": 3, "destin": 3, "previous": 3, "state_dict": [3, 4, 7], "limit": [3, 6], "footprint": 3, "serial": [3, 6], "xser": 3, "stream": 3, "restor": [3, 7], "load_state_dict": [3, 7], "under": [3, 4, 6], "consum": [3, 5], "disk": 3, "significantli": [3, 6], "still": [3, 4, 5, 6, 7], "occur": 3, "opt": 3, "through": [3, 5, 7], "initialize_cach": 3, "xr": [3, 4, 6, 7], "your_cache_path": 3, "readonli": 3, "fals": [3, 4, 7], "specifi": [3, 4], "whether": 3, "write": [3, 7], "mount": 3, "int": [3, 5, 6, 7], "instanc": [3, 4, 7], "virtual": [3, 7], "device_count": [3, 7], "address": [3, 6, 7], "bool": 3, "finish": 3, "f": [3, 4, 7], "callabl": [3, 4], "full_graph": 3, "repres": [3, 5, 6], "funciton": 3, "pass": [3, 4, 6, 7], "manag": [3, 7], "foo": 3, "sin": 3, "co": 3, "foo2": 3, "compiled_foo2": 3, "manual_se": [3, 6], "seed": 3, "state": [3, 4, 7], "rng": [3, 6], "device_typ": 3, "str": 3, "select": [3, 6, 7], "local_process_count": 3, "local_device_count": 3, "total": [3, 5, 7], "addressable_device_count": 3, "global_device_count": 3, "global_runtime_device_count": [3, 7], "especi": [3, 6, 7, 8], "world_siz": [3, 4, 6, 7], "particip": [3, 6], "job": [3, 8], "global_ordin": [3, 4, 6], "rang": [3, 6, 7], "guarante": 3, "predict": 3, "nor": 3, "local_ordin": 3, "get_master_ip": 3, "ip": [3, 6, 7], "discoveri": 3, "string": [3, 7], "use_spmd": [3, 7], "auto": [3, 4], "is_spmd": 3, "devkind": 3, "custom": [3, 4, 5, 7], "deprec": 3, "xla_device_hw": 3, "map": 3, "real": [3, 8], "is_master_ordin": 3, "replic": [3, 7], "while": [3, 4, 5], "num_host": 3, "boolean": 3, "all_reduc": 3, "reduce_typ": 3, "scale": [3, 6, 7, 8], "group": [3, 4, 6, 7], "pin_layout": 3, "inplac": [3, 7], "One": [3, 4], "reduce_sum": 3, "reduce_mul": 3, "reduce_and": 3, "reduce_or": 3, "reduce_min": 3, "reduce_max": 3, "float": [3, 5], "appli": [3, 4, 7], "replica": [3, 6], "defin": [3, 7], "pin": 3, "pine": 3, "prevent": [3, 7, 8], "corrupt": 3, "slightli": 3, "unpin": 3, "hlomodul": 3, "mix": [3, 7], "constrain": [3, 6], "hold": [3, 7], "tupl": [3, 5, 7], "itself": [3, 4], "all_gath": [3, 6], "dim": 3, "gather": [3, 7], "along": [3, 4], "dimens": [3, 7], "all_to_al": 3, "split_dimens": 3, "concat_dimens": 3, "split_count": 3, "alltoal": 3, "www": 3, "org": [3, 4, 6], "operation_semant": 3, "upon": 3, "split": 3, "concat": 3, "add_step_closur": 3, "closur": 3, "run_async": 3, "ones": [3, 5], "report": 3, "consol": 3, "tensorboard": 3, "content": 3, "intermediari": 3, "inspect": 3, "point": [3, 5], "typic": 3, "ensur": [3, 5, 7], "live": [3, 5], "argument": [3, 4, 8], "queu": 3, "sequenti": 3, "advis": 3, "throttl": 3, "event": 3, "asynchron": [3, 7], "wait_device_op": 3, "async": [3, 8], "whose": 3, "empti": 3, "optimizer_arg": 3, "parallelload": [3, 7], "dataparallel": 3, "support": [3, 4, 5, 6, 7, 8], "dict": [3, 4], "dictionari": 3, "file_or_path": 3, "master_onli": [3, 4], "global_mast": 3, "nest": [3, 4], "combin": [3, 5], "object": [3, 7], "overrid": 3, "locat": 3, "flag": 3, "hang": 3, "rendezv": 3, "tag": [3, 6], "payload": [3, 6], "b": [3, 5, 6, 7, 8], "mesh": [3, 6], "client": [3, 6], "reach": 3, "xrt": 3, "server": [3, 6], "effect": 3, "alia": 3, "xla_rendezv": 3, "join": 3, "byte": 3, "exchang": 3, "posit": 3, "mesh_reduc": 3, "reduce_fn": 3, "reduct": 3, "receiv": 3, "come": [3, 5], "set_rng_stat": 3, "get_rng_stat": 3, "get_memory_info": 3, "memoryinfo": 3, "get_stablehlo": 3, "stablehlo": 3, "todo": 3, "lsy323": 3, "investig": [3, 4], "infer": [3, 6, 7], "straightforward": 3, "identifi": [3, 7], "env": [3, 6, 7], "var": [3, 7], "get_stablehlo_bytecod": 3, "bytecod": [3, 8], "class": [3, 4, 7], "batchdim": 3, "loader_prefetch_s": 3, "device_prefetch_s": 3, "host_to_device_transfer_thread": 3, "input_shard": [3, 7], "background": [3, 7], "upload": [3, 7], "th": [3, 7], "len": 3, "max": [3, 5, 7], "capac": 3, "queue": 3, "deposit": 3, "shardingspec": [3, 7], "spec": 3, "per_device_load": [3, 7], "structur": [3, 4, 7], "resid": 3, "xla_multiprocess": 3, "fn": 3, "nproc": [3, 6], "daemon": 3, "start_method": 3, "At": 3, "moment": 3, "maximum": 3, "creation": 3, "method": [3, 6, 7], "mark_shard": [3, 7], "union": 3, "xlashardedtensor": 3, "partition_spec": [3, 7], "annot": [3, 7], "partit": 3, "xlatensor": [3, 7], "spmdpartition": [3, 7], "param": 3, "device_mesh": [3, 7], "axi": [3, 7], "rank": [3, 4, 6, 7], "mesh_shap": [3, 7], "ax": [3, 7], "row": 3, "wise": 3, "8x10": 3, "column": 3, "dynamo_custom_op": 3, "dynamo": [3, 8], "variant": [3, 5], "recogniz": 3, "traceabl": 3, "num_devic": [3, 7], "device_id": [3, 7], "np": [3, 7], "arrai": [3, 7], "clear_shard": 3, "clear": 3, "cast": 3, "place": [3, 7], "get_1d_mesh": 3, "set_global_mesh": 3, "get_global_mesh": 3, "axis_nam": [3, 7], "helper": 3, "ndarrai": 3, "ravel": 3, "reshap": 3, "fill": 3, "element": [3, 5, 7], "sequenc": 3, "Its": 3, "length": [3, 5], "get_xla_supported_devic": 3, "get_logical_mesh": 3, "ordereddict": [3, 7], "hybridmesh": [3, 7], "ici_mesh_shap": [3, 7], "dcn_mesh_shap": [3, 7], "hybrid": 3, "ici": 3, "dcn": [3, 7], "increas": 3, "intens": 3, "mdl": 3, "inner": [3, 4, 7], "outer": [3, 4, 7], "slice": [3, 7], "metric": [3, 4], "counter_nam": 3, "metric_nam": 3, "counter_valu": 3, "metric_data": 3, "total_sampl": 3, "accumul": 3, "retain": 3, "circular": 3, "buffer": 3, "sum": [3, 4, 7], "document": [4, 6], "further": 4, "against": 4, "minimum": [4, 7], "runnabl": [4, 7], "abil": [4, 5], "api": [4, 5, 6, 7, 8], "And": [4, 5, 7], "who": 4, "know": [4, 5], "xla_backend": [4, 6, 7], "init": [4, 6, 8], "similar": [4, 6], "nccl": 4, "gloo": [4, 6, 7], "dist": [4, 6, 7], "init_process_group": [4, 6, 7], "new_rank": 4, "gradient_as_bucket_view": [4, 6], "ddp_model": 4, "final": [4, 7], "launcher": 4, "demo_fn": 4, "everyth": [4, 5], "touch": [4, 7], "plu": 4, "five": 4, "sy": 4, "tempfil": 4, "master_addr": [4, 6], "localhost": [4, 6], "master_port": [4, 6], "12355": [4, 6], "cleanup": 4, "destroy_process_group": 4, "toymodel": 4, "__init__": 4, "self": [4, 7], "super": [4, 8], "net1": 4, "1000000": 4, "relu": 4, "net2": 4, "demo_bas": 4, "assert": 4, "graident_as_bucket_view": 4, "mseloss": [4, 6], "001": [4, 6], "label": 4, "run_demo": 4, "collect": [4, 6, 7, 8], "num_epoch": [4, 6], "tot": 4, "statist": 4, "produc": [4, 5], "unit": 4, "median": 4, "90th": 4, "std": 4, "cv": 4, "418": 4, "54": 4, "419": 4, "22": 4, "430": 4, "9": [4, 5], "76": 4, "97": 4, "407": 4, "39": 4, "seem": 4, "extra": [4, 7], "overhead": [4, 6, 8], "test_train_mp_mnist": [4, 6], "17864": 4, "19": [4, 8], "20108": 4, "96": 4, "24351": 4, "74": 4, "5866": 4, "83": 4, "10701": 4, "11770": 4, "14313": 4, "78": 4, "3102": 4, "92": 4, "41": [4, 8], "round": 4, "heavili": [4, 8], "sens": 4, "amort": 4, "logdir": 4, "converg": 4, "high": 4, "accuraci": 4, "caution": 4, "interest": 4, "known": 4, "enforc": [4, 5], "crash": 4, "xlafullyshardeddataparallel": 4, "my_modul": [4, 7], "adam": [4, 7], "0001": [4, 7], "individu": [4, 7], "leftov": [4, 7], "both": [4, 5, 6, 7, 8], "arxiv": 4, "ab": 4, "1910": 4, "02054": 4, "reshard_after_forward": 4, "test_train_mp_mnist_fsdp_with_ckpt": 4, "test_train_mp_imagenet_fsdp": 4, "larg": [4, 5, 6, 7], "cannot": [4, 5], "fit": 4, "interleav": 4, "submodul": 4, "fsdpvitmodel": 4, "ronghanghu": 4, "vit_10b_fsdp_exampl": 4, "run_vit_train": 4, "simpl": [4, 6, 7], "checkpoint_modul": [4, 7], "3524": 4, "auto_wrap_polici": [4, 7], "size_based_auto_wrap_polici": 4, "polici": [4, 7], "larger": [4, 8], "100m": 4, "transformer_auto_wrap_polici": [4, 7], "transform": [4, 7], "conv2d": 4, "partial": [4, 5, 7], "transformer_layer_cl": [4, 7], "addition": 4, "auto_wrapper_cal": 4, "remateri": 4, "lambda": 4, "kwarg": [4, 7], "latter": 4, "resum": 4, "get_shard_metadata": 4, "consolidate_sharded_model_checkpoint": 4, "stitch": 4, "ckpt": 4, "shard_metadata": 4, "ckpt_path": 4, "pth": 4, "tool": 4, "consolidate_sharded_ckpt": 4, "ckpt_prefix": 4, "your_sharded_checkpoint_fil": 4, "ckpt_suffix": 4, "_rank": 4, "inspir": 4, "mostli": [4, 6], "fairscal": 4, "fullyshardeddataparallel": 4, "readthedoc": 4, "en": 4, "biggest": [4, 8], "explicit": 4, "resort": 4, "train_resnet_fsdp_auto_wrap": 4, "newer": 4, "wheel": 4, "around": [4, 5, 6], "98": 4, "batch_siz": [4, 6], "drop_last": 4, "use_nested_fsdp": 4, "use_gradient_checkpoint": 4, "final_ckpt": 4, "75": 4, "download": 4, "1k": 4, "datadir": 4, "test_set_batch_s": 4, "eval_interv": 4, "128": [4, 6], "num_warmup_epoch": 4, "lr_scheduler_divide_every_n_epoch": 4, "lr_scheduler_divisor": 4, "residu": 4, "entir": [4, 6], "algorithm": [4, 7], "vision": 4, "vit": 4, "static": 5, "word": 5, "hurt": 5, "understand": 5, "normal": [5, 6, 7], "pov": 5, "sai": 5, "assur": 5, "magic": 5, "gone": 5, "good": [5, 7], "coverag": 5, "aim": [5, 7], "explan": 5, "common": [5, 6, 7], "rid": 5, "mainli": 5, "problem": 5, "beginn": 5, "propos": 5, "reli": 5, "impract": 5, "assumpt": 5, "ye": 5, "sentenc": 5, "vari": [5, 6, 7], "ll": 5, "bucket": [5, 7], "kinda": 5, "anti": 5, "frontend": 5, "matter": 5, "workaround": 5, "okai": 5, "teach": 5, "practic": [5, 6, 7], "enough": 5, "theoret": 5, "trade": 5, "less": [5, 6, 8], "faster": [5, 6, 8], "speed": [5, 8], "well": [5, 6, 7], "sort": 5, "obviou": 5, "shown": [5, 6], "s64": 5, "num_output": 5, "mul": 5, "although": [5, 6], "inde": 5, "_get_xla_tensor_dimension_s": 5, "commonli": 5, "dtype": [5, 6], "cut": 5, "correct": 5, "wrong": 5, "wors": 5, "probabl": 5, "upper": 5, "nit": 5, "simplic": 5, "rand": 5, "solv": 5, "world": [5, 6, 7, 8], "kept": 5, "earli": 5, "accessor": 5, "2d": [5, 7], "implicitli": 5, "doubl": 5, "overload": 5, "easili": [5, 8], "explod": 5, "convers": 5, "cheap": 5, "ve": 5, "hoc": 5, "think": 5, "verison": 5, "bla": 5, "blabla": 5, "interpret": 5, "proce": 5, "choic": 5, "wide": 5, "adopt": 5, "uglier": 5, "win": 5, "pars": 5, "statement": 5, "torchscript": 5, "somehow": 5, "merg": 5, "lazili": [5, 7], "properli": 5, "haven": 5, "thought": 5, "trivial": 5, "effort": [5, 7], "side": 5, "That": 5, "hit": 5, "bandwidth": 5, "automag": 5, "gold": 5, "smart": 5, "trick": 5, "tbh": 5, "longer": 5, "sometim": 5, "unawar": 5, "hope": 5, "smash": 5, "ideal": [5, 8], "blocker": 5, "ahead": 5, "nnc": 5, "symbol": 5, "By": [5, 6], "concret": 5, "kernel": 5, "exactli": 5, "transpos": 5, "With": [5, 6, 8], "brian": 5, "hirsh": 5, "bdhirsh": 5, "question": 5, "comment": 5, "worth": 5, "stick": 5, "torch_warn": 5, "yea": 5, "tell": 5, "hei": 5, "won": 5, "blaze": 5, "fast": 5, "isn": [5, 7], "rewrit": [5, 7], "devirtu": 5, "v": [5, 6], "sound": 5, "great": 5, "carri": [5, 7], "truth": 5, "As": [5, 7], "irvalu": 5, "discrep": 5, "followup": 5, "mention": [5, 8], "1000": 5, "my": [5, 7], "properti": 5, "presenc": 5, "get_dimention_s": 5, "didn": 5, "altern": [5, 6], "condit": 5, "middl": [5, 6], "exponenti": 5, "blowup": 5, "smaller": 5, "fewer": 5, "opportun": 5, "recogn": [5, 8], "could": [5, 6, 7], "break": 5, "feasibl": 5, "annoi": 5, "z": 5, "subgraph": 5, "variabl": [5, 6], "wasn": 5, "materiz": 5, "involv": [5, 7], "combo": 5, "migrat": 6, "jax": 6, "public": 6, "renam": 6, "regist": [6, 7], "init_method": [6, 7], "plugin": 6, "xpu": 6, "neuron": 6, "continu": [6, 8], "xrt_tpu_config": 6, "libtpu": 6, "thousand": 6, "preview": 6, "On": [6, 7], "safe": 6, "broadcast": 6, "broadcast_master_param": 6, "pjrt_backend": 6, "These": [6, 7], "diff": 6, "42": 6, "confirm": 6, "localservic": 6, "51011": 6, "grpc": 6, "torchbench": 6, "2048": 6, "read": 6, "central2": 6, "256": 6, "tpu_process_bound": 6, "tpu_visible_chip": 6, "r1": 6, "preinstal": 6, "docker_imag": 6, "gcr": 6, "authent": 6, "privat": 6, "gcp": 6, "auth": 6, "rm": 6, "privileg": 6, "simpli": 6, "nnode": 6, "num_gpu_devic": 6, "pjrt_distribut": 6, "physic": [6, 7], "number_gpu_vm": 6, "node_rank": 6, "current_node_rank": 6, "nproc_per_nod": 6, "number_local_gpu_devic": 6, "rdzv_endpoint": 6, "internal_ip_address": 6, "port": 6, "multinode_train": 6, "endpoint": 6, "omit": [6, 7], "machine_0": 6, "machine_1": 6, "machine_0_internal_ip_address": 6, "ident": 6, "page": 6, "interchang": 6, "subtl": 6, "importantli": 6, "latenc": 6, "deseri": 6, "gain": 6, "interact": 6, "profil": 6, "plan": 6, "simpler": 6, "xla_dist": 6, "sdk": 6, "reimplement": 6, "enhanc": 6, "xmp": 6, "substanti": 6, "consist": 6, "servic": 6, "unreli": 6, "inbound": 6, "failur": 6, "impos": 6, "unwant": 6, "permit": 6, "subset": 6, "old": 6, "alter": 6, "consid": 6, "all_gather_object": 6, "new_group": 6, "subgroup": 6, "reliabl": 6, "md": 6, "strongli": 6, "queri": 6, "_all_gath": 6, "tensor": [6, 7, 8], "int32": 6, "zeros_lik": 6, "get_world_s": 6, "averag": 6, "task": 6, "175": 6, "chart": 6, "breakdown": 6, "tfrt": 6, "legaci": 6, "streamexecutor": 6, "tpu_legaci": 6, "comparison": [6, 7], "discuss": 7, "gspmd": 7, "overview": 7, "illustr": 7, "ml": 7, "proper": 7, "hint": 7, "figur": 7, "strategi": 7, "numpi": 7, "concept": 7, "librari": 7, "cluster": 7, "interconnect": 7, "almost": 7, "encourag": 7, "fist": 7, "express": 7, "paral": 7, "fsdpv2": 7, "famou": 7, "offer": 7, "enjoi": 7, "benefit": 7, "bring": 7, "tabl": 7, "review": 7, "proceed": 7, "spmd_fully_sharded_data_parallel": 7, "spmdfullyshardeddataparallel": 7, "autowrap": 7, "decoderlay": 7, "functool": 7, "decoder_only_model": 7, "shard_output": 7, "fall": 7, "categori": 7, "0th": 7, "children": 7, "infinit": 7, "fork": 7, "hf": 7, "demonstr": 7, "cover": 7, "proced": 7, "src": 7, "_input_sharding_": 7, "4d": 7, "input_mesh": 7, "_after": 7, "_the": 7, "unnecessari": 7, "forth": 7, "techniqu": 7, "decis": 7, "nice": 7, "abstract": 7, "arrang": 7, "center": 7, "box": 7, "multislic": 7, "accept": 7, "denot": 7, "hardcod": 7, "rfc": 7, "delai": 7, "except": 7, "satisfi": 7, "subclass": 7, "__torch_dispatch__": 7, "invok": [7, 8], "global_tensor": 7, "special": 7, "strictli": 7, "local_shard": 7, "xlashard": 7, "4e8e5511555073ce8b6d1a436bf808c9333dcac6": 7, "xla_sharded_tensor": 7, "l12": 7, "ongo": 7, "distributedtensor": 7, "prototyp": 7, "proof": 7, "distribute_tensor": 7, "devicemesh": 7, "big_tensor": 7, "100000": 7, "my_dtensor": 7, "stai": 7, "tune": 7, "upcom": [7, 8], "dynamo_mark_shard": 7, "placement": 7, "visual": 7, "multi": 7, "visualize_tensor_shard": 7, "visualize_shard": 7, "rich": 7, "2x2": 7, "generated_t": 7, "use_color": 7, "style": 7, "tile": 7, "partial_repl": 7, "envvar": 7, "xla_auto_spmd": 7, "_tensor": 7, "distribute_modul": 7, "auto_polici": 7, "mymodul": 7, "sharded_model": 7, "behvaior": 7, "xla_auto_use_group_shard": 7, "reshard": 7, "xla_auto_spmd_mesh": 7, "unset": 7, "dedic": 7, "planner": 7, "spmdsaveplann": 7, "spmdloadplann": 7, "dist_cp": 7, "distributed_checkpoint": 7, "xc": 7, "storage_writ": 7, "filesystemwrit": 7, "checkpoint_dir": 7, "desir": 7, "storage_read": 7, "filesystemread": 7, "checkpointmanag": 7, "all_step": 7, "save_async": 7, "written": 7, "unblock": 7, "durat": 7, "dispatch": 7, "preemption": 7, "detect": 7, "termin": 7, "provis": 7, "queuedresourc": 7, "autocheckpoint": 7, "chkpt_on_preempt": 7, "fsspec": 7, "filesystem": 7, "gc": 7, "prime_optim": 7, "chkpt_mgr": 7, "tracked_step": 7, "choos": 7, "highest": 7, "best_step": 7, "prime": 7, "enumer": 7, "present": 7, "attempt": 7, "unprim": 7, "destruct": 7, "discov": 7, "jit": 8, "unmodifi": 8, "hook": 8, "bridg": 8, "torchfx": 8, "technologi": 8, "fx": 8, "a_xla": 8, "b_xla": 8, "compiled_cod": 8, "eval_model": 8, "xla_resnet18": 8, "eval": 8, "dynamo_resnet18": 8, "no_grad": 8, "resent18": 8, "binari": 8, "analysi": 8, "bench": 8, "59": 8, "resnext50_32x4d": 8, "91": 8, "alexnet": 8, "28": 8, "mobilenet_v2": 8, "18": 8, "mnasnet1_0": 8, "68": 8, "vgg16": 8, "bert_pytorch": 8, "squeezenet1_1": 8, "timm_vision_transform": 8, "52": 8, "geomean": 8, "team": 8, "train_model": 8, "crossentropyloss": 8, "pred": 8, "train_model_main": 8, "dynamo_train_model": 8, "xla_optim": 8, "weight_decai": 8, "extract": 8, "07": 8, "81": 8, "87": 8, "fwd": 8, "bwd": 8, "e2": 8, "hide": 8, "cost": 8, "scenario": 8, "best": 8, "promis": 8, "complex": 8, "tradit": 8, "seen": 8, "expand": 8, "excit": 8, "invest": 8, "upstream": 8, "matur": 8, "stori": 8}, "objects": {"": [[3, 0, 0, "-", "torch_xla"]], "torch_xla": [[3, 1, 1, "", "compile"], [3, 1, 1, "", "device"], [3, 1, 1, "", "device_count"], [3, 1, 1, "", "devices"], [3, 0, 0, "-", "experimental"], [3, 1, 1, "", "manual_seed"], [3, 0, 0, "-", "runtime"], [3, 1, 1, "", "sync"]], "torch_xla.core": [[3, 0, 0, "-", "xla_model"]], "torch_xla.core.xla_model": [[3, 1, 1, "", "add_step_closure"], [3, 1, 1, "", "all_gather"], [3, 1, 1, "", "all_reduce"], [3, 1, 1, "", "all_to_all"], [3, 1, 1, "", "get_memory_info"], [3, 1, 1, "", "get_rng_state"], [3, 1, 1, "", "get_stablehlo"], [3, 1, 1, "", "get_stablehlo_bytecode"], [3, 1, 1, "", "is_master_ordinal"], [3, 1, 1, "", "mesh_reduce"], [3, 1, 1, "", "optimizer_step"], [3, 1, 1, "", "rendezvous"], [3, 1, 1, "", "save"], [3, 1, 1, "", "set_rng_state"], [3, 1, 1, "", "wait_device_ops"], [3, 1, 1, "", "xla_device"], [3, 1, 1, "", "xla_device_hw"]], "torch_xla.debug": [[3, 0, 0, "-", "metrics"]], "torch_xla.debug.metrics": [[3, 1, 1, "", "counter_names"], [3, 1, 1, "", "counter_value"], [3, 1, 1, "", "metric_data"], [3, 1, 1, "", "metric_names"], [3, 1, 1, "", "metrics_report"], [3, 1, 1, "", "short_metrics_report"]], "torch_xla.distributed": [[3, 0, 0, "-", "parallel_loader"], [3, 0, 0, "-", "spmd"], [3, 0, 0, "-", "xla_multiprocessing"]], "torch_xla.distributed.parallel_loader": [[3, 2, 1, "", "ParallelLoader"]], "torch_xla.distributed.parallel_loader.ParallelLoader": [[3, 3, 1, "", "per_device_loader"]], "torch_xla.distributed.spmd": [[3, 2, 1, "", "HybridMesh"], [3, 2, 1, "", "Mesh"], [3, 1, 1, "", "clear_sharding"], [3, 1, 1, "", "get_1d_mesh"], [3, 1, 1, "", "get_global_mesh"], [3, 1, 1, "", "mark_sharding"], [3, 1, 1, "", "set_global_mesh"]], "torch_xla.distributed.xla_multiprocessing": [[3, 1, 1, "", "spawn"]], "torch_xla.experimental": [[3, 1, 1, "", "eager_mode"]], "torch_xla.runtime": [[3, 1, 1, "", "addressable_device_count"], [3, 1, 1, "", "device_type"], [3, 1, 1, "", "get_master_ip"], [3, 1, 1, "", "global_device_count"], [3, 1, 1, "", "global_ordinal"], [3, 1, 1, "", "global_runtime_device_count"], [3, 1, 1, "", "initialize_cache"], [3, 1, 1, "", "is_spmd"], [3, 1, 1, "", "local_device_count"], [3, 1, 1, "", "local_ordinal"], [3, 1, 1, "", "local_process_count"], [3, 1, 1, "", "use_spmd"], [3, 1, 1, "", "world_size"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class", "3": "py:method"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"], "3": ["py", "method", "Python method"]}, "titleterms": {"troubleshoot": 0, "saniti": 0, "check": [0, 2], "pytorch": [0, 2, 3, 4, 7, 8], "xla": [0, 2, 3, 4, 7, 8], "version": 0, "perform": [0, 6], "A": 0, "simpl": [0, 2], "calcul": 0, "run": [0, 2, 3, 7], "resnet": [0, 2, 4], "With": 0, "fake": [0, 4], "data": [0, 4, 7], "debug": [0, 3, 7], "tool": [0, 7], "auto": [0, 7], "metric": 0, "analysi": 0, "compil": [0, 1, 3, 7, 8], "execut": 0, "get": 0, "report": 0, "understand": 0, "The": 0, "clear": 0, "dynamo": 0, "profil": 0, "benchmark": [0, 1, 4], "known": 0, "caveat": 0, "tensor": [0, 3, 5], "quirk": 0, "more": 0, "environ": [0, 2], "variabl": [0, 2], "common": 0, "combin": 0, "reproduc": 0, "ci": 0, "cd": 0, "unit": 0, "test": 0, "failur": 0, "eager": 1, "mode": [1, 7], "api": [1, 3], "background": [1, 4], "basic": 1, "usag": 1, "infer": [1, 8], "train": [1, 4, 6, 8], "how": [2, 4, 7], "gpu": [2, 6], "creat": [2, 3], "instanc": 2, "setup": 2, "docker": [2, 6], "wheel": 2, "some": [2, 5], "model": [2, 3], "mp_imagenet": 2, "exampl": [2, 4, 7], "amp": 2, "automat": 2, "mix": 2, "precis": 2, "develop": 2, "build": 2, "from": [2, 3, 5, 6], "sourc": [2, 5], "support": 2, "document": 3, "doc": 3, "devic": 3, "an": 3, "ar": 3, "singl": [3, 6], "multipl": 3, "multi": [3, 6], "process": 3, "tpu": [3, 4, 6, 7], "pod": [3, 4, 6, 7], "deep": 3, "dive": 3, "lazi": 3, "memori": 3, "layout": 3, "move": 3, "cpu": [3, 6], "save": 3, "load": 3, "cach": 3, "further": [3, 7], "read": [3, 7], "torch_xla": [3, 5], "runtim": [3, 6], "xla_model": 3, "distribut": [3, 6, 7], "spmd": [3, 7], "experiment": 3, "do": 4, "distributeddataparallel": 4, "ddp": 4, "motiv": 4, "us": [4, 5, 7], "resnet50": 4, "mnist": 4, "real": [4, 5], "disclaim": 4, "fulli": [4, 7], "shard": [4, 7], "parallel": [4, 7], "fsdp": [4, 7], "script": 4, "imagenet": 4, "instal": 4, "clone": 4, "repo": 4, "v3": [4, 6], "8": 4, "50": 4, "10": 4, "billion": 4, "paramet": 4, "recompil": 5, "let": 5, "": 5, "first": 5, "start": 5, "fact": 5, "constraint": 5, "1": 5, "input": 5, "dataset": 5, "2": [5, 7], "oper": 5, "output": [5, 7], "bound": 5, "dynam": 5, "shape": 5, "can": 5, "fix": 5, "case": 5, "when": 5, "you": 5, "without": 5, "queri": 5, "its": 5, "dimens": 5, "what": [5, 7], "i": [5, 7], "3": 5, "control": 5, "flow": 5, "conclus": 5, "appendix": 5, "pjrt": 6, "tl": 6, "dr": 6, "benefit": 6, "quickstart": 6, "node": 6, "differ": 6, "xrt": 6, "multithread": 6, "v2": 6, "chang": 6, "xm": 6, "rendezv": 6, "torch": [6, 7, 8], "new": 6, "user": 7, "guid": 7, "mesh": 7, "partit": 7, "spec": 7, "via": 7, "gradient": 7, "checkpoint": 7, "huggingfac": 7, "llama": 7, "advanc": 7, "topic": 7, "hybrid": 7, "xlashardedtensor": 7, "dtensor": 7, "integr": [7, 8], "activ": 7, "torchdynamo": 8, "featur": 8, "gap": 8, "take": 8, "awai": 8}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"Troubleshooting": [[0, "troubleshooting"]], "Sanity Check": [[0, "sanity-check"]], "Check PyTorch/XLA Version": [[0, "check-pytorch-xla-version"]], "Perform A Simple Calculation": [[0, "perform-a-simple-calculation"]], "Run Resnet With Fake Data": [[0, "run-resnet-with-fake-data"]], "Performance Debugging": [[0, "performance-debugging"]], "PyTorch/XLA Debugging Tool": [[0, "pytorch-xla-debugging-tool"]], "Perform A Auto-Metrics Analysis": [[0, "perform-a-auto-metrics-analysis"]], "Compilation & Execution Analysis": [[0, "compilation-execution-analysis"]], "Get A Metrics Report": [[0, "get-a-metrics-report"]], "Understand The Metrics Report": [[0, "understand-the-metrics-report"]], "Clear The Metrics Report": [[0, "clear-the-metrics-report"]], "PyTorch/XLA + Dynamo Debugging Tool": [[0, "pytorch-xla-dynamo-debugging-tool"]], "Performance Profiling": [[0, "performance-profiling"]], "Simple Benchmarking": [[0, "simple-benchmarking"]], "Known Performance Caveats": [[0, "known-performance-caveats"]], "XLA Tensor Quirks": [[0, "xla-tensor-quirks"]], "More Debugging Tools": [[0, "more-debugging-tools"]], "Environment Variables": [[0, "environment-variables"]], "Common Debugging Environment Variables Combinations": [[0, "common-debugging-environment-variables-combinations"]], "Reproducing PyTorch/XLA CI/CD unit test failures.": [[0, "reproducing-pytorch-xla-ci-cd-unit-test-failures"]], "Eager Mode + Compile API": [[1, "eager-mode-compile-api"]], "Background": [[1, "background"]], "Basic Usage": [[1, "basic-usage"]], "Inference": [[1, "inference"], [8, "inference"]], "Training": [[1, "training"], [8, "training"]], "Benchmark": [[1, "benchmark"]], "How to run with PyTorch/XLA:GPU": [[2, "how-to-run-with-pytorch-xla-gpu"]], "Create a GPU instance": [[2, "create-a-gpu-instance"]], "Environment Setup": [[2, "environment-setup"]], "Docker": [[2, "docker"], [6, "docker"]], "Check environment variable": [[2, "check-environment-variable"]], "Wheel": [[2, "wheel"]], "Run some simple models": [[2, "run-some-simple-models"]], "MP_ImageNet Example": [[2, "mp-imagenet-example"]], "ResNet Example": [[2, "resnet-example"]], "AMP (AUTOMATIC MIXED PRECISION)": [[2, "amp-automatic-mixed-precision"]], "Develop PyTorch/XLA on a GPU instance (build PyTorch/XLA from source with GPU support)": [[2, "develop-pytorch-xla-on-a-gpu-instance-build-pytorch-xla-from-source-with-gpu-support"]], "PyTorch/XLA documentation": [[3, "pytorch-xla-documentation"]], "Docs": [[3, null]], "PyTorch on XLA Devices": [[3, "pytorch-on-xla-devices"]], "Creating an XLA Tensor": [[3, "creating-an-xla-tensor"]], "XLA Tensors are PyTorch Tensors": [[3, "xla-tensors-are-pytorch-tensors"]], "Running Models on XLA Devices": [[3, "running-models-on-xla-devices"]], "Running on a Single XLA Device": [[3, "running-on-a-single-xla-device"]], "Running on Multiple XLA Devices with Multi-processing": [[3, "running-on-multiple-xla-devices-with-multi-processing"]], "Running on TPU Pods": [[3, "running-on-tpu-pods"]], "XLA Tensor Deep Dive": [[3, "id3"]], "XLA Tensors are Lazy": [[3, "xla-tensors-are-lazy"]], "Memory Layout": [[3, "memory-layout"]], "Moving XLA Tensors to and from the CPU": [[3, "moving-xla-tensors-to-and-from-the-cpu"]], "Saving and Loading XLA Tensors": [[3, "saving-and-loading-xla-tensors"]], "Compilation Caching": [[3, "compilation-caching"]], "Further Reading": [[3, "further-reading"], [7, "further-reading"]], "PyTorch/XLA API": [[3, "pytorch-xla-api"]], "torch_xla": [[3, "module-torch_xla"]], "runtime": [[3, "module-torch_xla.runtime"]], "xla_model": [[3, "module-torch_xla.core.xla_model"]], "distributed": [[3, "module-torch_xla.distributed.parallel_loader"]], "spmd": [[3, "module-torch_xla.distributed.spmd"]], "experimental": [[3, "module-torch_xla.experimental"]], "debug": [[3, "module-torch_xla.debug.metrics"]], "How to do DistributedDataParallel(DDP)": [[4, "how-to-do-distributeddataparallel-ddp"]], "Background / Motivation": [[4, "background-motivation"]], "How to use DistributedDataParallel": [[4, "how-to-use-distributeddataparallel"]], "Benchmarking": [[4, "benchmarking"]], "Resnet50 with fake data": [[4, "resnet50-with-fake-data"]], "MNIST with fake data": [[4, "mnist-with-fake-data"]], "MNIST with real data": [[4, "mnist-with-real-data"]], "Disclaimer": [[4, "disclaimer"]], "Fully Sharded Data Parallel (FSDP) in PyTorch XLA": [[4, "fully-sharded-data-parallel-fsdp-in-pytorch-xla"]], "Example training scripts on MNIST and ImageNet": [[4, "example-training-scripts-on-mnist-and-imagenet"]], "Installation": [[4, "installation"]], "Clone PyTorch/XLA repo": [[4, "clone-pytorch-xla-repo"]], "Train MNIST on v3-8 TPU": [[4, "train-mnist-on-v3-8-tpu"]], "Train ImageNet with ResNet-50 on v3-8 TPU": [[4, "train-imagenet-with-resnet-50-on-v3-8-tpu"]], "Example training scripts on TPU pod (with 10 billion parameters)": [[4, "example-training-scripts-on-tpu-pod-with-10-billion-parameters"]], "Source of recompilations in torch_xla": [[5, "source-of-recompilations-in-torch-xla"]], "Let\u2019s first start with some facts/constraints:": [[5, "lets-first-start-with-some-facts-constraints"]], "#1. From input dataset.": [[5, "from-input-dataset"]], "#2. From operator output": [[5, "from-operator-output"]], "2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension.": [[5, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"]], "2.2 what if real dimension is queried on a tensor with dynamic shape?": [[5, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"]], "#3. From control flow": [[5, "from-control-flow"]], "Conclusion:": [[5, "conclusion"]], "Appendix:": [[5, "appendix"]], "PJRT Runtime": [[6, "pjrt-runtime"]], "TL;DR": [[6, "tl-dr"]], "Benefits": [[6, "benefits"]], "Quickstart": [[6, "quickstart"]], "CPU": [[6, "cpu"]], "TPU": [[6, "tpu"]], "Pods": [[6, "pods"]], "GPU": [[6, "gpu"]], "Single-node GPU training": [[6, "single-node-gpu-training"]], "Multi-node GPU training": [[6, "multi-node-gpu-training"]], "Differences from XRT": [[6, "differences-from-xrt"]], "Multithreading on TPU v2/v3": [[6, "id3"]], "Changes to xm.rendezvous": [[6, "changes-to-xm-rendezvous"]], "PJRT and torch.distributed": [[6, "pjrt-and-torch-distributed"]], "Performance": [[6, "performance"]], "New TPU runtime": [[6, "new-tpu-runtime"]], "PyTorch/XLA SPMD User Guide": [[7, "pytorch-xla-spmd-user-guide"]], "What is PyTorch/XLA SPMD?": [[7, "what-is-pytorch-xla-spmd"]], "How to use PyTorch/XLA SPMD?": [[7, "how-to-use-pytorch-xla-spmd"]], "SPMD Mode": [[7, "spmd-mode"]], "Mesh": [[7, "mesh"]], "Partition Spec": [[7, "partition-spec"]], "Fully Sharded Data Parallel(FSDP) via SPMD": [[7, "fully-sharded-data-parallel-fsdp-via-spmd"]], "Sharding output": [[7, "sharding-output"]], "Gradient checkpointing": [[7, "gradient-checkpointing"]], "HuggingFace Llama 2 Example": [[7, "huggingface-llama-2-example"]], "PyTorch/XLA SPMD advanced topics": [[7, "pytorch-xla-spmd-advanced-topics"]], "Hybrid Mesh": [[7, "hybrid-mesh"]], "Running SPMD on TPU Pod": [[7, "running-spmd-on-tpu-pod"]], "XLAShardedTensor": [[7, "xlashardedtensor"]], "DTensor Integration": [[7, "dtensor-integration"]], "Activation Sharding for torch.compile": [[7, "activation-sharding-for-torch-compile"]], "SPMD Debugging Tool": [[7, "spmd-debugging-tool"]], "Auto-Sharding": [[7, "auto-sharding"]], "Distributed Checkpointing": [[7, "distributed-checkpointing"]], "TorchDynamo(torch.compile) integration in PyTorch XLA": [[8, "torchdynamo-torch-compile-integration-in-pytorch-xla"]], "Integration": [[8, "integration"]], "Feature gaps": [[8, "feature-gaps"]], "Take away": [[8, "take-away"]]}, "indexentries": {"hybridmesh (class in torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.HybridMesh"]], "mesh (class in torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.Mesh"]], "parallelloader (class in torch_xla.distributed.parallel_loader)": [[3, "torch_xla.distributed.parallel_loader.ParallelLoader"]], "add_step_closure() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.add_step_closure"]], "addressable_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.addressable_device_count"]], "all_gather() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.all_gather"]], "all_reduce() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.all_reduce"]], "all_to_all() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.all_to_all"]], "clear_sharding() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.clear_sharding"]], "compile() (in module torch_xla)": [[3, "torch_xla.compile"]], "counter_names() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.counter_names"]], "counter_value() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.counter_value"]], "device() (in module torch_xla)": [[3, "torch_xla.device"]], "device_count() (in module torch_xla)": [[3, "torch_xla.device_count"]], "device_type() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.device_type"]], "devices() (in module torch_xla)": [[3, "torch_xla.devices"]], "eager_mode() (in module torch_xla.experimental)": [[3, "torch_xla.experimental.eager_mode"]], "get_1d_mesh() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.get_1d_mesh"]], "get_global_mesh() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.get_global_mesh"]], "get_master_ip() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.get_master_ip"]], "get_memory_info() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_memory_info"]], "get_rng_state() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_rng_state"]], "get_stablehlo() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_stablehlo"]], "get_stablehlo_bytecode() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.get_stablehlo_bytecode"]], "global_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.global_device_count"]], "global_ordinal() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.global_ordinal"]], "global_runtime_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.global_runtime_device_count"]], "initialize_cache() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.initialize_cache"]], "is_master_ordinal() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.is_master_ordinal"]], "is_spmd() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.is_spmd"]], "local_device_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.local_device_count"]], "local_ordinal() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.local_ordinal"]], "local_process_count() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.local_process_count"]], "manual_seed() (in module torch_xla)": [[3, "torch_xla.manual_seed"]], "mark_sharding() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.mark_sharding"]], "mesh_reduce() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.mesh_reduce"]], "metric_data() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.metric_data"]], "metric_names() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.metric_names"]], "metrics_report() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.metrics_report"]], "module": [[3, "module-torch_xla"], [3, "module-torch_xla.core.xla_model"], [3, "module-torch_xla.debug.metrics"], [3, "module-torch_xla.distributed.parallel_loader"], [3, "module-torch_xla.distributed.spmd"], [3, "module-torch_xla.distributed.xla_multiprocessing"], [3, "module-torch_xla.experimental"], [3, "module-torch_xla.runtime"]], "optimizer_step() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.optimizer_step"]], "per_device_loader() (torch_xla.distributed.parallel_loader.parallelloader method)": [[3, "torch_xla.distributed.parallel_loader.ParallelLoader.per_device_loader"]], "rendezvous() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.rendezvous"]], "save() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.save"]], "set_global_mesh() (in module torch_xla.distributed.spmd)": [[3, "torch_xla.distributed.spmd.set_global_mesh"]], "set_rng_state() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.set_rng_state"]], "short_metrics_report() (in module torch_xla.debug.metrics)": [[3, "torch_xla.debug.metrics.short_metrics_report"]], "spawn() (in module torch_xla.distributed.xla_multiprocessing)": [[3, "torch_xla.distributed.xla_multiprocessing.spawn"]], "sync() (in module torch_xla)": [[3, "torch_xla.sync"]], "torch_xla": [[3, "module-torch_xla"]], "torch_xla.core.xla_model": [[3, "module-torch_xla.core.xla_model"]], "torch_xla.debug.metrics": [[3, "module-torch_xla.debug.metrics"]], "torch_xla.distributed.parallel_loader": [[3, "module-torch_xla.distributed.parallel_loader"]], "torch_xla.distributed.spmd": [[3, "module-torch_xla.distributed.spmd"]], "torch_xla.distributed.xla_multiprocessing": [[3, "module-torch_xla.distributed.xla_multiprocessing"]], "torch_xla.experimental": [[3, "module-torch_xla.experimental"]], "torch_xla.runtime": [[3, "module-torch_xla.runtime"]], "use_spmd() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.use_spmd"]], "wait_device_ops() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.wait_device_ops"]], "world_size() (in module torch_xla.runtime)": [[3, "torch_xla.runtime.world_size"]], "xla_device() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.xla_device"]], "xla_device_hw() (in module torch_xla.core.xla_model)": [[3, "torch_xla.core.xla_model.xla_device_hw"]]}}) \ No newline at end of file diff --git a/master/spmd.html b/master/spmd.html index f42c2f6ac77..66748924d24 100644 --- a/master/spmd.html +++ b/master/spmd.html @@ -267,7 +267,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )
diff --git a/master/torch_compile.html b/master/torch_compile.html index adcaedd8c85..44778886119 100644 --- a/master/torch_compile.html +++ b/master/torch_compile.html @@ -266,7 +266,7 @@
- master (2.5.0+git52ea89f ) + master (2.5.0+git4e94ff6 )