From 263af0741638f6d61b90c7a4dcee7e697410e05a Mon Sep 17 00:00:00 2001 From: Montek Thind Date: Tue, 24 Sep 2024 15:37:39 -0700 Subject: [PATCH] draft of 621 --- include/micm/solver/solver.hpp | 7 +++++++ test/unit/cuda/solver/test_cuda_rosenbrock.cpp | 4 ++-- test/unit/jit/solver/test_jit_rosenbrock.cpp | 4 ++-- test/unit/solver/test_rosenbrock.cpp | 11 +++++------ 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/include/micm/solver/solver.hpp b/include/micm/solver/solver.hpp index f8fc816d0..a387866c0 100644 --- a/include/micm/solver/solver.hpp +++ b/include/micm/solver/solver.hpp @@ -63,6 +63,13 @@ namespace micm return solver_.Solve(time_step, state); } + // Overloaded Solve function to change parameters + SolverResult Solve(double time_step, StatePolicy& state, RosenbrockSolverParameters params) + { + solver_.parameters_.h_start_ = params.h_start_; + return solver_.Solve(time_step, state); + } + /// @brief Returns the number of grid cells /// @return std::size_t GetNumberOfGridCells() const diff --git a/test/unit/cuda/solver/test_cuda_rosenbrock.cpp b/test/unit/cuda/solver/test_cuda_rosenbrock.cpp index f40d11ed5..825d111c1 100644 --- a/test/unit/cuda/solver/test_cuda_rosenbrock.cpp +++ b/test/unit/cuda/solver/test_cuda_rosenbrock.cpp @@ -284,12 +284,12 @@ TEST(RosenbrockSolver, SingularSystemZeroInBottomRightOfU) // so H needs to be 1 / ( (-k1 - k2) * gamma) // since H is positive we need -k1 -k2 to be positive, hence the smaller, negative value for k1 double H = 1 / ((-k1 - k2) * params.gamma_[0]); - vector_solver.solver_.parameters_.h_start_ = H; + params.h_start_ = H; vector_solver.CalculateRateConstants(vector_state); vector_state.SyncInputsToDevice(); - auto vector_result = vector_solver.Solve(2 * H, vector_state); + auto vector_result = vector_solver.Solve(2 * H, vector_state, params); vector_state.SyncOutputsToHost(); EXPECT_NE(vector_result.stats_.singular_, 0); } diff --git a/test/unit/jit/solver/test_jit_rosenbrock.cpp b/test/unit/jit/solver/test_jit_rosenbrock.cpp index 542bde6f3..4a1d8eaff 100644 --- a/test/unit/jit/solver/test_jit_rosenbrock.cpp +++ b/test/unit/jit/solver/test_jit_rosenbrock.cpp @@ -105,11 +105,11 @@ TEST(JitRosenbrockSolver, SingularSystemZeroInBottomRightOfU) // so H needs to be 1 / ( (-k1 - k2) * gamma) // since H is positive we need -k1 -k2 to be positive, hence the smaller, negative value for k1 double H = 1 / ((-k1 - k2) * params.gamma_[0]); - vector_solver.solver_.parameters_.h_start_ = H; + params.h_start_ = H; vector_solver.CalculateRateConstants(vector_state); - auto vector_result = vector_solver.Solve(2 * H, vector_state); + auto vector_result = vector_solver.Solve(2 * H, vector_state, params); EXPECT_NE(vector_result.stats_.singular_, 0); } diff --git a/test/unit/solver/test_rosenbrock.cpp b/test/unit/solver/test_rosenbrock.cpp index 06f3cd4bf..568b8e6b9 100644 --- a/test/unit/solver/test_rosenbrock.cpp +++ b/test/unit/solver/test_rosenbrock.cpp @@ -17,7 +17,7 @@ void testNormalizedErrorDiff(SolverBuilderPolicy builder, std::size_t number_of_ { builder = getSolver(builder); auto solver = builder.SetNumberOfGridCells(number_of_grid_cells).Build(); - std::vector atol = solver.solver_.parameters_.absolute_tolerance_; + std::vector atol = solver.solver_.parameters_.absolute_tolerance_; double rtol = solver.solver_.parameters_.relative_tolerance_; auto state = solver.GetState(); @@ -172,18 +172,17 @@ TEST(RosenbrockSolver, SingularSystemZeroInBottomRightOfU) // alpha is 1 / (H * gamma), where H is the time step and gamma is the gamma value from // the rosenbrock paramters // so H needs to be 1 / ( (-k1 - k2) * gamma) - // since H is positive we need -k1 -k2 to be positive, hence the smaller, negative value for k1 + // since H is positive we need -k1 -k2 to be positive, hence the smaller, negative value for k1 double H = 1 / ((-k1 - k2) * params.gamma_[0]); - standard_solver.solver_.parameters_.h_start_ = H; - vector_solver.solver_.parameters_.h_start_ = H; + params.h_start_ = H; standard_solver.CalculateRateConstants(standard_state); vector_solver.CalculateRateConstants(vector_state); - auto standard_result = standard_solver.Solve(2 * H, standard_state); + auto standard_result = standard_solver.Solve(2 * H, standard_state, params); EXPECT_NE(standard_result.stats_.singular_, 0); - auto vector_result = vector_solver.Solve(2 * H, vector_state); + auto vector_result = vector_solver.Solve(2 * H, vector_state, params); EXPECT_NE(vector_result.stats_.singular_, 0); }