internalize _parse_xla_device #7675

Merged: 17 commits, Jul 18, 2024
8 changes: 8 additions & 0 deletions torch_xla/_internal/utils.py
@@ -0,0 +1,8 @@
+import logging
+import re
+
+
+def parse_xla_device(device: str):
+  m = re.match(r'([A-Z]+):(\d+)$', device)
+  if m:
+    return (m.group(1), int(m.group(2)))
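
For reference, the helper's behavior at the REPL, inferred from the regex above (the 'TPU:0' case is also exercised later in this thread):

>>> from torch_xla._internal.utils import parse_xla_device
>>> parse_xla_device('TPU:0')
('TPU', 0)
>>> parse_xla_device('bad-device') is None  # no match: the function falls through and returns None
True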
14 changes: 8 additions & 6 deletions torch_xla/core/xla_model.py
@@ -19,6 +19,8 @@
import torch_xla.utils.utils as xu
import torch_xla.utils.closures as xc
import os
+from torch_xla.experimental.deprecation import register_deprecated
+import torch_xla._internal.utils as iutils

_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())

@@ -40,6 +42,12 @@

XLA_LIB = Library("xla", "DEF")

+aliases = [
+    iutils.parse_xla_device,
+]
+for alias in aliases:
+  register_deprecated(torch_xla.core, alias)
Collaborator:

Thanks! Looking at the implementation of register_deprecated, I think this should actually be torch_xla.core.xla_model. Please double-check that you get the expected message when you call torch_xla.core.xla_model.parse_xla_device(...).

Collaborator (Author):

Updated to parse_xla_device = deprecated(torch_xla.core, iutils.parse_xla_device) instead.

Collaborator:

The error message still says torch_xla.core.parse_xla_device instead of torch_xla.core.xla_model.parse_xla_device:

root@t1v-n-bf2f726f-w-0:/workspaces/ptxla# python
...
>>> import torch_xla.core.xla_model as xm
>>> xm.parse_xla_device('TPU:0')
WARNING:root:torch_xla.core.parse_xla_device is deprecated. Use torch_xla._internal.utils.parse_xla_device instead.
('TPU', 0)

You'll actually need to pass the xla_model module into register_deprecated. That was the only way I could find to print the qualified name of the aliased method consistently, but feel free to refactor it if you find a better way.
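
For intuition, a rough sketch of how such a helper might build the qualified names (illustrative only, not the actual torch_xla.experimental.deprecation implementation):

import functools
import logging


def deprecated(module, func):
  # Hypothetical sketch: the old qualified name is derived from the module
  # hosting the alias, which is why the aliasing module itself (e.g.
  # torch_xla.core.xla_model) has to be passed in.
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    logging.warning(f'{module.__name__}.{func.__name__} is deprecated. '
                    f'Use {func.__module__}.{func.__name__} instead.')
    return func(*args, **kwargs)

  return wrapper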

I think we need to add a unit test here to confirm that the right warning is getting printed, since you're going to be adding a lot of these aliases and this is an easy mistake to make. You can create a new unit test file for all of these public-to-internal aliases, and we can remove it later.
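
A minimal sketch of such a test (file and test names are illustrative; it assumes the warning is emitted via the standard logging module, as in the transcript above):

# test_public_api_aliases.py (hypothetical file name)
import logging
import unittest

import torch_xla.core.xla_model as xm


class PublicApiAliasTest(unittest.TestCase):

  def test_parse_xla_device_warns_with_qualified_name(self):
    # The deprecated alias should still behave like the internal helper...
    with self.assertLogs(level=logging.WARNING) as logs:
      self.assertEqual(xm.parse_xla_device('TPU:0'), ('TPU', 0))
    # ...and the warning should name torch_xla.core.xla_model, not torch_xla.core.
    self.assertTrue(
        any('torch_xla.core.xla_model.parse_xla_device is deprecated' in line
            for line in logs.output))


if __name__ == '__main__':
  unittest.main()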

Collaborator (Author):

This should work:

import torch_xla._internal.utils as _utils
from . import xla_model as this_module
parse_xla_device = deprecated(this_module, _utils.parse_xla_device)
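
With that change, the expected output (following the transcript format above) would be:

>>> import torch_xla.core.xla_model as xm
>>> xm.parse_xla_device('TPU:0')
WARNING:root:torch_xla.core.xla_model.parse_xla_device is deprecated. Use torch_xla._internal.utils.parse_xla_device instead.
('TPU', 0)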



def _init_world_size_ordinal():
  global _WORLD_SIZE, _ORDINAL
@@ -76,12 +84,6 @@ def is_xla_tensor(tensor):
  return tensor.device.type == 'xla'


-def parse_xla_device(device):
-  m = re.match(r'([A-Z]+):(\d+)$', device)
-  if m:
-    return (m.group(1), int(m.group(2)))


def get_xla_supported_devices(devkind=None, max_devices=None):
"""Returns a list of supported devices of a given kind.

3 changes: 2 additions & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
@@ -4,6 +4,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
+import torch_xla._internal.utils as iutils
from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
import torch_xla.runtime as xr

@@ -219,7 +220,7 @@ def __init__(self,
    mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
    self.device_attributes = xr.global_runtime_device_attributes()
    self.device_attributes.sort(
-        key=lambda attr: xm.parse_xla_device(attr['name'])[1])
+        key=lambda attr: iutils.parse_xla_device(attr['name'])[1])

    if 'slice_index' in self.device_attributes[0] and np.prod(
        dcn_mesh_shape) == 1:
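
As a toy illustration of this sort key (hypothetical attribute dicts, not real runtime output):

>>> import torch_xla._internal.utils as iutils
>>> attrs = [{'name': 'TPU:2'}, {'name': 'TPU:0'}, {'name': 'TPU:1'}]
>>> attrs.sort(key=lambda attr: iutils.parse_xla_device(attr['name'])[1])
>>> [a['name'] for a in attrs]
['TPU:0', 'TPU:1', 'TPU:2']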