Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add depreacated decorator #7703

Merged
merged 6 commits into from
Jul 18, 2024
Merged

add depreacated decorator #7703

merged 6 commits into from
Jul 18, 2024

Conversation

zpcore
Copy link
Collaborator

@zpcore zpcore commented Jul 17, 2024

Introduce the mark_deprecated decorator to map deprecated function to the new function. The reason is that existing deprecated function only maps to function with the same name. It's impossible to map function like xrt_world_size to world_size.

Instead of passing old function to the deprecation wrapper, I think use the decorator above the old function is the most straightward.

Example to use:

import torch_xla
from torch_xla.experimental.deprecation import mark_deprecated

@mark_deprecated(torch_xla.runtime.world_size)
def xrt_world_size(defval=1):
  """Retrieves the number of devices which is taking part of the replication.

  Args:
    defval (int, optional): The default value to be returned in case there is no
      replication information available.
      Default: 1

  Returns:
    The number of devices which is taking part of the replication.
  """
  global _WORLD_SIZE
  if _WORLD_SIZE is not None:
    return _WORLD_SIZE

  return runtime.world_size()

Output:

>>> import torch_xla
>>> torch_xla.core.xla_model.xrt_world_size()
WARNING:root:torch_xla.core.xla_model.xrt_world_size is deprecated. Use torch_xla.runtime.world_size instead.
1

@zpcore zpcore requested a review from will-cromar July 17, 2024 06:19
@zpcore zpcore marked this pull request as ready for review July 17, 2024 06:20
@will-cromar
Copy link
Collaborator

The decorator makes sense here, but not because It's impossible to map function like xrt_world_size to world_size.. For example, see how these three functions in experimental.pjrt were renamed after being deprecated:

rendezvous = deprecated(this_module, xm.xla_rendezvous)
device_attributes = deprecated(this_module, runtime.runtime_device_attributes)
global_device_attributes = deprecated(this_module,
runtime.global_runtime_device_attributes)

You would need the decorator if the implementation of the old function is actually different than the new one. In this case, the signatures of xrt_world_size ((int) -> int) and world_size (() -> int) differ.

Having said that, you can make mark_deprecated DRY. Remember that a decorator is a function that takes one function as an argument and returns a function. If you use functools.partial to fill the first argument of deprecated ((module, fn) -> fn), the partial callable is (fn) -> fn (ie a decorator. For any module, you can create a decorator like this:

mark_deprecated = functools.partial(deprecated, this_module)

Or, you can put this idea to work in experimental.deprecation to reduce the duplicated logic in your implementation:

def mark_deprecated(module):
  def f(func):
    return deprecated(module, func)
  return f

@zpcore
Copy link
Collaborator Author

zpcore commented Jul 17, 2024

The decorator makes sense here, but not because It's impossible to map function like xrt_world_size to world_size.. For example, see how these three functions in experimental.pjrt were renamed after being deprecated:

rendezvous = deprecated(this_module, xm.xla_rendezvous)
device_attributes = deprecated(this_module, runtime.runtime_device_attributes)
global_device_attributes = deprecated(this_module,
runtime.global_runtime_device_attributes)

You would need the decorator if the implementation of the old function is actually different than the new one. In this case, the signatures of xrt_world_size ((int) -> int) and world_size (() -> int) differ.

Having said that, you can make mark_deprecated DRY. Remember that a decorator is a function that takes one function as an argument and returns a function. If you use functools.partial to fill the first argument of deprecated ((module, fn) -> fn), the partial callable is (fn) -> fn (ie a decorator. For any module, you can create a decorator like this:

mark_deprecated = functools.partial(deprecated, this_module)

Or, you can put this idea to work in experimental.deprecation to reduce the duplicated logic in your implementation:

def mark_deprecated(module):
  def f(func):
    return deprecated(module, func)
  return f

Thanks for the point. It's mostly about the warning message. If we did something like:

 rendezvous = deprecated(this_module, xm.xla_rendezvous) 

It will warn like this_module.xla_rendezvous is deprecated with xm.xla_rendezvous, while there is no function named xla_rendezvous originally in this module.

@will-cromar
Copy link
Collaborator

It will warn like this_module.xla_rendezvous is deprecated with xm.xla_rendezvous, while there is no function named xla_rendezvous originally in this module.

I see, good catch. It looks like deprecated is missing a parameter to print the correct warning then. You can add something like this:

def deprecated(module, new_fn, old_name=None):
  old_name = old_name or newfn.__name__
  ...
     logging.warning(f'{module.__name__}.{old_name} is deprecated. Use {new.__module__}.{new.__name__} instead.)
  ...

Copy link
Collaborator

@will-cromar will-cromar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@zpcore zpcore merged commit a1e99a8 into master Jul 18, 2024
23 checks passed
@zpcore zpcore deleted the piz/new_deprecated branch July 18, 2024 16:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants