lower NativeDropoutBackward (#5642)
* lower NativeDropoutBackward

* fix lowering and add python test
JackCaoG authored and zpcore committed Sep 28, 2023
1 parent 9b1952d commit 8d872d5
Showing 5 changed files with 46 additions and 0 deletions.
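Adding `native_dropout_backward` to the `full_codegen` list lets the op be code-generated, so the rest of the commit only supplies the XLA lowering and the output-shape function. For reference, `native_dropout_backward(grad_output, mask, scale)` multiplies the incoming gradient by the boolean dropout mask and rescales it by `scale = 1 / (1 - p)`, mirroring what `native_dropout` does in the forward pass. A minimal eager-mode sketch of those semantics (illustrative only; the tensors and values here are not part of the commit):

```python
import torch

p = 0.5
x = torch.randn(4, 4)
out, mask = torch.native_dropout(x, p, train=True)  # mask is a bool tensor
grad_output = torch.ones_like(x)

# Assumed reference semantics: grad_input = grad_output * mask * scale.
grad_input = torch.ops.aten.native_dropout_backward(grad_output, mask, 1.0 / (1 - p))
assert torch.allclose(grad_input, grad_output * mask.to(grad_output.dtype) / (1 - p))
```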
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -65,6 +65,7 @@ full_codegen:
- lt.Tensor
- maximum
- minimum
- native_dropout_backward
- ne.Scalar
- ne.Tensor
- reciprocal
23 changes: 23 additions & 0 deletions test/test_operations.py
@@ -1646,6 +1646,29 @@ def test_fn(*indices):
for dtype in (torch.long, torch.int32, torch.bool)
], test_fn)

def test_native_dropout_backward(self):

def test_fn(input):
dropped = torch.native_dropout(input, 0.5, train=True)
loss = dropped[0] + 0.5
loss.mean().backward()
return dropped[1].cpu(), input.grad.cpu()

met.clear_all()
xla_device = xm.xla_device()
input_cpu = torch.randn(7, 7, requires_grad=True)
input_xla = torch.randn(7, 7, device=xla_device, requires_grad=True)
mask_cpu, grad_cpu = test_fn(input_cpu)
mask_xla, grad_xla = test_fn(input_xla)
    # Dropout is random, so build the expected XLA gradient from the XLA mask
    # and the (uniform) kept-element gradient value taken from the CPU run.
    grad_cpu_single = grad_cpu[mask_cpu][0]
    self.assertTrue(
        torch.allclose(
            grad_cpu_single * mask_xla.to(torch.float), grad_xla, rtol=1e-03))

self.assertIn("xla::native_dropout_backward", met.counter_names())
self.assertNotIn("aten::native_dropout_backward", met.counter_names())

def test_conv2d_backward(self):
# Somehow eager cpu produces different results than us, and
# therefore we can't compare eager and xla.
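The expected gradient in the new test works because `loss = (dropped + 0.5).mean()` gives every kept element the same gradient, `scale / numel`, and zero elsewhere; one CPU gradient entry combined with the XLA mask therefore reproduces the whole XLA gradient. A small sketch of that reasoning on CPU (illustrative only, not part of the commit):

```python
import torch

x = torch.randn(7, 7, requires_grad=True)
out, mask = torch.native_dropout(x, 0.5, train=True)
(out + 0.5).mean().backward()

# d(loss)/dx = mask * (1 / (1 - p)) / numel: a single constant (2 / 49 here)
# on kept elements, zero on dropped ones.
kept = x.grad[mask]
assert torch.allclose(kept, torch.full_like(kept, 2.0 / x.numel()))
assert torch.count_nonzero(x.grad[~mask]) == 0
```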
14 changes: 14 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -562,6 +562,20 @@ torch_xla::XlaOpVector Minimum::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Min(promoted.first, promoted.second), loctx);
}

torch_xla::XlaOpVector NativeDropoutBackward::Lower(
LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp mask = loctx->GetOutputOp(operand(1));
xla::PrimitiveType grad_type =
ShapeHelper::ShapeOfXlaOp(grad_output).element_type();
xla::XlaOp res = grad_output * xla::ConvertElementType(mask, grad_type);
if (scale != 1.0) {
res = res * XlaHelpers::ScalarValue<float>(scale, grad_type,
grad_output.builder());
}
return ReturnOp(res, loctx);
}

torch_xla::XlaOpVector NeScalar::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
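The lowering is a direct elementwise translation: convert the mask to the gradient's element type, multiply, and apply `scale` only when it differs from 1.0 (the `p == 0` case), which saves one multiply in the emitted HLO. A rough Python equivalent of the arithmetic (the function name is illustrative, not generated code):

```python
import torch

def native_dropout_backward_reference(grad_output, mask, scale):
  # Mirror of the XLA lowering: cast the bool mask to grad_output's dtype,
  # multiply elementwise, and scale only when scale != 1.0.
  res = grad_output * mask.to(grad_output.dtype)
  if scale != 1.0:
    res = res * scale
  return res
```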
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -659,6 +659,11 @@ xla::Shape MinimumOutputShape(const torch::lazy::Value& input,
lower_for_shape_fn);
}

xla::Shape NativeDropoutBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& mask) {
return GetXlaShape(grad_output);
}

xla::Shape NeScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other) {
auto lower_for_shape_fn =
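The shape function simply forwards `grad_output`'s shape, since the backward is purely elementwise and the bool mask does not affect the output type. A quick illustrative check under the same assumed semantics as above:

```python
import torch

_, mask = torch.native_dropout(torch.randn(7, 7), 0.5, train=True)
grad_input = torch.ops.aten.native_dropout_backward(torch.ones(7, 7), mask, 2.0)
assert grad_input.shape == (7, 7) and grad_input.dtype == torch.float32
```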
3 changes: 3 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -216,6 +216,9 @@ xla::Shape MaximumOutputShape(const torch::lazy::Value& input,
xla::Shape MinimumOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other);

xla::Shape NativeDropoutBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& mask);

xla::Shape NeScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);
