Skip to content

Commit

Permalink
Update the CUDA implementation of CudaLuDecomposition with the CUDA m…
Browse files Browse the repository at this point in the history
…atrix class (#464)

* update cuda matrix class for cuda lu decomposition

* more debugs

* fix compilation bug for calling base class function

* address PR comments

---------

Co-authored-by: Jian Sun <sunjian@ucar.edu>
  • Loading branch information
sjsprecious and sjsprecious committed Apr 19, 2024
1 parent abf001c commit 9a958e3
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 104 deletions.
5 changes: 4 additions & 1 deletion include/micm/solver/cuda_lu_decomposition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ namespace micm
{
/// This is the host function that will call the CUDA kernel
/// to perform LU decomposition on the device
std::chrono::nanoseconds DecomposeKernelDriver(CudaSparseMatrixParam& sparseMatrix, const LuDecomposeParam& devstruct);
void DecomposeKernelDriver(const CudaMatrixParam& A_param,
CudaMatrixParam& L_param,
CudaMatrixParam& U_param,
const LuDecomposeParam& devstruct);

/// This is the function that will copy the constant data
/// members of class "CudaLuDecomposition" to the device;
Expand Down
37 changes: 21 additions & 16 deletions include/micm/solver/cuda_lu_decomposition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <micm/solver/cuda_lu_decomposition.cuh>
#include <micm/solver/lu_decomposition.hpp>
#include <micm/util/cuda_param.hpp>
#include <micm/util/cuda_sparse_matrix.hpp>
#include <stdexcept>

namespace micm
Expand Down Expand Up @@ -73,28 +74,32 @@ namespace micm
/// L is the lower triangular matrix created by decomposition
/// U is the upper triangular matrix created by decomposition
template<typename T, template<class> typename SparseMatrixPolicy>
requires VectorizableSparse<SparseMatrixPolicy<T>> std::chrono::nanoseconds
Decompose(const SparseMatrixPolicy<T>& A, SparseMatrixPolicy<T>& L, SparseMatrixPolicy<T>& U)
requires(CudaSparseMatrices<SparseMatrixPolicy<T>> && VectorizableSparse<SparseMatrixPolicy<T>>)
void Decompose(const SparseMatrixPolicy<T>& A, SparseMatrixPolicy<T>& L, SparseMatrixPolicy<T>& U)
const;

/// Overload selected when the matrix type does NOT satisfy the CudaSparseMatrices concept
/// A is the sparse matrix to decompose
/// L is the lower triangular matrix created by decomposition
/// U is the upper triangular matrix created by decomposition
template<typename T, template<class> typename SparseMatrixPolicy>
requires(!CudaSparseMatrices<SparseMatrixPolicy<T>>)
void Decompose(const SparseMatrixPolicy<T>& A, SparseMatrixPolicy<T>& L, SparseMatrixPolicy<T>& U)
const;
};

template<typename T, template<class> class SparseMatrixPolicy>
requires(VectorizableSparse<SparseMatrixPolicy<T>>) std::chrono::nanoseconds
CudaLuDecomposition::Decompose(const SparseMatrixPolicy<T>& A, SparseMatrixPolicy<T>& L, SparseMatrixPolicy<T>& U)
requires(CudaSparseMatrices<SparseMatrixPolicy<T>> && VectorizableSparse<SparseMatrixPolicy<T>>)
void CudaLuDecomposition::Decompose(const SparseMatrixPolicy<T>& A, SparseMatrixPolicy<T>& L, SparseMatrixPolicy<T>& U)
const
{
/// Once the CudaMatrix class is generated, we won't need the following lines any more;
CudaSparseMatrixParam sparseMatrix;
sparseMatrix.A_ = A.AsVector().data();
sparseMatrix.A_size_ = A.AsVector().size();
sparseMatrix.L_ = L.AsVector().data();
sparseMatrix.L_size_ = L.AsVector().size();
sparseMatrix.U_ = U.AsVector().data();
sparseMatrix.U_size_ = U.AsVector().size();
sparseMatrix.n_grids_ = A.Size();
auto L_param = L.AsDeviceParam(); // we need to update lower matrix so it can't be constant and must be an lvalue
auto U_param = U.AsDeviceParam(); // we need to update upper matrix so it can't be constant and must be an lvalue
micm::cuda::DecomposeKernelDriver(A.AsDeviceParam(), L_param, U_param, this->devstruct_);
}

/// Call the "DecomposeKernelDriver" function that invokes the
/// CUDA kernel to perform LU decomposition on the device
return micm::cuda::DecomposeKernelDriver(sparseMatrix, this->devstruct_);

/// Fallback used when the matrix type does not satisfy the CudaSparseMatrices concept:
/// nothing CUDA-specific is done in this overload; the call is forwarded unchanged to
/// LuDecomposition::Decompose with the same A, L and U arguments.
template<typename T, template<class> class SparseMatrixPolicy>
requires(!CudaSparseMatrices<SparseMatrixPolicy<T>>)
void CudaLuDecomposition::Decompose(const SparseMatrixPolicy<T>& A, SparseMatrixPolicy<T>& L, SparseMatrixPolicy<T>& U)
const
{
// Explicitly qualified so that the LuDecomposition implementation is the one invoked
// rather than another CudaLuDecomposition::Decompose overload.
LuDecomposition::Decompose<T, SparseMatrixPolicy>(A, L, U);
}
} // end of namespace micm
14 changes: 13 additions & 1 deletion include/micm/util/cuda_sparse_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@

namespace micm
{
/// Concept for CUDA sparse matrices: a type satisfies it when it provides
/// CopyToDevice() and CopyToHost(), both returning void, and AsDeviceParam()
/// returning a CudaMatrixParam
template<typename MatrixType>
concept CudaSparseMatrices = requires(MatrixType matrix)
{
  { matrix.CopyToDevice() } -> std::same_as<void>;
  { matrix.CopyToHost() } -> std::same_as<void>;
  { matrix.AsDeviceParam() } -> std::same_as<CudaMatrixParam>;
};

template<class T, class OrderingPolicy>
class CudaSparseMatrix : public SparseMatrix<T, OrderingPolicy>
{
Expand Down Expand Up @@ -120,7 +132,7 @@ namespace micm
CHECK_CUDA_ERROR(micm::cuda::CopyToHost(this->param_, this->data_), "cudaMemcpyDeviceToHost");
}

CudaMatrixParam AsDeviceParam()
CudaMatrixParam AsDeviceParam() const
{
return this->param_;
}
Expand Down
59 changes: 19 additions & 40 deletions src/solver/lu_decomposition.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ namespace micm
namespace cuda
{
/// This is the CUDA kernel that performs LU decomposition on the device
/// Note that passing the reference "LuDecomposeParam&" will pass the
/// compilation but the execution of this CUDA test hangs somehow
__global__ void DecomposeKernel(const double* d_A, double* d_L, double* d_U, LuDecomposeParam devstruct, size_t ngrids)
__global__ void DecomposeKernel(const CudaMatrixParam A_param,
CudaMatrixParam L_param,
CudaMatrixParam U_param,
const LuDecomposeParam devstruct)
{
/// Local device variables
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
// Calculate global thread ID
size_t tid = blockIdx.x * BLOCK_SIZE + threadIdx.x;

// Local device variables
std::pair<size_t, size_t>* d_niLU = devstruct.niLU_;
char* d_do_aik = devstruct.do_aik_;
size_t* d_aik = devstruct.aik_;
Expand All @@ -39,7 +41,12 @@ namespace micm
size_t lki_nkj_offset = 0;
size_t uii_offset = 0;

if (tid < ngrids)
double* d_A = A_param.d_data_;
double* d_L = L_param.d_data_;
double* d_U = U_param.d_data_;
size_t number_of_grid_cells = A_param.number_of_grid_cells_;

if (tid < number_of_grid_cells)
{
// loop through every element in niLU
for (size_t i = 0; i < niLU_size; i++)
Expand Down Expand Up @@ -154,42 +161,14 @@ namespace micm
cudaFree(devstruct.uii_);
}

std::chrono::nanoseconds DecomposeKernelDriver(CudaSparseMatrixParam& sparseMatrix, const LuDecomposeParam& devstruct)
void DecomposeKernelDriver(const CudaMatrixParam& A_param,
CudaMatrixParam& L_param,
CudaMatrixParam& U_param,
const LuDecomposeParam& devstruct)
{
/// Create device pointers
double* d_A;
double* d_L;
double* d_U;

/// Allocate device memory
cudaMalloc(&d_A, sizeof(double) * sparseMatrix.A_size_);
cudaMalloc(&d_L, sizeof(double) * sparseMatrix.L_size_);
cudaMalloc(&d_U, sizeof(double) * sparseMatrix.U_size_);

/// Copy data from host to device
cudaMemcpy(d_A, sparseMatrix.A_, sizeof(double) * sparseMatrix.A_size_, cudaMemcpyHostToDevice);
cudaMemcpy(d_L, sparseMatrix.L_, sizeof(double) * sparseMatrix.L_size_, cudaMemcpyHostToDevice);
cudaMemcpy(d_U, sparseMatrix.U_, sizeof(double) * sparseMatrix.U_size_, cudaMemcpyHostToDevice);

size_t num_block = (sparseMatrix.n_grids_ + BLOCK_SIZE - 1) / BLOCK_SIZE;

/// Call CUDA kernel and measure the execution time
auto startTime = std::chrono::high_resolution_clock::now();
DecomposeKernel<<<num_block, BLOCK_SIZE>>>(d_A, d_L, d_U, devstruct, sparseMatrix.n_grids_);
size_t number_of_blocks = (A_param.number_of_grid_cells_ + BLOCK_SIZE - 1) / BLOCK_SIZE;
DecomposeKernel<<<number_of_blocks, BLOCK_SIZE>>>(A_param, L_param, U_param, devstruct);
cudaDeviceSynchronize();
auto endTime = std::chrono::high_resolution_clock::now();
auto kernel_duration = std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - startTime);

/// Copy the data from device to host
cudaMemcpy(sparseMatrix.L_, d_L, sizeof(double) * sparseMatrix.L_size_, cudaMemcpyDeviceToHost);
cudaMemcpy(sparseMatrix.U_, d_U, sizeof(double) * sparseMatrix.U_size_, cudaMemcpyDeviceToHost);

/// Clean up
cudaFree(d_A);
cudaFree(d_L);
cudaFree(d_U);

return kernel_duration;
} // end of DecomposeKernelDriver
} // end of namespace cuda
} // end of namespace micm
101 changes: 55 additions & 46 deletions test/unit/solver/test_cuda_lu_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <micm/util/cuda_param.hpp>
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>
#include <micm/util/cuda_sparse_matrix.hpp>
#include <random>
#include <vector>

Expand Down Expand Up @@ -49,76 +50,84 @@ void check_results(
}
}

/// Compares the L and U factors computed on the GPU against the CPU reference.
///
/// Both pairs are compared element by element over the flat storage returned by
/// AsVector(). The sizes are asserted equal first, so a mismatch is reported as a
/// test failure instead of reading past the end of the shorter vector.
///
/// gpu_L / gpu_U are the factors produced by the CUDA decomposition
/// cpu_L / cpu_U are the factors produced by the CPU decomposition
template<typename T, template<class> class SparseMatrixPolicy>
void gpu_validation(
    const SparseMatrixPolicy<T>& gpu_L,
    const SparseMatrixPolicy<T>& cpu_L,
    const SparseMatrixPolicy<T>& gpu_U,
    const SparseMatrixPolicy<T>& cpu_U)
{
  // Bind by const reference: avoids copying the full storage of four matrices
  // (lifetime extension keeps this valid even if AsVector() returns by value).
  const std::vector<T>& gpu_L_vector = gpu_L.AsVector();
  const std::vector<T>& cpu_L_vector = cpu_L.AsVector();
  const std::vector<T>& gpu_U_vector = gpu_U.AsVector();
  const std::vector<T>& cpu_U_vector = cpu_U.AsVector();

  // Abort this helper early when the layouts differ; indexing below would be out of bounds.
  ASSERT_EQ(gpu_L_vector.size(), cpu_L_vector.size());
  ASSERT_EQ(gpu_U_vector.size(), cpu_U_vector.size());

  // std::size_t indices: no signed/unsigned comparison and no overflow for large matrices.
  for (std::size_t i = 0; i < cpu_L_vector.size(); ++i)
  {
    EXPECT_EQ(gpu_L_vector[i], cpu_L_vector[i]);
  }
  for (std::size_t j = 0; j < cpu_U_vector.size(); ++j)
  {
    EXPECT_EQ(gpu_U_vector[j], cpu_U_vector[j]);
  }
}

template<template<class> class SparseMatrixPolicy>
template<template<class> class CPUSparseMatrixPolicy, template<class> class GPUSparseMatrixPolicy>
void testRandomMatrix(size_t n_grids)
{
auto gen_bool = std::bind(std::uniform_int_distribution<>(0, 1), std::default_random_engine());
auto get_double = std::bind(std::lognormal_distribution(-2.0, 2.0), std::default_random_engine());

auto builder = SparseMatrixPolicy<double>::create(10).number_of_blocks(n_grids).initial_value(1.0e-30);
auto builder = CPUSparseMatrixPolicy<double>::create(10).number_of_blocks(n_grids).initial_value(1.0e-30);
for (std::size_t i = 0; i < 10; ++i)
for (std::size_t j = 0; j < 10; ++j)
if (i == j || gen_bool())
builder = builder.with_element(i, j);

SparseMatrixPolicy<double> A(builder);
CPUSparseMatrixPolicy<double> cpu_A(builder);
GPUSparseMatrixPolicy<double> gpu_A(builder);

for (std::size_t i = 0; i < 10; ++i)
for (std::size_t j = 0; j < 10; ++j)
if (!A.IsZero(i, j))
if (!cpu_A.IsZero(i, j))
for (std::size_t i_block = 0; i_block < n_grids; ++i_block)
A[i_block][i][j] = get_double();

micm::CudaLuDecomposition gpu_lud(A);
auto gpu_LU = micm::CudaLuDecomposition::GetLUMatrices(A, 1.0e-30);
gpu_lud.Decompose<double, SparseMatrixPolicy>(A, gpu_LU.first, gpu_LU.second);
check_results<double, SparseMatrixPolicy>(
A, gpu_LU.first, gpu_LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });
{
cpu_A[i_block][i][j] = get_double();
gpu_A[i_block][i][j] = cpu_A[i_block][i][j];
}

micm::CudaLuDecomposition gpu_lud(gpu_A);
auto gpu_LU = micm::CudaLuDecomposition::GetLUMatrices(gpu_A, 1.0e-30);
gpu_A.CopyToDevice();
gpu_LU.first.CopyToDevice();
gpu_LU.second.CopyToDevice();
gpu_lud.Decompose<double, GPUSparseMatrixPolicy>(gpu_A, gpu_LU.first, gpu_LU.second);
gpu_LU.first.CopyToHost();
gpu_LU.second.CopyToHost();
check_results<double, GPUSparseMatrixPolicy>(
gpu_A, gpu_LU.first, gpu_LU.second, [&](const double a, const double b) -> void { EXPECT_NEAR(a, b, 1.0e-5); });

micm::LuDecomposition cpu_lud = micm::LuDecomposition::Create<double, SparseMatrixPolicy>(A);
auto cpu_LU = micm::LuDecomposition::GetLUMatrices<double, SparseMatrixPolicy>(A, 1.0e-30);
cpu_lud.Decompose<double, SparseMatrixPolicy>(A, cpu_LU.first, cpu_LU.second);
micm::LuDecomposition cpu_lud = micm::LuDecomposition::Create<double, CPUSparseMatrixPolicy>(cpu_A);
auto cpu_LU = micm::LuDecomposition::GetLUMatrices<double, CPUSparseMatrixPolicy>(cpu_A, 1.0e-30);
cpu_lud.Decompose<double, CPUSparseMatrixPolicy>(cpu_A, cpu_LU.first, cpu_LU.second);

// checking GPU result against CPU
gpu_validation<double, SparseMatrixPolicy>(gpu_LU.first, cpu_LU.first, gpu_LU.second, cpu_LU.second);
size_t L_size = cpu_LU.first.AsVector().size();
size_t U_size = cpu_LU.second.AsVector().size();
std::vector<double> gpu_L_vector = gpu_LU.first.AsVector();
std::vector<double> gpu_U_vector = gpu_LU.second.AsVector();
std::vector<double> cpu_L_vector = cpu_LU.first.AsVector();
std::vector<double> cpu_U_vector = cpu_LU.second.AsVector();
for (int i = 0; i < L_size; ++i)
{
EXPECT_DOUBLE_EQ(gpu_L_vector[i], cpu_L_vector[i]);
};
for (int j = 0; j < U_size; ++j)
{
EXPECT_DOUBLE_EQ(gpu_U_vector[j], cpu_U_vector[j]);
};
}

template<class T>
using Group1SparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<10>>;
using Group1CPUSparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1>>;
template<class T>
using Group100CPUSparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<100>>;
template<class T>
using Group1000CPUSparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1000>>;
template<class T>
using Group100000CPUSparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<100000>>;

template<class T>
using Group1CudaSparseMatrix = micm::CudaSparseMatrix<T, micm::SparseMatrixVectorOrdering<1>>;
template<class T>
using Group2SparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<100>>;
using Group100CudaSparseMatrix = micm::CudaSparseMatrix<T, micm::SparseMatrixVectorOrdering<100>>;
template<class T>
using Group3SparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1000>>;
using Group1000CudaSparseMatrix = micm::CudaSparseMatrix<T, micm::SparseMatrixVectorOrdering<1000>>;
template<class T>
using Group4SparseVectorMatrix = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<100000>>;
using Group100000CudaSparseMatrix = micm::CudaSparseMatrix<T, micm::SparseMatrixVectorOrdering<100000>>;

TEST(CudaLuDecomposition, RandomMatrixVectorOrdering)
{
testRandomMatrix<Group1SparseVectorMatrix>(10);
testRandomMatrix<Group2SparseVectorMatrix>(100);
testRandomMatrix<Group3SparseVectorMatrix>(1000);
testRandomMatrix<Group4SparseVectorMatrix>(100000);
testRandomMatrix<Group1CPUSparseVectorMatrix, Group1CudaSparseMatrix>(1);
testRandomMatrix<Group100CPUSparseVectorMatrix, Group100CudaSparseMatrix>(100);
testRandomMatrix<Group1000CPUSparseVectorMatrix, Group1000CudaSparseMatrix>(1000);
testRandomMatrix<Group100000CPUSparseVectorMatrix, Group100000CudaSparseMatrix>(100000);
}

0 comments on commit 9a958e3

Please sign in to comment.