From b35622b3a682be09e5e563e201625e206b653eba Mon Sep 17 00:00:00 2001 From: Matt Dawson Date: Tue, 9 Apr 2024 09:41:23 -0700 Subject: [PATCH] Update template arguments for LuDecomposition (#454) update template arguments for LuDecomposition --- include/micm/solver/cuda_lu_decomposition.hpp | 6 +-- include/micm/solver/jit_lu_decomposition.inl | 2 +- include/micm/solver/linear_solver.inl | 4 +- include/micm/solver/lu_decomposition.hpp | 28 +++++++++-- include/micm/solver/lu_decomposition.inl | 48 +++++++++++++++---- include/micm/solver/state.inl | 2 +- .../solver/test_cuda_lu_decomposition.cpp | 4 +- .../unit/solver/test_linear_solver_policy.hpp | 14 +++--- test/unit/solver/test_lu_decomposition.cpp | 40 ++++++++-------- .../solver/test_lu_decomposition_policy.hpp | 8 ++-- 10 files changed, 102 insertions(+), 54 deletions(-) diff --git a/include/micm/solver/cuda_lu_decomposition.hpp b/include/micm/solver/cuda_lu_decomposition.hpp index 536040c07..97e1e5a64 100644 --- a/include/micm/solver/cuda_lu_decomposition.hpp +++ b/include/micm/solver/cuda_lu_decomposition.hpp @@ -25,9 +25,9 @@ namespace micm /// This is the overloaded constructor that takes one argument called "matrix"; /// We need to specify the type (e.g., double, int, etc) and /// ordering (e.g., vector-stored, non-vector-stored, etc) of the "matrix"; - template - CudaLuDecomposition(const SparseMatrix& matrix) - : LuDecomposition(matrix) + template + CudaLuDecomposition(const SparseMatrixPolicy& matrix) + : LuDecomposition(LuDecomposition::Create(matrix)) { /// Passing the class itself as an argument is not support by CUDA; /// Thus we generate a host struct first to save the pointers to diff --git a/include/micm/solver/jit_lu_decomposition.inl b/include/micm/solver/jit_lu_decomposition.inl index 9a63c4fcf..f9ee62228 100644 --- a/include/micm/solver/jit_lu_decomposition.inl +++ b/include/micm/solver/jit_lu_decomposition.inl @@ -29,7 +29,7 @@ namespace micm inline JitLuDecomposition::JitLuDecomposition( std::shared_ptr compiler, const SparseMatrix> &matrix) - : LuDecomposition(matrix), + : LuDecomposition(LuDecomposition::Create>>(matrix)), compiler_(compiler) { decompose_function_ = NULL; diff --git a/include/micm/solver/linear_solver.inl b/include/micm/solver/linear_solver.inl index d51f74eba..6ad48eb05 100644 --- a/include/micm/solver/linear_solver.inl +++ b/include/micm/solver/linear_solver.inl @@ -55,7 +55,7 @@ namespace micm : LinearSolver( matrix, initial_value, - [](const SparseMatrixPolicy& m) -> LuDecompositionPolicy { return LuDecompositionPolicy(m); }) + [](const SparseMatrixPolicy& m) -> LuDecompositionPolicy { return LuDecomposition::Create(m); }) { } @@ -70,7 +70,7 @@ namespace micm Uij_xj_(), lu_decomp_(create_lu_decomp(matrix)) { - auto lu = lu_decomp_.GetLUMatrices(matrix, initial_value); + auto lu = lu_decomp_.template GetLUMatrices(matrix, initial_value); auto lower_matrix = std::move(lu.first); auto upper_matrix = std::move(lu.second); for (std::size_t i = 0; i < lower_matrix[0].size(); ++i) diff --git a/include/micm/solver/lu_decomposition.hpp b/include/micm/solver/lu_decomposition.hpp index 8c204184b..aa5c5b442 100644 --- a/include/micm/solver/lu_decomposition.hpp +++ b/include/micm/solver/lu_decomposition.hpp @@ -75,15 +75,26 @@ namespace micm /// @brief Construct an LU decomposition algorithm for a given sparse matrix /// @param matrix Sparse matrix - template - LuDecomposition(const SparseMatrix& matrix); + template + LuDecomposition(const SparseMatrix& matrix); + + /// @brief Create an LU decomposition algorithm for a given sparse matrix policy + /// @param matrix Sparse matrix + template class SparseMatrixPolicy> + static LuDecomposition Create(const SparseMatrixPolicy& matrix); + template + static LuDecomposition Create(const SparseMatrixPolicy& matrix); /// @brief Create sparse L and U matrices for a given A matrix /// @param A Sparse matrix that will be decomposed /// @return L and U Sparse matrices - template - static std::pair, SparseMatrix> GetLUMatrices( - const SparseMatrix& A, + template class SparseMatrixPolicy> + static std::pair, SparseMatrixPolicy> GetLUMatrices( + const SparseMatrixPolicy& A, + T initial_value); + template + static std::pair GetLUMatrices( + const SparseMatrixPolicy& A, T initial_value); /// @brief Perform an LU decomposition on a given A matrix @@ -110,6 +121,13 @@ namespace micm SparseMatrixPolicy& L, SparseMatrixPolicy& U, bool& is_singular) const; + + private: + + /// @brief Initialize arrays for the LU decomposition + /// @param A Sparse matrix to decompose + template + void Initialize(const SparseMatrixPolicy& A, T initial_value); }; } // namespace micm diff --git a/include/micm/solver/lu_decomposition.inl b/include/micm/solver/lu_decomposition.inl index 08ba9e816..ccc2e8020 100644 --- a/include/micm/solver/lu_decomposition.inl +++ b/include/micm/solver/lu_decomposition.inl @@ -8,11 +8,33 @@ namespace micm { } - template - inline LuDecomposition::LuDecomposition(const SparseMatrix& matrix) + template + inline LuDecomposition::LuDecomposition(const SparseMatrix& matrix) + { + Initialize(matrix); + } + + template class SparseMatrixPolicy> + inline LuDecomposition LuDecomposition::Create(const SparseMatrixPolicy& matrix) + { + LuDecomposition lu_decomp{}; + lu_decomp.Initialize>(matrix, T{}); + return lu_decomp; + } + + template + inline LuDecomposition LuDecomposition::Create(const SparseMatrixPolicy& matrix) + { + LuDecomposition lu_decomp{}; + lu_decomp.Initialize(matrix, T{}); + return lu_decomp; + } + + template + inline void LuDecomposition::Initialize(const SparseMatrixPolicy& matrix, T initial_value) { std::size_t n = matrix[0].size(); - auto LU = GetLUMatrices(matrix, T{}); + auto LU = GetLUMatrices(matrix, initial_value); const auto& L_row_start = LU.first.RowStartVector(); const auto& L_row_ids = LU.first.RowIdsVector(); const auto& U_row_start = LU.second.RowStartVector(); @@ -82,9 +104,17 @@ namespace micm } } - template - inline std::pair, SparseMatrix> LuDecomposition::GetLUMatrices( - const SparseMatrix& A, + template class SparseMatrixPolicy> + inline std::pair, SparseMatrixPolicy> LuDecomposition::GetLUMatrices( + const SparseMatrixPolicy& A, + T initial_value) + { + return GetLUMatrices>(A, initial_value); + } + + template + inline std::pair LuDecomposition::GetLUMatrices( + const SparseMatrixPolicy& A, T initial_value) { std::size_t n = A[0].size(); @@ -129,18 +159,18 @@ namespace micm } } auto L_builder = - micm::SparseMatrix::create(n).number_of_blocks(A.size()).initial_value(initial_value); + SparseMatrixPolicy::create(n).number_of_blocks(A.size()).initial_value(initial_value); for (auto& pair : L_ids) { L_builder = L_builder.with_element(pair.first, pair.second); } auto U_builder = - micm::SparseMatrix::create(n).number_of_blocks(A.size()).initial_value(initial_value); + SparseMatrixPolicy::create(n).number_of_blocks(A.size()).initial_value(initial_value); for (auto& pair : U_ids) { U_builder = U_builder.with_element(pair.first, pair.second); } - std::pair, SparseMatrix> LU(L_builder, U_builder); + std::pair LU(L_builder, U_builder); return LU; } diff --git a/include/micm/solver/state.inl b/include/micm/solver/state.inl index 366740d6e..66fb835fb 100644 --- a/include/micm/solver/state.inl +++ b/include/micm/solver/state.inl @@ -42,7 +42,7 @@ namespace micm state_size_ ); - auto lu = LuDecomposition::GetLUMatrices(jacobian_, 1.0e-30); + auto lu = LuDecomposition::GetLUMatrices(jacobian_, 1.0e-30); auto lower_matrix = std::move(lu.first); auto upper_matrix = std::move(lu.second); lower_matrix_ = lower_matrix; diff --git a/test/unit/solver/test_cuda_lu_decomposition.cpp b/test/unit/solver/test_cuda_lu_decomposition.cpp index a18e25336..094c4b5c7 100644 --- a/test/unit/solver/test_cuda_lu_decomposition.cpp +++ b/test/unit/solver/test_cuda_lu_decomposition.cpp @@ -98,8 +98,8 @@ void testRandomMatrix(size_t n_grids) check_results( A, gpu_LU.first, gpu_LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); }); - micm::LuDecomposition cpu_lud(A); - auto cpu_LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); + micm::LuDecomposition cpu_lud = micm::LuDecomposition::Create(A); + auto cpu_LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); cpu_lud.Decompose(A, cpu_LU.first, cpu_LU.second); // checking GPU result again CPU diff --git a/test/unit/solver/test_linear_solver_policy.hpp b/test/unit/solver/test_linear_solver_policy.hpp index b05e65c68..9c13f1b19 100644 --- a/test/unit/solver/test_linear_solver_policy.hpp +++ b/test/unit/solver/test_linear_solver_policy.hpp @@ -85,7 +85,7 @@ void testDenseMatrix(const std::function(A, 1.0e-30); auto lower_matrix = std::move(lu.first); auto upper_matrix = std::move(lu.second); solver.Factor(A, lower_matrix, upper_matrix); @@ -123,7 +123,7 @@ void testRandomMatrix( b[i_block][i] = get_double(); LinearSolverPolicy solver = create_linear_solver(A, 1.0e-30); - auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); + auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); auto lower_matrix = std::move(lu.first); auto upper_matrix = std::move(lu.second); solver.Factor(A, lower_matrix, upper_matrix); @@ -152,7 +152,7 @@ void testDiagonalMatrix( A[i_block][i][i] = get_double(); LinearSolverPolicy solver = create_linear_solver(A, 1.0e-30); - auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); + auto lu = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); auto lower_matrix = std::move(lu.first); auto upper_matrix = std::move(lu.second); solver.Factor(A, lower_matrix, upper_matrix); @@ -188,11 +188,11 @@ void testMarkowitzReordering() builder = builder.with_element(i, j); SparseMatrixPolicy reordered_jac{ builder }; - auto orig_LU_calc = micm::LuDecomposition{ orig_jac }; - auto reordered_LU_calc = micm::LuDecomposition{ reordered_jac }; + auto orig_LU_calc = micm::LuDecomposition::Create( orig_jac ); + auto reordered_LU_calc = micm::LuDecomposition::Create( reordered_jac ); - auto orig_LU = orig_LU_calc.GetLUMatrices(orig_jac, 0.0); - auto reordered_LU = reordered_LU_calc.GetLUMatrices(reordered_jac, 0.0); + auto orig_LU = orig_LU_calc.template GetLUMatrices(orig_jac, 0.0); + auto reordered_LU = reordered_LU_calc.template GetLUMatrices(reordered_jac, 0.0); std::size_t sum_orig = 0; std::size_t sum_reordered = 0; diff --git a/test/unit/solver/test_lu_decomposition.cpp b/test/unit/solver/test_lu_decomposition.cpp index 5a91e43e3..561f7eb35 100644 --- a/test/unit/solver/test_lu_decomposition.cpp +++ b/test/unit/solver/test_lu_decomposition.cpp @@ -21,76 +21,76 @@ using Group4SparseVectorMatrix = micm::SparseMatrix( - [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; }); + [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create( matrix ); }); } TEST(LuDecomposition, SingularMatrixStandardOrdering) { testSingularMatrix( - [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; }); + [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create( matrix ); }); } TEST(LuDecomposition, RandomMatrixStandardOrdering) { testRandomMatrix( - [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; }, 5); + [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create( matrix ); }, 5); } TEST(LuDecomposition, DiagonalMatrixStandardOrdering) { testDiagonalMatrix( - [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition{ matrix }; }, 5); + [](const SparseMatrixTest& matrix) -> micm::LuDecomposition { return micm::LuDecomposition::Create( matrix ); }, 5); } TEST(LuDecomposition, DenseMatrixVectorOrdering) { testDenseMatrix( [](const Group1SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); testDenseMatrix( [](const Group2SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); testDenseMatrix( [](const Group3SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); testDenseMatrix( [](const Group4SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); } TEST(LuDecomposition, SingluarMatrixVectorOrdering) { testSingularMatrix( [](const Group1SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); testSingularMatrix( [](const Group2SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); testSingularMatrix( [](const Group3SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); testSingularMatrix( [](const Group4SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }); + { return micm::LuDecomposition::Create( matrix ); }); } TEST(LuDecomposition, RandomMatrixVectorOrdering) { testRandomMatrix( [](const Group1SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); testRandomMatrix( [](const Group2SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); testRandomMatrix( [](const Group3SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); testRandomMatrix( [](const Group4SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); } @@ -98,18 +98,18 @@ TEST(LuDecomposition, DiagonalMatrixVectorOrdering) { testDiagonalMatrix( [](const Group1SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); testDiagonalMatrix( [](const Group2SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); testDiagonalMatrix( [](const Group3SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); testDiagonalMatrix( [](const Group4SparseVectorMatrix& matrix) -> micm::LuDecomposition - { return micm::LuDecomposition{ matrix }; }, + { return micm::LuDecomposition::Create( matrix ); }, 5); } diff --git a/test/unit/solver/test_lu_decomposition_policy.hpp b/test/unit/solver/test_lu_decomposition_policy.hpp index ec375242a..e940555b1 100644 --- a/test/unit/solver/test_lu_decomposition_policy.hpp +++ b/test/unit/solver/test_lu_decomposition_policy.hpp @@ -98,7 +98,7 @@ void testDenseMatrix(const std::function(A, 1.0e-30); lud.template Decompose(A, LU.first, LU.second); check_results( A, LU.first, LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); }); @@ -120,7 +120,7 @@ void testSingularMatrix(const std::function(A, 1.0E-30); bool is_singular{ false }; lud.template Decompose(A, LU.first, LU.second, is_singular); EXPECT_TRUE(is_singular); @@ -152,7 +152,7 @@ void testRandomMatrix( A[i_block][i][j] = get_double(); LuDecompositionPolicy lud = create_lu_decomp(A); - auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); + auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); lud.template Decompose(A, LU.first, LU.second); check_results( A, LU.first, LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); }); @@ -176,7 +176,7 @@ void testDiagonalMatrix( A[i_block][i][i] = get_double(); LuDecompositionPolicy lud = create_lu_decomp(A); - auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); + auto LU = micm::LuDecomposition::GetLUMatrices(A, 1.0e-30); lud.template Decompose(A, LU.first, LU.second); check_results( A, LU.first, LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });