[Core Aten Ops] Lower reflection_pad1d, reflection_pad1d_backward, reflection_pad3d and reflection_pad3d_backward (#6588)
ManfeiBai committed Feb 23, 2024
1 parent 263da4e commit 13e8647
Showing 6 changed files with 174 additions and 5 deletions.
4 changes: 4 additions & 0 deletions codegen/xla_native_functions.yaml
@@ -282,8 +282,12 @@ supported:
- random_.from
- random_.to
- randperm
- reflection_pad1d
- reflection_pad1d_backward
- reflection_pad2d
- reflection_pad2d_backward
- reflection_pad3d
- reflection_pad3d_backward
- remainder.Scalar
- remainder.Tensor
- replication_pad1d
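Registering the ops under `supported:` tells the codegen to expect hand-written `XLANativeFunctions` overrides, which the changes below provide. A minimal usage sketch, assuming a torch_xla build that includes this change and an available XLA device:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# reflection_pad1d pads the last dimension: {left, right}.
x = torch.rand(2, 3, device=device)
y = torch.ops.aten.reflection_pad1d(x, [2, 2])
print(y.shape)  # torch.Size([2, 7])

# reflection_pad3d pads the last three dimensions:
# {left, right, top, bottom, front, back}.
x3 = torch.rand(2, 2, 3, 4, 2, device=device)
z = torch.ops.aten.reflection_pad3d(x3, [1, 1, 1, 2, 2, 1])
print(z.shape)  # torch.Size([2, 2, 6, 7, 4])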
90 changes: 90 additions & 0 deletions test/cpp/test_aten_xla_tensor_3.cpp
@@ -633,6 +633,51 @@ TEST_F(AtenXlaTensorTest, TestConstantPadIncomplete) {
});
}

TEST_F(AtenXlaTensorTest, TestReflectionPad1dRank2) {
torch::Tensor input =
torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> pad{2, 2};
torch::Tensor output = torch::reflection_pad1d(input, pad);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::reflection_pad1d(xla_input, pad);
AllClose(output, xla_output);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::reflection_pad1d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad1dRank3) {
torch::Tensor input =
torch::rand({2, 3, 4}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> pad{2, 2};
torch::Tensor output = torch::reflection_pad1d(input, pad);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::reflection_pad1d(xla_input, pad);
AllClose(output, xla_output);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::reflection_pad1d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad1dBackward) {
std::vector<int64_t> pad{2, 2};
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::reflection_pad1d(inputs[0], pad);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({2, 2, 3},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad2dRank3) {
torch::Tensor input =
torch::rand({2, 3, 4}, torch::TensorOptions(torch::kFloat));
@@ -678,6 +723,51 @@ TEST_F(AtenXlaTensorTest, TestReflectionPad2dBackward) {
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad3dRank5) {
torch::Tensor input =
torch::rand({2, 2, 3, 4, 2}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> pad{1, 1, 1, 2, 2, 1};
torch::Tensor output = torch::reflection_pad3d(input, pad);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::reflection_pad3d(xla_input, pad);
AllClose(output, xla_output);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::reflection_pad3d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad3dRank4) {
torch::Tensor input =
torch::rand({2, 2, 3, 4}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> pad{1, 1, 1, 1, 1, 1};
torch::Tensor output = torch::reflection_pad3d(input, pad);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::reflection_pad3d(xla_input, pad);
AllClose(output, xla_output);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::reflection_pad3d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad3dBackward) {
std::vector<int64_t> pad{1, 1, 1, 1, 1, 1};
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::reflection_pad3d(inputs[0], pad);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({2, 2, 4, 4, 2},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReplicationPad1d) {
torch::Tensor input =
torch::rand({1, 4}, torch::TensorOptions(torch::kFloat));
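The `ExpectCounterChanged` / `ExpectCounterNotChanged` assertions above have a rough Python-side analogue in `torch_xla.debug.metrics`; a hedged sketch of checking that a call hits the new `xla::` lowering instead of an `aten::` fallback:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
met.clear_counters()

x = torch.rand(2, 3, 4, device=device)
y = torch.ops.aten.reflection_pad1d(x, [2, 2])
xm.mark_step()

# TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::") in aten_xla_type.cpp bumps
# the "xla::reflection_pad1d" counter at trace time; any aten:: counter
# here would indicate a CPU fallback.
assert met.counter_value("xla::reflection_pad1d") == 1
assert not any(n.startswith("aten::") for n in met.counter_names())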
5 changes: 0 additions & 5 deletions test/test_core_aten_ops.py
@@ -2943,7 +2943,6 @@ def test_aten_reciprocal_2(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)

@unittest.skip
def test_aten_reflection_pad1d_0(self):
args = (
torch.randn((10, 10)).to(torch.float32),
@@ -2955,7 +2954,6 @@ def test_aten_reflection_pad1d_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad1d_1(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
@@ -2993,7 +2991,6 @@ def test_aten_reflection_pad2d_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad3d_0(self):
args = (
torch.randn((3, 3, 3, 3, 3)).to(torch.float32),
@@ -3009,7 +3006,6 @@ def test_aten_reflection_pad3d_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad3d_1(self):
args = (
torch.randn((3, 3, 3, 3, 3)).to(torch.float16),
@@ -3025,7 +3021,6 @@ def test_aten_reflection_pad3d_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs)

@unittest.skip
def test_aten_reflection_pad3d_2(self):
args = (
torch.randint(0, 10, (3, 3, 3, 3, 3)).to(torch.int32),
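Un-skipping these cases restores core-aten export coverage for float32, float16, and int32 inputs. A standalone repro of the int32 case, assuming an available XLA device (the in-tree tests route through `run_export_and_compare` instead):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randint(0, 10, (10, 10), dtype=torch.int32, device=device)
y = torch.ops.aten.reflection_pad1d(x, [1, 1])
print(y.shape, y.dtype)  # torch.Size([10, 12]) torch.int32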
32 changes: 32 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2495,6 +2495,22 @@ at::Tensor XLANativeFunctions::randperm(int64_t n,
n, GetXlaDeviceOrCurrent(device), at::ScalarType::Long));
}

at::Tensor XLANativeFunctions::reflection_pad1d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad1d(
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(padding)));
}

at::Tensor XLANativeFunctions::reflection_pad1d_backward(
const at::Tensor& grad_output, const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad1d_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
torch::lazy::ToVector<int64_t>(padding)));
}

at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
@@ -2511,6 +2527,22 @@ at::Tensor XLANativeFunctions::reflection_pad2d_backward(
torch::lazy::ToVector<int64_t>(padding)));
}

at::Tensor XLANativeFunctions::reflection_pad3d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad3d(
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(padding)));
}

at::Tensor XLANativeFunctions::reflection_pad3d_backward(
const at::Tensor& grad_output, const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad3d_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
torch::lazy::ToVector<int64_t>(padding)));
}

at::Tensor XLANativeFunctions::remainder(const at::Tensor& self,
const at::Tensor& other) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
34 changes: 34 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2296,6 +2296,23 @@ XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
return XLATensor::Create(node, device, scalar_type);
}

XLATensorPtr reflection_pad1d(const XLATensorPtr& input,
std::vector<int64_t> padding) {
  // Reuse the existing `ReflectionPad2d` IR node (named after
  // `at::aten::reflection_pad2d`); its lowering covers the 1d case as well.
return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2d>(
input->GetIrValue(), std::move(padding)));
}

XLATensorPtr reflection_pad1d_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
std::vector<int64_t> padding) {
  // Reuse the existing `ReflectionPad2dBackward` IR node (named after
  // `at::aten::reflection_pad2d_backward`); its lowering covers the 1d case
  // as well.
return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2dBackward>(
grad_output->GetIrValue(), input->GetIrValue(), std::move(padding)));
}

XLATensorPtr reflection_pad2d(const XLATensorPtr& input,
std::vector<int64_t> padding) {
return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2d>(
@@ -2309,6 +2326,23 @@ XLATensorPtr reflection_pad2d_backward(const XLATensorPtr& grad_output,
grad_output->GetIrValue(), input->GetIrValue(), std::move(padding)));
}

XLATensorPtr reflection_pad3d(const XLATensorPtr& input,
std::vector<int64_t> padding) {
  // Reuse the existing `ReflectionPad2d` IR node (named after
  // `at::aten::reflection_pad2d`); its lowering covers the 3d case as well.
return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2d>(
input->GetIrValue(), std::move(padding)));
}

XLATensorPtr reflection_pad3d_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
std::vector<int64_t> padding) {
  // Reuse the existing `ReflectionPad2dBackward` IR node (named after
  // `at::aten::reflection_pad2d_backward`); its lowering covers the 3d case
  // as well.
return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2dBackward>(
grad_output->GetIrValue(), input->GetIrValue(), std::move(padding)));
}

XLATensorPtr remainder(const XLATensorPtr& input, const XLATensorPtr& other) {
return input->CreateFrom(Remainder(input->GetIrValue(), other->GetIrValue()));
}
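Reusing the `ReflectionPad2d` / `ReflectionPad2dBackward` IR nodes works because their lowering is driven by the length of the padding vector (padding the trailing `padding.size() / 2` dimensions) rather than a fixed rank, so no new lowerings are required. A quick forward/backward parity check against eager CPU, under the same assumptions as the sketches above:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
pad = [1, 1, 1, 1, 1, 1]

cpu_in = torch.rand(2, 2, 4, 4, 2, requires_grad=True)
xla_in = cpu_in.detach().to(device).requires_grad_(True)

cpu_out = torch.ops.aten.reflection_pad3d(cpu_in, pad)
xla_out = torch.ops.aten.reflection_pad3d(xla_in, pad)
assert torch.allclose(cpu_out, xla_out.cpu(), atol=1e-5)

# Exercises the reflection_pad3d_backward lowering through autograd.
cpu_out.sum().backward()
xla_out.sum().backward()
assert torch.allclose(cpu_in.grad, xla_in.grad.cpu(), atol=1e-5)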
14 changes: 14 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -716,13 +716,27 @@ void random_(XLATensorPtr& input, int64_t from, int64_t to);
XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type);

XLATensorPtr reflection_pad1d(const XLATensorPtr& input,
std::vector<int64_t> padding);

XLATensorPtr reflection_pad1d_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
std::vector<int64_t> padding);

XLATensorPtr reflection_pad2d(const XLATensorPtr& input,
std::vector<int64_t> padding);

XLATensorPtr reflection_pad2d_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
std::vector<int64_t> padding);

XLATensorPtr reflection_pad3d(const XLATensorPtr& input,
std::vector<int64_t> padding);

XLATensorPtr reflection_pad3d_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
std::vector<int64_t> padding);

XLATensorPtr remainder(const XLATensorPtr& input, const XLATensorPtr& other);
XLATensorPtr remainder(const XLATensorPtr& input, const at::Scalar& other);

