Skip to content

Commit

Permalink
add sparse matrix ordering policy
Browse files Browse the repository at this point in the history
  • Loading branch information
mattldawson committed Jun 29, 2023
1 parent 6106d99 commit cd027d4
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 26 deletions.
50 changes: 24 additions & 26 deletions include/micm/util/sparse_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,31 @@
#include <utility>
#include <vector>

#include <micm/util/standard_sparse_matrix.hpp>

namespace micm
{

template<class T>
template<class T, class A>
class SparseMatrixBuilder;

/// @brief A sparse block-diagonal 2D matrix class with contiguous memory
///
/// Each block sub-matrix is square and has the same structure of non-zero elements
///
/// Sparse matrix data structure follows the Compressed Sparse Row (CSR) pattern
template<class T>
class SparseMatrix
///
/// The template parameters are the type of the matrix elements and a class that
/// defines the sizing and ordering of the data elements
template<class T, class A = StandardSparseMatrix>
class SparseMatrix : public A
{
std::size_t number_of_blocks_; // Number of block sub-matrices in the overall matrix
std::vector<T> data_; // Value of each non-zero matrix element
std::vector<std::size_t> row_ids_; // Row indices of each non-zero element in a block
std::vector<std::size_t> row_start_; // Index in data_ and row_ids_ of the start of each column in a block

friend class SparseMatrixBuilder<T>;
std::vector<std::size_t> row_start_; // Index in data_ and row_ids_ of the start of each row in a block
std::vector<T> data_; // Value of each non-zero matrix element

friend class SparseMatrixBuilder<T, A>;
friend class ProxyRow;
friend class ConstProxyRow;
friend class Proxy;
Expand Down Expand Up @@ -132,27 +137,27 @@ namespace micm
};

public:
static SparseMatrixBuilder<T> create(std::size_t block_size)
static SparseMatrixBuilder<T, A> create(std::size_t block_size)
{
return SparseMatrixBuilder<T>{ block_size };
return SparseMatrixBuilder<T, A>{ block_size };
}

SparseMatrix() = default;

SparseMatrix(SparseMatrixBuilder<T>& builder)
SparseMatrix(SparseMatrixBuilder<T, A>& builder)
: number_of_blocks_(builder.number_of_blocks_),
data_(builder.NumberOfElements(), builder.initial_value_),
row_ids_(builder.RowIdsVector()),
row_start_(builder.RowStartVector())
{
row_start_(builder.RowStartVector()),
data_(A::VectorSize(number_of_blocks_, row_ids_, row_start_), builder.initial_value_)
{
}

SparseMatrix<T>& operator=(SparseMatrixBuilder<T>& builder)
SparseMatrix<T, A>& operator=(SparseMatrixBuilder<T, A>& builder)
{
number_of_blocks_ = builder.number_of_blocks_;
data_ = std::vector<T>(builder.NumberOfElements(), builder.initial_value_);
row_ids_ = builder.RowIdsVector();
row_start_ = builder.RowStartVector();
data_ = std::vector<T>(A::VectorSize(number_of_blocks_, row_ids_, row_start_), builder.initial_value_);

return *this;
}
Expand All @@ -168,14 +173,7 @@ namespace micm

std::size_t VectorIndex(std::size_t block, std::size_t row, std::size_t column) const
{
if (row >= row_start_.size() - 1 || column >= row_start_.size() - 1 || block >= number_of_blocks_)
throw std::invalid_argument("SparseMatrix element out of range");
auto begin = std::next(row_ids_.begin(), row_start_[row]);
auto end = std::next(row_ids_.begin(), row_start_[row + 1]);
auto elem = std::find(begin, end, column);
if (elem == end)
throw std::invalid_argument("SparseMatrix zero element access not allowed");
return std::size_t{ (elem - row_ids_.begin()) + block * row_ids_.size() };
return A::VectorIndex(number_of_blocks_, row_ids_, row_start_, block, row, column);
}

std::size_t VectorIndex(std::size_t row, std::size_t column) const
Expand Down Expand Up @@ -228,7 +226,7 @@ namespace micm
}
};

template<class T>
template<class T, class A = StandardSparseMatrix>
class SparseMatrixBuilder
{
std::size_t number_of_blocks_{ 1 };
Expand All @@ -245,9 +243,9 @@ namespace micm
{
}

operator SparseMatrix<T>() const
operator SparseMatrix<T, A>() const
{
return SparseMatrix<T>(*this);
return SparseMatrix<T, A>(*this);
}

SparseMatrixBuilder<T>& number_of_blocks(std::size_t n)
Expand Down
42 changes: 42 additions & 0 deletions include/micm/util/standard_sparse_matrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2023 National Center for Atmospheric Research,
//
// SPDX-License-Identifier: Apache-2.0
#pragma once

namespace micm
{

/// @brief Defines the ordering of SparseMatrix object data
///
/// Data is stored with blocks in the block diagonal matrix as the highest
/// level structure, then by row, then by non-zero columns in each row.
class StandardSparseMatrix
{
protected:
static std::size_t VectorSize(
std::size_t number_of_blocks,
const std::vector<std::size_t>& row_ids,
const std::vector<std::size_t>& row_start)
{
return number_of_blocks * row_ids.size();
};

static std::size_t VectorIndex(
std::size_t number_of_blocks,
const std::vector<std::size_t>& row_ids,
const std::vector<std::size_t>& row_start,
std::size_t block,
std::size_t row,
std::size_t column)
{
if (row >= row_start.size() - 1 || column >= row_start.size() - 1 || block >= number_of_blocks)
throw std::invalid_argument("SparseMatrix element out of range");
auto begin = std::next(row_ids.begin(), row_start[row]);
auto end = std::next(row_ids.begin(), row_start[row + 1]);
auto elem = std::find(begin, end, column);
if (elem == end)
throw std::invalid_argument("SparseMatrix zero element access not allowed");
return std::size_t{ (elem - row_ids.begin()) + block * row_ids.size() };
};
};
} // namespace micm

0 comments on commit cd027d4

Please sign in to comment.