From 31fbf4af9ca5c0cf02956a6dd0ab5fdde2af0866 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Tue, 22 Mar 2022 14:37:44 +0800 Subject: [PATCH] add _modules and _parameters property --- python/jittor/__init__.py | 10 +++++++++- python/jittor/test/test_core.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 6b56c972..94181785 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.1.49' +__version__ = '1.3.1.50' from jittor_utils import lock with lock.lock_scope(): ori_int = int @@ -957,6 +957,14 @@ def callback_leave(parents, k, v, n): self.dfs([], "", callback, callback_leave) return ms + @property + def _modules(self): + return { k:v for k,v in self.__dict__.items() if isinstance(v, Module) } + + @property + def _parameters(self): + return { k:v for k,v in self.__dict__.items() if isinstance(v, Var) } + def requires_grad_(self, requires_grad=True): self._requires_grad = requires_grad self._place_hooker() diff --git a/python/jittor/test/test_core.py b/python/jittor/test/test_core.py index f1fc1b63..86d88a16 100644 --- a/python/jittor/test/test_core.py +++ b/python/jittor/test/test_core.py @@ -106,5 +106,18 @@ def test_module(self): a.y = 2 assert a.y == 2 + def test_modules(self): + a = jt.Module() + a.x = jt.Module() + a.y = jt.Module() + a.a = jt.array([1,2,3]) + a.b = jt.array([1,2,3]) + assert a._modules.keys() == ["x", "y"] + assert a._modules['x'] is a.x + assert a._modules['y'] is a.y + assert a._parameters.keys() == ['a', 'b'] + assert a._parameters['a'] is a.a + assert a._parameters['b'] is a.b + if __name__ == "__main__": unittest.main() \ No newline at end of file