Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
axpy.hh 13.66 KiB
#ifndef DUNE_MATRIX_VECTOR_AXPY_HH
#define DUNE_MATRIX_VECTOR_AXPY_HH

#include <type_traits>

#include <dune/common/diagonalmatrix.hh>
#include <dune/common/fmatrix.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/scaledidmatrix.hh>

#include <dune/matrix-vector/algorithm.hh>
#include <dune/matrix-vector/traits/utilities.hh>

namespace Dune {
namespace MatrixVector {

  /// Forward declaration of the helper implementing a += b*c. The three bool
  /// parameters receive the isScalar<>() results for A, B and C, so that
  /// specializations can be selected on them.
  template <typename A, typename B, typename C, bool AisScalar,
            bool BisScalar, bool CisScalar>
  struct ProductHelper;

  /// Forward declaration of the helper implementing a += scalar*b*c. The
  /// three bool parameters receive the isScalar<>() results for A, B and C.
  template <typename A, typename Scalar, typename B, typename C,
            bool AisScalar, bool BisScalar, bool CisScalar>
  struct ScaledProductHelper;

  /** \brief Add a product to some matrix or vector
   *
   * This function computes a+=b*c.
   *
   * This function should tolerate all meaningful
   * combinations of scalars, vectors, and matrices.
   *
   * a,b,c could be matrices with appropriate
   * dimensions. But b can also always be a scalar
   * represented by a 1-dim vector or a 1 by 1 matrix.
   */
  template <class A, class B, class C>
  void addProduct(A& a, const B& b, const C& c) {
    ProductHelper<A, B, C, isScalar<A>(), isScalar<B>(),
                  isScalar<C>()>::addProduct(a, b, c);
  }

  /** \brief Subtract a product from some matrix or vector
   *
   * This function computes a-=b*c.
   *
   * This function should tolerate all meaningful
   * combinations of scalars, vectors, and matrices.
   *
   * a,b,c could be matrices with appropriate
   * dimensions. But b can also always be a scalar
   * represented by a 1-dim vector or a 1 by 1 matrix.
   */
  template <class A, class B, class C>
  void subtractProduct(A& a, const B& b, const C& c) {
    ScaledProductHelper<A, int, B, C, isScalar<A>(), isScalar<B>(),
                        isScalar<C>()>::addProduct(a, -1, b, c);
  }

  /** \brief Accumulate a scaled product into a matrix, vector or scalar
   *
   * Performs a += b*c*d.
   *
   * a, c and d may be any meaningful combination of scalars, vectors and
   * matrices of matching dimensions. b must (currently) be a scalar; c
   * may additionally always be a scalar, including one represented by a
   * 1-dim vector or a 1 by 1 matrix.
   *
   * The EnableScalar<B> return type presumably removes this overload for
   * non-scalar B and otherwise names void (defined outside this file;
   * nothing is returned here).
   */
  template <class A, class B, class C, class D>
  EnableScalar<B> addProduct(
      A& a, const B& b, const C& c, const D& d) {
    using Helper = ScaledProductHelper<A, B, C, D, isScalar<A>(),
                                       isScalar<C>(), isScalar<D>()>;
    Helper::addProduct(a, b, c, d);
  }

  /** \brief Remove a scaled product from a matrix, vector or scalar
   *
   * Performs a -= b*c*d.
   *
   * a, c and d may be any meaningful combination of scalars, vectors and
   * matrices of matching dimensions. b must (currently) be a scalar; c
   * may additionally always be a scalar, including one represented by a
   * 1-dim vector or a 1 by 1 matrix.
   *
   * Implemented as the scaled product a += (-b)*c*d, where the negated
   * factor is handed to the helper as a value of type B.
   */
  template <class A, class B, class C, class D>
  EnableScalar<B>
  subtractProduct(A& a, const B& b, const C& c, const D& d) {
    using Helper = ScaledProductHelper<A, B, C, D, isScalar<A>(),
                                       isScalar<C>(), isScalar<D>()>;
    Helper::addProduct(a, -b, c, d);
  }

