Lower aten::_linalg_eigh (#7674)
tengyifei committed Jul 15, 2024
1 parent 3bbd3f6 commit f975ad6
Showing 7 changed files with 160 additions and 0 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -121,6 +121,7 @@ supported:
- _copy_from
- _copy_from_and_resize
- _index_put_impl_
- _linalg_eigh
- _linalg_slogdet
- _linalg_svd
- _local_scalar_dense
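
Listing `_linalg_eigh` under `supported:` is what routes `aten::_linalg_eigh` to a hand-written XLA kernel: the torchgen-based codegen then emits a declaration that `aten_xla_type.cpp` (below) must define. A rough sketch of that declaration — the generated file's exact name and location are assumptions, not part of this change:

// Assumed to land in the codegen output (e.g. XLANativeFunctions.h);
// the signature matches the definition in aten_xla_type.cpp below.
static std::tuple<at::Tensor, at::Tensor> _linalg_eigh(
    const at::Tensor& self, c10::string_view uplo, bool compute_v);
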
42 changes: 42 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -511,6 +511,48 @@ TEST_F(AtenXlaTensorTest, TestLinalgVectorNormInDimsKeepDtype) {
cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLinalgEigh) {
// Hardcode the test input to avoid the numerical instability that random
// inputs can cause in eigenvalue decomposition.
auto complex64 = [](float real, float imag) {
return c10::complex<float>{real, imag};
};
torch::Tensor input = torch::tensor({
{complex64(1, 0), complex64(2, -7), complex64(4, -8)},
{complex64(2, 7), complex64(3, 0), complex64(5, -9)},
{complex64(4, 8), complex64(5, 9), complex64(6, 0)},
});
for (c10::string_view uplo : {"U", "L"}) {
auto [eigenvalues, eigenvectors] = torch::linalg_eigh(input, uplo);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
auto [xla_eigenvalues, xla_eigenvectors] = torch::linalg_eigh(xla_input, uplo);
AllClose(eigenvalues, xla_eigenvalues);
// The eigenvectors of a Hermitian matrix are not unique, nor are they
// continuous with respect to the input. Because of this, different
// hardware and software may compute different (yet equally valid)
// eigenvectors. Therefore, instead of comparing eigenvectors directly,
// we verify that the decomposition satisfies its mathematical definition.
torch::Tensor input_reconstructed = torch::mm(
torch::mm(
eigenvectors,
torch::diag(eigenvalues).toType(c10::ScalarType::ComplexFloat)),
eigenvectors.t().conj());
auto xla_eigenvalues_cpu = ToCpuTensor(xla_eigenvalues);
auto xla_eigenvectors_cpu = ToCpuTensor(xla_eigenvectors);
torch::Tensor xla_input_reconstructed =
torch::mm(torch::mm(xla_eigenvectors_cpu,
torch::diag(xla_eigenvalues_cpu)
.toType(c10::ScalarType::ComplexFloat)),
xla_eigenvectors_cpu.t().conj());
AllClose(input_reconstructed, input);
AllClose(xla_input_reconstructed, input);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_linalg_eigh", cpp_test::GetIgnoredCounters());
}
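
For reference, the reconstruction checks above assert the defining property of the decomposition, which holds regardless of the sign/phase ambiguity of individual eigenvectors: for a Hermitian input A, linalg_eigh returns a real eigenvalue vector \Lambda and a matrix Q whose columns are the eigenvectors, with

    A = Q \operatorname{diag}(\Lambda) Q^{H}, \qquad Q^{H} Q = I.

This identity is exactly what the torch::mm(...) expressions compute on both the CPU and XLA results.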

TEST_F(AtenXlaTensorTest, TestQR) {
static const int dims[] = {4, 7};
for (auto m : dims) {
15 changes: 15 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -741,6 +741,21 @@ at::Tensor& XLANativeFunctions::_index_put_impl_(
accumulate);
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::_linalg_eigh(
const at::Tensor& self, c10::string_view uplo, bool compute_v) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (!compute_v) {
// Fall back to ATen when only eigenvalues are requested (`eigvalsh`).
return at::native::call_fallback_fn<&xla_fallback,
ATEN_OP(_linalg_eigh)>::call(self, uplo,
compute_v);
}
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
auto outputs = tensor_methods::eigh(self_tensor, uplo);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
bridge::AtenFromXlaTensor(std::get<1>(outputs)));
}
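
A minimal usage sketch of the two dispatch paths above (the device string, helper name, and symmetrization are illustrative assumptions, not part of this change): torch::linalg_eigh on an XLA tensor takes the lowered path, while torch::linalg_eigvalsh reaches _linalg_eigh with compute_v == false and therefore the fallback.

#include <torch/torch.h>

void EighDispatchSketch() {
  // Assumes an initialized PyTorch/XLA backend exposing device "xla:0".
  torch::Tensor a =
      torch::rand({3, 3}, torch::TensorOptions().device("xla:0"));
  a = a + a.t();  // symmetrize so the eigh precondition holds
  // Lowered path: increments the xla::_linalg_eigh counter.
  auto [w, v] = torch::linalg_eigh(a, /*uplo=*/"L");
  // Fallback path: eigvalsh requests eigenvalues only (compute_v == false).
  torch::Tensor w_only = torch::linalg_eigvalsh(a, /*uplo=*/"L");
}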

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
58 changes: 58 additions & 0 deletions torch_xla/csrc/ops/eigh.cpp
@@ -0,0 +1,58 @@
#include "torch_xla/csrc/ops/eigh.h"

#include <array>

#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
#include "xla/client/lib/self_adjoint_eig.h"

namespace torch_xla {

namespace {

std::array<xla::XlaOp, 2> LowerImpl(xla::XlaOp input, bool lower) {
auto [eigenvectors, eigenvalues] =
// The default `max_iter` and `tol` values lead to a very inaccurate
// decomposition. To improve accuracy we run more iterations with a
// tighter tolerance. These values are taken from the JAX lowering of eigh:
// https://github.com/google/jax/blob/a8b425cac50c842f66f36903dfb93fe6ad5a2a5b/jax/_src/lax/linalg.py#L726
xla::SelfAdjointEig(input, lower, /* max_iter */ 100, /* tol */ 1e-6);
// Torch expects `(eigenvalues, eigenvectors)` and XLA returns the reverse.
return {eigenvalues, eigenvectors};
}

// Infers the (eigenvalues, eigenvectors) tuple shape by tracing LowerImpl
// over a dummy operand built from the input's shape.
xla::Shape NodeOutputShape(const torch::lazy::Value& input) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return xla::Tuple(operands[0].builder(), LowerImpl(operands[0], true));
};
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

} // namespace

Eigh::Eigh(const torch::lazy::Value& input, c10::string_view uplo)
: XlaNode(
torch::lazy::OpKind(at::aten::_linalg_eigh), {input},
[&]() { return NodeOutputShape(input); },
/*num_outputs=*/2, torch::lazy::MHash(uplo)) {
XLA_CHECK(uplo == "L" || uplo == "U") << "Expected L or U, got: " << uplo;
uplo_ = uplo[0];
}

XlaOpVector Eigh::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
bool lower = uplo_ == 'L';
return ReturnOps(LowerImpl(input, lower), loctx);
}

std::string Eigh::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", uplo=" << uplo_;
return ss.str();
}

} // namespace torch_xla
25 changes: 25 additions & 0 deletions torch_xla/csrc/ops/eigh.h
@@ -0,0 +1,25 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_EIGH_H_
#define XLA_TORCH_XLA_CSRC_OPS_EIGH_H_

#include <c10/util/string_view.h>

#include "torch_xla/csrc/ir.h"
#include "xla/types.h"

namespace torch_xla {

class Eigh : public XlaNode {
public:
Eigh(const torch::lazy::Value& input, c10::string_view uplo);

std::string ToString() const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

private:
char uplo_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_EIGH_H_
16 changes: 16 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -49,6 +49,7 @@
#include "torch_xla/csrc/ops/discrete_uniform.h"
#include "torch_xla/csrc/ops/dynamic_expand.h"
#include "torch_xla/csrc/ops/dynamic_view.h"
#include "torch_xla/csrc/ops/eigh.h"
#include "torch_xla/csrc/ops/einsum.h"
#include "torch_xla/csrc/ops/einsum_backward.h"
#include "torch_xla/csrc/ops/embedding_bag.h"
@@ -2797,6 +2798,21 @@ XLATensorPtr slice(const XLATensorPtr& input, int64_t dim, int64_t start,
input->GetIrValue(), dim, start, end, step));
}

std::tuple<XLATensorPtr, XLATensorPtr> eigh(const XLATensorPtr& input,
c10::string_view uplo) {
torch::lazy::NodePtr node =
torch::lazy::MakeNode<Eigh>(input->GetIrValue(), uplo);
// Here we explicitly pass std::nullopt as logical_element_type because
// otherwise the result will inherit the input's logical_element_type. In
// the case of eigh(complex) -> (real, complex), we want to derive the
// dtype from the IR value instead of from the input's dtype.
return std::make_tuple(
input->CreateFrom(torch::lazy::Value(node, 0), std::nullopt),
// From https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html,
// eigenvectors will have the same dtype as A.
input->CreateFrom(torch::lazy::Value(node, 1)));
}
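
A minimal CPU-side sketch (helper name assumed, for illustration only) of the dtype contract the comment above relies on: a ComplexFloat input yields real (Float) eigenvalues, while the eigenvectors keep the input's dtype.

#include <torch/torch.h>

void EighDtypeSketch() {
  torch::Tensor a = torch::rand({3, 3}, torch::kComplexFloat);
  a = a + a.t().conj();  // make the input Hermitian
  auto [w, v] = torch::linalg_eigh(a);
  TORCH_CHECK(w.scalar_type() == torch::kFloat);         // eigenvalues: real
  TORCH_CHECK(v.scalar_type() == torch::kComplexFloat);  // eigenvectors: input dtype
}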

std::tuple<XLATensorPtr, XLATensorPtr> slogdet(const XLATensorPtr& input) {
torch::lazy::NodePtr node = SLogDet(input->GetIrValue());
return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -860,6 +860,9 @@ XLATensorPtr sigmoid_backward(const XLATensorPtr& grad_output,
XLATensorPtr slice(const XLATensorPtr& input, int64_t dim, int64_t start,
int64_t end, int64_t step);

std::tuple<XLATensorPtr, XLATensorPtr> eigh(const XLATensorPtr& input,
c10::string_view uplo);

std::tuple<XLATensorPtr, XLATensorPtr> slogdet(const XLATensorPtr& input);

// Computes a loss that uses a squared term if the absolute element-wise error
