Commit 620c3d54 authored by maxka's avatar maxka
Browse files

Merge branch 'feature/new_tnnmg-simplexSolver' into 'feature/new_tnnmg'

Feature/new tnnmg simplex solver

See merge request !3
parents be6330b1 fe2acad8
Pipeline #6171 failed with stage
in 4 minutes and 53 seconds
// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_TNNMG_LOCALSOLVERS_SIMPLEXSOLVER_HH
#define DUNE_TNNMG_LOCALSOLVERS_SIMPLEXSOLVER_HH
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <dune/matrix-vector/axpy.hh>
#include "../functionals/boxconstrainedquadraticfunctional.hh"
namespace Dune {
namespace TNNMG {
/**
* \brief A local solver for quadratic minimization problems with lower obstacle
* where the quadratic part is given by a scaled identity. Additionally a sum
* constraint is imposed.
*
* Note: In order to use it as a solver to compute corrections for a
* simplex-constrained problem, make sure the iterate already fulfills the sum
* constraint and you compute corrections with sum constraint 0.
*
* The available implementation solves the problem in n*log(n) time, exploiting
* the decoupling of the energy w.r.t. the coordinate directions.
*
* \todo Add generic case where the quadratic part is not given by a scaled
* identity.
*
*/
template <class Field = double>
struct SimplexSolver {
SimplexSolver(Field sum = 0.0)
: r_(sum) {}
/**
* \brief Project b to a (scaled) Gibbs simplex
* \param x projected vector
* \param b input vector
* \param s scaling of the simplex
*
* Write the problem as follows:
* ( I_N + \partial\phi_N 1_N^T ) ( x ) = ( b )
* ( 1_N 0 ) ( lambda ) = ( s )
* where I_N is the identity matrix with N rows and columns,
* phi_N is an N-vector of phi_0, the latter denoting the obstacle at 0 and
* 1_N is the N-vector of ones.
*
* We determine lambda by a Schur approach:
* For each lambda, (I_N + \partial\phi_N) \ (b - lambda) yields a unique
* solution which we can use to replace x in the second equation. Furthermore
* the contributions do not couple, so we obtain
* s = \sum_i (1 + \partial\phi_0) \ (b_i - lambda).
*
* We sort the values of b by size and successively check if the resulting
* intervals do permit a lambda that induces an activity pattern which does
* indeed cut off all subsequent (lower) values of b.
* Note: This also works in the corner case where b contains duplicate
* values, because lambda cannot be larger than the successive element (with
* the same value) in these cases -- the interval has width 0.
*
* With given lambda, we can compute x easily.
*
*/
template <class X, class R>
static auto projectToSimplex(X& x, X b, R s = 1.0) {
if (s == 0.0) {
x = 0.0;
return;
}
assert(s > 0.0 && "Sum constraint must be non-negative.");
size_t N = b.size();
assert(N > 0 && "Vector has zero size.");
// determine lambda by successively checking the intervals of b
R lambda;
x = b; // alienate x for a sorted version of b
std::sort(x.begin(), x.end(), std::greater<R>());
s *= -1.0;
for (size_t i = 0; i < N; ++i) {
s += x[i]; //
lambda = s / (i + 1);
if (x[i + 1] < lambda)
break;
}
// compute x
for (size_t j = 0; j < N; ++j)
x[j] = std::max(b[j] - lambda, R(0));
}
/**
* \brief Solve a quadratic minimization problem with lower obstacle where
* the quadratic part is given by a scaled identity. Additionally, a sum
* constraint is imposed on the solution.
* \param x the minimizer
* \param f the quadratic functional with lower obstacle
* \param ignore ignore nodes
*
* This determines x for the following system (1).
* The equivalence transformations (2),(3) show what is implemented.
*
* (1a) x = argmin 0.5*<Av,v> - <b,v> where
* (1b) \sum v_i = r
* (1c) v_i \geq l_i
*
* Divide energy by a>0.
* Set c := 1/a * b.
* (2a) x = argmin 0.5*<v,v> - <c,v> where
* (2b) \sum v_i = r
* (2c) v_i \geq l_i
*
* Transform w := v - l.
* Set d := c - l, s = r - \sum l_i, const = 0.5*<l,l> - <c,l>.
* (3a) x = l + argmin 0.5*<w,w> - <d,w> + const
* (3b) \sum w_i = s
* (3c) v_i \geq 0
*
*/
template <class X, class K, int n, class V, class L, class U, class R,
class Ignore>
auto operator()(X&& x, BoxConstrainedQuadraticFunctional<
ScaledIdentityMatrix<K, n>&, V, L, U, R>& f,
Ignore& ignore) const {
if (ignore.all())
return;
if (ignore.any())
DUNE_THROW(NotImplemented, "All or no ignore nodes should be set.");
assert(ignore.none());
const auto& a = f.quadraticPart().scalar();
if (a <= 0.0)
DUNE_THROW(MathError, "ScaledIdentity scaling must be positive.");
for (auto&& ui : f.upperObstacle())
assert(std::isinf(ui) && "Upper obstacle must be infinity.");
auto&& l = f.lowerObstacle();
auto s = std::accumulate(l.begin(), l.end(), r_, std::minus<R>());
auto d = l;
d *= -1.0;
Dune::MatrixVector::addProduct(d, 1.0 / a, f.linearPart());
projectToSimplex(x, d, s);
x += l;
}
private:
const Field r_;
};
} // end namespace TNNMG
} // end namespace Dune
#endif // DUNE_TNNMG_LOCALSOLVERS_SIMPLEXSOLVER_HH
......@@ -5,3 +5,4 @@ dune_add_test(SOURCES multitypegstest)
dune_add_test(SOURCES nonlineargstest)
dune_add_test(SOURCES nonlineargsperformancetest NAME nonlineargsperformancetest_corrections)
dune_add_test(SOURCES nonlineargsperformancetest NAME nonlineargsperformancetest_iterates COMPILE_DEFINITIONS "NEW_TNNMG_COMPUTE_ITERATES_DIRECTLY=1")
dune_add_test(SOURCES simplexsolvertest)
// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#include <config.h>
#include <iostream>
#include <limits>
#include <dune/common/fvector.hh>
#include <dune/common/parallel/mpihelper.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/bvector.hh>
#include <dune/istl/matrixindexset.hh>
#include <dune/istl/scaledidmatrix.hh>
#include <dune/solvers/common/defaultbitvector.hh>
#include <dune/tnnmg/functionals/boxconstrainedquadraticfunctional.hh>
#include <dune/tnnmg/iterationsteps/nonlineargsstep.hh>
#include <dune/tnnmg/localsolvers/simplexsolver.hh>
constexpr size_t N = 7;
using Field = double;
using Solver = Dune::TNNMG::SimplexSolver<Field>;
using LocalVector = Dune::FieldVector<Field, N>;
using LocalMatrix = Dune::ScaledIdentityMatrix<Field, N>;
using LocalFunctional = Dune::TNNMG::BoxConstrainedQuadraticFunctional<
LocalMatrix&, LocalVector&, LocalVector&, LocalVector&, Field>;
using LocalIgnore = std::bitset<N>;
using Vector = Dune::BlockVector<LocalVector>;
using Matrix = Dune::BCRSMatrix<LocalMatrix>;
using Functional = Dune::TNNMG::BoxConstrainedQuadraticFunctional<
Matrix, Vector, Vector, Vector, Field>;
using Ignore = Dune::Solvers::DefaultBitVector_t<Vector>;
constexpr Field tol = 1e-14;
bool equalsWithTol(Field a, Field b) { return fabs(b - a) < tol; }
bool hasSumWithTol(LocalVector& v, Field sum) {
return fabs(std::accumulate(v.begin(), v.end(), -sum)) < tol;
}
int main(int argc, char** argv) try {
Dune::MPIHelper::instance(argc, argv);
bool passed(true);
{ // test locally
LocalVector upper(std::numeric_limits<Field>::infinity());
LocalVector lower(0.0);
LocalMatrix matrix(1.0);
LocalVector rhs(1.0);
LocalFunctional f(matrix, rhs, lower, upper);
LocalVector x(0.0);
LocalIgnore ignore(false);
{ // all values one, constraint N
Solver solver(N);
solver(x, f, ignore);
for (auto&& xi : x)
passed = passed and equalsWithTol(xi, 1.0);
}
{ // all values one, constraint 1 -> project
Solver solver(1.0);
solver(x, f, ignore);
for (auto&& xi : x)
passed = passed and equalsWithTol(xi, 1.0 / N);
}
{ // one entry zero, others 1
Solver solver(N - 1);
rhs[0] = 0.0;
solver(x, f, ignore);
passed = passed and equalsWithTol(x[0], 0.0);
for (size_t i = 1; i < N; ++i)
passed = passed and equalsWithTol(x[i], 1.0);
}
{ // everything ignored
Solver solver;
for (size_t i = 0; i < N; ++i)
ignore[i] = true;
x = std::numeric_limits<Field>::quiet_NaN();
solver(x, f, ignore);
for (auto&& xi : x)
passed = passed and std::isnan(xi);
for (size_t i = 0; i < N; ++i)
ignore[i] = false;
}
{ // partially ignored (throws)
Solver solver;
ignore[0] = true;
try {
solver(x, f, ignore);
passed = passed and false;
} catch (Dune::NotImplemented) {
passed = passed and true;
}
ignore[0] = false;
}
{ // different values, fitting constraint
for (size_t i = 0; i < N; ++i)
rhs[i] = i + 1;
Field sum = N * (N + 1) / 2;
Solver solver(sum);
solver(x, f, ignore);
for (size_t i = 0; i < N; ++i)
passed = passed and equalsWithTol(x[i], i + 1);
}
{ // different values, oversized constraint -> project
for (size_t i = 0; i < N; ++i)
rhs[i] = i + 1;
Field sum = N * (N + 1) / 2;
Solver solver(10 * sum);
solver(x, f, ignore);
Field sumShift = (sum * 10 - sum) / N;
for (size_t i = 0; i < N; ++i)
passed = passed and equalsWithTol(x[i], i + 1 + sumShift);
}
{ // different values, undersized constraint -> project
for (size_t i = 0; i < N; ++i)
rhs[i] = i + 1;
Field sum = N * (N + 1) / 2;
Solver solver(sum - N);
solver(x, f, ignore);
for (size_t i = 0; i < N; ++i)
passed = passed and equalsWithTol(x[i], i);
}
{ // different and coinciding values, undersized constraint -> project
rhs[0] = 1;
for (size_t i = 1; i < N; ++i)
rhs[i] = i;
Field sum = (N - 1) * (N - 2) / 2;
Solver solver(sum);
solver(x, f, ignore);
passed = passed and equalsWithTol(x[0], 0);
for (size_t i = 1; i < N; ++i)
passed = passed and equalsWithTol(x[i], i - 1);
}
{ // non-zero obstacle
lower = 1.0;
for (size_t i = 0; i < N; ++i)
rhs[i] = i;
Field sum = N * (N - 1) / 2 + 1;
Solver solver(sum);
solver(x, f, ignore);
passed = passed and equalsWithTol(x[0], 1);
for (size_t i = 1; i < N; ++i)
passed = passed and equalsWithTol(x[i], i);
}
}
{ // test globally
size_t m = 10;
Vector upper(m);
Vector lower(m);
Vector rhs(m);
Vector x(m);
Ignore ignore(m);
Dune::MatrixIndexSet indices(m, m);
for (size_t i = 0; i < m; ++i)
indices.add(i, i);
Matrix matrix;
indices.exportIdx(matrix);
upper = std::numeric_limits<Field>::infinity();
lower = 0.0;
matrix = 1.0;
for (auto&& rhs_i : rhs)
for (size_t j = 0; j < N; ++j)
rhs_i[j] = j + 1;
Functional f(matrix, rhs, lower, upper);
x = 0.0;
for (auto&& xi : x)
xi[0] = 1.0; // make sure the local sum constraint is fulfilled initially
ignore.unsetAll();
Dune::TNNMG::gaussSeidelLocalSolver(Solver(1.0));
for (auto&& xi : x)
passed = passed and hasSumWithTol(xi, 1.0);
}
return passed ? 0 : 1;
} catch (Dune::Exception e) {
std::cout << e << std::endl;
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment