-
Notifications
You must be signed in to change notification settings - Fork 468
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* prototype version (compiling error) * Add native_dropout manual lowering. * fix to tensor IR and add a simple native_dropout test * fix data type issue and update test case * fix IR hash issue * fix corner case when probability==0 * remove typo line * add test case when probability=0
- Loading branch information
Showing
9 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include "torch_xla/csrc/ops/native_dropout.h" | ||
|
||
#include "torch_xla/csrc/helpers.h" | ||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/xla_lower_util.h" | ||
|
||
namespace torch_xla { | ||
namespace { | ||
|
||
xla::Shape NodeOutputShape(const torch::lazy::Value& input) { | ||
xla::Shape input_shape = GetXlaShape(input); | ||
return xla::ShapeUtil::MakeTupleShape({input_shape, input_shape}); | ||
} | ||
|
||
} // namespace | ||
|
||
NativeDropout::NativeDropout(const torch::lazy::Value& input, | ||
const torch::lazy::Value& seed, float p, | ||
c10::optional<bool> train) | ||
: XlaNode(torch::lazy::OpKind(at::aten::native_dropout), {input, seed}, | ||
[&]() { return NodeOutputShape(input); }, 2, | ||
torch::lazy::MHash(p, train)), | ||
p_(p), | ||
train_(train) {} | ||
|
||
torch::lazy::NodePtr NativeDropout::Clone(torch::lazy::OpList operands) const { | ||
return torch::lazy::MakeNode<NativeDropout>(operands.at(0), operands.at(1), | ||
p_, train_); | ||
} | ||
|
||
XlaOpVector NativeDropout::Lower(LoweringContext* loctx) const { | ||
xla::XlaOp input = loctx->GetOutputOp(operand(0)); | ||
xla::XlaOp seed = loctx->GetOutputOp(operand(1)); | ||
return ReturnOps(BuildNativeDropout(input, seed, p_, train_), loctx); | ||
} | ||
|
||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#ifndef XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ | ||
#define XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ | ||
|
||
#include "torch_xla/csrc/ir.h" | ||
|
||
namespace torch_xla { | ||
|
||
// This node has no metadata, so it could have been implemented as generic-op in | ||
// ops.cpp, but since this might require special handling from upper IR layers, | ||
// it gets its own IR node class. | ||
class NativeDropout : public XlaNode { | ||
public: | ||
NativeDropout(const torch::lazy::Value& input, const torch::lazy::Value& seed, | ||
float p, c10::optional<bool> train); | ||
|
||
torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; | ||
|
||
XlaOpVector Lower(LoweringContext* loctx) const override; | ||
|
||
private: | ||
float p_; | ||
c10::optional<bool> train_; | ||
}; | ||
|
||
} // namespace torch_xla | ||
|
||
#endif // XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters