Add function for retrieving fallback operations. (#7116)
ysiraichi authored May 29, 2024
1 parent 468a5c9 commit 6d27123
Showing 6 changed files with 77 additions and 10 deletions.
54 changes: 54 additions & 0 deletions test/test_metrics.py
@@ -9,6 +9,11 @@
import unittest


def XLAExperimentalContains(feat):
  experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
  return feat in experimental


class MetricsTest(unittest.TestCase):

  def test_clear_counters(self):
@@ -205,6 +210,55 @@ def test_pybind_increment_counter(self):
    torch_xla._XLAC._xla_increment_counter('FakeCounter', 2)
    self.assertEqual(met.counter_value('FakeCounter'), 2)

  def test_get_fallback_ops(self):

    def getAndAssertFallbackOpsLenEquals(count):
      fallback_ops = met.executed_fallback_ops()
      fallback_ops_number = len(fallback_ops)
      self.assertEqual(
          fallback_ops_number,
          count,
          msg=f"found {fallback_ops_number}: {fallback_ops}")
      return fallback_ops

    # Reset all metrics, and make sure we don't start with any fallback ops.
    met.clear_all()
    getAndAssertFallbackOpsLenEquals(0)

    # Create N boxes in the format XYXY.
    # This should not run any fallback ops.
    N = 10
    x = torch.rand(N, 1).to(xm.xla_device())
    y = torch.rand(N, 1).to(xm.xla_device())
    width = torch.rand(N, 1).to(xm.xla_device())
    height = torch.rand(N, 1).to(xm.xla_device())
    xys = torch.cat((x, x + width, y, y - height), dim=1)
    getAndAssertFallbackOpsLenEquals(0)

    # tensor.item() is a fallback operation.
    xys[0, 0].item()
    ops = getAndAssertFallbackOpsLenEquals(1)
    self.assertEqual(ops[0], "aten::_local_scalar_dense")

    # Reset all metrics, and make sure we also don't retrieve any
    # fallback operations.
    met.clear_all()
    getAndAssertFallbackOpsLenEquals(0)

    if not XLAExperimentalContains("nms"):
      # Run torchvision operations as fallback.
      import torchvision
      scores = torch.rand(N).to(xm.xla_device())
      # NMS doesn't have a PyTorch/XLA implementation without dynamic shapes.
      torchvision.ops.nms(xys, scores, 0.5)
      # remove_small_boxes is not implemented in C++. It calls other PyTorch
      # operations. One of them, nonzero, is a fallback operation.
      torchvision.ops.remove_small_boxes(
          xys, torch.median(torch.stack((width, height))))
      ops = getAndAssertFallbackOpsLenEquals(3)
      self.assertEqual(
          set(ops), {"aten::nonzero", "aten::median", "torchvision::nms"})


if __name__ == '__main__':
  test = unittest.main()
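
The new test doubles as usage documentation. Outside the test suite, the same check reduces to a few lines. A minimal sketch, assuming an installed torch_xla with a reachable XLA device:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

met.clear_all()
t = torch.rand(4).to(xm.xla_device())
t[0].item()  # Materializing a scalar triggers the aten::_local_scalar_dense fallback.
print(met.executed_fallback_ops())  # expected: ['aten::_local_scalar_dense']
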
10 changes: 1 addition & 9 deletions torch_xla/core/dynamo_bridge.py
@@ -100,15 +100,7 @@ def __call__(self, args):


def get_fallback_ops():
-  fallback_ops = []
-  for opname in metrics.counter_names():
-    if "aten::" not in opname:
-      continue
-    val = int(metrics.counter_value(opname))
-    if val > 0:
-      fallback_ops.append(f"{opname}={val}")
-
-  return fallback_ops
+  return metrics.executed_fallback_ops()


# Checks that all input args that are tensors are on the same device.
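
Note the behavioral change for callers of get_fallback_ops(): the deleted counter scan matched only "aten::" counters and folded each count into the string, while the new delegation returns bare operation names, including non-aten fallbacks such as torchvision::nms. A sketch of the difference, with illustrative values:

# Before this commit: "name=count" strings, aten-only.
#   get_fallback_ops()  ->  ['aten::nonzero=2']
# After: bare names from metrics.executed_fallback_ops(), non-aten ops included.
#   get_fallback_ops()  ->  ['aten::nonzero', 'torchvision::nms']
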
12 changes: 12 additions & 0 deletions torch_xla/csrc/aten_cpu_fallback.cpp
@@ -16,6 +16,18 @@ namespace torch_xla {
static std::unordered_map<std::string, ::torch_xla::runtime::metrics::Counter*>
    _cpu_fallback_counters;

// Get all the executed fallback operations,
// i.e. those whose counters are non-zero.
std::vector<std::string> GetFallbackOperations() {
  std::vector<std::string> fallback;
  for (auto const& pair : _cpu_fallback_counters) {
    if (pair.second->Value() != 0) {
      fallback.push_back(pair.first);
    }
  }
  return fallback;
}

void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  XLA_FN_TRACK(3);
  const auto name = c10::toString(op.operator_name());
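
For context: xla_cpu_fallback (partially shown above) bumps a per-operation counter in _cpu_fallback_counters each time an op falls back to CPU, and GetFallbackOperations simply filters that map for non-zero values. Those counters remain visible through the regular metrics API, so the counter and the new list stay in sync. A Python-side sketch, with the op and printed values illustrative and following the test above:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

met.clear_all()
t = torch.rand(3, 2).to(xm.xla_device())
torch.median(t)  # runs as a CPU fallback, per the test above
print(met.counter_value("aten::median"))  # per-op counter, e.g. 1
print(met.executed_fallback_ops())  # ops with non-zero counters: ['aten::median']
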
4 changes: 3 additions & 1 deletion torch_xla/csrc/aten_cpu_fallback.h
@@ -7,6 +7,8 @@ namespace torch_xla {

void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

std::vector<std::string> GetFallbackOperations();

} // namespace torch_xla

#endif  // XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_
2 changes: 2 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -33,6 +33,7 @@
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_cpu_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/dl_convertor.h"
@@ -1781,6 +1782,7 @@ void InitXlaModuleBindings(py::module m) {
        }
      },
      py::arg("devices"));
  m.def("_get_executed_fallback_ops", []() { return GetFallbackOperations(); });
  m.def("_xla_counter_names", []() {
    auto counter_names = torch::lazy::GetCounterNames();
    auto xla_counter_names = runtime::metrics::GetCounterNames();
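
This exposes GetFallbackOperations to Python as torch_xla._XLAC._get_executed_fallback_ops. The raw binding is callable directly, but the wrapper added to torch_xla/debug/metrics.py below is the intended entry point. A quick sketch:

import torch_xla
import torch_xla.debug.metrics as met

ops = torch_xla._XLAC._get_executed_fallback_ops()  # raw pybind entry point
ops = met.executed_fallback_ops()                   # preferred public wrapper
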
5 changes: 5 additions & 0 deletions torch_xla/debug/metrics.py
@@ -79,3 +79,8 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None):
        'TransferToDeviceTime', 'TransferFromDeviceTime'
    ]
  return torch_xla._XLAC._short_xla_metrics_report(counter_names, metric_names)


def executed_fallback_ops():
"""Retrieves a list of operations that were run in fallback mode."""
return torch_xla._XLAC._get_executed_fallback_ops()
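
A typical debugging pattern with the new helper: clear the metrics, run the region of interest, then inspect the list. A sketch in which train_step, model, and batch are hypothetical stand-ins for user code:

import torch_xla.debug.metrics as met

met.clear_all()                  # start from a clean slate
loss = train_step(model, batch)  # hypothetical user workload
fallbacks = met.executed_fallback_ops()
if fallbacks:
  print(f"ops executed via CPU fallback: {fallbacks}")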
