From d6598f9f90c732c8c2fad070af0682fcfb40c991 Mon Sep 17 00:00:00 2001
From: Patrick Jaap <patrick.jaap@tu-dresden.de>
Date: Fri, 31 Mar 2023 14:17:54 +0200
Subject: [PATCH] Extend ProximalNewtonSolver to more general ProximalSolver
 for first order problems

---
 CHANGELOG.md                                  |   2 +
 dune/solvers/solvers/CMakeLists.txt           |   1 +
 ...ximalnewtonsolver.hh => proximalsolver.hh} | 126 ++++++----
 dune/solvers/test/CMakeLists.txt              |   4 +-
 dune/solvers/test/proximalsolvertest.cc       | 226 ++++++++++++++++++
 5 files changed, 316 insertions(+), 43 deletions(-)
 rename dune/solvers/solvers/{proximalnewtonsolver.hh => proximalsolver.hh} (73%)
 create mode 100644 dune/solvers/test/proximalsolvertest.cc

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dca526b..96b6eb4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,8 @@
 
 - A new solver `ProximalNewtonSolver` is added which solves non-smooth minimization problems.
 
+- A more general `ProximalSolver` replaces the `ProximalNewtonSolver`. The new variant also handles first-order subproblems.
+
 # Release 2.9
 
 - The internal matrix of the`EnergyNorm` can now be accessed by `getMatrix()`.
diff --git a/dune/solvers/solvers/CMakeLists.txt b/dune/solvers/solvers/CMakeLists.txt
index 514d97f..ef9ba56 100644
--- a/dune/solvers/solvers/CMakeLists.txt
+++ b/dune/solvers/solvers/CMakeLists.txt
@@ -7,6 +7,6 @@ install(FILES
     loopsolver.cc
     loopsolver.hh
-    proximalnewtonsolver.hh
+    proximalsolver.hh
     quadraticipopt.hh
     solver.hh
     tcgsolver.cc
diff --git a/dune/solvers/solvers/proximalnewtonsolver.hh b/dune/solvers/solvers/proximalsolver.hh
similarity index 73%
rename from dune/solvers/solvers/proximalnewtonsolver.hh
rename to dune/solvers/solvers/proximalsolver.hh
index 609d17d..9ed0530 100644
--- a/dune/solvers/solvers/proximalnewtonsolver.hh
+++ b/dune/solvers/solvers/proximalsolver.hh
@@ -1,7 +1,7 @@
 // -*- tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*-
 // vi: set et ts=8 sw=4 sts=4:
-#ifndef DUNE_SOLVERS_SOLVERS_PROXIMALNEWTONSOLVER_HH
-#define DUNE_SOLVERS_SOLVERS_PROXIMALNEWTONSOLVER_HH
+#ifndef DUNE_SOLVERS_SOLVERS_PROXIMALSOLVER_HH
+#define DUNE_SOLVERS_SOLVERS_PROXIMALSOLVER_HH
 
 #include <dune/common/exceptions.hh>
 
@@ -14,7 +14,7 @@
 
 namespace Dune::Solvers
 {
-  namespace ProximalNewton
+  namespace ProxSolver
   {
     /** \brief List of the four stages of the proximal Newton step
      *
@@ -29,7 +29,7 @@ namespace Dune::Solvers
       accepted       // last stage: the new increment was accepted
     };
 
-    //! A dummy class for g=0 in the ProximalNewtonSolver
+    //! A dummy class for g=0 in the ProximalSolver
     template<class VectorType>
     struct ZeroFunctional
     {
@@ -45,7 +45,7 @@ namespace Dune::Solvers
     };
 
 
-    // A simple regularization updater which doubles in case of failure and halves in case of success
+    // A simple regularization updater which multiplies by 10 in case of failure and halves in case of success
     struct SimpleRegUpdater
     {
       void operator()( double& regWeight, Stage stage) const
@@ -63,26 +63,33 @@ namespace Dune::Solvers
 
 
 
-  /** \brief Generic proximal Newton solver to solve a given minimization problem
+  /** \brief Generic Proximal solver to solve a given minimization problem
    *
-   *  The proximal Newton solver aims to solve a minimization problem given in the form
+   *  The proximal solver aims to solve a minimization problem given in the form
    *      Minimize J(x) := f(x) + g(x)
-   *  where f is a C^2 functional and g is a possibly non-smooth functional.
-   *  During the minimization, a sequence of increments dx as solutions of the second order subproblems
-   *      Minimize 0.5*f''(x)[dx,dx] +  f'(x)[dx] + g(x + dx) + r*||dx||^2
+   *  where f is a C^1 or C^2 functional and g is a possibly non-smooth functional.
+   *  During the minimization, a sequence of increments dx as solutions of the subproblems
+   *
+   *      Proximal Gradient: Minimize f'(x)[dx] + g(x + dx) + r*||dx||^2
+   *      Proximal Newton:   Minimize 0.5*f''(x)[dx,dx] +  f'(x)[dx] + g(x + dx) + r*||dx||^2
+   *
    *  is computed until the update x := x + dx converges in some sense.
    *  The user has to provide a suitable regularization strategy to control the regularization weight r,
    *  and a proper norm ||.|| for the subproblem.
+   *
+   *  The `diffOrder` parameter determines the order of the subproblem. Currently, only the values 1 and 2 are supported.
    */
-  template<class SEA, class NSF, class SOS, class VectorType, class ErrorNorm, class RegUpdater, class BitVectorType = DefaultBitVector_t<VectorType>>
-  class ProximalNewtonSolver : public Solver, public CanIgnore<BitVectorType>
+  template<class SEA, class NSF, class SPS, class VectorType, class ErrorNorm, class RegUpdater, class BitVectorType = DefaultBitVector_t<VectorType>, int diffOrder = 2>
+  class ProximalSolver : public Solver, public CanIgnore<BitVectorType>
   {
+    static_assert( diffOrder==1 or diffOrder==2, "The ProximalSolver only supports first and second order problems" );
+
   public:
 
     using SmoothEnergyAssembler = SEA;
     using NonsmoothFunctional = NSF;
-    using SecondOrderSolver = SOS;
-    using MatrixType = typename SecondOrderSolver::MatrixType;
+    using SubproblemSolver = SPS;
+    using MatrixType = typename SubproblemSolver::MatrixType;
 
     void solve() override;
 
@@ -93,8 +100,11 @@ namespace Dune::Solvers
      *             the evaluation by computeEnergy( x ) to return f(x)
      *  \param nsf The NonsmoothFunctional representing g: It must provide the method
      *             updateOffset( x ) to update the offset in g( x + dx ), and an evaluation operator().
-     *  \param sos The SecondOrderSolver which is able to minimize the second order subproblem. It must provide
-     *             a method minimize( f'', f', g, r, x, ignore) which overwrites the parameter x with the minimizer
+     *  \param sps The SubproblemSolver which is able to minimize the subproblem. It must provide
+     *             a method
+     *                  minimize( f'', f', g, r, x, ignore)    for diffOrder = 2, or
+     *                  minimize( f', g, r, x, ignore)         for diffOrder = 1,
+     *             which overwrites the parameter x with the minimizer
      *             and throws a Dune::Exception in case the minimization failed.
      *  \param solution This is the solution of the global problem. It is overwritten during the computation and serves
      *                  also as the initial value.
@@ -102,14 +112,14 @@ namespace Dune::Solvers
      *  \param regUpdater The regularization strategy. It must provide a call operator ( r, Stage ) that overwrites r
      *                    based on the given Stage of the computation
      *  \param initialRegularizationWeight The initial regularization weight to begin with
-     *  \param maxIterations The maximal number of proximal Newton steps before the Proximal Newton solver aborts the loop
+     *  \param maxIterations The maximal number of proximal steps before the proximal solver aborts the loop
      *  \param threshold The threshold to stop the iteration once || dx || < threshold
-     *  \param verbosity If the verbosity is set to Solver::FULL the ProximalNewtonSolver will print a table showing
+     *  \param verbosity If the verbosity is set to Solver::FULL the ProximalSolver will print a table showing
      *                   the current iterations and some useful information.
      */
-    ProximalNewtonSolver( const SmoothEnergyAssembler& sea,
+    ProximalSolver( const SmoothEnergyAssembler& sea,
                           NonsmoothFunctional& nsf,
-                          const SecondOrderSolver& sos,
+                          const SubproblemSolver& sps,
                           VectorType& solution,
                           const ErrorNorm& errorNorm,
                           const RegUpdater& regUpdater,
@@ -119,7 +129,7 @@ namespace Dune::Solvers
                           Solver::VerbosityMode verbosity)
     : smoothEnergyAssembler_(&sea)
     , nonsmoothFunctional_(&nsf)
-    , sos_(&sos)
+    , sps_(&sps)
     , solution_(&solution)
     , norm_(&errorNorm)
     , regUpdater_(regUpdater)
@@ -209,12 +219,12 @@ namespace Dune::Solvers
 
     const SmoothEnergyAssembler* smoothEnergyAssembler_;
     NonsmoothFunctional* nonsmoothFunctional_;
-    const SecondOrderSolver* sos_;
+    const SubproblemSolver* sps_;
 
     // current iterate of the solution of the minimization problem
     VectorType* solution_;
 
-    // increments of the proximal Newton step
+    // increments of the proximal step
     std::shared_ptr<VectorType> correction_;
 
     const ErrorNorm* norm_;
@@ -250,8 +260,8 @@ namespace Dune::Solvers
 
 
 
-  template<class SEA, class NSF, class SOS, class V, class EN, class RU, class BV>
-  void ProximalNewtonSolver<SEA,NSF,SOS,V,EN,RU,BV>::solve()
+  template<class SEA, class NSF, class SPS, class V, class EN, class RU, class BV, int diffOrder>
+  void ProximalSolver<SEA,NSF,SPS,V,EN,RU,BV,diffOrder>::solve()
   {
     using VectorType = V;
 
@@ -285,7 +295,7 @@ namespace Dune::Solvers
       if ( (1.0 + regWeight)*normCorrection < threshold_ )
       {
         if ( printOutput )
-          std::cout << "ProximalNewtonSolver terminated because of weighted correction is below threshold: " << (1.0 + regWeight)*normCorrection << std::endl;
+          std::cout << "ProximalSolver terminated because of weighted correction is below threshold: " << (1.0 + regWeight)*normCorrection << std::endl;
         break;
       }
 
@@ -297,7 +307,7 @@ namespace Dune::Solvers
         if ( std::get<0>(r) )
         {
           if ( printOutput )
-            std::cout << "ProximalNewtonSolver terminated because of a user added stop criterion: " << std::get<1>( r ) << std::endl;
+            std::cout << "ProximalSolver terminated because of a user added stop criterion: " << std::get<1>( r ) << std::endl;
 
           stop = true;
           break;
@@ -315,15 +325,27 @@ namespace Dune::Solvers
       // store some information in case the step gets discarded
       auto oldX = *solution_;
 
-      // assemble the quadratic and linear part if not recycled from previous step
-      if ( not hasGradient() or not hasHessian() )
+      // assemble the derivative(s) if not recycled from previous step
+      if constexpr ( diffOrder == 1 )
       {
-        hessianPtr_ = std::make_shared<MatrixType>();
-        gradientPtr_ = std::make_shared<VectorType>();
+        if ( not hasGradient() )
+        {
+          gradientPtr_ = std::make_shared<VectorType>();
+          smoothEnergyAssembler_->assembleGradient( oldX, *gradientPtr_ );
+        }
+      }
+      if constexpr ( diffOrder == 2 )
+      {
+        if ( not hasGradient() or not hasHessian() )
+        {
+          hessianPtr_ = std::make_shared<MatrixType>();
+          gradientPtr_ = std::make_shared<VectorType>();
 
-        smoothEnergyAssembler_->assembleGradientAndHessian( oldX, *gradientPtr_, *hessianPtr_ );
+          smoothEnergyAssembler_->assembleGradientAndHessian( oldX, *gradientPtr_, *hessianPtr_ );
+        }
       }
 
+
       // shift the nonsmoothFunctional by the current x
       nonsmoothFunctional_->updateOffset( oldX );
 
@@ -333,20 +355,27 @@ namespace Dune::Solvers
       auto oldEnergy = smoothEnergyAssembler_->computeEnergy( oldX ) + (*nonsmoothFunctional_)( zeroVector );
 
       ///////////////////////////////////////////////////////////////////////////////////////////
-      /// Stage I: Try to compute a Proximal Newton step ////////////////////////////////////////
+      /// Stage I: Try to compute a Proximal step ///////////////////////////////////////////////
       ///////////////////////////////////////////////////////////////////////////////////////////
 
-      // compute one Proximal Newton Step with the second order solver
+      // compute one proximal step with the subproblem solver
       try
       {
-        sos_->minimize( *hessianPtr_, *gradientPtr_, *nonsmoothFunctional_, regWeight, *correction_, this->ignore() );
+        if constexpr ( diffOrder == 1 )
+        {
+          sps_->minimize( *gradientPtr_, *nonsmoothFunctional_, regWeight, *correction_, this->ignore() );
+        }
+        if constexpr ( diffOrder == 2 )
+        {
+          sps_->minimize( *hessianPtr_, *gradientPtr_, *nonsmoothFunctional_, regWeight, *correction_, this->ignore() );
+        }
       }
       catch(const MathError& e)
       {
         if ( printOutput )
-          printLine( iter_, usedReg, 0, oldEnergy, 0, "The Proximal Newton Step reported an error: " + std::string(e.what()) );
+          printLine( iter_, usedReg, 0, oldEnergy, 0, "The Proximal Step reported an error: " + std::string(e.what()) );
 
-        regUpdater_(regWeight, ProximalNewton::Stage::minimize );
+        regUpdater_(regWeight, ProxSolver::Stage::minimize );
         continue;
       }
 
@@ -371,7 +400,7 @@ namespace Dune::Solvers
         if ( printOutput )
           printLine( iter_, usedReg, 0, oldEnergy, 0, "Computing the new energy resulted in an error: " + std::string(e.what()) );
 
-        regUpdater_(regWeight, ProximalNewton::Stage::configuration );
+        regUpdater_(regWeight, ProxSolver::Stage::configuration );
         continue;
       }
 
@@ -405,7 +434,7 @@ namespace Dune::Solvers
         if ( printOutput )
           printLine( iter_, usedReg, 0, oldEnergy, 0, "The following descent criterion was not accepted: " + errorMessage );
 
-        regUpdater_(regWeight, ProximalNewton::Stage::descent );
+        regUpdater_(regWeight, ProxSolver::Stage::descent );
         continue;
       }
 
@@ -422,7 +451,7 @@ namespace Dune::Solvers
       /// Stage IV: Update the regularization weight for the next step  /////////////////////////
       ///////////////////////////////////////////////////////////////////////////////////////////
 
-      regUpdater_(regWeight, ProximalNewton::Stage::accepted );
+      regUpdater_(regWeight, ProxSolver::Stage::accepted );
 
 
       // seems like the step was accepted:
@@ -430,9 +459,24 @@ namespace Dune::Solvers
 
       // reset gradient and hessian since x is updated
       gradientPtr_.reset();
-      hessianPtr_.reset();
+
+      if constexpr ( diffOrder == 2 )
+      {
+        hessianPtr_.reset();
+      }
     }
   }
+
+  // some convenient definitions
+
+  // ProximalNewtonSolver is a ProximalSolver with differentiability order 2
+  template<class SEA, class NSF, class SPS, class V, class EN, class RU, class BV = DefaultBitVector_t<V>>
+  using ProximalNewtonSolver = ProximalSolver<SEA,NSF,SPS,V,EN,RU,BV,2>;
+
+  // ProximalGradientSolver is a ProximalSolver with differentiability order 1
+    template<class SEA, class NSF, class SPS, class V, class EN, class RU, class BV = DefaultBitVector_t<V>>
+  using ProximalGradientSolver = ProximalSolver<SEA,NSF,SPS,V,EN,RU,BV,1>;
+
 } // namespace Dune::Solvers
 
 
diff --git a/dune/solvers/test/CMakeLists.txt b/dune/solvers/test/CMakeLists.txt
index 5d7a0f5..f883aa9 100644
--- a/dune/solvers/test/CMakeLists.txt
+++ b/dune/solvers/test/CMakeLists.txt
@@ -26,6 +26,6 @@ endif()
 if(SuiteSparse_CHOLMOD_FOUND)
   dune_add_test(SOURCES cholmodsolvertest.cc)
   add_dune_suitesparse_flags(cholmodsolvertest)
-  dune_add_test(SOURCES proximalnewtonsolvertest.cc)
-  add_dune_suitesparse_flags(proximalnewtonsolvertest)
+  dune_add_test(SOURCES proximalsolvertest.cc)
+  add_dune_suitesparse_flags(proximalsolvertest)
 endif()
diff --git a/dune/solvers/test/proximalsolvertest.cc b/dune/solvers/test/proximalsolvertest.cc
new file mode 100644
index 0000000..22f5d9d
--- /dev/null
+++ b/dune/solvers/test/proximalsolvertest.cc
@@ -0,0 +1,226 @@
+// -*- tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 4 -*-
+// vi: set et ts=8 sw=4 sts=4:
+
+#include <config.h>
+
+#include <cmath>
+#include <iostream>
+#include <sstream>
+
+#include <dune/common/bitsetvector.hh>
+#include <dune/common/parallel/mpihelper.hh>
+
+#include <dune/solvers/norms/energynorm.hh>
+#include <dune/solvers/norms/twonorm.hh>
+#include <dune/solvers/solvers/cholmodsolver.hh>
+#include <dune/solvers/solvers/proximalsolver.hh>
+
+#include "common.hh"
+
+using namespace Dune;
+
+// grid dimension
+const int dim = 3;
+
+// Smooth part f(x) = 1/4 * ||x||^4 of the test energy, together with its gradient and Hessian
+template<class Matrix, class Vector>
+struct F
+{
+  void assembleGradient( const Vector& x, Vector& grad ) const   // overwrites grad with f'(x) = ||x||^2 * x
+  {
+    auto norm2 = x.two_norm2();
+
+    grad = x;
+    grad *= norm2;
+  }
+  void assembleGradientAndHessian( const Vector& x, Vector& grad, Matrix& hess ) const   // overwrites grad with f'(x) and every entry of hess with f''(x)
+  {
+    assembleGradient( x , grad );
+
+    auto norm2 = x.two_norm2();
+
+    for(size_t i=0; i<x.size(); i++)
+      for(size_t j=0; j<x.size(); j++)
+        hess[i][j] = (i==j)*norm2 + 2*x[i]*x[j];   // f''(x)_ij = delta_ij * ||x||^2 + 2 * x_i * x_j
+  }
+
+  double computeEnergy( const Vector& x ) const   // returns f(x)
+  {
+    return  0.25* x.two_norm2() * x.two_norm2();
+  }
+};
+
+
+// Linear part of the test energy, g(x) = ( -1, -1, -1, ... , -1)^T * x, evaluated at the shifted point offset_ + x
+template<class Vector>
+struct G
+{
+  double operator()( const Vector& x ) const   // returns minus the sum of the entries of x + offset_
+  {
+    auto realX = x;
+    realX += offset_;
+
+    double sum = 0.0;
+    for ( size_t i=0; i<realX.size(); i++ )
+      sum -= realX[i];
+
+    return sum;
+  }
+
+  void updateOffset( const Vector& offset )   // stores the shift that operator() adds to its argument
+  {
+    offset_ = offset;
+  }
+
+  Vector offset_;   // shift used by operator(); G itself never initializes it, callers set it via updateOffset()
+};
+
+template<class Matrix, class Vector>
+struct SecondOrdersolver   // test subproblem solver; despite its name it offers both the second order and the first order minimize() overload
+{
+  using MatrixType = Matrix;
+
+  // second order variant: hands (hess + reg*I) and (-grad + 1) to CHOLMOD, result goes to dx; g and ignore are not used
+  template<class BitVector>
+  void minimize( const Matrix& hess, const Vector& grad, const G<Vector>& g, double reg, Vector& dx, const BitVector& /*ignore*/) const
+  {
+    // add reg to the diagonal of a copy of the Hessian
+    auto HH = hess;
+    for ( size_t i=0; i<dx.size(); i++ )
+      HH[i][i] += reg;
+
+    // build the right-hand side -grad + 1; the +1 hard-codes minus the constant gradient (-1,...,-1) of G
+    auto gg = grad;
+    gg *= -1.0;
+    for ( size_t i=0; i<dx.size(); i++ )
+      gg[i] += 1.0;
+
+
+    Solvers::CholmodSolver cholmod( HH, dx, gg );   // presumably solves HH * dx = gg for dx -- confirm against CholmodSolver
+    cholmod.solve();
+  }
+
+  // first order variant: same as above but without a Hessian, i.e. the matrix is just reg*I; g and ignore are not used
+  template<class BitVector>
+  void minimize( const Vector& grad, const G<Vector>& g, double reg, Vector& dx, const BitVector& /*ignore*/) const
+  {
+    MatrixType HH = 0;
+    for ( size_t i=0; i<dx.size(); i++ )
+      HH[i][i] += reg;
+
+    // build the right-hand side -grad + 1; the +1 hard-codes minus the constant gradient (-1,...,-1) of G
+    auto gg = grad;
+    gg *= -1.0;
+    for ( size_t i=0; i<dx.size(); i++ )
+      gg[i] += 1.0;
+
+
+    Solvers::CholmodSolver cholmod( HH, dx, gg );   // presumably solves HH * dx = gg for dx -- confirm against CholmodSolver
+    cholmod.solve();
+  }
+
+};
+
+
+
+int main (int argc, char *argv[])   // runs the Newton and the gradient variant on min f(x)+g(x); returns 0 only if both end within tolerance of exactSol
+{
+  // initialize MPI, finalize is done automatically on exit
+  [[maybe_unused]] MPIHelper& mpiHelper = MPIHelper::instance(argc, argv);
+
+  const int size = 10;
+
+  using Vector = FieldVector<double,size>;
+  using Matrix = FieldMatrix<double,size>;
+  using BitVector = std::bitset<size>;
+
+  // analytical minimizer of f + g: f'(x) + g'(x) = ||x||^2 * x - (1,...,1) = 0 holds for x_i = size^(-1/3)
+  Vector exactSol, zeroVector;
+  exactSol = std::pow( size, -1.0/3.0 );
+
+  zeroVector = 0.0;
+
+  F<Matrix,Vector> f;
+  G<Vector> g;
+  g.updateOffset(zeroVector);
+  SecondOrdersolver<Matrix,Vector> sos;
+
+
+  // choose a fixed (deterministic, not random) initial vector: sol[i] = i
+  Vector sol;
+  for ( size_t i=0; i<sol.size(); i++ )
+    sol[i] = i;
+
+  TwoNorm<Vector> norm;
+
+  // create the regularization update strategy that is passed to both solvers below
+  Solvers::ProxSolver::SimpleRegUpdater regUpdater;
+
+  bool passed = true;
+
+  // Proximal Newton (second order subproblems); works on private copies of f, g and the initial vector
+  {
+    double reg = 42.0;
+    auto fPN = f;
+    auto gPN = g;
+    auto solPN = sol;
+    Solvers::ProximalNewtonSolver proximalNewtonSolver( fPN, gPN, sos, solPN, norm, regUpdater, reg, 100, 1e-14, Solver::FULL);
+
+    // ignore nothing: a default-constructed std::bitset has no bit set
+    BitVector ignore;
+    proximalNewtonSolver.setIgnore( ignore );
+
+    // go!
+    proximalNewtonSolver.solve();
+
+    // subtract the exact solution so that solPN holds the error
+    solPN -= exactSol;
+
+    std::cout << "Proximal Newton: Difference to exact solution = " << solPN.two_norm() << std::endl;
+
+    passed = passed && solPN.two_norm() < 1e-15;
+  }
+
+  // Proximal Gradient (first order subproblems); works on private copies of f, g and the initial vector
+  {
+    double reg = 42.0;
+    auto fPG = f;
+    auto gPG = g;
+    auto solPG = sol;
+    Solvers::ProximalGradientSolver proximalGradientSolver( fPG, gPG, sos, solPG, norm, regUpdater, reg, 100, 1e-7, Solver::FULL);
+
+    // ignore nothing: a default-constructed std::bitset has no bit set
+    BitVector ignore;
+    proximalGradientSolver.setIgnore( ignore );
+
+    // descent criterion: true only if the energy f + g with the correction is strictly below the energy without it (assumes solPG is still the old iterate and gPG is offset by it when this is called -- TODO confirm)
+    proximalGradientSolver.addDescentCriterion(
+      [&](){
+        auto correction = proximalGradientSolver.correction();
+        auto zeroVector = correction;
+        zeroVector *= 0.0;
+
+        auto solPlusCorrection = solPG;
+        solPlusCorrection += correction;
+
+        auto oldEnergy = fPG.computeEnergy( solPG )             + gPG( zeroVector );
+        auto newEnergy = fPG.computeEnergy( solPlusCorrection ) + gPG( correction );
+
+        auto actualDescent = newEnergy - oldEnergy;
+
+        return std::tuple( actualDescent < 0 , std::string("no descent!"));
+      },"");
+
+    // go!
+    proximalGradientSolver.solve();
+
+    // subtract the exact solution so that solPG holds the error
+    solPG -= exactSol;
+
+    std::cout << "Proximal Gradient: Difference to exact solution = " << solPG.two_norm() << std::endl;
+
+    passed = passed && solPG.two_norm() < 1e-7;
+  }
+
+  return passed ? 0 : 1;
+}
-- 
GitLab