Skip to content
Snippets Groups Projects
Commit 9350852c authored by Elias Pipping's avatar Elias Pipping Committed by Elias Pipping
Browse files

Add more flexible modified gradient method

parent eef21539
Branches
No related tags found
No related merge requests found
......@@ -3,11 +3,15 @@ SUBDIRS =
noinst_PROGRAMS = \
bisection-example \
bisection-example-flexible \
bisection-simpler-example \
bisection-simpler-example2 \
bisection-simpler-example2-gradient
bisection_example_SOURCES = bisection-example.cc
bisection_example_flexible_SOURCES = \
bisection-example-flexible.cc \
properscalarincreasingconvexfunction.hh
bisection_simpler_example_SOURCES = bisection-simpler-example.cc
bisection_simpler_example2_SOURCES = bisection-simpler-example2.cc
bisection_simpler_example2_gradient_SOURCES = bisection-simpler-example2-gradient.cc
......@@ -17,6 +21,7 @@ check-am:
./bisection-simpler-example2
./bisection-simpler-example2-gradient
./bisection-example
./bisection-example-flexible
AM_CXXFLAGS = -Wall -Wextra
......
/* -*- mode:c++; mode: flymake -*- */
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include <dune/common/exceptions.hh>
#include <dune/common/stdstreams.hh>
#include <dune/fufem/interval.hh>
#include <dune/tnnmg/nonlinearities/smallfunctional.hh>
#include <dune/tnnmg/problem-classes/bisection.hh>
#include <cassert>
#include <cstdlib>
#include <limits>
#include "properscalarincreasingconvexfunction.hh"
// TODO: default to TrivialFunction
template <int dimension, class Function = TrivialFunction>
class SampleFunctional {
public:
typedef Dune::FieldVector<double, dimension> SmallVector;
typedef Dune::FieldMatrix<double, dimension, dimension> SmallMatrix;
// FIXME: hardcoded function H
SampleFunctional(SmallMatrix A, SmallVector b) : A_(A), b_(b), func_() {}
double operator()(const SmallVector v) const {
SmallVector y;
A_.mv(v, y); // y = Av
y /= 2; // y = 1/2 Av
y -= b_; // y = 1/2 Av - b
return y * v + func_(v.two_norm()); // <1/2 Av - b,v> + H(|v|)
}
double directionalDerivative(const SmallVector x,
const SmallVector dir) const {
if (x == SmallVector(0.0))
// Well, not in this way -- but can we compute them?
DUNE_THROW(Dune::Exception,
"Directional derivatives cannot be computed at zero.");
if (x * dir > 0)
return PlusGrad(x) * dir;
else
return MinusGrad(x) * dir;
}
SmallVector minimise(const SmallVector x, unsigned int iterations) const {
SmallVector descDir = ModifiedGradient(x);
Dune::dverb << "Starting at x with J(x) = " << operator()(x) << std::endl;
Dune::dverb << "Minimizing in direction w with dJ(x,w) = "
<< directionalDerivative(x, descDir) << std::endl;
double l = 0;
double r = 1;
SmallVector tmp;
while (true) {
tmp = x;
tmp.axpy(r, descDir);
if (directionalDerivative(tmp, descDir) >= 0)
break;
l = r;
r *= 2;
Dune::dverb << "Widened interval!" << std::endl;
}
Dune::dverb << "Interval now [" << l << "," << r << "]" << std::endl;
// Debugging
{
SmallVector tmpl = x;
tmpl.axpy(l, descDir);
SmallVector tmpr = x;
tmpr.axpy(r, descDir);
assert(directionalDerivative(tmpl, descDir) < 0);
assert(directionalDerivative(tmpr, descDir) > 0);
}
double m = l / 2 + r / 2;
SmallVector middle;
for (unsigned int count = 0; count < iterations; ++count) {
Dune::dverb << "now at m = " << m << std::endl;
Dune::dverb << "Value of J here: " << operator()(x + middle) << std::endl;
middle = descDir;
middle *= m;
double derivative = directionalDerivative(x + middle, descDir);
if (derivative < 0) {
l = m;
m = (m + r) / 2;
} else if (derivative > 0) {
r = m;
m = (l + m) / 2;
} else
break;
}
return middle;
}
private:
SmallMatrix A_;
SmallVector b_;
Function func_;
// Gradient of the smooth part
SmallVector SmoothGrad(const SmallVector x) const {
SmallVector y;
A_.mv(x, y); // y = Av
y -= b_; // y = Av - b
return y;
}
SmallVector PlusGrad(const SmallVector x) const {
SmallVector y = SmoothGrad(x);
y.axpy(func_.rightDifferential(x.two_norm()) / x.two_norm(), x);
return y;
}
SmallVector MinusGrad(const SmallVector x) const {
SmallVector y = SmoothGrad(x);
y.axpy(func_.leftDifferential(x.two_norm()) / x.two_norm(), x);
return y;
}
SmallVector ModifiedGradient(const SmallVector x) const {
if (x == SmallVector(0.0))
// TODO
DUNE_THROW(Dune::Exception, "The case x = 0 is not yet handled.");
SmallVector const pg = PlusGrad(x);
SmallVector const mg = MinusGrad(x);
SmallVector ret;
// TODO: collinearity checks suck
if (pg * x == pg.two_norm() * x.two_norm() &&
-(mg * x) == mg.two_norm() * x.two_norm()) {
return SmallVector(0);
} else if (pg * x >= 0 && mg * x >= 0) {
ret = pg;
} else if (pg * x <= 0 && mg * x <= 0) {
ret = mg;
} else {
ret = project(SmoothGrad(x), x);
}
ret *= -1;
return ret;
}
SmallVector project(const SmallVector z, const SmallVector x) const {
SmallVector y = z;
y.axpy(-(z * x) / x.two_norm2(), x);
return y;
}
};
void testSampleFunction() {
int const dim = 2;
typedef SampleFunctional<dim, SampleFunction> SampleFunctional;
SampleFunctional::SmallMatrix A;
A[0][0] = 3;
A[0][1] = 0;
A[1][0] = 0;
A[1][1] = 3;
SampleFunctional::SmallVector b;
b[0] = 1;
b[1] = 2;
SampleFunctional J(A, b);
std::cout << J.directionalDerivative(b, b) << std::endl;
assert(J.directionalDerivative(b, b) == 10 + 2 * sqrt(5));
SampleFunctional::SmallVector start = b;
start *= 17;
SampleFunctional::SmallVector correction = J.minimise(start, 20);
assert(J(start + correction) <= J(start));
assert(std::abs(J(start + correction) + 0.254644) < 1e-8);
std::cout << J(start + correction) << std::endl;
}
void testTrivialFunction() {
int const dim = 2;
typedef SampleFunctional<dim> SampleFunctional;
SampleFunctional::SmallMatrix A;
A[0][0] = 3;
A[0][1] = 0;
A[1][0] = 0;
A[1][1] = 3;
SampleFunctional::SmallVector b;
b[0] = 1;
b[1] = 2;
SampleFunctional J(A, b);
std::cout << J.directionalDerivative(b, b) << std::endl;
assert(J.directionalDerivative(b, b) == 10);
SampleFunctional::SmallVector start = b;
start *= 17;
SampleFunctional::SmallVector correction = J.minimise(start, 20);
assert(J(start + correction) <= J(start));
assert(std::abs(J(start + correction) + 0.833333) < 1e-6);
std::cout << J(start + correction) << std::endl;
}
int main() {
try {
testSampleFunction();
testTrivialFunction();
return 0;
}
catch (Dune::Exception &e) {
Dune::derr << "Dune reported error: " << e << std::endl;
}
}
class ProperScalarIncreasingConvexFunction {
public:
virtual double operator()(const double s) const = 0;
virtual double leftDifferential(const double s) const = 0;
virtual double rightDifferential(const double s) const = 0;
};
class SampleFunction : public ProperScalarIncreasingConvexFunction {
public:
double operator()(const double s) const { return (s < 1) ? s : (2 * s - 1); }
double leftDifferential(const double s) const { return (s <= 1) ? 1 : 2; }
double rightDifferential(const double s) const { return (s < 1) ? 1 : 2; }
};
class TrivialFunction : public ProperScalarIncreasingConvexFunction {
public:
double operator()(const double) const { return 0; }
double leftDifferential(const double) const { return 0; }
double rightDifferential(const double) const { return 0; }
};
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment