diff --git a/cpp/modmesh/toggle/pymod/wrap_profile.cpp b/cpp/modmesh/toggle/pymod/wrap_profile.cpp index 82c43586..fdc213ac 100644 --- a/cpp/modmesh/toggle/pymod/wrap_profile.cpp +++ b/cpp/modmesh/toggle/pymod/wrap_profile.cpp @@ -169,6 +169,7 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapCallProfiler : public WrapBase wrapped_type & { return wrapped_type::instance(); }) - .def("stat", [](CallProfiler & profiler) - { - std::stringstream ss; - profiler.print_statistics(ss); - return ss.str(); }) - .def("result", [](CallProfiler & profiler) - { - const RadixTreeNode * root = profiler.radix_tree().get_root(); - if (root->empty_children()) { - return py::dict(); - } - py::dict result; - std::queue*> node_queue; - std::unordered_map*, py::dict> dict_storage; - - node_queue.push(root); - dict_storage[root] = result; - - while (!node_queue.empty()) { - const RadixTreeNode* cur_node = node_queue.front(); - const py::dict& current_dict = dict_storage[cur_node]; - node_queue.pop(); - - current_dict["name"] = cur_node->name(); - current_dict["total_time"] = cur_node->data().total_time.count() / 1e6; - current_dict["count"] = cur_node->data().call_count; - if (cur_node == profiler.radix_tree().get_current_node()){ - current_dict["current_node"] = true; - } - - py::list children_list; - for (const auto& child : cur_node->children()) { - dict_storage[child.get()] = py::dict(); - py::dict& child_dict = dict_storage[child.get()]; - children_list.append(child_dict); - node_queue.push(child.get()); - } - current_dict["children"] = children_list; - } - return result; }) + .def( + "stat", + [](CallProfiler & profiler) + { + std::stringstream ss; + profiler.print_statistics(ss); + return ss.str(); + }) + .def("result", &result) .def("reset", &wrapped_type::reset); ; @@ -228,6 +198,49 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapCallProfiler : public WrapBase * root = profiler.radix_tree().get_root(); + if (root->empty_children()) + { + return {}; + } + py::dict result; + std::queue *> node_queue; + std::unordered_map *, py::dict> dict_storage; + + node_queue.push(root); + dict_storage[root] = result; + + while (!node_queue.empty()) + { + const RadixTreeNode * cur_node = node_queue.front(); + const py::dict & current_dict = dict_storage[cur_node]; + node_queue.pop(); + + current_dict["name"] = cur_node->name(); + current_dict["total_time"] = cur_node->data().total_time.count() / 1e6; + current_dict["count"] = cur_node->data().call_count; + if (cur_node == profiler.radix_tree().get_current_node()) + { + current_dict["current_node"] = true; + } + + py::list children_list; + for (const auto & child : cur_node->children()) + { + dict_storage[child.get()] = py::dict(); + py::dict & child_dict = dict_storage[child.get()]; + children_list.append(child_dict); + node_queue.push(child.get()); + } + current_dict["children"] = children_list; + } + return result; +} + class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapCallProfilerProbe : public WrapBase { public: @@ -265,4 +278,4 @@ void wrap_profile(pybind11::module & mod) } /* end namespace modmesh */ -// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: +// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: \ No newline at end of file diff --git a/tests/test_callprofiler.py b/tests/test_callprofiler.py index 66a641c9..f3044446 100644 --- a/tests/test_callprofiler.py +++ b/tests/test_callprofiler.py @@ -46,7 +46,6 @@ def wrapper(*args, **kwargs): _ = modmesh.CallProfilerProbe(func.__name__) result = func(*args, **kwargs) return result - return wrapper @@ -83,12 +82,9 @@ def foo1(): foo1() result = modmesh.call_profiler.result() - path = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "data", - "profiler_python_schema.json", - ) - with open(path, "r") as schema_file: + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), + "data", "profiler_python_schema.json") + with open(path, 'r') as schema_file: schema = json.load(schema_file) try: @@ -234,7 +230,7 @@ def baz(): self.assertEqual(words[2], "calls") self.assertEqual(words[3], "in") ref_total_time = time1 * 4 + time2 * 2 + time3 - self.assertTrue(abs(float(words[4]) - ref_total_time) <= 6e-4) + self.assertLessEqual(abs(float(words[4]) - ref_total_time), 8e-4) self.assertEqual(words[5], "seconds") # Check the second line @@ -262,24 +258,41 @@ def baz(): "cumulative_time": float(words[4]), "cumulative_per_call": float(words[5]), } + bar_dict = stat_dict["bar"] + diff_total_time = abs(bar_dict["total_time"] - time1 * 4) + diff_total_per_call = abs(bar_dict["total_per_call"] - time1) + diff_cul_time = abs(bar_dict["cumulative_time"] - time1 * 4) + diff_cul_per_call = abs(bar_dict["cumulative_per_call"] - time1) + self.assertEqual(bar_dict["call_count"], 4) - self.assertTrue(bar_dict["total_time"] - (time1 * 4) <= 3e-4) - self.assertTrue(bar_dict["total_per_call"] - time1 <= 3e-4) - self.assertTrue(bar_dict["cumulative_time"] - (time1 * 4) <= 3e-4) - self.assertTrue(bar_dict["cumulative_per_call"] - time1 <= 3e-4) + self.assertLessEqual(diff_total_time, 3e-4) + self.assertLessEqual(diff_total_per_call, 3e-4) + self.assertLessEqual(diff_cul_time, 3e-4) + self.assertLessEqual(diff_cul_per_call, 3e-4) foo_dict = stat_dict["foo"] + ref_per_call = time1 + time2 + diff_total_time = abs(foo_dict["total_time"] - ref_per_call * 2) + diff_total_per_call = abs(foo_dict["total_per_call"] - ref_per_call) + diff_cul_time = abs(foo_dict["cumulative_time"] - time2 * 2) + diff_cul_per_call = abs(foo_dict["cumulative_per_call"] - time2) + self.assertEqual(foo_dict["call_count"], 2) - self.assertTrue(foo_dict["total_time"] - (time1 + time2) * 2 <= 3e-4) - self.assertTrue(foo_dict["total_per_call"] - (time1 + time2) <= 3e-4) - self.assertTrue(foo_dict["cumulative_time"] - (time2 * 2) <= 3e-4) - self.assertTrue(foo_dict["cumulative_per_call"] - time2 <= 3e-4) + self.assertLessEqual(diff_total_time, 3e-4) + self.assertLessEqual(diff_total_per_call, 3e-4) + self.assertLessEqual(diff_cul_time, 3e-4) + self.assertLessEqual(diff_cul_per_call, 3e-4) baz_dict = stat_dict["baz"] - ref_total_time = time1 + time2 + time3 + ref_per_call = time1 + time2 + time3 + diff_total_time = abs(baz_dict["total_time"] - ref_per_call) + diff_total_per_call = abs(baz_dict["total_per_call"] - ref_per_call) + diff_cul_time = abs(baz_dict["cumulative_time"] - time3) + diff_cul_per_call = abs(baz_dict["cumulative_per_call"] - time3) + self.assertEqual(baz_dict["call_count"], 1) - self.assertTrue(baz_dict["total_time"] - ref_total_time <= 3e-4) - self.assertTrue(baz_dict["total_per_call"] - ref_total_time <= 3e-4) - self.assertTrue(baz_dict["cumulative_time"] - time3 <= 3e-4) - self.assertTrue(baz_dict["cumulative_per_call"] - time3 <= 3e-4) + self.assertLessEqual(diff_total_time, 3e-4) + self.assertLessEqual(diff_total_per_call, 3e-4) + self.assertLessEqual(diff_cul_time, 3e-4) + self.assertLessEqual(diff_cul_per_call, 3e-4)