Skip to content

Commit

Permalink
Update template arguments for LuDecomposition (#454)
Browse files Browse the repository at this point in the history
update template arguments for LuDecomposition
  • Loading branch information
mattldawson committed Apr 9, 2024
1 parent 582b057 commit b35622b
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 54 deletions.
6 changes: 3 additions & 3 deletions include/micm/solver/cuda_lu_decomposition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ namespace micm
/// This is the overloaded constructor that takes one argument called "matrix";
/// We need to specify the type (e.g., double, int, etc) and
/// ordering (e.g., vector-stored, non-vector-stored, etc) of the "matrix";
template<typename T, typename OrderingPolicy>
CudaLuDecomposition(const SparseMatrix<T, OrderingPolicy>& matrix)
: LuDecomposition(matrix)
template<typename SparseMatrixPolicy>
CudaLuDecomposition(const SparseMatrixPolicy& matrix)
: LuDecomposition(LuDecomposition::Create<double, SparseMatrixPolicy>(matrix))
{
/// Passing the class itself as an argument is not support by CUDA;
/// Thus we generate a host struct first to save the pointers to
Expand Down
2 changes: 1 addition & 1 deletion include/micm/solver/jit_lu_decomposition.inl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace micm
inline JitLuDecomposition<L>::JitLuDecomposition(
std::shared_ptr<JitCompiler> compiler,
const SparseMatrix<double, SparseMatrixVectorOrdering<L>> &matrix)
: LuDecomposition(matrix),
: LuDecomposition(LuDecomposition::Create<double, SparseMatrix<double, SparseMatrixVectorOrdering<L>>>(matrix)),
compiler_(compiler)
{
decompose_function_ = NULL;
Expand Down
4 changes: 2 additions & 2 deletions include/micm/solver/linear_solver.inl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace micm
: LinearSolver<T, SparseMatrixPolicy, LuDecompositionPolicy>(
matrix,
initial_value,
[](const SparseMatrixPolicy<T>& m) -> LuDecompositionPolicy { return LuDecompositionPolicy(m); })
[](const SparseMatrixPolicy<T>& m) -> LuDecompositionPolicy { return LuDecomposition::Create<T, SparseMatrixPolicy>(m); })
{
}

Expand All @@ -70,7 +70,7 @@ namespace micm
Uij_xj_(),
lu_decomp_(create_lu_decomp(matrix))
{
auto lu = lu_decomp_.GetLUMatrices(matrix, initial_value);
auto lu = lu_decomp_.template GetLUMatrices<T, SparseMatrixPolicy>(matrix, initial_value);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
for (std::size_t i = 0; i < lower_matrix[0].size(); ++i)
Expand Down
28 changes: 23 additions & 5 deletions include/micm/solver/lu_decomposition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,26 @@ namespace micm

/// @brief Construct an LU decomposition algorithm for a given sparse matrix
/// @param matrix Sparse matrix
template<typename T, typename OrderingPolicy>
LuDecomposition(const SparseMatrix<T, OrderingPolicy>& matrix);
template<typename T>
LuDecomposition(const SparseMatrix<T>& matrix);

/// @brief Create an LU decomposition algorithm for a given sparse matrix policy
/// @param matrix Sparse matrix
template<typename T, template<class> class SparseMatrixPolicy>
static LuDecomposition Create(const SparseMatrixPolicy<T>& matrix);
template<typename T, class SparseMatrixPolicy>
static LuDecomposition Create(const SparseMatrixPolicy& matrix);

/// @brief Create sparse L and U matrices for a given A matrix
/// @param A Sparse matrix that will be decomposed
/// @return L and U Sparse matrices
template<typename T, typename OrderingPolicy>
static std::pair<SparseMatrix<T, OrderingPolicy>, SparseMatrix<T, OrderingPolicy>> GetLUMatrices(
const SparseMatrix<T, OrderingPolicy>& A,
template<typename T, template <class> class SparseMatrixPolicy>
static std::pair<SparseMatrixPolicy<T>, SparseMatrixPolicy<T>> GetLUMatrices(
const SparseMatrixPolicy<T>& A,
T initial_value);
template<typename T, class SparseMatrixPolicy>
static std::pair<SparseMatrixPolicy, SparseMatrixPolicy> GetLUMatrices(
const SparseMatrixPolicy& A,
T initial_value);

/// @brief Perform an LU decomposition on a given A matrix
Expand All @@ -110,6 +121,13 @@ namespace micm
SparseMatrixPolicy<T>& L,
SparseMatrixPolicy<T>& U,
bool& is_singular) const;

private:

/// @brief Initialize arrays for the LU decomposition
/// @param A Sparse matrix to decompose
template<typename T, class SparseMatrixPolicy>
void Initialize(const SparseMatrixPolicy& A, T initial_value);
};

} // namespace micm
Expand Down
48 changes: 39 additions & 9 deletions include/micm/solver/lu_decomposition.inl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,33 @@ namespace micm
{
}

template<typename T, typename OrderingPolicy>
inline LuDecomposition::LuDecomposition(const SparseMatrix<T, OrderingPolicy>& matrix)
template<typename T>
inline LuDecomposition::LuDecomposition(const SparseMatrix<T>& matrix)
{
Initialize<T, SparseMatrix>(matrix);
}

template<typename T, template <class> class SparseMatrixPolicy>
inline LuDecomposition LuDecomposition::Create(const SparseMatrixPolicy<T>& matrix)
{
LuDecomposition lu_decomp{};
lu_decomp.Initialize<T, SparseMatrixPolicy<T>>(matrix, T{});
return lu_decomp;
}

template<typename T, class SparseMatrixPolicy>
inline LuDecomposition LuDecomposition::Create(const SparseMatrixPolicy& matrix)
{
LuDecomposition lu_decomp{};
lu_decomp.Initialize<T, SparseMatrixPolicy>(matrix, T{});
return lu_decomp;
}

template<typename T, class SparseMatrixPolicy>
inline void LuDecomposition::Initialize(const SparseMatrixPolicy& matrix, T initial_value)
{
std::size_t n = matrix[0].size();
auto LU = GetLUMatrices(matrix, T{});
auto LU = GetLUMatrices<T, SparseMatrixPolicy>(matrix, initial_value);
const auto& L_row_start = LU.first.RowStartVector();
const auto& L_row_ids = LU.first.RowIdsVector();
const auto& U_row_start = LU.second.RowStartVector();
Expand Down Expand Up @@ -82,9 +104,17 @@ namespace micm
}
}

