# TPU Testing
#
# TPU Testing #289
#
# Workflow file for this run

# Builds PyTorch (release/2.3) and PyTorch/XLA from source on the
# `v4-runner-set` runner, then runs the TPU test script
# (test/tpu/run_tests.sh) with PJRT_DEVICE=TPU.
name: TPU Integration Test
run-name: TPU Testing

# Triggers:
#   * manual dispatch;
#   * pull requests targeting a release branch (e.g. r2.3);
#   * pushes to master or a release branch.
# Changes confined to experimental/torch_xla2/ do not trigger a run.
on:
  workflow_dispatch:
  pull_request:
    branches:
      - 'r[0-9]+.[0-9]+'
    paths-ignore:
      - 'experimental/torch_xla2/**'
  push:
    branches:
      - master
      - 'r[0-9]+.[0-9]+'
    paths-ignore:
      - 'experimental/torch_xla2/**'

jobs:
  tpu-test:
    # NOTE(review): presumably a self-hosted TPU v4 runner set — confirm.
    runs-on: v4-runner-set
    steps:
      # PyTorch is cloned manually (not via actions/checkout) into
      # ./pytorch so that the XLA checkout below can live at ./pytorch/xla.
      - name: Checkout and Setup PyTorch Repo
        env:
          # NOTE(review): presumably selects the pre-C++11 libstdc++ ABI for
          # the PyTorch build — confirm it matches the XLA build settings.
          _GLIBCXX_USE_CXX11_ABI: '0'
        run: |
          git clone -b release/2.3 --recursive https://github.com/pytorch/pytorch
          cd pytorch/
          python3 setup.py install --user

      # Installs torchvision at the commit pinned by the PyTorch checkout
      # (.github/ci_commit_pins/vision.txt).
      - name: Install torchvision
        run: |
          cd pytorch/
          pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"

      # Checks out this repository inside the PyTorch clone.
      - name: Checkout PyTorch/XLA Repo
        uses: actions/checkout@v4
        with:
          path: pytorch/xla

      - name: Run PyTorch/XLA Setup
        env:
          # NOTE(review): both flags are presumably read by the XLA
          # setup.py — confirm their exact effect there.
          BAZEL_VERBOSE: '1'
          TPUVM_MODE: '1'
        run: |
          cd pytorch/xla
          python3 setup.py install --user

      - name: Run Tests
        env:
          PJRT_DEVICE: TPU
        # Jax is needed for pallas tests; it is pulled in by the
        # torch_xla[pallas] extra installed below.
        run: |
          pip install fsspec
          pip install rich
          # Extras are quoted so the shell never treats [..] as a glob.
          pip install 'torch_xla[pallas]' \
            -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
          pip install 'torch_xla[tpu]' \
            -f https://storage.googleapis.com/libtpu-releases/index.html
          cd pytorch/xla
          test/tpu/run_tests.sh