  /** \brief Return a plus a (scaled) product, leaving a untouched
   *
   * Copies a, forwards the copy together with the remaining arguments to
   * addProduct, and returns the updated copy.
   */
  template <class A, class... BCD>
  A getAddProduct(const A& a, BCD&&... bcd) {
    auto result = a;
    addProduct(result, std::forward<BCD>(bcd)...);
    return result;
  }

  /** \brief Return a minus a (scaled) product, leaving a untouched
   *
   * Copies a, forwards the copy together with the remaining arguments to
   * subtractProduct, and returns the updated copy.
   */
  template <class A, class... BCD>
  A getSubtractProduct(const A& a, BCD&&... bcd) {
    auto result = a;
    subtractProduct(result, std::forward<BCD>(bcd)...);
    return result;
  }


  /** \brief Generic implementation of a += b*c
   *
   * Primary template, used whenever no specialization matches. The
   * overload is chosen through EnableMatrix / DisableMatrix on A
   * (presumably mutually exclusive traits from traits/utilities.hh).
   */
  template <class A, class B, class C, bool AisScalar, bool BisScalar,
            bool CisScalar>
  struct ProductHelper {
    /// A is not a matrix: b applies itself to c and accumulates into a.
    template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const B& b, const C& c) {
      b.umv(c, a);
    }

    /// A is a matrix: for every visited entry b[row][inner] and every
    /// visited entry c[inner][col], accumulate their product into
    /// a[row][col] by recursing into addProduct on the entries.
    template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const B& b, const C& c) {
      sparseRangeFor(b, [&](auto&& bRow, auto&& row) {
        sparseRangeFor(bRow, [&](auto&& bEntry, auto&& inner) {
          sparseRangeFor(c[inner], [&](auto&& cEntry, auto&& col) {
            Dune::MatrixVector::addProduct(a[row][col], bEntry, cEntry);
          });
        });
      });
    }
  };

  // Generic implementation of a += scalar*b*c. Primary template, used
  // whenever no specialization matches; the leading factor is always a
  // scalar. The overload is chosen through EnableMatrix / DisableMatrix on A.
  template <class A, class Scalar, class B, class C, bool AisScalar,
            bool BisScalar, bool CisScalar>
  struct ScaledProductHelper {
    // A is not a matrix: b applies itself to c, scaled, and accumulates
    // into a.
    template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
      b.usmv(scalar, c, a);
    }

    // A is a matrix: for every visited entry b[row][inner] and every visited
    // entry c[inner][col], accumulate their scaled product into a[row][col]
    // by recursing into the four-argument addProduct on the entries.
    template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
      sparseRangeFor(b, [&](auto&& bRow, auto&& row) {
        sparseRangeFor(bRow, [&](auto&& bEntry, auto&& inner) {
          sparseRangeFor(c[inner], [&](auto&& cEntry, auto&& col) {
            Dune::MatrixVector::addProduct(a[row][col], scalar, bEntry,
                                           cEntry);
          });
        });
      });
    }
  };

  /// a += b*c for FieldMatrix operands: (n x k) += (n x m) * (m x k),
  /// computed entry by entry with a plain triple loop.
  template <class T, int n, int m, int k>
  struct ProductHelper<Dune::FieldMatrix<T, n, k>, Dune::FieldMatrix<T, n, m>,
                       Dune::FieldMatrix<T, m, k>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, k>;
    using B = Dune::FieldMatrix<T, n, m>;
    using C = Dune::FieldMatrix<T, m, k>;

    static void addProduct(A& a, const B& b, const C& c) {
      for (size_t r = 0; r < n; ++r) {
        for (size_t s = 0; s < k; ++s) {
          for (size_t t = 0; t < m; ++t) {
            a[r][s] += b[r][t] * c[t][s];
          }
        }
      }
    }
  };

  /// a += b*c*d for a scalar b of the field type T and FieldMatrix
  /// operands: (n x k) += b * (n x m) * (m x k), entry by entry.
  template <class T, int n, int m, int k>
  struct ScaledProductHelper<Dune::FieldMatrix<T, n, k>, T,
                             Dune::FieldMatrix<T, n, m>,
                             Dune::FieldMatrix<T, m, k>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, k>;
    using C = Dune::FieldMatrix<T, n, m>;
    using D = Dune::FieldMatrix<T, m, k>;

    static void addProduct(A& a, const T& b, const C& c, const D& d) {
      for (size_t r = 0; r < n; ++r) {
        for (size_t s = 0; s < k; ++s) {
          for (size_t t = 0; t < m; ++t) {
            a[r][s] += b * c[r][t] * d[t][s];
          }
        }
      }
    }
  };

  /// a += b*c where all three operands are n x n DiagonalMatrix objects;
  /// only the diagonal entries are combined.
  template <class T, int n>
  struct ProductHelper<Dune::DiagonalMatrix<T, n>, Dune::DiagonalMatrix<T, n>,
                       Dune::DiagonalMatrix<T, n>, false, false, false> {
    using A = Dune::DiagonalMatrix<T, n>;
    using B = A;
    using C = B;

    static void addProduct(A& a, const B& b, const C& c) {
      for (size_t d = 0; d < n; ++d) {
        a.diagonal(d) += b.diagonal(d) * c.diagonal(d);
      }
    }
  };

  /// a += b*c*d for a scalar b of the field type T where a, c and d are
  /// n x n DiagonalMatrix objects; only the diagonal entries are combined.
  template <class T, int n>
  struct ScaledProductHelper<Dune::DiagonalMatrix<T, n>, T,
                             Dune::DiagonalMatrix<T, n>,
                             Dune::DiagonalMatrix<T, n>, false, false, false> {
    using A = Dune::DiagonalMatrix<T, n>;
    using C = A;
    using D = C;

    static void addProduct(A& a, const T& b, const C& c, const D& d) {
      for (size_t idx = 0; idx < n; ++idx) {
        a.diagonal(idx) += b * c.diagonal(idx) * d.diagonal(idx);
      }
    }
  };

  /// a += b*c where a is a dense n x n FieldMatrix and b, c are
  /// DiagonalMatrix objects; only the diagonal of a is updated.
  template <class T, int n>
  struct ProductHelper<Dune::FieldMatrix<T, n, n>, Dune::DiagonalMatrix<T, n>,
                       Dune::DiagonalMatrix<T, n>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, n>;
    using B = Dune::DiagonalMatrix<T, n>;
    using C = B;

    static void addProduct(A& a, const B& b, const C& c) {
      for (size_t d = 0; d < n; ++d) {
        a[d][d] += b.diagonal(d) * c.diagonal(d);
      }
    }
  };

  /// a += b*c*d for a scalar b of the field type T where a is a dense
  /// n x n FieldMatrix and c, d are DiagonalMatrix objects; only the
  /// diagonal of a is updated.
  template <class T, int n>
  struct ScaledProductHelper<Dune::FieldMatrix<T, n, n>, T,
                             Dune::DiagonalMatrix<T, n>,
                             Dune::DiagonalMatrix<T, n>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, n>;
    using C = Dune::DiagonalMatrix<T, n>;
    using D = C;

    static void addProduct(A& a, const T& b, const C& c, const D& d) {
      for (size_t idx = 0; idx < n; ++idx) {
        a[idx][idx] += b * c.diagonal(idx) * d.diagonal(idx);
      }
    }
  };

  /// a += b*c where b is an n x n DiagonalMatrix and a, c are n x m
  /// FieldMatrix objects: each row of c is scaled by the matching diagonal
  /// entry of b and added to the same row of a.
  template <class T, int n, int m>
  struct ProductHelper<Dune::FieldMatrix<T, n, m>, Dune::DiagonalMatrix<T, n>,
                       Dune::FieldMatrix<T, n, m>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, m>;
    using B = Dune::DiagonalMatrix<T, n>;
    using C = A;

    static void addProduct(A& a, const B& b, const C& c) {
      for (size_t row = 0; row < n; ++row) {
        a[row].axpy(b.diagonal(row), c[row]);
      }
    }
  };

  /// a += b*c*d for a scalar b of the field type T where c is an n x n
  /// DiagonalMatrix and a, d are n x m FieldMatrix objects: each row of d
  /// is scaled by b times the matching diagonal entry of c and added to
  /// the same row of a.
  template <class T, int n, int m>
  struct ScaledProductHelper<Dune::FieldMatrix<T, n, m>, T,
                             Dune::DiagonalMatrix<T, n>,
                             Dune::FieldMatrix<T, n, m>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, m>;
    using C = Dune::DiagonalMatrix<T, n>;
    using D = A;

    static void addProduct(A& a, const T& b, const C& c, const D& d) {
      for (size_t row = 0; row < n; ++row) {
        a[row].axpy(b * c.diagonal(row), d[row]);
      }
    }
  };

  /// a += b*c where a, b are n x m FieldMatrix objects and c is an m x m
  /// DiagonalMatrix: every column of b is scaled by the matching diagonal
  /// entry of c before being added to a.
  template <class T, int n, int m>
  struct ProductHelper<Dune::FieldMatrix<T, n, m>, Dune::FieldMatrix<T, n, m>,
                       Dune::DiagonalMatrix<T, m>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, m>;
    using C = Dune::DiagonalMatrix<T, m>;
    using B = A;

    static void addProduct(A& a, const B& b, const C& c) {
      for (size_t row = 0; row < n; ++row) {
        for (size_t col = 0; col < m; col++) {
          a[row][col] += c.diagonal(col) * b[row][col];
        }
      }
    }
  };

  /// a += b*c*d for a scalar b of the field type T where a, c are n x m
  /// FieldMatrix objects and d is an m x m DiagonalMatrix: every column of
  /// c is scaled by b times the matching diagonal entry of d before being
  /// added to a.
  template <class T, int n, int m>
  struct ScaledProductHelper<Dune::FieldMatrix<T, n, m>, T,
                             Dune::FieldMatrix<T, n, m>,
                             Dune::DiagonalMatrix<T, m>, false, false, false> {
    using A = Dune::FieldMatrix<T, n, m>;
    using D = Dune::DiagonalMatrix<T, m>;
    using C = A;

    static void addProduct(A& a, const T& b, const C& c, const D& d) {
      for (size_t row = 0; row < n; ++row) {
        for (size_t col = 0; col < m; col++) {
          a[row][col] += b * d.diagonal(col) * c[row][col];
        }
      }
    }
  };

  /// a += b*c where all three operands are ScaledIdentityMatrix objects of
  /// the same size; only the stored scalars are combined.
  template <class T, int n>
  struct ProductHelper<Dune::ScaledIdentityMatrix<T, n>,
                       Dune::ScaledIdentityMatrix<T, n>,
                       Dune::ScaledIdentityMatrix<T, n>, false, false, false> {
    using A = Dune::ScaledIdentityMatrix<T, n>;
    using B = A;
    using C = B;

    static void addProduct(A& a, const B& b, const C& c) {
      a.scalar() += b.scalar() * c.scalar();
    }
  };

  /// a += b*c*d for an arbitrary scalar type b where a, c and d are
  /// ScaledIdentityMatrix objects of the same size; only the stored scalars
  /// are combined.
  template <class T, class Scalar, int n>
  struct ScaledProductHelper<Dune::ScaledIdentityMatrix<T, n>, Scalar,
                             Dune::ScaledIdentityMatrix<T, n>,
                             Dune::ScaledIdentityMatrix<T, n>, false, false,
                             false> {
    using A = Dune::ScaledIdentityMatrix<T, n>;
    using C = A;
    using D = C;

    static void addProduct(A& a, const Scalar& b, const C& c, const D& d) {
      a.scalar() += b * c.scalar() * d.scalar();
    }
  };

  /** \brief a += b*c for a scalar b
   *
   * Selected whenever the BisScalar flag is true. The overload is chosen
   * through EnableMatrix / DisableMatrix on A.
   */
  template <class A, class ScalarB, class C, bool AisScalar, bool CisScalar>
  struct ProductHelper<A, ScalarB, C, AisScalar, true, CisScalar> {
    using B = ScalarB;

    /// A is not a matrix: a single axpy on a does the job.
    template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const B& b, const C& c) {
      a.axpy(b, c);
    }

    /// A is a matrix: visit the entries of c and accumulate b times each
    /// entry into the entry of a at the same position.
    template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const B& b, const C& c) {
      sparseRangeFor(c, [&](auto&& cRow, auto&& row) {
        sparseRangeFor(cRow, [&](auto&& cEntry, auto&& col) {
          Dune::MatrixVector::addProduct(a[row][col], b, cEntry);
        });
      });
    }
  };

  /** \brief a += scalar*b*c for a scalar b
   *
   * Selected whenever the BisScalar flag is true. The overload is chosen
   * through EnableMatrix / DisableMatrix on A.
   */
  template <class A, class Scalar, class ScalarB, class C, bool AisScalar,
            bool CisScalar>
  struct ScaledProductHelper<A, Scalar, ScalarB, C, AisScalar, true,
                             CisScalar> {
    using B = ScalarB;

    /// A is not a matrix: fold both scalar factors into one axpy on a.
    template <class ADummy = A, DisableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
      a.axpy(scalar * b, c);
    }

    /// A is a matrix: visit the entries of c and accumulate scalar*b times
    /// each entry into the entry of a at the same position.
    template <class ADummy = A, EnableMatrix<ADummy, int> = 0>
    static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
      sparseRangeFor(c, [&](auto&& cRow, auto&& row) {
        sparseRangeFor(cRow, [&](auto&& cEntry, auto&& col) {
          Dune::MatrixVector::addProduct(a[row][col], scalar, b, cEntry);
        });
      });
    }
  };

  /// a += b*c for a scalar b where a and c are ScaledIdentityMatrix objects
  /// of the same size; only the stored scalars are combined.
  template <class T, int n, class ScalarB>
  struct ProductHelper<Dune::ScaledIdentityMatrix<T, n>, ScalarB,
                       Dune::ScaledIdentityMatrix<T, n>, false, true, false> {
    using A = Dune::ScaledIdentityMatrix<T, n>;
    using B = ScalarB;
    using C = Dune::ScaledIdentityMatrix<T, n>;

    static void addProduct(A& a, const B& b, const C& c) {
      a.scalar() += b * c.scalar();
    }
  };

  /// a += scalar*b*c for a scalar b where a and c are ScaledIdentityMatrix
  /// objects of the same size; only the stored scalars are combined.
  template <class T, int n, class Scalar, class ScalarB>
  struct ScaledProductHelper<Dune::ScaledIdentityMatrix<T, n>, Scalar, ScalarB,
                             Dune::ScaledIdentityMatrix<T, n>, false, true,
                             false> {
    using A = Dune::ScaledIdentityMatrix<T, n>;
    using B = ScalarB;
    using C = Dune::ScaledIdentityMatrix<T, n>;

    static void addProduct(A& a, const Scalar& scalar, const B& b, const C& c) {
      a.scalar() += scalar * b * c.scalar();
    }
  };

  /// a += b*c where all three operands are scalars: plain arithmetic.
  template <class A, class B, class C>
  struct ProductHelper<A, B, C, true, true, true> {
    static void addProduct(A& a, const B& b, const C& c) {
      a += b * c;
    }
  };

  /// a += b*c*d where a, c and d are flagged as scalars (b is the scaling
  /// factor): plain arithmetic.
  template <class A, class B, class C, class D>
  struct ScaledProductHelper<A, B, C, D, true, true, true> {
    static void addProduct(A& a, const B& b, const C& c, const D& d) {
      a += b * c * d;
    }
  };

} // end namespace MatrixVector
} // end namespace Dune

#endif // DUNE_MATRIX_VECTOR_AXPY_HH