diff --git a/examples/README.md b/examples/README.md index 3bebd3b860f..1ad0018c981 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,2 +1,17 @@ ## Overview -This repo aims to provide some basic examples of how to run an existing pytorch model with PyTorch/XLA. train_resnet_base.py is a minimal trainer to run ResNet50 with fake data on a single device. Other examples will import the train_resnet_base and demonstrate how to enable different features(distributed training, profiling, dynamo etc) on PyTorch/XLA.The objective of this repository is to offer fundamental examples of executing an existing PyTorch model utilizing PyTorch/XLA. train_resnet_base.py acts as a bare-bones trainer for running ResNet50 with simulated data on an individual device. Additional examples will import train_resnet_base and illustrate how to activate various features (e.g., distributed training, profiling, dynamo) on PyTorch/XLA. +This repo aims to provide some basic examples of how to run an existing pytorch model with PyTorch/XLA. `train_resnet_base.py` is a minimal trainer to run ResNet50 with fake data on a single device. `train_decoder_only_base.py` is similar to `train_resnet_base.py` but with a decoder only model. + +Other examples will import the `train_resnet_base` or `train_decoder_only_base` and demonstrate how to enable different features(distributed training, profiling, dynamo etc) on PyTorch/XLA.The objective of this repository is to offer fundamental examples of executing an existing PyTorch model utilizing PyTorch/XLA. + +## Setup +Follow our [README](https://github.com/pytorch/xla#getting-started) to install latest release of torch_xla. Check out this [link](https://github.com/pytorch/xla#python-packages) for torch_xla at other versions. To install the nightly torchvision(required for the resnet) you can do + +```shell +pip install --no-deps --pre torchvision -i https://download.pytorch.org/whl/nightly/cu118 +``` + +## Run the example +You can run all models directly. Only environment you want to set is `PJRT_DEVICE`. +``` +PJRT_DEVICE=TPU python fsdp/train_decoder_only_fsdp_v2.py +``` diff --git a/examples/debug/train_resnet_profile.py b/examples/debug/train_resnet_profile.py index 886c25f65b7..158c138dba7 100644 --- a/examples/debug/train_resnet_profile.py +++ b/examples/debug/train_resnet_profile.py @@ -1,5 +1,5 @@ import os -import os +import sys example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) sys.path.append(example_folder) from train_resnet_base import TrainResNetBase @@ -13,9 +13,11 @@ if __name__ == '__main__': base = TrainResNetBase() profile_port = 9012 + # you can also set profile_logdir to a gs bucket, for example + # profile_logdir = "gs://your_gs_bucket/profile" profile_logdir = "/tmp/profile/" duration_ms = 30000 - assert os.path.exists(profile_logdir) + assert profile_logdir.startswith('gs://') or os.path.exists(profile_logdir) server = xp.start_server(profile_port) # Ideally you want to start the profile tracing after the initial compilation, for example # at step 5.