Skip to content

Commit

Permalink
lower native_dropout (#5643)
Browse files Browse the repository at this point in the history
* prototype version (compiling error)

* Add native_dropout manual lowering.

* fix to tensor IR and add a simple native_dropout test

* fix data type issue and update test case

* fix IR hash issue

* fix corner case when probability==0

* remove typo line

* add test case when probability=0
  • Loading branch information
zpcore authored Sep 26, 2023
1 parent 00272c3 commit 2c6e4a7
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 0 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ supported:
- nan_to_num
- native_batch_norm
- native_batch_norm_backward
- native_dropout
- neg
- nll_loss2d_backward
- nll_loss2d_forward
Expand Down
73 changes: 73 additions & 0 deletions test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,79 @@ TEST_F(AtenXlaTensorTest, TestDropoutInPlace) {
ExpectCounterChanged("xla::bernoulli", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNativeDropout) {
torch::Tensor a = torch::rand({17, 21}, torch::TensorOptions(torch::kFloat));
float allowance = 0.04;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
for (float probability : {0.1, 0.5}) {
auto [xla_b_val, xla_b_mask] =
torch::native_dropout(xla_a, probability, /*train=*/true);
double prob = static_cast<double>(
xla_b_val.cpu().eq(0.0f).sum().item().toDouble()) /
a.numel();
EXPECT_GT(prob, probability - allowance);
EXPECT_LT(prob, probability + allowance);
EXPECT_EQ(xla_b_val.scalar_type(), torch::kFloat);
EXPECT_EQ(xla_b_mask.scalar_type(), torch::kBool);
}
});

ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*",
cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_dropout", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNativeDropoutNotTrain) {
torch::Tensor a = torch::rand({17, 21}, torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
auto [xla_b_val, xla_b_mask] =
torch::native_dropout(xla_a, 0.5, /*train=*/false);
AllEqual(xla_b_val, xla_a);
EXPECT_EQ(xla_b_val.scalar_type(), torch::kFloat);
EXPECT_EQ(xla_b_mask.scalar_type(), torch::kBool);
});

ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*",
cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_dropout", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNativeDropoutMask) {
torch::Tensor a = torch::rand({17, 21}, torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
auto [xla_b_val, xla_b_mask] =
torch::native_dropout(xla_a, 0.5, /*train=*/true);
auto count1 = xla_b_val.cpu().eq(0.0f).sum().item().toInt();
auto count2 = xla_b_mask.cpu().eq(0.0f).sum().item().toInt();
EXPECT_EQ(count1, count2);
});

ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*",
cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_dropout", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNativeDropoutZeroProbability) {
torch::Tensor a = torch::rand({17, 21}, torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
auto [xla_b_val, xla_b_mask] =
torch::native_dropout(xla_a, 0, /*train=*/true);
auto count1 = xla_b_val.cpu().ne(0.0f).sum().item().toInt();
auto count2 = xla_b_mask.cpu().ne(0.0f).sum().item().toInt();
auto count3 = xla_a.cpu().ne(0.0f).sum().item().toInt();
EXPECT_EQ(count1, count2);
EXPECT_EQ(count2, count3);
});

ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*",
cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_dropout", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestRandperm) {
int n = 5;
torch::Tensor shuffle = torch::randperm(
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2083,6 +2083,15 @@ XLANativeFunctions::native_batch_norm_backward(
: undefined);
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::native_dropout(
const at::Tensor& self, double p, c10::optional<bool> train) {
TORCH_LAZY_FN_COUNTER("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
auto results = tensor_methods::native_dropout(self_tensor, p, train);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
bridge::AtenFromXlaTensor(std::get<1>(results)));
}

at::Tensor XLANativeFunctions::neg(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER("xla::");
XLA_CHECK(self.scalar_type() != at::kBool)
Expand Down
37 changes: 37 additions & 0 deletions torch_xla/csrc/ops/native_dropout.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "torch_xla/csrc/ops/native_dropout.h"

#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(const torch::lazy::Value& input) {
xla::Shape input_shape = GetXlaShape(input);
return xla::ShapeUtil::MakeTupleShape({input_shape, input_shape});
}

} // namespace

