diff --git a/test/run_tests.sh b/test/run_tests.sh index d1e8bef6125..3c8d7f405e3 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -174,6 +174,7 @@ function run_xla_op_tests1 { run_test "$CDIR/test_python_ops.py" run_test "$CDIR/test_ops.py" run_test "$CDIR/test_metrics.py" + run_test "$CDIR/test_deprecation.py" run_test "$CDIR/dynamo/test_dynamo_integrations_util.py" run_test "$CDIR/dynamo/test_dynamo_aliasing.py" run_test "$CDIR/dynamo/test_dynamo.py" diff --git a/test/test_deprecation.py b/test/test_deprecation.py new file mode 100644 index 00000000000..d813384c0d0 --- /dev/null +++ b/test/test_deprecation.py @@ -0,0 +1,48 @@ +import os +import logging +import io +import unittest +import importlib + +from torch_xla.experimental.deprecation import deprecated, mark_deprecated + + +def old_function(): + return False + + +def new_function(): + return True + + +class TestDepecation(unittest.TestCase): + + def test_map_to_new_func(self): + this_module = importlib.import_module(__name__) + old_function = deprecated(this_module, this_module.new_function, + "random_old_name") + with self.assertLogs(level='WARNING') as log: + result = old_function() + self.assertIn("random_old_name", log.output[0]) + assert (result) + + @unittest.mock.patch('__main__.new_function') + def test_decorator(self, mock_new_function): + mock_new_function.__name__ = "new_name" + mock_new_function.return_value = True + with self.assertLogs(level='WARNING') as log: + + @mark_deprecated(new_function) + def function_to_deprecate(): + return False + + result = function_to_deprecate() + assert (result) + self.assertIn("function_to_deprecate", log.output[0]) + self.assertIn(mock_new_function.__name__, log.output[0]) + mock_new_function.assert_called_once() + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/deprecation.py b/torch_xla/experimental/deprecation.py index 608939c02cc..b24af695727 100644 --- a/torch_xla/experimental/deprecation.py +++ b/torch_xla/experimental/deprecation.py @@ -1,18 +1,20 @@ import functools import logging +import importlib from typing import TypeVar FN = TypeVar('FN') -def deprecated(module, new: FN) -> FN: +def deprecated(module, new: FN, old_name=None) -> FN: already_warned = [False] + old_name = old_name or new.__name__ @functools.wraps(new) def wrapped(*args, **kwargs): if not already_warned[0]: logging.warning( - f'{module.__name__}.{new.__name__} is deprecated. Use {new.__module__}.{new.__name__} instead.' + f'{module.__name__}.{old_name} is deprecated. Use {new.__module__}.{new.__name__} instead.' ) already_warned[0] = True @@ -21,5 +23,23 @@ def wrapped(*args, **kwargs): return wrapped +def mark_deprecated(new: FN) -> FN: + """Decorator to mark a function as deprecated and map to new function. + + Args: + module: current module of the deprecated function that is in. Assume current module name is X, you can use `from . import X` and pass X here. + new: new function that we map to. Need to include the path the new function that is in. + + Returns: + Wrapper of the new function. + """ + + def decorator(func): + return deprecated( + importlib.import_module(func.__module__), new, old_name=func.__name__) + + return decorator + + def register_deprecated(module, new: FN): setattr(module, new.__name__, deprecated(module, new))