diff --git a/dune/solvers/iterationsteps/cgstep.cc b/dune/solvers/iterationsteps/cgstep.cc index d4e61031e4819efe79cc3a399bc0298f927b7d5c..f95821e45df9214e3efe0be84000da3fd030feb3 100644 --- a/dune/solvers/iterationsteps/cgstep.cc +++ b/dune/solvers/iterationsteps/cgstep.cc @@ -1,9 +1,22 @@ +template <class MatrixType, class VectorType> +void CGStep<MatrixType, VectorType>::check() const +{ + if (preconditioner_) + preconditioner_->check(); +} + template <class MatrixType, class VectorType> void CGStep<MatrixType, VectorType>::preprocess() { // Compute the residual (r starts out as the rhs) matrix_.mmv(x_,r_); - p_ = r_; + + if (preconditioner_) { + preconditioner_->setMatrix(matrix_); + preconditioner_->apply(p_, r_); + } else + p_ = r_; + r_squared_old = p_*r_; } @@ -17,9 +30,14 @@ void CGStep<MatrixType, VectorType>::iterate() x_.axpy(alpha, p_); // x_1 = x_0 + alpha_0 p_0 r_.axpy(-alpha, q); // r_1 = r_0 - alpha_0 Ap_0 - const double r_squared = r_ * r_; + if (preconditioner_) + preconditioner_->apply(q, r_); + else + q = r_; + + const double r_squared = q * r_; const double beta = r_squared / r_squared_old; // beta_0 = r_1*r_1/ (r_0*r_0) p_ *= beta; // p_1 = r_1 + beta_0 p_0 - p_ += r_; + p_ += q; r_squared_old = r_squared; } diff --git a/dune/solvers/iterationsteps/cgstep.hh b/dune/solvers/iterationsteps/cgstep.hh index 77e74f58465de6874f3fc5d1ad2a5a5756a0f1b9..f782bae6103870dfa476b0d2b721355f5edd12ca 100644 --- a/dune/solvers/iterationsteps/cgstep.hh +++ b/dune/solvers/iterationsteps/cgstep.hh @@ -1,19 +1,30 @@ #ifndef DUNE_SOLVERS_ITERATIONSTEPS_CGSTEP_HH #define DUNE_SOLVERS_ITERATIONSTEPS_CGSTEP_HH +#include <dune/solvers/common/preconditioner.hh> #include <dune/solvers/iterationsteps/iterationstep.hh> namespace Dune { namespace Solvers { + //! A conjugate gradient solver template <class MatrixType, class VectorType> class CGStep : public IterationStep<VectorType> { public: CGStep(const MatrixType& matrix, - VectorType& x, - const VectorType& rhs) - : p_(rhs.size()), r_(rhs), x_(x), matrix_(matrix) + VectorType& x, + const VectorType& rhs) + : p_(rhs.size()), r_(rhs), x_(x), matrix_(matrix), + preconditioner_(nullptr) + {} + + CGStep(const MatrixType& matrix, + VectorType& x, + const VectorType& rhs, + Preconditioner<MatrixType, VectorType>& preconditioner) + : p_(rhs.size()), r_(rhs), x_(x), matrix_(matrix), + preconditioner_(&preconditioner) {} void check() const; @@ -28,6 +39,7 @@ namespace Dune { VectorType& x_; const MatrixType& matrix_; double r_squared_old; + Preconditioner<MatrixType, VectorType>* preconditioner_; }; #include "cgstep.cc"