Skip to content

Commit

Permalink
JITed forcing function (#179)
Browse files Browse the repository at this point in the history
* add llvm to cmake scripts and add Dockerfile

* add JIT compiler

* add scalar JIT type tests

* add pointer type JIT tests

* add loops to jit function class

* add JIT function tests

* fix JIT loop bug

* Update include/micm/jit/jit_function.hpp

Co-authored-by: Kyle Shores <kyle.shores44@gmail.com>

* Update cmake/test_util.cmake

Co-authored-by: Kyle Shores <kyle.shores44@gmail.com>

* add assert on regenerating JIT function

* add JIT forcing function

---------

Co-authored-by: Kyle Shores <kyle.shores44@gmail.com>
  • Loading branch information
mattldawson and K20shores committed Aug 14, 2023
1 parent 3a1011b commit 2b9ad5d
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 190 deletions.
152 changes: 131 additions & 21 deletions include/micm/process/jit_process_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,170 @@
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include <micm/jit/jit_compiler.hpp>
#include <micm/jit/jit_function.hpp>
#include <micm/process/process_set.hpp>
#include <micm/util/vector_matrix.hpp>

namespace micm
{

/// @brief JIT-compiled solver function calculators for a collection of processes
/// The template parameter is the number of grid cells to solve simultaneously
template<std::size_t L>
class JitProcessSet : public ProcessSet<VectorMatrix>
class JitProcessSet : public ProcessSet
{
std::shared_ptr<JitCompiler> compiler_;
llvm::orc::ResourceTrackerSP resource_tracker_;
void (*forcing_function_)(double*, double*, double*);
void (*forcing_function_)(const double *, const double *, double *);

public:
/// @brief Create a JITed process set calculator for a given set of processes
/// @param compiler JIT compiler
/// @param processes Processes to create calculator for
/// @param state Solver state
template<template<class> class MatrixPolicy>
JitProcessSet(
std::shared_ptr<JitCompiler> compiler,
const std::vector<Process>& processes,
const State<VectorMatrix>& state);
const std::vector<Process> &processes,
const State<MatrixPolicy> &state);

~JitProcessSet();

/// @brief Add forcing terms for the set of processes for the current conditions
/// @param rate_constants Current values for the process rate constants (grid cell, process)
/// @param state_variables Current state variable values (grid cell, state variable)
/// @param forcing Forcing terms for each state variable (grid cell, state variable)
template<template<class> class MatrixPolicy>
void AddForcingTerms(
const VectorMatrix<double>& rate_constants,
const VectorMatrix<double>& state_variables,
VectorMatrix<double>& forcing) const;
const MatrixPolicy<double> &rate_constants,
const MatrixPolicy<double> &state_variables,
MatrixPolicy<double> &forcing) const;
};

inline JitProcessSet::JitProcessSet(
template<std::size_t L>
template<template<class> class MatrixPolicy>
inline JitProcessSet<L>::JitProcessSet(
std::shared_ptr<JitCompiler> compiler,
const std::vector<Process>& processes,
const State<VectorMatrix>& state)
: ProcessSet<VectorMatrix>(processes, state),
const std::vector<Process> &processes,
const State<MatrixPolicy> &state)
: ProcessSet(processes, state),
compiler_(compiler)
{
JitFunction func = JitFunction::create(compiler.get())
MatrixPolicy<double> test_matrix;
if (test_matrix.GroupVectorSize() != L)
{
std::cerr << "Vector matrix group size invalid for JitProcessSet";
std::exit(micm::ExitCodes::InvalidMatrixDimension);
}
JitFunction func = JitFunction::create(compiler)
.name("add_forcing_terms")
.arguments({ { "rate constants", JitType::Double },
{ "state variables", JitType::Double },
{ "forcing", JitType::Double } }),
.return_type(JitType::Void);
.arguments({ { "rate constants", JitType::DoublePtr },
{ "state variables", JitType::DoublePtr },
{ "forcing", JitType::DoublePtr } })
.return_type(JitType::Void);
llvm::Type *double_type = func.GetType(JitType::Double);
llvm::Value *zero = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, 0));
llvm::Type *rate_array_type = llvm::ArrayType::get(double_type, L);
llvm::AllocaInst *rate_array = func.builder_->CreateAlloca(
rate_array_type, llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, 1)), "rate_array");

