Skip to content

Commit

Permalink
draft jit jacobian functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mattldawson committed Aug 23, 2023
1 parent 2d16fe7 commit 11c65f7
Showing 1 changed file with 56 additions and 9 deletions.
65 changes: 56 additions & 9 deletions include/micm/process/jit_process_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <micm/jit/jit_compiler.hpp>
#include <micm/jit/jit_function.hpp>
#include <micm/process/process_set.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>
#include <micm/util/vector_matrix.hpp>

namespace micm
Expand All @@ -19,6 +20,8 @@ namespace micm
std::shared_ptr<JitCompiler> compiler_;
llvm::orc::ResourceTrackerSP forcing_function_resource_tracker_;
void (*forcing_function_)(const double *, const double *, double *);
llvm::orc::ResourceTrackerSP jacobian_function_resource_tracker_;
void (*jacobian_function_)(const double *, const double *, double *);

public:
/// @brief Create a JITed process set calculator for a given set of processes
Expand All @@ -33,6 +36,11 @@ namespace micm

~JitProcessSet();

/// @brief Sets the indices for each non-zero Jacobian element in the underlying vector
/// @param matrix The sparse matrix used for the Jacobian
template<typename OrderingPolicy>
void SetJacobianFlatIds(const SparseMatrix<double, OrderingPolicy> &matrix);

/// @brief Add forcing terms for the set of processes for the current conditions
/// @param rate_constants Current values for the process rate constants (grid cell, process)
/// @param state_variables Current state variable values (grid cell, state variable)
Expand All @@ -42,10 +50,24 @@ namespace micm
const MatrixPolicy<double> &rate_constants,
const MatrixPolicy<double> &state_variables,
MatrixPolicy<double> &forcing) const;

#if 0
/// @brief Add Jacobian terms for the set of processes for the current conditions
/// @param rate_constants Current values for the process rate constants (grid cell, process)
/// @param state_variables Current state variable values (grid cell, state variable)
/// @param jacobian Jacobian matrix for the system (grid cell, dependent variable, independent variable)
template<template<class> class MatrixPolicy, template<class> class SparseMatrixPolicy>
void AddJacobianTerms(
const MatrixPolicy<double> &rate_constants,
const MatrixPolicy<double> &state_variables,
SparseMatrixPolicy<double> &jacobian) const;
#endif
private:
/// @brief Generate a function to calculate forcing terms
void GenerateForcingFunction();
/// @param matrix The matrix that will hold the forcing terms
void GenerateForcingFunction(const VectorMatrix<double, L> &matrix);
/// @brief Generate a function to calculate Jacobian terms
/// @param matrix The sparse matrix that will hold the Jacobian
void GenerateJacobianFunction(const SparseMatrix<double, SparseMatrixVectorOrdering<L>> &matrix);
};

template<std::size_t L>
Expand All @@ -58,16 +80,11 @@ namespace micm
compiler_(compiler)
{
MatrixPolicy<double> test_matrix;
if (test_matrix.GroupVectorSize() != L)
{
std::cerr << "Vector matrix group size invalid for JitProcessSet";
std::exit(micm::ExitCodes::InvalidMatrixDimension);
}
this->GenerateForcingFunction();
this->GenerateForcingFunction(test_matrix);
}

template<std::size_t L>
void JitProcessSet<L>::GenerateForcingFunction()
void JitProcessSet<L>::GenerateForcingFunction(const VectorMatrix<double, L> &matrix)
{
JitFunction func = JitFunction::create(compiler_)
.name("add_forcing_terms")
Expand Down Expand Up @@ -159,6 +176,19 @@ namespace micm
forcing_function_resource_tracker_ = target.first;
}

template<std::size_t L>
template<typename OrderingPolicy>
inline void JitProcessSet<L>::SetJacobianFlatIds(const SparseMatrix<double, OrderingPolicy> &matrix)
{
ProcessSet::SetJacobianFlatIds(matrix);
GenerateJacobianFunction(matrix);
}

template<std::size_t L>
void JitProcessSet<L>::GenerateJacobianFunction(const SparseMatrix<double, SparseMatrixVectorOrdering<L>> &matrix)
{
}

template<std::size_t L>
JitProcessSet<L>::~JitProcessSet()
{
Expand All @@ -167,6 +197,11 @@ namespace micm
llvm::ExitOnError exit_on_error;
exit_on_error(forcing_function_resource_tracker_->remove());
}
if (jacobian_function_resource_tracker_)
{
llvm::ExitOnError exit_on_error;
exit_on_error(jacobian_function_resource_tracker_->remove());
}
}

template<std::size_t L>
Expand All @@ -179,4 +214,16 @@ namespace micm
forcing_function_(rate_constants.AsVector().data(), state_variables.AsVector().data(), forcing.AsVector().data());
}

#if 0
template<std::size_t L>
template<template<class> class MatrixPolicy, template<class> class SparseMatrixPolicy>
void JitProcessSet<L>::AddJacobianTerms(
const MatrixPolicy<double> &rate_constants,
const MatrixPolicy<double> &state_variables,
SparseMatrixPolicy<double> &jacobian) const
{
jacobian_function_(rate_constants.AsVector().data(), state_variables.AsVector().data(), jacobian.AsVector().data());
}
#endif

} // namespace micm

0 comments on commit 11c65f7

Please sign in to comment.