Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Sep 25, 2024
1 parent 7dd0663 commit 757ed0c
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ on:
- master
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/torch_xla2/**'
- 'experimental/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/torch_xla2/**'
- 'experimental/**'
workflow_dispatch:

concurrency:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_upstream_image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
- master
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/torch_xla2/**'
- 'experimental/**'
workflow_dispatch:
jobs:
build:
Expand Down
1 change: 1 addition & 0 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
instantiate_device_type_tests, ops)
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
import torch_xla2


skiplist = {
Expand Down
15 changes: 13 additions & 2 deletions experimental/torch_xla2/torch_xla2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from jax._src import xla_bridge
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
jax.config.update('jax_enable_x64', True)

# torch_xla2:oss-begin
old_pjrt_options = jax.config.jax_pjrt_client_create_options
Expand Down Expand Up @@ -80,4 +79,16 @@ def disable_globally():
unsupported_dtype=unsupported_dtype)

import jax
torch._register_device_module('jax', jax)
torch._register_device_module('jax', jax)


def enable_accuracy_mode():
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_matmul_precision', 'highest')
default_env().config.internal_respect_torch_return_dtypes = True


def enable_performance_mode():
jax.config.update('jax_enable_x64', False)
jax.config.update('jax_default_matmul_precision', 'default')
default_env().config.internal_respect_torch_return_dtypes = False
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ class Configuration:
# device
treat_cuda_as_jax_device: bool = True
use_torch_native_for_cpu_tensor: bool = False
internal_respect_torch_return_dtypes: bool = False

0 comments on commit 757ed0c

Please sign in to comment.