auto react_ids = reactant_ids_.begin();
auto prod_ids = product_ids_.begin();
auto yields = yields_.begin();
for (std::size_t i_rxn = 0; i_rxn < number_of_reactants_.size(); ++i_rxn)
{
llvm::Value* rc_start = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, i_rxn * L));
llvm::Value* rc_end = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, i_rxn * L + L));
llvm::ArrayType* rate_arr = llvm::ArrayType::get(func.GetType(JitType::Int32), ) auto loop =
func.StartLoop("rate constant loop", rc_start, rc_end);
llvm::Value* rate = func.GetArrayElement(func.arguments_[0], index_list, micm::JitType::Double);
llvm::Value *rc_start = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, i_rxn * L));

// save rate constant in rate array for each grid cell
auto loop = func.StartLoop("rate constant", 0, L);
llvm::Value *ptr_index[1];
ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, rc_start);
llvm::Value *rate_const = func.GetArrayElement(func.arguments_[0], ptr_index, JitType::Double);
llvm::Value *array_index[2];
array_index[0] = zero;
array_index[1] = loop.index_;
llvm::Value *rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index);
func.builder_->CreateStore(rate_const, rate_ptr);
func.EndLoop(loop);

// rates[i_cell] *= reactant_concentration for each reactant
for (std::size_t i_react = 0; i_react < number_of_reactants_[i_rxn]; ++i_react)
{
loop = func.StartLoop("rate calc", 0, L);
llvm::Value *react_id = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, react_ids[i_react] * L));
ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, react_id);
llvm::Value *react_conc = func.GetArrayElement(func.arguments_[1], ptr_index, JitType::Double);
array_index[1] = loop.index_;
rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index);
llvm::Value *rate = func.builder_->CreateLoad(double_type, rate_ptr, "rate");
rate = func.builder_->CreateFMul(rate, react_conc, "rate");
func.builder_->CreateStore(rate, rate_ptr);
func.EndLoop(loop);
}

// set forcing for each reactant f[i_react][i_cell] -= rate[i_cell]
for (std::size_t i_react = 0; i_react < number_of_reactants_[i_rxn]; ++i_react)
{
loop = func.StartLoop("reactant forcing", 0, L);
llvm::Value *react_id = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, react_ids[i_react] * L));
ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, react_id);
llvm::Value *react_forcing_ptr = func.builder_->CreateGEP(double_type, func.arguments_[2].ptr_, ptr_index);
llvm::Value *react_forcing = func.builder_->CreateLoad(double_type, react_forcing_ptr, "forcing");
array_index[1] = loop.index_;
rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index);
llvm::Value *rate = func.builder_->CreateLoad(double_type, rate_ptr, "rate");
react_forcing = func.builder_->CreateFSub(react_forcing, rate, "forcing");
func.builder_->CreateStore(react_forcing, react_forcing_ptr);
func.EndLoop(loop);
}

// set forcing for each product f[i_prod][i_cell] += yield * rate[i_cell]
for (std::size_t i_prod = 0; i_prod < number_of_products_[i_rxn]; ++i_prod)
{
loop = func.StartLoop("product forcing", 0, L);
llvm::Value *prod_id = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, prod_ids[i_prod] * L));
ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, prod_id);
llvm::Value *prod_forcing_ptr = func.builder_->CreateGEP(double_type, func.arguments_[2].ptr_, ptr_index);
llvm::Value *prod_forcing = func.builder_->CreateLoad(double_type, prod_forcing_ptr, "forcing");
array_index[1] = loop.index_;
rate_ptr = func.builder_->CreateInBoundsGEP(rate_array_type, rate_array, array_index);
llvm::Value *rate = func.builder_->CreateLoad(double_type, rate_ptr, "rate");
llvm::Value *yield = llvm::ConstantFP::get(*(func.context_), llvm::APFloat(yields[i_prod]));
rate = func.builder_->CreateFMul(rate, yield, "rate_yield");
prod_forcing = func.builder_->CreateFAdd(prod_forcing, rate, "forcing");
func.builder_->CreateStore(prod_forcing, prod_forcing_ptr);
func.EndLoop(loop);
}
react_ids += number_of_reactants_[i_rxn];
prod_ids += number_of_products_[i_rxn];
yields += number_of_products_[i_rxn];
}
func.builder_->CreateRetVoid();

