Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed Jul 18, 2024
1 parent 9422bd8 commit 41e8877
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_python_ops.py"
run_test "$CDIR/test_ops.py"
run_test "$CDIR/test_metrics.py"
run_test "$CDIR/test_deprecation.py"
run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
run_test "$CDIR/dynamo/test_dynamo.py"
Expand Down
55 changes: 55 additions & 0 deletions test/test_deprecation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import logging
import io
import unittest
import importlib

from unittest.mock import patch

from torch_xla.experimental.deprecation import deprecated, mark_deprecated


def old_function():
return False


def new_function():
return True


@mark_deprecated(new_function)
def old_funtion_to_wrap():
return False


class TestDepecation(unittest.TestCase):

def test_map_to_new_func(self):
this_module = importlib.import_module(__name__)
old_function = deprecated(this_module, this_module.new_function,
"random_old_name")
with self.assertLogs(level='WARNING') as log:
result = old_function()
self.assertIn("random_old_name", log.output[0])
assert (result)

@patch('__main__.new_function')
def test_decorator(self, mock_new_function):
mock_new_function.__name__ = "new_name"
mock_new_function.return_value = True
with self.assertLogs(level='WARNING') as log:

@mark_deprecated(new_function)
def function_to_deprecate():
return False

result = function_to_deprecate()
assert (result)
self.assertIn("function_to_deprecate", log.output[0])
self.assertIn(mock_new_function.__name__, log.output[0])
mock_new_function.assert_called_once()


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
6 changes: 4 additions & 2 deletions torch_xla/experimental/deprecation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
import importlib
from typing import TypeVar

FN = TypeVar('FN')
Expand All @@ -22,7 +23,7 @@ def wrapped(*args, **kwargs):
return wrapped


def mark_deprecated(module, new: FN) -> FN:
def mark_deprecated(new: FN) -> FN:
"""Decorator to mark a function as deprecated and map to new function.
Args:
Expand All @@ -34,7 +35,8 @@ def mark_deprecated(module, new: FN) -> FN:
"""

def decorator(func):
return deprecated(module, new, old_name=func.__name__)
return deprecated(
importlib.import_module(func.__module__), new, old_name=func.__name__)

return decorator

Expand Down

0 comments on commit 41e8877

Please sign in to comment.