Skip to content

Commit

Permalink
Wrap adding logging options in try / catch. (#7307)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored Jun 18, 2024
1 parent 91389c7 commit c275371
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions experimental/torch_xla2/torch_xla2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,30 @@
'extract_jax',
]


from jax._src import xla_bridge
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
jax.config.update('jax_enable_x64', True)
jax.config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
)
old_pjrt_options = jax.config.jax_pjrt_client_create_options

try:
jax.config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
)
xla_bridge._clear_backends()
jax.devices() # open PJRT to see if it opens
except RuntimeError:
jax.config.update(
'jax_pjrt_client_create_options', old_pjrt_options
)
xla_bridge._clear_backends()
jax.devices() # open PJRT to see if it opens


env = None
def default_env():
global env

os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')

if env is None:
env = tensor.Environment()
return env
Expand Down

0 comments on commit c275371

Please sign in to comment.