Skip to content
Snippets Groups Projects
Commit c8774a97 authored by Elias Pipping's avatar Elias Pipping
Browse files

[Cleanup] Inherit from BNGSP

parent 935590db
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <dune/solvers/common/interval.hh> #include <dune/solvers/common/interval.hh>
#include <dune/solvers/computeenergy.hh> #include <dune/solvers/computeenergy.hh>
#include <dune/tnnmg/problem-classes/bisection.hh> #include <dune/tnnmg/problem-classes/bisection.hh>
#include <dune/tnnmg/problem-classes/blocknonlineargsproblem.hh>
#include "ellipticenergy.hh" #include "ellipticenergy.hh"
#include "globalnonlinearity.hh" #include "globalnonlinearity.hh"
...@@ -19,13 +20,18 @@ ...@@ -19,13 +20,18 @@
/** \brief Base class for problems where each block can be solved with a /** \brief Base class for problems where each block can be solved with a
* modified gradient method */ * modified gradient method */
template <class ConvexProblem> class MyBlockProblem { template <class ConvexProblem>
class MyBlockProblem : /* NOT PUBLIC */ BlockNonlinearGSProblem<ConvexProblem> {
private:
typedef BlockNonlinearGSProblem<ConvexProblem> BNGSP;
public: public:
using ConvexProblemType = ConvexProblem; using typename BNGSP::ConvexProblemType;
using VectorType = typename ConvexProblem::VectorType; using typename BNGSP::LocalMatrixType;
using MatrixType = typename ConvexProblem::MatrixType; using typename BNGSP::LocalVectorType;
using LocalVector = typename ConvexProblem::LocalVectorType; using typename BNGSP::MatrixType;
using LocalMatrix = typename ConvexProblem::LocalMatrixType; using typename BNGSP::VectorType;
size_t static const block_size = ConvexProblem::block_size; size_t static const block_size = ConvexProblem::block_size;
size_t static const coarse_block_size = block_size; size_t static const coarse_block_size = block_size;
...@@ -35,7 +41,7 @@ template <class ConvexProblem> class MyBlockProblem { ...@@ -35,7 +41,7 @@ template <class ConvexProblem> class MyBlockProblem {
struct Linearization { struct Linearization {
size_t static const block_size = coarse_block_size; size_t static const block_size = coarse_block_size;
using LocalMatrix = typename MyBlockProblem<ConvexProblem>::LocalMatrix; using LocalMatrix = typename MyBlockProblem<ConvexProblem>::LocalMatrixType;
using MatrixType = Dune::BCRSMatrix<typename Linearization::LocalMatrix>; using MatrixType = Dune::BCRSMatrix<typename Linearization::LocalMatrix>;
using VectorType = using VectorType =
Dune::BlockVector<Dune::FieldVector<double, Linearization::block_size>>; Dune::BlockVector<Dune::FieldVector<double, Linearization::block_size>>;
...@@ -48,9 +54,10 @@ template <class ConvexProblem> class MyBlockProblem { ...@@ -48,9 +54,10 @@ template <class ConvexProblem> class MyBlockProblem {
Dune::BitSetVector<Linearization::block_size> truncation; Dune::BitSetVector<Linearization::block_size> truncation;
}; };
MyBlockProblem(Dune::ParameterTree const &parset, MyBlockProblem(Dune::ParameterTree const &parset, ConvexProblem &problem)
ConvexProblem const &problem) : BNGSP(parset, problem),
: parset(parset), problem(problem), localBisection() // NOTE: defaults parset_(parset),
localBisection() // TODO
{} {}
std::string getOutput(bool header = false) const { std::string getOutput(bool header = false) const {
...@@ -92,8 +99,8 @@ template <class ConvexProblem> class MyBlockProblem { ...@@ -92,8 +99,8 @@ template <class ConvexProblem> class MyBlockProblem {
MyDirectionalConvexFunction< MyDirectionalConvexFunction<
GlobalNonlinearity<MatrixType, VectorType>> const GlobalNonlinearity<MatrixType, VectorType>> const
psi(computeDirectionalA(problem.A, v), psi(computeDirectionalA(problem_.A, v),
computeDirectionalb(problem.A, problem.f, u, v), problem.phi, u, v); computeDirectionalb(problem_.A, problem_.f, u, v), problem_.phi, u, v);
Dune::Solvers::Interval<double> D; Dune::Solvers::Interval<double> D;
psi.subDiff(0, D); psi.subDiff(0, D);
...@@ -115,7 +122,7 @@ template <class ConvexProblem> class MyBlockProblem { ...@@ -115,7 +122,7 @@ template <class ConvexProblem> class MyBlockProblem {
linearization.truncation.resize(u.size()); linearization.truncation.resize(u.size());
linearization.truncation.unsetAll(); linearization.truncation.unsetAll();
for (size_t i = 0; i < u.size(); ++i) { for (size_t i = 0; i < u.size(); ++i) {
if (problem.phi.regularity(i, u[i]) > 1e8) { // TODO: Make customisable if (problem_.phi.regularity(i, u[i]) > 1e8) { // TODO: Make customisable
linearization.truncation[i] = true; linearization.truncation[i] = true;
continue; continue;
} }
...@@ -126,31 +133,31 @@ template <class ConvexProblem> class MyBlockProblem { ...@@ -126,31 +133,31 @@ template <class ConvexProblem> class MyBlockProblem {
} }
// construct sparsity pattern for linearization // construct sparsity pattern for linearization
Dune::MatrixIndexSet indices(problem.A.N(), problem.A.M()); Dune::MatrixIndexSet indices(problem_.A.N(), problem_.A.M());
indices.import(problem.A); indices.import(problem_.A);
problem.phi.addHessianIndices(indices); problem_.phi.addHessianIndices(indices);
// construct matrix from pattern and initialize it // construct matrix from pattern and initialize it
indices.exportIdx(linearization.A); indices.exportIdx(linearization.A);
linearization.A = 0.0; linearization.A = 0.0;
// compute quadratic part of hessian (linearization.A += problem.A) // compute quadratic part of hessian (linearization.A += problem_.A)
for (size_t i = 0; i < problem.A.N(); ++i) { for (size_t i = 0; i < problem_.A.N(); ++i) {
auto const end = problem.A[i].end(); auto const end = problem_.A[i].end();
for (auto it = problem.A[i].begin(); it != end; ++it) for (auto it = problem_.A[i].begin(); it != end; ++it)
linearization.A[i][it.index()] += *it; linearization.A[i][it.index()] += *it;
} }
// compute nonlinearity part of hessian // compute nonlinearity part of hessian
problem.phi.addHessian(u, linearization.A); problem_.phi.addHessian(u, linearization.A);
// compute quadratic part of gradient // compute quadratic part of gradient
linearization.b.resize(u.size()); linearization.b.resize(u.size());
problem.A.mv(u, linearization.b); problem_.A.mv(u, linearization.b);
linearization.b -= problem.f; linearization.b -= problem_.f;
// compute nonlinearity part of gradient // compute nonlinearity part of gradient
problem.phi.addGradient(u, linearization.b); problem_.phi.addGradient(u, linearization.b);
// -grad is needed for Newton step // -grad is needed for Newton step
linearization.b *= -1.0; linearization.b *= -1.0;
...@@ -182,14 +189,14 @@ template <class ConvexProblem> class MyBlockProblem { ...@@ -182,14 +189,14 @@ template <class ConvexProblem> class MyBlockProblem {
/** \brief Constructs and returns an iterate object */ /** \brief Constructs and returns an iterate object */
IterateObject getIterateObject() { IterateObject getIterateObject() {
return IterateObject(parset, localBisection, problem); return IterateObject(parset_, localBisection, problem_);
} }
private: private:
Dune::ParameterTree const &parset; Dune::ParameterTree const &parset_;
// problem data // problem data
ConvexProblem const &problem; using BNGSP::problem_;
Bisection const localBisection; Bisection const localBisection;
...@@ -220,7 +227,7 @@ class MyBlockProblem<ConvexProblem>::IterateObject { ...@@ -220,7 +227,7 @@ class MyBlockProblem<ConvexProblem>::IterateObject {
} }
/** \brief Update the i-th block of the current iterate */ /** \brief Update the i-th block of the current iterate */
void updateIterate(LocalVector const &ui, size_t i) { void updateIterate(LocalVectorType const &ui, size_t i) {
u[i] = ui; u[i] = ui;
return; return;
} }
...@@ -231,11 +238,11 @@ class MyBlockProblem<ConvexProblem>::IterateObject { ...@@ -231,11 +238,11 @@ class MyBlockProblem<ConvexProblem>::IterateObject {
* \param ignore Set of degrees of freedom to leave untouched * \param ignore Set of degrees of freedom to leave untouched
*/ */
void solveLocalProblem( void solveLocalProblem(
LocalVector &ui, size_t m, LocalVectorType &ui, size_t m,
typename Dune::BitSetVector<block_size>::const_reference ignore) { typename Dune::BitSetVector<block_size>::const_reference ignore) {
{ {
LocalMatrix const *localA = nullptr; LocalMatrixType const *localA = nullptr;
LocalVector localb(problem.f[m]); LocalVectorType localb(problem.f[m]);
auto const end = problem.A[m].end(); auto const end = problem.A[m].end();
for (auto it = problem.A[m].begin(); it != end; ++it) { for (auto it = problem.A[m].begin(); it != end; ++it) {
......
...@@ -361,7 +361,7 @@ int main(int argc, char *argv[]) { ...@@ -361,7 +361,7 @@ int main(int argc, char *argv[]) {
myGlobalNonlinearity->updateLogState(_alpha); myGlobalNonlinearity->updateLogState(_alpha);
// NIT: Do we really need to pass u here? // NIT: Do we really need to pass u here?
typename NonlinearFactory::ConvexProblem const convexProblem( typename NonlinearFactory::ConvexProblem convexProblem(
1.0, velocityMatrix, *myGlobalNonlinearity, velocityRHS, 1.0, velocityMatrix, *myGlobalNonlinearity, velocityRHS,
_velocityIterate); _velocityIterate);
typename NonlinearFactory::BlockProblem velocityProblem(parset, typename NonlinearFactory::BlockProblem velocityProblem(parset,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment