From 2b9ad5d2407023ed4fbbe3a636ced2f28127c16d Mon Sep 17 00:00:00 2001 From: Matt Dawson Date: Mon, 14 Aug 2023 14:48:20 -0700 Subject: [PATCH] JITed forcing function (#179) * add llvm to cmake scripts and add Dockerfile * add JIT compiler * add scalar JIT type tests * add pointer type JIT tests * add loops to jit function class * add JIT function tests * fix JIT loop bug * Update include/micm/jit/jit_function.hpp Co-authored-by: Kyle Shores * Update cmake/test_util.cmake Co-authored-by: Kyle Shores * add assert on regenerating JIT function * add JIT forcing function --------- Co-authored-by: Kyle Shores --- include/micm/process/jit_process_set.hpp | 152 +++++++++++-- test/unit/process/CMakeLists.txt | 4 + test/unit/process/test_jit_process_set.cpp | 78 +++++++ test/unit/process/test_process_set.cpp | 213 ++++-------------- test/unit/process/test_process_set_policy.hpp | 169 ++++++++++++++ 5 files changed, 426 insertions(+), 190 deletions(-) create mode 100644 test/unit/process/test_jit_process_set.cpp create mode 100644 test/unit/process/test_process_set_policy.hpp diff --git a/include/micm/process/jit_process_set.hpp b/include/micm/process/jit_process_set.hpp index 46d7b1de6..5a04bb152 100644 --- a/include/micm/process/jit_process_set.hpp +++ b/include/micm/process/jit_process_set.hpp @@ -3,7 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once +#include +#include #include +#include namespace micm { @@ -11,52 +14,159 @@ namespace micm /// @brief JIT-compiled solver function calculators for a collection of processes /// The template parameter is the number of grid cells to solve simultaneously template - class JitProcessSet : public ProcessSet + class JitProcessSet : public ProcessSet { std::shared_ptr compiler_; llvm::orc::ResourceTrackerSP resource_tracker_; - void (*forcing_function_)(double*, double*, double*); + void (*forcing_function_)(const double *, const double *, double *); public: /// @brief Create a JITed process set calculator for a given set of processes /// @param compiler JIT compiler /// @param processes Processes to create calculator for /// @param state Solver state + template class MatrixPolicy> JitProcessSet( std::shared_ptr compiler, - const std::vector& processes, - const State& state); + const std::vector &processes, + const State &state); + + ~JitProcessSet(); /// @brief Add forcing terms for the set of processes for the current conditions /// @param rate_constants Current values for the process rate constants (grid cell, process) /// @param state_variables Current state variable values (grid cell, state variable) /// @param forcing Forcing terms for each state variable (grid cell, state variable) + template class MatrixPolicy> void AddForcingTerms( - const VectorMatrix& rate_constants, - const VectorMatrix& state_variables, - VectorMatrix& forcing) const; + const MatrixPolicy &rate_constants, + const MatrixPolicy &state_variables, + MatrixPolicy &forcing) const; }; - inline JitProcessSet::JitProcessSet( + template + template class MatrixPolicy> + inline JitProcessSet::JitProcessSet( std::shared_ptr compiler, - const std::vector& processes, - const State& state) - : ProcessSet(processes, state), + const std::vector &processes, + const State &state) + : ProcessSet(processes, state), compiler_(compiler) { - JitFunction func = JitFunction::create(compiler.get()) + MatrixPolicy test_matrix; + if (test_matrix.GroupVectorSize() != L) + { + std::cerr << "Vector matrix group size invalid for JitProcessSet"; + std::exit(micm::ExitCodes::InvalidMatrixDimension); + } + JitFunction func = JitFunction::create(compiler) .name("add_forcing_terms") - .arguments({ { "rate constants", JitType::Double }, - { "state variables", JitType::Double }, - { "forcing", JitType::Double } }), - .return_type(JitType::Void); + .arguments({ { "rate constants", JitType::DoublePtr }, + { "state variables", JitType::DoublePtr }, + { "forcing", JitType::DoublePtr } }) + .return_type(JitType::Void); + llvm::Type *double_type = func.GetType(JitType::Double); + llvm::Value *zero = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, 0)); + llvm::Type *rate_array_type = llvm::ArrayType::get(double_type, L); + llvm::AllocaInst *rate_array = func.builder_->CreateAlloca( + rate_array_type, llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, 1)), "rate_array"); + + auto react_ids = reactant_ids_.begin(); + auto prod_ids = product_ids_.begin(); + auto yields = yields_.begin(); for (std::size_t i_rxn = 0; i_rxn < number_of_reactants_.size(); ++i_rxn) { - llvm::Value* rc_start = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, i_rxn * L)); - llvm::Value* rc_end = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, i_rxn * L + L)); - llvm::ArrayType* rate_arr = llvm::ArrayType::get(func.GetType(JitType::Int32), ) auto loop = - func.StartLoop("rate constant loop", rc_start, rc_end); - llvm::Value* rate = func.GetArrayElement(func.arguments_[0], index_list, micm::JitType::Double); + llvm::Value *rc_start = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, i_rxn * L)); + + // save rate constant in rate array for each grid cell + auto loop = func.StartLoop("rate constant", 0, L); + llvm::Value *ptr_index[1]; + ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, rc_start); + llvm::Value *rate_const = func.GetArrayElement(func.arguments_[0], ptr_index, JitType::Double); + llvm::Value *array_index[2]; + array_index[0] = zero; + array_index[1] = loop.index_; + llvm::Value *rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index); + func.builder_->CreateStore(rate_const, rate_ptr); + func.EndLoop(loop); + + // rates[i_cell] *= reactant_concentration for each reactant + for (std::size_t i_react = 0; i_react < number_of_reactants_[i_rxn]; ++i_react) + { + loop = func.StartLoop("rate calc", 0, L); + llvm::Value *react_id = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, react_ids[i_react] * L)); + ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, react_id); + llvm::Value *react_conc = func.GetArrayElement(func.arguments_[1], ptr_index, JitType::Double); + array_index[1] = loop.index_; + rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index); + llvm::Value *rate = func.builder_->CreateLoad(double_type, rate_ptr, "rate"); + rate = func.builder_->CreateFMul(rate, react_conc, "rate"); + func.builder_->CreateStore(rate, rate_ptr); + func.EndLoop(loop); + } + + // set forcing for each reactant f[i_react][i_cell] -= rate[i_cell] + for (std::size_t i_react = 0; i_react < number_of_reactants_[i_rxn]; ++i_react) + { + loop = func.StartLoop("reactant forcing", 0, L); + llvm::Value *react_id = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, react_ids[i_react] * L)); + ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, react_id); + llvm::Value *react_forcing_ptr = func.builder_->CreateGEP(double_type, func.arguments_[2].ptr_, ptr_index); + llvm::Value *react_forcing = func.builder_->CreateLoad(double_type, react_forcing_ptr, "forcing"); + array_index[1] = loop.index_; + rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index); + llvm::Value *rate = func.builder_->CreateLoad(double_type, rate_ptr, "rate"); + react_forcing = func.builder_->CreateFSub(react_forcing, rate, "forcing"); + func.builder_->CreateStore(react_forcing, react_forcing_ptr); + func.EndLoop(loop); + } + + // set forcing for each product f[i_prod][i_cell] += yield * rate[i_cell] + for (std::size_t i_prod = 0; i_prod < number_of_products_[i_rxn]; ++i_prod) + { + loop = func.StartLoop("product forcing", 0, L); + llvm::Value *prod_id = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, prod_ids[i_prod] * L)); + ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, prod_id); + llvm::Value *prod_forcing_ptr = func.builder_->CreateGEP(double_type, func.arguments_[2].ptr_, ptr_index); + llvm::Value *prod_forcing = func.builder_->CreateLoad(double_type, prod_forcing_ptr, "forcing"); + array_index[1] = loop.index_; + rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index); + llvm::Value *rate = func.builder_->CreateLoad(double_type, rate_ptr, "rate"); + llvm::Value *yield = llvm::ConstantFP::get(*(func.context_), llvm::APFloat(yields[i_prod])); + rate = func.builder_->CreateFMul(rate, yield, "rate_yield"); + prod_forcing = func.builder_->CreateFAdd(prod_forcing, rate, "forcing"); + func.builder_->CreateStore(prod_forcing, prod_forcing_ptr); + func.EndLoop(loop); + } + react_ids += number_of_reactants_[i_rxn]; + prod_ids += number_of_products_[i_rxn]; + yields += number_of_products_[i_rxn]; } + func.builder_->CreateRetVoid(); + + auto target = func.Generate(); + forcing_function_ = (void (*)(const double *, const double *, double *))(intptr_t)target.second; + resource_tracker_ = target.first; } + + template + JitProcessSet::~JitProcessSet() + { + if (resource_tracker_) + { + llvm::ExitOnError exit_on_error; + exit_on_error(resource_tracker_->remove()); + } + } + + template + template class MatrixPolicy> + void JitProcessSet::AddForcingTerms( + const MatrixPolicy &rate_constants, + const MatrixPolicy &state_variables, + MatrixPolicy &forcing) const + { + forcing_function_(rate_constants.AsVector().data(), state_variables.AsVector().data(), forcing.AsVector().data()); + } + } // namespace micm \ No newline at end of file diff --git a/test/unit/process/CMakeLists.txt b/test/unit/process/CMakeLists.txt index 5c3717257..51f240390 100644 --- a/test/unit/process/CMakeLists.txt +++ b/test/unit/process/CMakeLists.txt @@ -22,4 +22,8 @@ endif() if(ENABLE_OPENACC) create_standard_test(NAME openacc_process_set SOURCES test_openacc_process_set.cpp LIBRARIES musica::micm_openacc) +endif() + +if(ENABLE_LLVM) + create_standard_test(NAME jit_process_set SOURCES test_jit_process_set.cpp) endif() \ No newline at end of file diff --git a/test/unit/process/test_jit_process_set.cpp b/test/unit/process/test_jit_process_set.cpp new file mode 100644 index 000000000..51190bfd5 --- /dev/null +++ b/test/unit/process/test_jit_process_set.cpp @@ -0,0 +1,78 @@ +#include + +#include +#include +#include +#include +#include + +#include "test_process_set_policy.hpp" + +template +using Group2VectorMatrix = micm::VectorMatrix; +template +using Group2000VectorMatrix = micm::VectorMatrix; +template +using Group3000VectorMatrix = micm::VectorMatrix; +template +using Group4000VectorMatrix = micm::VectorMatrix; + +template +using Group2SparseVectorMatrix = micm::SparseMatrix>; + +TEST(JitProcessSet, VectorMatrix) +{ + auto jit{ micm::JitCompiler::create() }; + if (auto err = jit.takeError()) + { + llvm::logAllUnhandledErrors(std::move(err), llvm::errs(), "[JIT Error]"); + EXPECT_TRUE(false); + } + testProcessSet>( + [&](const std::vector& processes, + const micm::State& state) -> micm::JitProcessSet<2> { + return micm::JitProcessSet<2>{ jit.get(), processes, state }; + }); +} + +TEST(RandomJitProcessSet, VectorMatrix) +{ + auto jit{ micm::JitCompiler::create() }; + if (auto err = jit.takeError()) + { + llvm::logAllUnhandledErrors(std::move(err), llvm::errs(), "[JIT Error]"); + EXPECT_TRUE(false); + } + testRandomSystem>( + 2000, + 20, + 30, + [&](const std::vector& processes, + const micm::State& state) -> micm::JitProcessSet<2000> { + return micm::JitProcessSet<2000>{ jit.get(), processes, state }; + }); + testRandomSystem>( + 3000, + 50, + 40, + [&](const std::vector& processes, + const micm::State& state) -> micm::JitProcessSet<3000> { + return micm::JitProcessSet<3000>{ jit.get(), processes, state }; + }); + testRandomSystem>( + 3000, + 30, + 20, + [&](const std::vector& processes, + const micm::State& state) -> micm::JitProcessSet<3000> { + return micm::JitProcessSet<3000>{ jit.get(), processes, state }; + }); + testRandomSystem>( + 4000, + 100, + 80, + [&](const std::vector& processes, + const micm::State& state) -> micm::JitProcessSet<4000> { + return micm::JitProcessSet<4000>{ jit.get(), processes, state }; + }); +} \ No newline at end of file diff --git a/test/unit/process/test_process_set.cpp b/test/unit/process/test_process_set.cpp index d39a5960f..265008ce3 100644 --- a/test/unit/process/test_process_set.cpp +++ b/test/unit/process/test_process_set.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include @@ -9,122 +8,10 @@ #include #include -using yields = std::pair; -using index_pair = std::pair; - -void compare_pair(const index_pair& a, const index_pair& b) -{ - EXPECT_EQ(a.first, b.first); - EXPECT_EQ(a.second, b.second); -} - -template class MatrixPolicy, template class SparseMatrixPolicy> -void testProcessSet() -{ - auto foo = micm::Species("foo"); - auto bar = micm::Species("bar"); - auto baz = micm::Species("baz"); - auto quz = micm::Species("quz"); - auto quuz = micm::Species("quuz"); - - micm::Phase gas_phase{ std::vector{ foo, bar, baz, quz, quuz } }; - - micm::State state{ micm::StateParameters{ .state_variable_names_{ "foo", "bar", "baz", "quz", "quuz" }, - .number_of_grid_cells_ = 2, - .number_of_rate_constants_ = 3 } }; - - micm::Process r1 = - micm::Process::create().reactants({ foo, baz }).products({ yields(bar, 1), yields(quuz, 2.4) }).phase(gas_phase); - - micm::Process r2 = - micm::Process::create().reactants({ bar }).products({ yields(foo, 1), yields(quz, 1.4) }).phase(gas_phase); - - micm::Process r3 = micm::Process::create().reactants({ quz }).products({}).phase(gas_phase); - - micm::ProcessSet set{ std::vector{ r1, r2, r3 }, state }; - - EXPECT_EQ(state.variables_.size(), 2); - EXPECT_EQ(state.variables_[0].size(), 5); - state.variables_[0] = { 0.1, 0.2, 0.3, 0.4, 0.5 }; - state.variables_[1] = { 1.1, 1.2, 1.3, 1.4, 1.5 }; - MatrixPolicy rate_constants{ 2, 3 }; - rate_constants[0] = { 10.0, 20.0, 30.0 }; - rate_constants[1] = { 110.0, 120.0, 130.0 }; - - MatrixPolicy forcing{ 2, 5, 1000.0 }; - - set.AddForcingTerms(rate_constants, state.variables_, forcing); - EXPECT_EQ(forcing[0][0], 1000.0 - 10.0 * 0.1 * 0.3 + 20.0 * 0.2); - EXPECT_EQ(forcing[1][0], 1000.0 - 110.0 * 1.1 * 1.3 + 120.0 * 1.2); - EXPECT_EQ(forcing[0][1], 1000.0 + 10.0 * 0.1 * 0.3 - 20.0 * 0.2); - EXPECT_EQ(forcing[1][1], 1000.0 + 110.0 * 1.1 * 1.3 - 120.0 * 1.2); - EXPECT_EQ(forcing[0][2], 1000.0 - 10.0 * 0.1 * 0.3); - EXPECT_EQ(forcing[1][2], 1000.0 - 110.0 * 1.1 * 1.3); - EXPECT_EQ(forcing[0][3], 1000.0 + 20.0 * 0.2 * 1.4 - 30.0 * 0.4); - EXPECT_EQ(forcing[1][3], 1000.0 + 120.0 * 1.2 * 1.4 - 130.0 * 1.4); - EXPECT_EQ(forcing[0][4], 1000.0 + 10.0 * 0.1 * 0.3 * 2.4); - EXPECT_EQ(forcing[1][4], 1000.0 + 110.0 * 1.1 * 1.3 * 2.4); - - auto non_zero_elements = set.NonZeroJacobianElements(); - // ---- foo bar baz quz quuz - // foo 0 1 2 - - - // bar 3 4 5 - - - // baz 6 - 7 - - - // quz - 8 - 9 - - // quuz 10 - 11 - - - - auto elem = non_zero_elements.begin(); - compare_pair(*elem, index_pair(0, 0)); - compare_pair(*(++elem), index_pair(0, 1)); - compare_pair(*(++elem), index_pair(0, 2)); - compare_pair(*(++elem), index_pair(1, 0)); - compare_pair(*(++elem), index_pair(1, 1)); - compare_pair(*(++elem), index_pair(1, 2)); - compare_pair(*(++elem), index_pair(2, 0)); - compare_pair(*(++elem), index_pair(2, 2)); - compare_pair(*(++elem), index_pair(3, 1)); - compare_pair(*(++elem), index_pair(3, 3)); - compare_pair(*(++elem), index_pair(4, 0)); - compare_pair(*(++elem), index_pair(4, 2)); - - auto builder = SparseMatrixPolicy::create(5).number_of_blocks(2).initial_value(100.0); - for (auto& elem : non_zero_elements) - builder = builder.with_element(elem.first, elem.second); - SparseMatrixPolicy jacobian{ builder }; - set.SetJacobianFlatIds(jacobian); - set.AddJacobianTerms(rate_constants, state.variables_, jacobian); - EXPECT_EQ(jacobian[0][0][0], 100.0 - 10.0 * 0.3); // foo -> foo - EXPECT_EQ(jacobian[1][0][0], 100.0 - 110.0 * 1.3); - EXPECT_EQ(jacobian[0][0][1], 100.0 + 20.0); // foo -> bar - EXPECT_EQ(jacobian[1][0][1], 100.0 + 120.0); - EXPECT_EQ(jacobian[0][0][2], 100.0 - 10.0 * 0.1); // foo -> baz - EXPECT_EQ(jacobian[1][0][2], 100.0 - 110.0 * 1.1); - EXPECT_EQ(jacobian[0][1][0], 100.0 + 10.0 * 0.3); // bar -> foo - EXPECT_EQ(jacobian[1][1][0], 100.0 + 110.0 * 1.3); - EXPECT_EQ(jacobian[0][1][1], 100.0 - 20.0); // bar -> bar - EXPECT_EQ(jacobian[1][1][1], 100.0 - 120.0); - EXPECT_EQ(jacobian[0][1][2], 100.0 + 10.0 * 0.1); // bar -> baz - EXPECT_EQ(jacobian[1][1][2], 100.0 + 110.0 * 1.1); - EXPECT_EQ(jacobian[0][2][0], 100.0 - 10.0 * 0.3); // baz -> foo - EXPECT_EQ(jacobian[1][2][0], 100.0 - 110.0 * 1.3); - EXPECT_EQ(jacobian[0][2][2], 100.0 - 10.0 * 0.1); // baz -> baz - EXPECT_EQ(jacobian[1][2][2], 100.0 - 110.0 * 1.1); - EXPECT_EQ(jacobian[0][3][1], 100.0 + 1.4 * 20.0); // quz -> bar - EXPECT_EQ(jacobian[1][3][1], 100.0 + 1.4 * 120.0); - EXPECT_EQ(jacobian[0][3][3], 100.0 - 30.0); // quz -> quz - EXPECT_EQ(jacobian[1][3][3], 100.0 - 130.0); - EXPECT_EQ(jacobian[0][4][0], 100.0 + 2.4 * 10.0 * 0.3); // quuz -> foo - EXPECT_EQ(jacobian[1][4][0], 100.0 + 2.4 * 110.0 * 1.3); - EXPECT_EQ(jacobian[0][4][2], 100.0 + 2.4 * 10.0 * 0.1); // quuz -> baz - EXPECT_EQ(jacobian[1][4][2], 100.0 + 2.4 * 110.0 * 1.1); -} +#include "test_process_set_policy.hpp" template using SparseMatrixTest = micm::SparseMatrix; -TEST(ProcessSet, Matrix) -{ - testProcessSet(); -} template using Group1VectorMatrix = micm::VectorMatrix; @@ -144,67 +31,55 @@ using Group3SparseVectorMatrix = micm::SparseMatrix using Group4SparseVectorMatrix = micm::SparseMatrix>; -TEST(ProcessSet, VectorMatrix) +TEST(ProcessSet, Matrix) { - testProcessSet(); - testProcessSet(); - testProcessSet(); - testProcessSet(); + testProcessSet( + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); } -template class MatrixPolicy> -void testRandomSystem(std::size_t n_cells, std::size_t n_reactions, std::size_t n_species) +TEST(ProcessSet, VectorMatrix) { - auto get_n_react = std::bind(std::uniform_int_distribution<>(0, 3), std::default_random_engine()); - auto get_n_product = std::bind(std::uniform_int_distribution<>(0, 10), std::default_random_engine()); - auto get_species_id = std::bind(std::uniform_int_distribution<>(0, n_species - 1), std::default_random_engine()); - auto get_double = std::bind(std::lognormal_distribution(-2.0, 4.0), std::default_random_engine()); - - std::vector species{}; - std::vector species_names{}; - for (std::size_t i = 0; i < n_species; ++i) - { - species.push_back(micm::Species{ std::to_string(i) }); - species_names.push_back(std::to_string(i)); - } - micm::Phase gas_phase{ species }; - micm::State state{ micm::StateParameters{ .state_variable_names_{ species_names }, - .number_of_grid_cells_ = n_cells, - .number_of_rate_constants_ = n_reactions } }; - std::vector processes{}; - for (std::size_t i = 0; i < n_reactions; ++i) - { - auto n_react = get_n_react(); - std::vector reactants{}; - for (std::size_t i_react = 0; i_react < n_react; ++i_react) - { - reactants.push_back({ std::to_string(get_species_id()) }); - } - auto n_product = get_n_product(); - std::vector products{}; - for (std::size_t i_prod = 0; i_prod < n_product; ++i_prod) - { - products.push_back(yields(std::to_string(get_species_id()), 1.2)); - } - auto proc = micm::Process(micm::Process::create().reactants(reactants).products(products).phase(gas_phase)); - processes.push_back(proc); - } - micm::ProcessSet set{ processes, state }; - - for (auto& elem : state.variables_.AsVector()) - elem = get_double(); - - MatrixPolicy rate_constants{ n_cells, n_reactions }; - for (auto& elem : rate_constants.AsVector()) - elem = get_double(); - MatrixPolicy forcing{ n_cells, n_species, 1000.0 }; - - set.AddForcingTerms(rate_constants, state.variables_, forcing); + testProcessSet( + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); + testProcessSet( + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); + testProcessSet( + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); + testProcessSet( + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); } TEST(RandomProcessSet, Matrix) { - testRandomSystem(2000, 500, 400); - testRandomSystem(3000, 300, 200); - testRandomSystem(4000, 100, 80); + testRandomSystem( + 2000, + 500, + 400, + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); + testRandomSystem( + 3000, + 300, + 200, + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); + testRandomSystem( + 4000, + 100, + 80, + [](const std::vector& processes, const micm::State& state) -> micm::ProcessSet { + return micm::ProcessSet{ processes, state }; + }); } \ No newline at end of file diff --git a/test/unit/process/test_process_set_policy.hpp b/test/unit/process/test_process_set_policy.hpp new file mode 100644 index 000000000..1e050b006 --- /dev/null +++ b/test/unit/process/test_process_set_policy.hpp @@ -0,0 +1,169 @@ +#include + +#include +#include + +using yields = std::pair; +using index_pair = std::pair; + +void compare_pair(const index_pair& a, const index_pair& b) +{ + EXPECT_EQ(a.first, b.first); + EXPECT_EQ(a.second, b.second); +} + +template class MatrixPolicy, template class SparseMatrixPolicy, class ProcessSetPolicy> +void testProcessSet( + const std::function&, const micm::State&)> create_set) +{ + auto foo = micm::Species("foo"); + auto bar = micm::Species("bar"); + auto baz = micm::Species("baz"); + auto quz = micm::Species("quz"); + auto quuz = micm::Species("quuz"); + + micm::Phase gas_phase{ std::vector{ foo, bar, baz, quz, quuz } }; + + micm::State state{ micm::StateParameters{ .state_variable_names_{ "foo", "bar", "baz", "quz", "quuz" }, + .number_of_grid_cells_ = 2, + .number_of_rate_constants_ = 3 } }; + + micm::Process r1 = + micm::Process::create().reactants({ foo, baz }).products({ yields(bar, 1), yields(quuz, 2.4) }).phase(gas_phase); + + micm::Process r2 = + micm::Process::create().reactants({ bar }).products({ yields(foo, 1), yields(quz, 1.4) }).phase(gas_phase); + + micm::Process r3 = micm::Process::create().reactants({ quz }).products({}).phase(gas_phase); + + ProcessSetPolicy set = create_set(std::vector{ r1, r2, r3 }, state); + + EXPECT_EQ(state.variables_.size(), 2); + EXPECT_EQ(state.variables_[0].size(), 5); + state.variables_[0] = { 0.1, 0.2, 0.3, 0.4, 0.5 }; + state.variables_[1] = { 1.1, 1.2, 1.3, 1.4, 1.5 }; + MatrixPolicy rate_constants{ 2, 3 }; + rate_constants[0] = { 10.0, 20.0, 30.0 }; + rate_constants[1] = { 110.0, 120.0, 130.0 }; + + MatrixPolicy forcing{ 2, 5, 1000.0 }; + + set.template AddForcingTerms(rate_constants, state.variables_, forcing); + EXPECT_EQ(forcing[0][0], 1000.0 - 10.0 * 0.1 * 0.3 + 20.0 * 0.2); + EXPECT_EQ(forcing[1][0], 1000.0 - 110.0 * 1.1 * 1.3 + 120.0 * 1.2); + EXPECT_EQ(forcing[0][1], 1000.0 + 10.0 * 0.1 * 0.3 - 20.0 * 0.2); + EXPECT_EQ(forcing[1][1], 1000.0 + 110.0 * 1.1 * 1.3 - 120.0 * 1.2); + EXPECT_EQ(forcing[0][2], 1000.0 - 10.0 * 0.1 * 0.3); + EXPECT_EQ(forcing[1][2], 1000.0 - 110.0 * 1.1 * 1.3); + EXPECT_EQ(forcing[0][3], 1000.0 + 20.0 * 0.2 * 1.4 - 30.0 * 0.4); + EXPECT_EQ(forcing[1][3], 1000.0 + 120.0 * 1.2 * 1.4 - 130.0 * 1.4); + EXPECT_EQ(forcing[0][4], 1000.0 + 10.0 * 0.1 * 0.3 * 2.4); + EXPECT_EQ(forcing[1][4], 1000.0 + 110.0 * 1.1 * 1.3 * 2.4); + + auto non_zero_elements = set.NonZeroJacobianElements(); + // ---- foo bar baz quz quuz + // foo 0 1 2 - - + // bar 3 4 5 - - + // baz 6 - 7 - - + // quz - 8 - 9 - + // quuz 10 - 11 - - + + auto elem = non_zero_elements.begin(); + compare_pair(*elem, index_pair(0, 0)); + compare_pair(*(++elem), index_pair(0, 1)); + compare_pair(*(++elem), index_pair(0, 2)); + compare_pair(*(++elem), index_pair(1, 0)); + compare_pair(*(++elem), index_pair(1, 1)); + compare_pair(*(++elem), index_pair(1, 2)); + compare_pair(*(++elem), index_pair(2, 0)); + compare_pair(*(++elem), index_pair(2, 2)); + compare_pair(*(++elem), index_pair(3, 1)); + compare_pair(*(++elem), index_pair(3, 3)); + compare_pair(*(++elem), index_pair(4, 0)); + compare_pair(*(++elem), index_pair(4, 2)); + + auto builder = SparseMatrixPolicy::create(5).number_of_blocks(2).initial_value(100.0); + for (auto& elem : non_zero_elements) + builder = builder.with_element(elem.first, elem.second); + SparseMatrixPolicy jacobian{ builder }; + set.SetJacobianFlatIds(jacobian); + set.template AddJacobianTerms(rate_constants, state.variables_, jacobian); + EXPECT_EQ(jacobian[0][0][0], 100.0 - 10.0 * 0.3); // foo -> foo + EXPECT_EQ(jacobian[1][0][0], 100.0 - 110.0 * 1.3); + EXPECT_EQ(jacobian[0][0][1], 100.0 + 20.0); // foo -> bar + EXPECT_EQ(jacobian[1][0][1], 100.0 + 120.0); + EXPECT_EQ(jacobian[0][0][2], 100.0 - 10.0 * 0.1); // foo -> baz + EXPECT_EQ(jacobian[1][0][2], 100.0 - 110.0 * 1.1); + EXPECT_EQ(jacobian[0][1][0], 100.0 + 10.0 * 0.3); // bar -> foo + EXPECT_EQ(jacobian[1][1][0], 100.0 + 110.0 * 1.3); + EXPECT_EQ(jacobian[0][1][1], 100.0 - 20.0); // bar -> bar + EXPECT_EQ(jacobian[1][1][1], 100.0 - 120.0); + EXPECT_EQ(jacobian[0][1][2], 100.0 + 10.0 * 0.1); // bar -> baz + EXPECT_EQ(jacobian[1][1][2], 100.0 + 110.0 * 1.1); + EXPECT_EQ(jacobian[0][2][0], 100.0 - 10.0 * 0.3); // baz -> foo + EXPECT_EQ(jacobian[1][2][0], 100.0 - 110.0 * 1.3); + EXPECT_EQ(jacobian[0][2][2], 100.0 - 10.0 * 0.1); // baz -> baz + EXPECT_EQ(jacobian[1][2][2], 100.0 - 110.0 * 1.1); + EXPECT_EQ(jacobian[0][3][1], 100.0 + 1.4 * 20.0); // quz -> bar + EXPECT_EQ(jacobian[1][3][1], 100.0 + 1.4 * 120.0); + EXPECT_EQ(jacobian[0][3][3], 100.0 - 30.0); // quz -> quz + EXPECT_EQ(jacobian[1][3][3], 100.0 - 130.0); + EXPECT_EQ(jacobian[0][4][0], 100.0 + 2.4 * 10.0 * 0.3); // quuz -> foo + EXPECT_EQ(jacobian[1][4][0], 100.0 + 2.4 * 110.0 * 1.3); + EXPECT_EQ(jacobian[0][4][2], 100.0 + 2.4 * 10.0 * 0.1); // quuz -> baz + EXPECT_EQ(jacobian[1][4][2], 100.0 + 2.4 * 110.0 * 1.1); +} + +template class MatrixPolicy, class ProcessSetPolicy> +void testRandomSystem( + std::size_t n_cells, + std::size_t n_reactions, + std::size_t n_species, + const std::function&, const micm::State&)> create_set) +{ + auto get_n_react = std::bind(std::uniform_int_distribution<>(0, 3), std::default_random_engine()); + auto get_n_product = std::bind(std::uniform_int_distribution<>(0, 10), std::default_random_engine()); + auto get_species_id = std::bind(std::uniform_int_distribution<>(0, n_species - 1), std::default_random_engine()); + auto get_double = std::bind(std::lognormal_distribution(-2.0, 4.0), std::default_random_engine()); + + std::vector species{}; + std::vector species_names{}; + for (std::size_t i = 0; i < n_species; ++i) + { + species.push_back(micm::Species{ std::to_string(i) }); + species_names.push_back(std::to_string(i)); + } + micm::Phase gas_phase{ species }; + micm::State state{ micm::StateParameters{ .state_variable_names_{ species_names }, + .number_of_grid_cells_ = n_cells, + .number_of_rate_constants_ = n_reactions } }; + std::vector processes{}; + for (std::size_t i = 0; i < n_reactions; ++i) + { + auto n_react = get_n_react(); + std::vector reactants{}; + for (std::size_t i_react = 0; i_react < n_react; ++i_react) + { + reactants.push_back({ std::to_string(get_species_id()) }); + } + auto n_product = get_n_product(); + std::vector products{}; + for (std::size_t i_prod = 0; i_prod < n_product; ++i_prod) + { + products.push_back(yields(std::to_string(get_species_id()), 1.2)); + } + auto proc = micm::Process(micm::Process::create().reactants(reactants).products(products).phase(gas_phase)); + processes.push_back(proc); + } + ProcessSetPolicy set = create_set(processes, state); + + for (auto& elem : state.variables_.AsVector()) + elem = get_double(); + + MatrixPolicy rate_constants{ n_cells, n_reactions }; + for (auto& elem : rate_constants.AsVector()) + elem = get_double(); + MatrixPolicy forcing{ n_cells, n_species, 1000.0 }; + + set.template AddForcingTerms(rate_constants, state.variables_, forcing); +} \ No newline at end of file