template<typename T, typename OrderingPolicy>
inline std::pair<SparseMatrix<T, OrderingPolicy>, SparseMatrix<T, OrderingPolicy>> LuDecomposition::GetLUMatrices(
const SparseMatrix<T, OrderingPolicy>& A,
template<typename T, template <class> class SparseMatrixPolicy>
inline std::pair<SparseMatrixPolicy<T>, SparseMatrixPolicy<T>> LuDecomposition::GetLUMatrices(
const SparseMatrixPolicy<T>& A,
T initial_value)
{
return GetLUMatrices<T, SparseMatrixPolicy<T>>(A, initial_value);
}

template<typename T, class SparseMatrixPolicy>
inline std::pair<SparseMatrixPolicy, SparseMatrixPolicy> LuDecomposition::GetLUMatrices(
const SparseMatrixPolicy& A,
T initial_value)
{
std::size_t n = A[0].size();
Expand Down Expand Up @@ -129,18 +159,18 @@ namespace micm
}
}
auto L_builder =
micm::SparseMatrix<T, OrderingPolicy>::create(n).number_of_blocks(A.size()).initial_value(initial_value);
SparseMatrixPolicy::create(n).number_of_blocks(A.size()).initial_value(initial_value);
for (auto& pair : L_ids)
{
L_builder = L_builder.with_element(pair.first, pair.second);
}
auto U_builder =
micm::SparseMatrix<T, OrderingPolicy>::create(n).number_of_blocks(A.size()).initial_value(initial_value);
SparseMatrixPolicy::create(n).number_of_blocks(A.size()).initial_value(initial_value);
for (auto& pair : U_ids)
{
U_builder = U_builder.with_element(pair.first, pair.second);
}
std::pair<SparseMatrix<T, OrderingPolicy>, SparseMatrix<T, OrderingPolicy>> LU(L_builder, U_builder);
std::pair<SparseMatrixPolicy, SparseMatrixPolicy> LU(L_builder, U_builder);
return LU;
}

Expand Down
2 changes: 1 addition & 1 deletion include/micm/solver/state.inl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace micm
state_size_
);

auto lu = LuDecomposition::GetLUMatrices(jacobian_, 1.0e-30);
auto lu = LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(jacobian_, 1.0e-30);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
lower_matrix_ = lower_matrix;
Expand Down
4 changes: 2 additions & 2 deletions test/unit/solver/test_cuda_lu_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ void testRandomMatrix(size_t n_grids)
check_results<double, SparseMatrixPolicy>(
A, gpu_LU.first, gpu_LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });

micm::LuDecomposition cpu_lud(A);
auto cpu_LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
micm::LuDecomposition cpu_lud = micm::LuDecomposition::Create<double, SparseMatrixPolicy>(A);
auto cpu_LU = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
cpu_lud.Decompose<double, SparseMatrixPolicy>(A, cpu_LU.first, cpu_LU.second);

// checking GPU result again CPU
Expand Down
14 changes: 7 additions & 7 deletions test/unit/solver/test_linear_solver_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void testDenseMatrix(const std::function<LinearSolverPolicy(const SparseMatrixPo
b[0][2] = 9;

LinearSolverPolicy solver = create_linear_solver(A, 1.0e-30);
auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
auto lu = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
solver.Factor(A, lower_matrix, upper_matrix);
Expand Down Expand Up @@ -123,7 +123,7 @@ void testRandomMatrix(
b[i_block][i] = get_double();

LinearSolverPolicy solver = create_linear_solver(A, 1.0e-30);
auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
auto lu = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
solver.Factor(A, lower_matrix, upper_matrix);
Expand Down Expand Up @@ -152,7 +152,7 @@ void testDiagonalMatrix(
A[i_block][i][i] = get_double();

LinearSolverPolicy solver = create_linear_solver(A, 1.0e-30);
auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
auto lu = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
solver.Factor(A, lower_matrix, upper_matrix);
Expand Down Expand Up @@ -188,11 +188,11 @@ void testMarkowitzReordering()
builder = builder.with_element(i, j);
SparseMatrixPolicy<double> reordered_jac{ builder };

auto orig_LU_calc = micm::LuDecomposition{ orig_jac };
auto reordered_LU_calc = micm::LuDecomposition{ reordered_jac };
auto orig_LU_calc = micm::LuDecomposition::Create<double, SparseMatrixPolicy>( orig_jac );
auto reordered_LU_calc = micm::LuDecomposition::Create<double, SparseMatrixPolicy>( reordered_jac );

auto orig_LU = orig_LU_calc.GetLUMatrices(orig_jac, 0.0);
auto reordered_LU = reordered_LU_calc.GetLUMatrices(reordered_jac, 0.0);
auto orig_LU = orig_LU_calc.template GetLUMatrices<double, SparseMatrixPolicy>(orig_jac, 0.0);
auto reordered_LU = reordered_LU_calc.template GetLUMatrices<double, SparseMatrixPolicy>(reordered_jac, 0.0);

std::size_t sum_orig = 0;
std::size_t sum_reordered = 0;
Expand Down
40 changes: 20 additions & 20 deletions test/unit/solver/test_lu_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,95 +21,95 @@ using Group4SparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorO
TEST(LuDecomposition, DenseMatrixStandardOrdering)
{
testDenseMatrix<SparseMatrixTest, micm::LuDecomposition>(
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; });
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create<double, SparseMatrixTest>( matrix ); });
}

TEST(LuDecomposition, SingularMatrixStandardOrdering)
{
testSingularMatrix<SparseMatrixTest, micm::LuDecomposition>(
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; });
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create<double, SparseMatrixTest>( matrix ); });
}

TEST(LuDecomposition, RandomMatrixStandardOrdering)
{
testRandomMatrix<SparseMatrixTest, micm::LuDecomposition>(
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; }, 5);
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create<double, SparseMatrixTest>( matrix ); }, 5);
}

TEST(LuDecomposition, DiagonalMatrixStandardOrdering)
{
testDiagonalMatrix<SparseMatrixTest, micm::LuDecomposition>(
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; }, 5);
[](const SparseMatrixTest<double>& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create<double, SparseMatrixTest>( matrix ); }, 5);
}

TEST(LuDecomposition, DenseMatrixVectorOrdering)
{
testDenseMatrix<Group1SparseVectorMatrix, micm::LuDecomposition>(
[](const Group1SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group1SparseVectorMatrix>( matrix ); });
testDenseMatrix<Group2SparseVectorMatrix, micm::LuDecomposition>(
[](const Group2SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group2SparseVectorMatrix>( matrix ); });
testDenseMatrix<Group3SparseVectorMatrix, micm::LuDecomposition>(
[](const Group3SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group3SparseVectorMatrix>( matrix ); });
testDenseMatrix<Group4SparseVectorMatrix, micm::LuDecomposition>(
[](const Group4SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group4SparseVectorMatrix>( matrix ); });
}

