Skip to content
Snippets Groups Projects
Commit bad7fc61 authored by Max Kahnt's avatar Max Kahnt
Browse files

Make LocalSolvers have an explicit type.

parent 36d177a6
No related branches found
No related tags found
1 merge request!14Feature/blockgsstep default constructible
......@@ -240,70 +240,108 @@ auto diagRegularize(double p, LocalSolver&& localSolver) {
//! \brief is @LinearSolvers::direct with ignore nodes.
//! \param r Regularization parameter. Set to 0.0 to switch off regularization.
template <typename dummy = void>
auto direct(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) {
return [r](const auto& m, const auto& b, const auto& ignore) {
auto directSolver = LinearSolvers::direct<std::decay_t<decltype(m)>,
std::decay_t<decltype(b)>>;
struct Direct {
Direct(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter)
: r_(r) {}
template <class M, class V, class I>
V operator()(const M& m, const V& b, const I& ignore) const {
auto directSolver = LinearSolvers::direct<M, V>;
auto directTruncatedSolver =
LocalSolverFromLinearSolver::truncateSymmetrically(directSolver);
auto directTruncatedRegularizedSolver =
LocalSolverRegularizer::diagRegularize(r, directTruncatedSolver);
LocalSolverRegularizer::diagRegularize(r_, directTruncatedSolver);
return directTruncatedRegularizedSolver(m, b, ignore);
}
private:
double r_;
};
template <class... Args>
auto direct(Args&&... args) {
return Direct(std::forward<Args>(args)...);
}
//! \brief is @LinearSolvers::ldlt with ignore nodes.
template <typename dummy = void>
auto ldlt(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) {
return [r](const auto& m, const auto& b, const auto& ignore) {
auto ldltSolver = LinearSolvers::ldlt<std::decay_t<decltype(m)>,
std::decay_t<decltype(b)>>;
struct LDLt {
LDLt(double r = LocalSolverRegularizer::defaultDiagRegularizeParameter)
: r_(r) {}
template <class M, class V, class I>
V operator()(const M& m, const V& b, const I& ignore) const {
auto ldltSolver = LinearSolvers::ldlt<M, V>;
auto ldltTruncatedSolver =
LocalSolverFromLinearSolver::truncateSymmetrically(ldltSolver);
auto ldltTruncatedRegularizedSolver =
LocalSolverRegularizer::diagRegularize(r, ldltTruncatedSolver);
LocalSolverRegularizer::diagRegularize(r_, ldltTruncatedSolver);
return ldltTruncatedRegularizedSolver(m, b, ignore);
}
private:
double r_;
};
template <class... Args>
auto ldlt(Args&&... args) {
return LDLt(std::forward<Args>(args)...);
}
//! \brief is @LinearSolvers::cg with ignore nodes.
template <typename dummy = void>
auto cg(size_t maxIter = LinearSolvers::defaultCgMaxIter,
struct CG {
CG(size_t maxIter = LinearSolvers::defaultCgMaxIter,
double tol = LinearSolvers::defaultCgTol,
double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) {
return [maxIter, tol, r](const auto& m, const auto& b, const auto& ignore) {
double r = LocalSolverRegularizer::defaultDiagRegularizeParameter)
: maxIter_(maxIter)
, tol_(tol)
, r_(r) {}
template <class M, class V, class I>
V operator()(const M& m, const V& b, const I& ignore) const {
using namespace std::placeholders;
auto cgSolver = std::bind(
LinearSolvers::cg<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>,
_1, _2, maxIter, tol);
auto cgTruncatedSolver = LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver);
auto cgSolver = std::bind(LinearSolvers::cg<M, V>, _1, _2, maxIter_, tol_);
auto cgTruncatedSolver =
LocalSolverFromLinearSolver::truncateSymmetrically(cgSolver);
auto cgTruncatedRegularizedSolver =
LocalSolverRegularizer::diagRegularize(r, cgTruncatedSolver);
LocalSolverRegularizer::diagRegularize(r_, cgTruncatedSolver);
return cgTruncatedRegularizedSolver(m, b, ignore);
}
private:
size_t maxIter_;
double tol_;
double r_;
};
template <class... Args>
auto cg(Args&&... args) {
return CG(std::forward<Args>(args)...);
}
/**
* \brief is @LinearSolvers::gs with ignore nodes.
* \param tol Tolerance value for skipping potentially zero diagonal entries
* after regularization.
* \param r Regularization parameter. Set to 0.0 to switch off regularization.
*/
template <typename dummy = void>
auto gs(double tol = LinearSolvers::defaultGsTol,
double r = LocalSolverRegularizer::defaultDiagRegularizeParameter) {
return [tol, r](const auto& m, const auto& b, const auto& ignore) {
struct GS {
GS(double tol = LinearSolvers::defaultGsTol,
double r = LocalSolverRegularizer::defaultDiagRegularizeParameter)
: tol_(tol)
, r_(r) {}
template <class M, class V, class I>
V operator()(const M& m, const V& b, const I& ignore) const {
using namespace std::placeholders;
auto gsSolver = std::bind(
LinearSolvers::gs<std::decay_t<decltype(m)>, std::decay_t<decltype(b)>>,
_1, _2, tol);
auto gsSolver = std::bind(LinearSolvers::gs<M, V>, _1, _2, tol_);
auto gsTruncatedSolver =
LocalSolverFromLinearSolver::truncateSymmetrically(gsSolver);
auto gsTruncatedRegularizedSolver =
LocalSolverRegularizer::diagRegularize(r, gsTruncatedSolver);
LocalSolverRegularizer::diagRegularize(r_, gsTruncatedSolver);
return gsTruncatedRegularizedSolver(m, b, ignore);
}
private:
double tol_;
double r_;
};
template <class... Args>
auto gs(Args&&... args) {
return GS(std::forward<Args>(args)...);
}
} // end namespace LocalSolvers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment