diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index b6544cd83615..8c80facf7d9d 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -65,6 +65,7 @@ full_codegen:
   - lt.Tensor
   - maximum
   - minimum
+  - native_dropout_backward
   - ne.Scalar
   - ne.Tensor
   - reciprocal
diff --git a/test/test_operations.py b/test/test_operations.py
index 28626044d4c1..0083f39c8395 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -1646,6 +1646,29 @@ def test_fn(*indices):
         for dtype in (torch.long, torch.int32, torch.bool)
     ], test_fn)
 
+  def test_native_dropout_backward(self):
+
+    def test_fn(input):
+      dropped = torch.native_dropout(input, 0.5, train=True)
+      loss = dropped[0] + 0.5
+      loss.mean().backward()
+      return dropped[1].cpu(), input.grad.cpu()
+
+    met.clear_all()
+    xla_device = xm.xla_device()
+    input_cpu = torch.randn(7, 7, requires_grad=True)
+    input_xla = torch.randn(7, 7, device=xla_device, requires_grad=True)
+    mask_cpu, grad_cpu = test_fn(input_cpu)
+    mask_xla, grad_xla = test_fn(input_xla)
+    # Dropout is random, so build the expected XLA gradient from mask_xla
+    # and the CPU gradient instead of comparing the two runs directly.
+    grad_cpu_single = grad_cpu[mask_cpu][0]
+    self.assertTrue(torch.allclose(
+        grad_cpu_single * mask_xla.to(torch.float), grad_xla, rtol=1e-03))
+
+    self.assertIn("xla::native_dropout_backward", met.counter_names())
+    self.assertNotIn("aten::native_dropout_backward", met.counter_names())
+
   def test_conv2d_backward(self):
     # Somehow eager cpu produces different results than us, and
     # therefore we can't compare eager and xla.
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 3ca447663551..28587c38df8b 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -562,6 +562,20 @@ torch_xla::XlaOpVector Minimum::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Min(promoted.first, promoted.second), loctx);
 }
 
+torch_xla::XlaOpVector NativeDropoutBackward::Lower(
+    LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp mask = loctx->GetOutputOp(operand(1));
+  xla::PrimitiveType grad_type =
+      ShapeHelper::ShapeOfXlaOp(grad_output).element_type();
+  xla::XlaOp res = grad_output * xla::ConvertElementType(mask, grad_type);
+  if (scale != 1.0) {
+    res = res * XlaHelpers::ScalarValue(scale, grad_type,
+                                        grad_output.builder());
+  }
+  return ReturnOp(res, loctx);
+}
+
 torch_xla::XlaOpVector NeScalar::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 562b9d6df50c..b9a9d4048390 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -659,6 +659,11 @@ xla::Shape MinimumOutputShape(const torch::lazy::Value& input,
                               lower_for_shape_fn);
 }
 
+xla::Shape NativeDropoutBackwardOutputShape(
+    const torch::lazy::Value& grad_output, const torch::lazy::Value& mask) {
+  return GetXlaShape(grad_output);
+}
+
 xla::Shape NeScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other) {
   auto lower_for_shape_fn =
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 716904caf4e3..7a79196fb772 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -216,6 +216,9 @@ xla::Shape MaximumOutputShape(const torch::lazy::Value& input,
 xla::Shape MinimumOutputShape(const torch::lazy::Value& input,
                               const torch::lazy::Value& other);
 
+xla::Shape NativeDropoutBackwardOutputShape(
+    const torch::lazy::Value& grad_output, const torch::lazy::Value& mask);
+
 xla::Shape NeScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
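
For context, the numerics the new lowering must reproduce are those of eager
native_dropout_backward: grad_input = grad_output * mask * scale, where autograd
passes scale = 1 / (1 - p) when dropout ran in training mode and scale = 1.0
otherwise (which is why the lowering skips the extra multiply when scale == 1.0).
A minimal eager-mode sketch of the expected behavior; the helper
reference_native_dropout_backward is illustrative only and not part of this patch:

import torch

def reference_native_dropout_backward(grad_output, mask, scale):
  # Mirrors the lowering: zero out dropped positions via the boolean
  # keep-mask, then rescale by 1 / (1 - p) to preserve expectations.
  return grad_output * mask.to(grad_output.dtype) * scale

p = 0.5
x = torch.randn(7, 7, requires_grad=True)
output, mask = torch.native_dropout(x, p, train=True)
output.sum().backward()  # grad_output is all ones

expected = reference_native_dropout_backward(
    torch.ones_like(x), mask, 1.0 / (1.0 - p))
assert torch.allclose(x.grad, expected)

With p = 0.5 the scale is 2.0, so a test like test_native_dropout_backward above
exercises the scale != 1.0 branch of NativeDropoutBackward::Lower.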