Commit 30617e82 authored by graeser's avatar graeser
Browse files

Merge branch 'feature/gradient-aware-truncation' into 'master'

Allow to make truncation gradient aware

See merge request !27
parents 3cd463a5 ca025a83
Pipeline #43137 passed with stage
in 7 minutes
......@@ -43,6 +43,23 @@ class BoxConstrainedQuadraticFunctionalConstrainedLinearization
}
}
template<class NV, class NBV, class T>
static void determineGradientAwareTruncation(const NV& x, const NV& lower, const NV& upper, const NV& negativeGradient, NBV&& truncationFlags, const T& truncationTolerance)
{
namespace H = Dune::Hybrid;
if constexpr (IsNumber<NV>())
{
if (((x <= lower+truncationTolerance) and (negativeGradient<0) ) || ((x >= upper - truncationTolerance) and (negativeGradient>0)))
truncationFlags = true;
}
else
{
H::forEach(H::integralRange(H::size(x)), [&](auto&& i) {
This::determineGradientAwareTruncation(x[i], lower[i], upper[i], negativeGradient[i], truncationFlags[i], truncationTolerance);
});
}
}
template<class NV, class NBV>
static void truncateVector(NV& x, const NBV& truncationFlags)
{
......@@ -99,7 +116,8 @@ public:
f_(f),
ignore_(ignore),
truncationTolerance_(1e-10),
regularizeDiagonal_(false)
regularizeDiagonal_(false),
gradientAwareTruncation_(false)
{}
void setTruncationTolerance(double tolerance)
......@@ -125,6 +143,26 @@ public:
regularizeDiagonal_ = regularizeDiagonal;
}
/**
* \brief Enable/disable gradient aware truncation
*
* If this is disabled (default) entries are truncated
* if they are close enough to the upper or lower obstacle.
*
* If this is enabled, it is additionally checked, if the
* gradient points inside of the admissible set for those
* entries. I.e. in order to be truncated, entries
* touching the lower obstacle are required to have a
* positive partial derivative, while entries touching
* the upper obstacle are required to have a negative
* partial derivative.
* This may lead to a less pessimistic active set.
*/
void enableGradientAwareTruncation(bool gradientAwareTruncation)
{
gradientAwareTruncation_ = gradientAwareTruncation;
}
void bind(const Vector& x)
{
negativeGradient_ = derivative(f_)(x);
......@@ -133,7 +171,10 @@ public:
truncationFlags_ = ignore_;
// determine which components to truncate
determineTruncation(x, f_.lowerObstacle(), f_.upperObstacle(), truncationFlags_, truncationTolerance_);
if (gradientAwareTruncation_)
determineGradientAwareTruncation(x, f_.lowerObstacle(), f_.upperObstacle(), negativeGradient_, truncationFlags_, truncationTolerance_);
else
determineTruncation(x, f_.lowerObstacle(), f_.upperObstacle(), truncationFlags_, truncationTolerance_);
// truncate gradient and hessian
truncateVector(negativeGradient_, truncationFlags_);
......@@ -167,6 +208,7 @@ private:
double truncationTolerance_;
bool regularizeDiagonal_;
bool gradientAwareTruncation_;
Vector negativeGradient_;
Matrix hessian_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment