Skip to content

Commit

Permalink
move calculation of rate constants out of solver
Browse files Browse the repository at this point in the history
  • Loading branch information
mattldawson committed Jun 6, 2024
1 parent 7d52e60 commit e42fec3
Show file tree
Hide file tree
Showing 35 changed files with 73 additions and 40 deletions.
1 change: 1 addition & 0 deletions examples/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ int main(const int argc, const char* argv[])

while (elapsed_solve_time < time_step)
{
solver.CalculateRateConstants(state);
auto result = solver.Solve(time_step - elapsed_solve_time, state);
elapsed_solve_time = result.final_time_;
if (result.state_ != SolverState::Converged)
Expand Down
1 change: 1 addition & 0 deletions examples/profile_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ int Run(const char* filepath, const char* initial_conditions, const std::string&
MICM_PROFILE_BEGIN_SESSION("Runtime", "Profile-Runtime-" + matrix_ordering_type + ".json");
while (elapsed_solve_time < time_step)
{
solver.CalculateRateConstants(state);
auto result = solver.Solve(time_step - elapsed_solve_time, state);
elapsed_solve_time = result.final_time_;
state.variables_ = result.result_;
Expand Down
10 changes: 5 additions & 5 deletions include/micm/process/process.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ namespace micm
std::unique_ptr<RateConstant> rate_constant_;
Phase phase_;

/// @brief Update the solver state rate constants
/// @brief Recalculate the rate constants for each process for the current state
/// @param processes The set of processes being solved
/// @param state The solver state to update
template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(!VectorizableDense<DenseMatrixPolicy>) static void UpdateState(
requires(!VectorizableDense<DenseMatrixPolicy>) static void CalculateRateConstants(
const std::vector<Process>& processes,
State<DenseMatrixPolicy, SparseMatrixPolicy>& state);
template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(VectorizableDense<DenseMatrixPolicy>) static void UpdateState(
requires(VectorizableDense<DenseMatrixPolicy>) static void CalculateRateConstants(
const std::vector<Process>& processes,
State<DenseMatrixPolicy, SparseMatrixPolicy>& state);

Expand Down Expand Up @@ -149,7 +149,7 @@ namespace micm
};

template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(!VectorizableDense<DenseMatrixPolicy>) void Process::UpdateState(
requires(!VectorizableDense<DenseMatrixPolicy>) void Process::CalculateRateConstants(
const std::vector<Process>& processes,
State<DenseMatrixPolicy, SparseMatrixPolicy>& state)
{
Expand All @@ -174,7 +174,7 @@ namespace micm
}

template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(VectorizableDense<DenseMatrixPolicy>) void Process::UpdateState(
requires(VectorizableDense<DenseMatrixPolicy>) void Process::CalculateRateConstants(
const std::vector<Process>& processes,
State<DenseMatrixPolicy, SparseMatrixPolicy>& state)
{
Expand Down
7 changes: 2 additions & 5 deletions include/micm/solver/backward_euler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ namespace micm
LinearSolverPolicy linear_solver_;
ProcessSetPolicy process_set_;
std::vector<std::size_t> jacobian_diagonal_elements_;
std::vector<micm::Process> processes_;

public:
/// @brief Solver parameters typename
Expand All @@ -48,13 +47,11 @@ namespace micm
BackwardEulerSolverParameters parameters,
LinearSolverPolicy&& linear_solver,
ProcessSetPolicy&& process_set,
auto& jacobian,
std::vector<micm::Process>& processes)
auto& jacobian)
: parameters_(parameters),
linear_solver_(std::move(linear_solver)),
process_set_(std::move(process_set)),
jacobian_diagonal_elements_(jacobian.DiagonalIndices(0)),
processes_(processes)
jacobian_diagonal_elements_(jacobian.DiagonalIndices(0))
{
}

Expand Down
2 changes: 0 additions & 2 deletions include/micm/solver/backward_euler.inl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ namespace micm
auto Yn1 = state.variables_;
auto forcing = state.variables_;

Process::UpdateState(processes_, state);

while (t < time_step)
{
bool converged = false;
Expand Down
9 changes: 3 additions & 6 deletions include/micm/solver/cuda_rosenbrock.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,21 @@ namespace micm
devstruct_.errors_output_ = nullptr;
};

/// @brief Builds a CUDA Rosenbrock solver for the given system, processes, and solver parameters
/// @brief Builds a CUDA Rosenbrock solver for the given system and solver parameters
/// @param parameters Solver parameters
/// @param linear_solver Linear solver
/// @param process_set Process set
/// @param jacobian Jacobian matrix
/// @param processes Vector of processes
CudaRosenbrockSolver(
RosenbrockSolverParameters parameters,
LinearSolverPolicy&& linear_solver,
ProcessSetPolicy&& process_set,
auto& jacobian,
std::vector<Process>& processes)
auto& jacobian)
: RosenbrockSolver<ProcessSetPolicy, LinearSolverPolicy>(
parameters,
std::move(linear_solver),
std::move(process_set),
jacobian,
processes)
jacobian)
{
CudaRosenbrockSolverParam hoststruct;
// jacobian.GroupVectorSize() is the same as the number of grid cells for the CUDA implementation
Expand Down
12 changes: 4 additions & 8 deletions include/micm/solver/jit_rosenbrock.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <micm/jit/jit_function.hpp>
#include <micm/process/jit_process_set.hpp>
#include <micm/solver/jit_linear_solver.hpp>
#include <micm/solver/jit_solver_parameters.hpp>
#include <micm/solver/rosenbrock.hpp>
#include <micm/solver/rosenbrock_solver_parameters.hpp>
#include <micm/util/random_string.hpp>
Expand All @@ -31,9 +30,9 @@

namespace micm
{
struct JitRosenbrockSolverParameters;

/// @brief A Rosenbrock solver with JIT-compiled optimizations

template<class ProcessSetPolicy, class LinearSolverPolicy>
class JitRosenbrockSolver : public RosenbrockSolver<ProcessSetPolicy, LinearSolverPolicy>
{
Expand Down Expand Up @@ -65,24 +64,21 @@ namespace micm
return *this;
}

/// @brief Builds a Rosenbrock solver for the given system, processes, and solver parameters
/// @brief Builds a Rosenbrock solver for the given system and solver parameters
/// @param parameters Solver parameters
/// @param linear_solver Linear solver
/// @param process_set Process set
/// @param jacobian Jacobian matrix
/// @param processes Vector of processes
JitRosenbrockSolver(
RosenbrockSolverParameters parameters,
LinearSolverPolicy linear_solver,
ProcessSetPolicy process_set,
auto& jacobian,
std::vector<Process>& processes)
auto& jacobian)
: RosenbrockSolver<ProcessSetPolicy, LinearSolverPolicy>(
parameters,
std::move(linear_solver),
std::move(process_set),
jacobian,
processes)
jacobian)
{
this->GenerateAlphaMinusJacobian(jacobian);
}
Expand Down
7 changes: 2 additions & 5 deletions include/micm/solver/rosenbrock.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ namespace micm
LinearSolverPolicy linear_solver_;
ProcessSetPolicy process_set_;
std::vector<std::size_t> jacobian_diagonal_elements_;
std::vector<Process> processes_;

static constexpr double DELTA_MIN = 1.0e-6;

Expand All @@ -68,13 +67,11 @@ namespace micm
RosenbrockSolverParameters parameters,
LinearSolverPolicy linear_solver,
ProcessSetPolicy process_set,
auto& jacobian,
std::vector<Process>& processes)
auto& jacobian)
: parameters_(parameters),
linear_solver_(std::move(linear_solver)),
process_set_(std::move(process_set)),
jacobian_diagonal_elements_(jacobian.DiagonalIndices(0)),
processes_(processes)
jacobian_diagonal_elements_(jacobian.DiagonalIndices(0))
{
}

Expand Down
2 changes: 0 additions & 2 deletions include/micm/solver/rosenbrock.inl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ namespace micm

SolverStats stats;

Process::UpdateState(processes_, state);

K.reserve(parameters_.stages_);
for (std::size_t i = 0; i < parameters_.stages_; ++i)
K.emplace_back(num_rows, num_cols, 0.0);
Expand Down
13 changes: 11 additions & 2 deletions include/micm/solver/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <micm/solver/solver_result.hpp>
#include <micm/process/process.hpp>

namespace micm
{
Expand All @@ -16,6 +17,7 @@ namespace micm
std::size_t number_of_species_;
std::size_t number_of_reactions_;
StateParameters state_parameters_;
std::vector<micm::Process> processes_;

public:
SolverPolicy solver_;
Expand All @@ -25,12 +27,14 @@ namespace micm
StateParameters state_parameters,
std::size_t number_of_grid_cells,
std::size_t number_of_species,
std::size_t number_of_reactions)
std::size_t number_of_reactions,
std::vector<micm::Process> processes)
: solver_(std::move(solver)),
number_of_grid_cells_(number_of_grid_cells),
number_of_species_(number_of_species),
number_of_reactions_(number_of_reactions),
state_parameters_(state_parameters)
state_parameters_(state_parameters),
processes_(std::move(processes))
{
}

Expand Down Expand Up @@ -64,6 +68,11 @@ namespace micm
{
return StatePolicy(state_parameters_);
}

void CalculateRateConstants(StatePolicy& state)
{
Process::CalculateRateConstants(processes_, state);
}
};


Expand Down
5 changes: 3 additions & 2 deletions include/micm/solver/solver_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,12 @@ namespace micm

return Solver<SolverPolicy, State<DenseMatrixPolicy, SparseMatrixPolicy>>(
SolverPolicy(
this->options_, std::move(linear_solver), std::move(process_set), jacobian, this->reactions_),
this->options_, std::move(linear_solver), std::move(process_set), jacobian),
state_parameters,
this->number_of_grid_cells_,
number_of_species,
this->reactions_.size());
this->reactions_.size(),
this->reactions_);
}

} // namespace micm
13 changes: 13 additions & 0 deletions test/integration/analytical_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ void test_analytical_troe(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -303,6 +304,7 @@ void test_analytical_stiff_troe(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -412,6 +414,7 @@ void test_analytical_photolysis(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -547,6 +550,7 @@ void test_analytical_stiff_photolysis(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -668,6 +672,7 @@ void test_analytical_ternary_chemical_activation(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -811,6 +816,7 @@ void test_analytical_stiff_ternary_chemical_activation(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -925,6 +931,7 @@ void test_analytical_tunneling(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -1055,6 +1062,7 @@ void test_analytical_stiff_tunneling(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -1157,6 +1165,7 @@ void test_analytical_arrhenius(BuilderPolicy& builder)
times.push_back(0);
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -1289,6 +1298,7 @@ void test_analytical_stiff_arrhenius(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -1417,6 +1427,7 @@ void test_analytical_branched(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -1577,6 +1588,7 @@ void test_analytical_stiff_branched(BuilderPolicy& builder)
for (size_t i_time = 1; i_time < nsteps; ++i_time)
{
times.push_back(time_step);
solver.CalculateRateConstants(state);
// Model results
auto result = solver.Solve(time_step, state);
EXPECT_EQ(result.state_, (micm::SolverState::Converged));
Expand Down Expand Up @@ -1708,6 +1720,7 @@ void test_analytical_robertson(BuilderPolicy& builder)
{
double solve_time = time_step + i_time * time_step;
times.push_back(solve_time);
solver.CalculateRateConstants(state);
// Model results
double actual_solve = 0;
while (actual_solve < time_step)
Expand Down
4 changes: 4 additions & 0 deletions test/integration/analytical_rosenbrock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ TEST(AnalyticalExamples, Oregonator)
{
double solve_time = time_step + i_time * time_step;
times.push_back(solve_time);
solver.CalculateRateConstants(state);
// Model results
double actual_solve = 0;
while (actual_solve < time_step)
Expand Down Expand Up @@ -329,6 +330,7 @@ TEST(AnalyticalExamples, Oregonator2)
{
double solve_time = time_step + i_time * time_step;
times.push_back(solve_time);
solver.CalculateRateConstants(state);
// Model results
double actual_solve = 0;
while (actual_solve < time_step)
Expand Down Expand Up @@ -471,6 +473,7 @@ TEST(AnalyticalExamples, HIRES)
{
double solve_time = time_step + i_time * time_step;
times.push_back(solve_time);
solver.CalculateRateConstants(state);
// Model results
double actual_solve = 0;
while (actual_solve < time_step)
Expand Down Expand Up @@ -570,6 +573,7 @@ TEST(AnalyticalExamples, E5)
{
double solve_time = time_step + i_time * time_step;
times.push_back(solve_time);
solver.CalculateRateConstants(state);
// Model results
double actual_solve = 0;
while (actual_solve < time_step)
Expand Down
1 change: 1 addition & 0 deletions test/integration/analytical_surface_rxn_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ void test_analytical_surface_rxn(BuilderPolicy& builder)
for (int i = 1; i <= nstep; ++i)
{
double elapsed_solve_time = 0;
solver.CalculateRateConstants(state);

// first iteration
auto result = solver.Solve(time_step - elapsed_solve_time, state);
Expand Down
Loading

0 comments on commit e42fec3

Please sign in to comment.