update example dir's README (#7136)
JackCaoG authored May 29, 2024
1 parent 6f406b7 commit c7bbdfb
Showing 2 changed files with 20 additions and 3 deletions.
17 changes: 16 additions & 1 deletion examples/README.md
@@ -1,2 +1,17 @@
 ## Overview
-This repo aims to provide some basic examples of how to run an existing pytorch model with PyTorch/XLA. train_resnet_base.py is a minimal trainer to run ResNet50 with fake data on a single device. Other examples will import the train_resnet_base and demonstrate how to enable different features(distributed training, profiling, dynamo etc) on PyTorch/XLA.The objective of this repository is to offer fundamental examples of executing an existing PyTorch model utilizing PyTorch/XLA. train_resnet_base.py acts as a bare-bones trainer for running ResNet50 with simulated data on an individual device. Additional examples will import train_resnet_base and illustrate how to activate various features (e.g., distributed training, profiling, dynamo) on PyTorch/XLA.
+This repo aims to provide some basic examples of how to run an existing PyTorch model with PyTorch/XLA. `train_resnet_base.py` is a minimal trainer that runs ResNet50 with fake data on a single device. `train_decoder_only_base.py` is similar to `train_resnet_base.py` but uses a decoder-only model.
+
+Other examples import `train_resnet_base` or `train_decoder_only_base` and demonstrate how to enable different features (distributed training, profiling, dynamo, etc.) on PyTorch/XLA.
+
+## Setup
+Follow our [README](https://github.com/pytorch/xla#getting-started) to install the latest release of torch_xla. Check out this [link](https://github.com/pytorch/xla#python-packages) for other versions of torch_xla. To install the nightly torchvision (required for the ResNet examples), run:
+
+```shell
+pip install --no-deps --pre torchvision -i https://download.pytorch.org/whl/nightly/cu118
+```
+
+## Run the example
+You can run all models directly. The only environment variable you need to set is `PJRT_DEVICE`.
+```
+PJRT_DEVICE=TPU python fsdp/train_decoder_only_fsdp_v2.py
+```
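As a rough illustration of the run step in the README, the sketch below launches an example script from Python with `PJRT_DEVICE` set only for the child process. The helper names `build_env` and `run_example` are hypothetical and not part of the repository; the `TPU` value and the script path in the usage note come from the README command.

```python
import os
import subprocess
import sys


def build_env(device="TPU", base_env=None):
  """Return a copy of the environment with PJRT_DEVICE set (hypothetical helper)."""
  env = dict(os.environ if base_env is None else base_env)
  env["PJRT_DEVICE"] = device
  return env


def run_example(script, device="TPU"):
  """Run a script in a child process, like `PJRT_DEVICE=<device> python <script>`."""
  return subprocess.run([sys.executable, script], env=build_env(device), check=True)
```

Calling `run_example("fsdp/train_decoder_only_fsdp_v2.py")` from the `examples` directory would correspond to the README command; it assumes torch_xla is installed and a TPU is attached. Copying the environment rather than mutating `os.environ` keeps the setting local to the launched script.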
6 changes: 4 additions & 2 deletions examples/debug/train_resnet_profile.py
@@ -1,5 +1,5 @@
-import os
+import os
 import sys
 example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
 sys.path.append(example_folder)
 from train_resnet_base import TrainResNetBase
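The path handling at the top of this file can be read as follows: the script lives in `examples/debug/`, so taking `dirname` twice on its absolute path yields the `examples` folder, which is then added to `sys.path` so that `train_resnet_base` can be imported. A minimal sketch of that computation, using a hypothetical helper `example_root` and a made-up `/repo` checkout location:

```python
import os


def example_root(script_path):
  """Return the directory two levels above script_path (hypothetical helper)."""
  return os.path.dirname(os.path.dirname(os.path.abspath(script_path)))


# Illustrative path only; "/repo" is a made-up checkout location.
script = os.path.join(os.sep, "repo", "examples", "debug", "train_resnet_profile.py")
root = example_root(script)
```

Here `root` ends in `examples`, the folder the script appends to `sys.path`. Because the real script relies on `sys.argv[0]`, the result depends on how the script is invoked, not on the current working directory.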
@@ -13,9 +13,11 @@
 if __name__ == '__main__':
   base = TrainResNetBase()
   profile_port = 9012
+  # you can also set profile_logdir to a gs bucket, for example
+  # profile_logdir = "gs://your_gs_bucket/profile"
   profile_logdir = "/tmp/profile/"
   duration_ms = 30000
-  assert os.path.exists(profile_logdir)
+  assert profile_logdir.startswith('gs://') or os.path.exists(profile_logdir)
   server = xp.start_server(profile_port)
   # Ideally you want to start the profile tracing after the initial compilation, for example
   # at step 5.
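The relaxed assertion in this hunk accepts either a `gs://` URI or an existing local path. A minimal sketch of the same check as a reusable function follows. The function names are hypothetical, and `prepare_profile_logdir` goes beyond what the commit does by creating a missing local directory instead of failing:

```python
import os
import tempfile


def is_usable_profile_logdir(path):
  """Same condition as the assertion: a gs:// URI, or a local path that exists."""
  return path.startswith("gs://") or os.path.exists(path)


def prepare_profile_logdir(path):
  """Create a missing local directory; leave gs:// URIs untouched (hypothetical helper)."""
  if not path.startswith("gs://"):
    os.makedirs(path, exist_ok=True)
  assert is_usable_profile_logdir(path)
  return path
```

For example, `prepare_profile_logdir(os.path.join(tempfile.gettempdir(), "profile"))` creates the directory if needed and returns its path, while `prepare_profile_logdir("gs://your_gs_bucket/profile")` returns the URI unchanged. As in the commit, the bucket itself is never checked for existence.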