auto target = func.Generate();
forcing_function_ = (void (*)(const double *, const double *, double *))(intptr_t)target.second;
resource_tracker_ = target.first;
}

template<std::size_t L>
JitProcessSet<L>::~JitProcessSet()
{
if (resource_tracker_)
{
llvm::ExitOnError exit_on_error;
exit_on_error(resource_tracker_->remove());
}
}

template<std::size_t L>
template<template<class> class MatrixPolicy>
void JitProcessSet<L>::AddForcingTerms(
const MatrixPolicy<double> &rate_constants,
const MatrixPolicy<double> &state_variables,
MatrixPolicy<double> &forcing) const
{
forcing_function_(rate_constants.AsVector().data(), state_variables.AsVector().data(), forcing.AsVector().data());
}

} // namespace micm
4 changes: 4 additions & 0 deletions test/unit/process/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ endif()

if(ENABLE_OPENACC)
create_standard_test(NAME openacc_process_set SOURCES test_openacc_process_set.cpp LIBRARIES musica::micm_openacc)
endif()

if(ENABLE_LLVM)
create_standard_test(NAME jit_process_set SOURCES test_jit_process_set.cpp)
endif()
78 changes: 78 additions & 0 deletions test/unit/process/test_jit_process_set.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <gtest/gtest.h>

#include <micm/process/jit_process_set.hpp>
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>
#include <micm/util/vector_matrix.hpp>
#include <random>

#include "test_process_set_policy.hpp"

template<class T>
using Group2VectorMatrix = micm::VectorMatrix<T, 2>;
template<class T>
using Group2000VectorMatrix = micm::VectorMatrix<T, 2000>;
template<class T>
using Group3000VectorMatrix = micm::VectorMatrix<T, 3000>;
template<class T>
using Group4000VectorMatrix = micm::VectorMatrix<T, 4000>;

template<class T>
using Group2SparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<2>>;

TEST(JitProcessSet, VectorMatrix)
{
auto jit{ micm::JitCompiler::create() };
if (auto err = jit.takeError())
{
llvm::logAllUnhandledErrors(std::move(err), llvm::errs(), "[JIT Error]");
EXPECT_TRUE(false);
}
testProcessSet<Group2VectorMatrix, Group2SparseVectorMatrix, micm::JitProcessSet<2>>(
[&](const std::vector<micm::Process>& processes,
const micm::State<Group2VectorMatrix>& state) -> micm::JitProcessSet<2> {
return micm::JitProcessSet<2>{ jit.get(), processes, state };
});
}

TEST(RandomJitProcessSet, VectorMatrix)
{
auto jit{ micm::JitCompiler::create() };
if (auto err = jit.takeError())
{
llvm::logAllUnhandledErrors(std::move(err), llvm::errs(), "[JIT Error]");
EXPECT_TRUE(false);
}
testRandomSystem<Group2000VectorMatrix, micm::JitProcessSet<2000>>(
2000,
20,
30,
[&](const std::vector<micm::Process>& processes,
const micm::State<Group2000VectorMatrix>& state) -> micm::JitProcessSet<2000> {
return micm::JitProcessSet<2000>{ jit.get(), processes, state };
});
testRandomSystem<Group3000VectorMatrix, micm::JitProcessSet<3000>>(
3000,
50,
40,
[&](const std::vector<micm::Process>& processes,
const micm::State<Group3000VectorMatrix>& state) -> micm::JitProcessSet<3000> {
return micm::JitProcessSet<3000>{ jit.get(), processes, state };
});
testRandomSystem<Group3000VectorMatrix, micm::JitProcessSet<3000>>(
3000,
30,
20,
[&](const std::vector<micm::Process>& processes,
const micm::State<Group3000VectorMatrix>& state) -> micm::JitProcessSet<3000> {
return micm::JitProcessSet<3000>{ jit.get(), processes, state };
});
testRandomSystem<Group4000VectorMatrix, micm::JitProcessSet<4000>>(
4000,
100,
80,
[&](const std::vector<micm::Process>& processes,
const micm::State<Group4000VectorMatrix>& state) -> micm::JitProcessSet<4000> {
return micm::JitProcessSet<4000>{ jit.get(), processes, state };
});
}
Loading

0 comments on commit 2b9ad5d

Please sign in to comment.