Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Stig Rune Sellevag committed Aug 8, 2023
1 parent 02644c7 commit 0fff1b1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
52 changes: 26 additions & 26 deletions include/scilib/linalg_impl/auxiliary.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,28 +243,28 @@ inline T prod(const Sci::MDArray<T, stdex::extents<IndexType, ext>, Layout, Cont
//--------------------------------------------------------------------------------------------------
// Create special vectors and matrices:

template <class M, class... Args>
requires(__Detail::Is_mdarray_v<M>&& M::rank() == sizeof...(Args))
inline M zeros(Args... args)
template <class M, class... IndexTypes>
requires(__Detail::Is_mdarray_v<M> && (M::rank() == sizeof...(IndexTypes)))
inline M zeros(IndexTypes... exts)
{
using extents_type = typename M::extents_type;
using value_type = typename M::value_type;
using index_type = typename M::index_type;

return M(stdex::extents<index_type, args...>{}, value_type{0});
return M(extents_type(exts...), value_type{0});
}

template <class M, class... Args>
requires(__Detail::Is_mdarray_v<M>&& M::rank() == sizeof...(Args))
inline M ones(Args... args)
template <class M, class... IndexTypes>
requires(__Detail::Is_mdarray_v<M> && (M::rank() == sizeof...(IndexTypes)))
inline M ones(IndexTypes... exts)
{
using extents_type = typename M::extents_type;
using value_type = typename M::value_type;
using index_type = typename M::index_type;

return M(stdex::extents<index_type, args...>{}, value_type{1});
return M(extents_type(exts...), value_type{1});
}

template <class M = Sci::Matrix<double>>
requires(__Detail::Is_mdarray_v<M>&& M::rank() == 2)
requires(__Detail::Is_mdarray_v<M> && (M::rank() == 2))
inline M identity(std::size_t n)
{
using value_type = typename M::value_type;
Expand All @@ -280,54 +280,54 @@ inline M identity(std::size_t n)

// Create a random MDArray from a normal distribution with zero mean and unit
// variance.
template <class M, class... Args>
requires(__Detail::Is_mdarray_v<M>&& std::is_floating_point_v<typename M::value_type>)
inline M randn(Args... args)
template <class M, class... IndexTypes>
requires(__Detail::Is_mdarray_v<M> && (M::rank() == sizeof...(IndexTypes)) &&
std::is_floating_point_v<typename M::value_type>)
inline M randn(IndexTypes... exts)
{
static_assert(M::rank() == sizeof...(Args));
using value_type = typename M::value_type;

std::random_device rd{};
std::mt19937_64 gen{rd()};
std::normal_distribution<value_type> nd{};

M res(args...);
M res(exts...);
res.apply([&](value_type& x) { x = nd(gen); });
return res;
}

// Create a random MDArray from a uniform real distribution on the
// interval [0, 1).
template <class M, class... Args>
requires(__Detail::Is_mdarray_v<M>&& std::is_floating_point_v<typename M::value_type>)
inline M randu(Args... args)
template <class M, class... IndexTypes>
requires(__Detail::Is_mdarray_v<M> && (M::rank() == sizeof...(IndexTypes)) &&
std::is_floating_point_v<typename M::value_type>)
inline M randu(IndexTypes... exts)
{
static_assert(M::rank() == sizeof...(Args));
using value_type = typename M::value_type;

std::random_device rd{};
std::mt19937_64 gen{rd()};
std::uniform_real_distribution<value_type> ur{};

M res(args...);
M res(exts...);
res.apply([&](value_type& x) { x = ur(gen); });
return res;
}

// Create a random MDArray from a uniform integer distribution on the
// interval [0, 1].
template <class M, class... Args>
requires(__Detail::Is_mdarray_v<M>&& std::is_integral_v<typename M::value_type>)
inline M randi(Args... args)
template <class M, class... IndexTypes>
requires(__Detail::Is_mdarray_v<M> && (M::rank() == sizeof...(IndexTypes)) &&
std::is_integral_v<typename M::value_type>)
inline M randi(IndexTypes... exts)
{
static_assert(M::rank() == sizeof...(Args));
using value_type = typename M::value_type;

std::random_device rd{};
std::mt19937_64 gen{rd()};
std::uniform_int_distribution<value_type> ui{};

M res(args...);
M res(exts...);
res.apply([&](value_type& x) { x = ui(gen); });
return res;
}
Expand Down
20 changes: 20 additions & 0 deletions tests/test_linalg_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,23 @@ TEST(TestLinalg, TestSumProd)
EXPECT_EQ(Sci::Linalg::sum(v), 10);
EXPECT_EQ(Sci::Linalg::prod(v), 24);
}

TEST(TestLinalg, TestZerosMatrix)
{
auto m = Sci::Linalg::zeros<Sci::Matrix<int>>(2, 2);
for (Sci::index i = 0; i < m.extent(0); ++i) {
for (Sci::index j = 0; j < m.extent(1); ++j) {
EXPECT_EQ(m(i, j), 0);
}
}
}

TEST(TestLinalg, TestOnesMatrix)
{
auto m = Sci::Linalg::ones<Sci::Matrix<int>>(2, 2);
for (Sci::index i = 0; i < m.extent(0); ++i) {
for (Sci::index j = 0; j < m.extent(1); ++j) {
EXPECT_EQ(m(i, j), 1);
}
}
}

0 comments on commit 0fff1b1

Please sign in to comment.