NativeDropout::NativeDropout(const torch::lazy::Value& input,
const torch::lazy::Value& seed, float p,
c10::optional<bool> train)
: XlaNode(torch::lazy::OpKind(at::aten::native_dropout), {input, seed},
[&]() { return NodeOutputShape(input); }, 2,
torch::lazy::MHash(p, train)),
p_(p),
train_(train) {}

torch::lazy::NodePtr NativeDropout::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<NativeDropout>(operands.at(0), operands.at(1),
p_, train_);
}

XlaOpVector NativeDropout::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp seed = loctx->GetOutputOp(operand(1));
return ReturnOps(BuildNativeDropout(input, seed, p_, train_), loctx);
}

} // namespace torch_xla
27 changes: 27 additions & 0 deletions torch_xla/csrc/ops/native_dropout.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_
#define XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

// This node has no metadata, so it could have been implemented as generic-op in
// ops.cpp, but since this might require special handling from upper IR layers,
// it gets its own IR node class.
class NativeDropout : public XlaNode {
public:
NativeDropout(const torch::lazy::Value& input, const torch::lazy::Value& seed,
float p, c10::optional<bool> train);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

private:
float p_;
c10::optional<bool> train_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_
11 changes: 11 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#include "torch_xla/csrc/ops/multinomial.h"
#include "torch_xla/csrc/ops/native_batch_norm_backward.h"
#include "torch_xla/csrc/ops/native_batch_norm_forward.h"
#include "torch_xla/csrc/ops/native_dropout.h"
#include "torch_xla/csrc/ops/nll_loss.h"
#include "torch_xla/csrc/ops/nll_loss2d.h"
#include "torch_xla/csrc/ops/nll_loss2d_backward.h"
Expand Down Expand Up @@ -1887,6 +1888,16 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm_backward(
std::move(grad_bias));
}

std::tuple<XLATensorPtr, XLATensorPtr> native_dropout(
const XLATensorPtr& input, double p, c10::optional<bool> train) {
torch::lazy::NodePtr node = torch::lazy::MakeNode<NativeDropout>(
input->GetIrValue(),
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), p, train);
return std::make_tuple(
input->CreateFrom(torch::lazy::Value(node, 0)),
input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Bool));
}

XLATensorPtr ne(const XLATensorPtr& input, const at::Scalar& other) {
return DispatchComparisonOp(at::aten::ne, input, other);
}
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,9 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm_backward(
const XLATensorPtr& weight, const XLATensorPtr& save_mean,
const XLATensorPtr& save_invstd, bool training, double eps);

std::tuple<XLATensorPtr, XLATensorPtr> native_dropout(
const XLATensorPtr& input, double p, c10::optional<bool> train);

XLATensorPtr ne(const XLATensorPtr& input, const at::Scalar& other);

XLATensorPtr ne(const XLATensorPtr& input, const XLATensorPtr& other);
Expand Down
19 changes: 19 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,25 @@ xla::XlaOp BuildDropout(xla::XlaOp input, float probability, xla::XlaOp seed) {
return input * mask;
}

std::vector<xla::XlaOp> BuildNativeDropout(xla::XlaOp input, xla::XlaOp seed,
float probability,
c10::optional<bool> train) {
const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
if (!train.has_value() || *train) {
xla::XlaOp prob = XlaHelpers::ScalarBroadcast<float>(1 - probability, shape,
input.builder());
xla::XlaOp one = xla::One(input.builder(), shape.element_type());
xla::XlaOp mask = BuildBernoulli(prob, seed, shape.element_type());
if (probability > 0.0f) {
mask = mask / (one - prob);
}
return {input * mask, mask};
} else {
xla::XlaOp one = xla::One(input.builder(), xla::PrimitiveType::PRED);
return {input, one};
}
}

std::vector<xla::XlaOp> CreateBroadcastTensors(
absl::Span<const xla::XlaOp> operands) {
xla::Shape result_shape = ShapeHelper::ShapeOfXlaOp(operands.front());
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ xla::XlaOp BuildExponential(xla::XlaOp lambda, xla::XlaOp seed,

xla::XlaOp BuildDropout(xla::XlaOp input, float probability, xla::XlaOp seed);

std::vector<xla::XlaOp> BuildNativeDropout(xla::XlaOp input, xla::XlaOp seed,
float probability,
c10::optional<bool> train);

xla::XlaOp BuildSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp output,
xla::XlaOp scalar_1);

Expand Down

0 comments on commit 2c6e4a7

Please sign in to comment.