TEST(LuDecomposition, SingluarMatrixVectorOrdering)
{
testSingularMatrix<Group1SparseVectorMatrix, micm::LuDecomposition>(
[](const Group1SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group1SparseVectorMatrix>( matrix ); });
testSingularMatrix<Group2SparseVectorMatrix, micm::LuDecomposition>(
[](const Group2SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group2SparseVectorMatrix>( matrix ); });
testSingularMatrix<Group3SparseVectorMatrix, micm::LuDecomposition>(
[](const Group3SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group3SparseVectorMatrix>( matrix ); });
testSingularMatrix<Group4SparseVectorMatrix, micm::LuDecomposition>(
[](const Group4SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; });
{ return micm::LuDecomposition::Create<double, Group4SparseVectorMatrix>( matrix ); });
}

TEST(LuDecomposition, RandomMatrixVectorOrdering)
{
testRandomMatrix<Group1SparseVectorMatrix, micm::LuDecomposition>(
[](const Group1SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group1SparseVectorMatrix>( matrix ); },
5);
testRandomMatrix<Group2SparseVectorMatrix, micm::LuDecomposition>(
[](const Group2SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group2SparseVectorMatrix>( matrix ); },
5);
testRandomMatrix<Group3SparseVectorMatrix, micm::LuDecomposition>(
[](const Group3SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group3SparseVectorMatrix>( matrix ); },
5);
testRandomMatrix<Group4SparseVectorMatrix, micm::LuDecomposition>(
[](const Group4SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group4SparseVectorMatrix>( matrix ); },
5);
}

TEST(LuDecomposition, DiagonalMatrixVectorOrdering)
{
testDiagonalMatrix<Group1SparseVectorMatrix, micm::LuDecomposition>(
[](const Group1SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group1SparseVectorMatrix>( matrix ); },
5);
testDiagonalMatrix<Group2SparseVectorMatrix, micm::LuDecomposition>(
[](const Group2SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group2SparseVectorMatrix>( matrix ); },
5);
testDiagonalMatrix<Group3SparseVectorMatrix, micm::LuDecomposition>(
[](const Group3SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group3SparseVectorMatrix>( matrix ); },
5);
testDiagonalMatrix<Group4SparseVectorMatrix, micm::LuDecomposition>(
[](const Group4SparseVectorMatrix<double>& matrix) -> micm::LuDecomposition
{ return micm::LuDecomposition{ matrix }; },
{ return micm::LuDecomposition::Create<double, Group4SparseVectorMatrix>( matrix ); },
5);
}
8 changes: 4 additions & 4 deletions test/unit/solver/test_lu_decomposition_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void testDenseMatrix(const std::function<LuDecompositionPolicy(const SparseMatri
A[0][2][2] = 8;

LuDecompositionPolicy lud = create_lu_decomp(A);
auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
auto LU = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
lud.template Decompose<double, SparseMatrixPolicy>(A, LU.first, LU.second);
check_results<double, SparseMatrixPolicy>(
A, LU.first, LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });
Expand All @@ -120,7 +120,7 @@ void testSingularMatrix(const std::function<LuDecompositionPolicy(const SparseMa
A[0][1][1] = 1;

LuDecompositionPolicy lud = create_lu_decomp(A);
auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0E-30);
auto LU = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0E-30);
bool is_singular{ false };
lud.template Decompose<double, SparseMatrixPolicy>(A, LU.first, LU.second, is_singular);
EXPECT_TRUE(is_singular);
Expand Down Expand Up @@ -152,7 +152,7 @@ void testRandomMatrix(
A[i_block][i][j] = get_double();

LuDecompositionPolicy lud = create_lu_decomp(A);
auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
auto LU = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
lud.template Decompose<double, SparseMatrixPolicy>(A, LU.first, LU.second);
check_results<double, SparseMatrixPolicy>(
A, LU.first, LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });
Expand All @@ -176,7 +176,7 @@ void testDiagonalMatrix(
A[i_block][i][i] = get_double();

LuDecompositionPolicy lud = create_lu_decomp(A);
auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30);
auto LU = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
lud.template Decompose<double, SparseMatrixPolicy>(A, LU.first, LU.second);
check_results<double, SparseMatrixPolicy>(
A, LU.first, LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });
Expand Down

0 comments on commit b35622b

Please sign in to comment.