Skip to content
Snippets Groups Projects
Commit 83950abb authored by Elias Pipping's avatar Elias Pipping
Browse files

Add triangularsolvetest

parent 87862e01
No related branches found
No related tags found
No related merge requests found
dune_add_test(SOURCES arithmetictest.cc) dune_add_test(SOURCES arithmetictest.cc)
dune_add_test(SOURCES staticmatrixtoolstest.cc) dune_add_test(SOURCES staticmatrixtoolstest.cc)
dune_add_test(SOURCES triangularsolvetest.cc)
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#undef NDEBUG
#include <dune/common/bitsetvector.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/fvector.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/bvector.hh>
#include <dune/istl/matrixindexset.hh>
#include "../triangularsolve.hh"
#include "common.hh"
double const tol = 1e-10;
template <int n, bool lower>
bool test() {
bool passed = true;
Dune::BCRSMatrix<Dune::FieldMatrix<double, 1, 1>> M;
Dune::MatrixIndexSet indices(n, n);
for (size_t i = 0; i < n; ++i)
for (size_t j = lower ? 0 : i; j < (lower ? i + 1 : n); ++j)
indices.add(i, j);
indices.exportIdx(M);
std::random_device randomDevice;
std::mt19937 generator(randomDevice());
std::uniform_real_distribution<> distribution(0.1, 0.4);
for (auto it = M.begin(); it != M.end(); ++it) {
size_t const i = it.index();
for (auto cIt = it->begin(); cIt != it->end(); ++cIt) {
size_t const j = cIt.index();
*cIt = distribution(generator);
if (i == j)
*cIt += 0.6;
}
}
Dune::BlockVector<Dune::FieldVector<double, 1>> sol(n);
for (auto &x : sol)
x = distribution(generator);
Dune::BlockVector<Dune::FieldVector<double, 1>> b(n);
M.mv(sol, b);
Dune::BitSetVector<1> *ignore = nullptr;
{
auto x = lower ? Dune::MatrixVector::lowerTriangularSolve(M, b, ignore)
: Dune::MatrixVector::upperTriangularSolve(M, b, ignore);
Dune::BlockVector<Dune::FieldVector<double, 1>> y(n);
M.mv(x, y);
auto diff = diffDune(b, y);
std::cout << "Difference: " << diff << std::endl;
passed &= diff < tol;
}
{
auto x = lower
? Dune::MatrixVector::upperTriangularSolve(M, b, ignore, true)
: Dune::MatrixVector::lowerTriangularSolve(M, b, ignore, true);
Dune::BlockVector<Dune::FieldVector<double, 1>> y(n);
M.mtv(x, y);
auto diff = diffDune(b, y);
std::cout << "Difference: " << diff << std::endl;
passed &= diff < tol;
}
return passed;
}
int main() {
bool passed = true;
passed &= test<1000, false>();
passed &= test<1000, true>();
return passed ? 0 : 1;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment