Skip to content
Snippets Groups Projects
Commit 24803525 authored by Max Kahnt's avatar Max Kahnt
Browse files

Allow different types for transformations T1 and T2.

parent 81df1fe8
Branches
No related tags found
No related merge requests found
......@@ -14,12 +14,14 @@ namespace MatrixVector {
// add transformed matrix A += T1^t*B*T2
// ******************************************************
template <class MatrixA, class MatrixB, class TransformationMatrix,
bool AisScalar, bool BisScalar, bool TisScalar>
template <class MatrixA, class MatrixB, class TransformationMatrix1,
class TransformationMatrix2, bool AisScalar, bool BisScalar,
bool T1isScalar, bool T2isScalar>
struct TransformMatrixHelper {
static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
static void addTransformedMatrix(MatrixA& A,
const TransformationMatrix1& T1,
const MatrixB& B,
const TransformationMatrix& T2) {
const TransformationMatrix2& T2) {
for (size_t i = 0; i < A.N(); ++i) {
for (auto jIt = A[i].begin(); jIt != A[i].end(); ++jIt) {
for (size_t k = 0; k < B.N(); ++k) {
......@@ -72,7 +74,8 @@ namespace MatrixVector {
template <class K1, class K2, class K3, int n, int m>
struct TransformMatrixHelper<
Dune::FieldMatrix<K1, m, m>, Dune::FieldMatrix<K3, n, n>,
Dune::FieldMatrix<K2, n, m>, false, false, false> {
Dune::FieldMatrix<K2, n, m>, Dune::FieldMatrix<K2, n, m>, false, false,
false, false> {
typedef Dune::FieldMatrix<K1, m, m> MatrixA;
typedef Dune::FieldMatrix<K3, n, n> MatrixB;
typedef Dune::FieldMatrix<K2, n, m> TransformationMatrix;
......@@ -111,25 +114,30 @@ namespace MatrixVector {
// }
};
template <class MatrixA, class MatrixB, class ScalarTransform, bool AisScalar,
bool BisScalar>
struct TransformMatrixHelper<MatrixA, MatrixB, ScalarTransform, AisScalar,
BisScalar, true> {
static void addTransformedMatrix(MatrixA& A, const ScalarTransform& T1,
template <class MatrixA, class MatrixB, class ScalarTransform1,
class ScalarTransform2, bool AisScalar, bool BisScalar>
struct TransformMatrixHelper<MatrixA, MatrixB, ScalarTransform1,
ScalarTransform2, AisScalar, BisScalar, true,
true> {
static void addTransformedMatrix(MatrixA& A, const ScalarTransform1& T1,
const MatrixB& B,
const ScalarTransform& T2) {
const ScalarTransform2& T2) {
addProduct(A, T1 * T2, B);
}
};
template <class MatrixA, class ScalarB, class TransformationMatrix,
bool AisScalar, bool TisScalar>
struct TransformMatrixHelper<MatrixA, ScalarB, TransformationMatrix,
AisScalar, true, TisScalar> {
static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
template <class MatrixA, class ScalarB, class TransformationMatrix1,
class TransformationMatrix2, bool AisScalar, bool T1isScalar,
bool T2isScalar>
struct TransformMatrixHelper<MatrixA, ScalarB, TransformationMatrix1,
TransformationMatrix2, AisScalar, true,
T1isScalar, T2isScalar> {
static void addTransformedMatrix(MatrixA& A,
const TransformationMatrix1& T1,
const ScalarB& B,
const TransformationMatrix& T2) {
for (size_t k = 0; k < TransformationMatrix::rows; ++k) {
const TransformationMatrix2& T2) {
assert(TransformationMatrix1::rows == TransformationMatrix2::rows);
for (size_t k = 0; k < TransformationMatrix1::rows; ++k) {
for (auto Skj = T2[k].begin(); Skj != T2[k].end(); ++Skj) {
for (auto Tki = T1[k].begin(); Tki != T1[k].end(); Tki++)
if (A.exists(Tki.index(), Skj.index()))
......@@ -139,41 +147,44 @@ namespace MatrixVector {
}
};
template <class FieldType, int n, class ScalarB, class TransformationMatrix,
bool AisScalar, bool TisScalar>
struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>,
ScalarB, TransformationMatrix, AisScalar, true,
TisScalar> {
template <class FieldType, int n, class ScalarB, class TransformationMatrix1,
class TransformationMatrix2, bool AisScalar, bool T1isScalar,
bool T2isScalar>
struct TransformMatrixHelper<
Dune::ScaledIdentityMatrix<FieldType, n>, ScalarB, TransformationMatrix1,
TransformationMatrix2, AisScalar, true, T1isScalar, T2isScalar> {
typedef Dune::ScaledIdentityMatrix<FieldType, n> MatrixA;
static void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
static void addTransformedMatrix(MatrixA& A,
const TransformationMatrix1& T1,
const ScalarB& B,
const TransformationMatrix& T2) {
TransformMatrixHelper<FieldType, ScalarB,
typename TransformationMatrix::field_type, true,
true, true>::addTransformedMatrix(A.scalar(),
T1.scalar(), B,
T2.scalar());
const TransformationMatrix2& T2) {
TransformMatrixHelper<
FieldType, ScalarB, typename TransformationMatrix1::field_type,
typename TransformationMatrix2::field_type, true, true, true,
true>::addTransformedMatrix(A.scalar(), T1.scalar(), B, T2.scalar());
}
};
template <class ScalarA, class ScalarB, class ScalarTransform>
struct TransformMatrixHelper<ScalarA, ScalarB, ScalarTransform, true, true,
true> {
static void addTransformedMatrix(ScalarA& A, const ScalarTransform& T1,
template <class ScalarA, class ScalarB, class ScalarTransform1,
class ScalarTransform2>
struct TransformMatrixHelper<ScalarA, ScalarB, ScalarTransform1,
ScalarTransform2, true, true, true, true> {
static void addTransformedMatrix(ScalarA& A, const ScalarTransform1& T1,
const ScalarB& B,
const ScalarTransform& T2) {
const ScalarTransform2& T2) {
A += T1 * B * T2;
}
};
template <class MatrixA, typename FieldType, int n,
class TransformationMatrix>
class TransformationMatrix1, class TransformationMatrix2>
struct TransformMatrixHelper<MatrixA, Dune::DiagonalMatrix<FieldType, n>,
TransformationMatrix, false, false, false> {
TransformationMatrix1, TransformationMatrix2,
false, false, false, false> {
static void addTransformedMatrix(
MatrixA& A, const TransformationMatrix& T1,
MatrixA& A, const TransformationMatrix1& T1,
const Dune::DiagonalMatrix<FieldType, n>& B,
const TransformationMatrix& T2) {
const TransformationMatrix2& T2) {
for (size_t k = 0; k < n; ++k) {
for (auto Skj = T2[k].begin(); Skj != T2[k].end(); ++Skj) {
for (auto Tki = T1[k].begin(); Tki != T1[k].end(); Tki++)
......@@ -186,28 +197,30 @@ namespace MatrixVector {
};
template <class MatrixA, typename FieldType, int n,
class TransformationMatrix>
struct TransformMatrixHelper<MatrixA,
Dune::ScaledIdentityMatrix<FieldType, n>,
TransformationMatrix, false, false, false> {
class TransformationMatrix1, class TransformationMatrix2>
struct TransformMatrixHelper<
MatrixA, Dune::ScaledIdentityMatrix<FieldType, n>, TransformationMatrix1,
TransformationMatrix2, false, false, false, false> {
static void addTransformedMatrix(
MatrixA& A, const TransformationMatrix& T1,
MatrixA& A, const TransformationMatrix1& T1,
const Dune::ScaledIdentityMatrix<FieldType, n>& B,
const TransformationMatrix& T2) {
TransformMatrixHelper<MatrixA, FieldType, TransformationMatrix, false,
true, false>::addTransformedMatrix(A, T1,
B.scalar(), T2);
const TransformationMatrix2& T2) {
TransformMatrixHelper<MatrixA, FieldType, TransformationMatrix1,
TransformationMatrix2, false, true, false,
false>::addTransformedMatrix(A, T1, B.scalar(), T2);
}
};
template <typename FieldType, int n, class TransformationMatrix>
template <typename FieldType, int n, class TransformationMatrix1,
class TransformationMatrix2>
struct TransformMatrixHelper<Dune::DiagonalMatrix<FieldType, n>,
Dune::DiagonalMatrix<FieldType, n>,
TransformationMatrix, false, false, false> {
TransformationMatrix1, TransformationMatrix2,
false, false, false, false> {
static void addTransformedMatrix(
Dune::DiagonalMatrix<FieldType, n>& A, const TransformationMatrix& T1,
Dune::DiagonalMatrix<FieldType, n>& A, const TransformationMatrix1& T1,
const Dune::DiagonalMatrix<FieldType, n>& B,
const TransformationMatrix& T2) {
const TransformationMatrix2& T2) {
for (size_t k = 0; k < n; k++) {
for (auto Tki = T1[k].begin(); Tki != T1[k].end(); ++Tki)
A.diagonal(Tki.index()) +=
......@@ -216,15 +229,17 @@ namespace MatrixVector {
}
};
template <typename FieldType, int n, class TransformationMatrix>
template <typename FieldType, int n, class TransformationMatrix1,
class TransformationMatrix2>
struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>,
Dune::ScaledIdentityMatrix<FieldType, n>,
TransformationMatrix, false, false, false> {
TransformationMatrix1, TransformationMatrix2,
false, false, false, false> {
static void addTransformedMatrix(
Dune::ScaledIdentityMatrix<FieldType, n>& A,
const TransformationMatrix& T1,
const TransformationMatrix1& T1,
const Dune::ScaledIdentityMatrix<FieldType, n>& B,
const TransformationMatrix& T2) {
const TransformationMatrix2& T2) {
for (size_t k = 0; k < n; k++)
A.scalar() += T1[k][0] * B.scalar() * T2[k][0];
}
......@@ -232,50 +247,54 @@ namespace MatrixVector {
template <typename FieldType, int n>
struct TransformMatrixHelper<Dune::ScaledIdentityMatrix<FieldType, n>,
Dune::ScaledIdentityMatrix<FieldType, n>,
Dune::ScaledIdentityMatrix<FieldType, n>,
Dune::ScaledIdentityMatrix<FieldType, n>, false,
false, false> {
false, false, false> {
static void addTransformedMatrix(
Dune::ScaledIdentityMatrix<FieldType, n>& A,
const Dune::ScaledIdentityMatrix<FieldType, n>& T1,
const Dune::ScaledIdentityMatrix<FieldType, n>& B,
const Dune::ScaledIdentityMatrix<FieldType, n>& T2) {
TransformMatrixHelper<FieldType, FieldType, FieldType, true, true,
true>::addTransformedMatrix(A.scalar(), T1.scalar(),
B.scalar(),
T2.scalar());
TransformMatrixHelper<
FieldType, FieldType, FieldType, FieldType, true, true, true,
true>::addTransformedMatrix(A.scalar(), T1.scalar(), B.scalar(),
T2.scalar());
}
};
template <class MatrixA, class MatrixB, class TransformationMatrix>
void addTransformedMatrix(MatrixA& A, const TransformationMatrix& T1,
const MatrixB& B, const TransformationMatrix& T2) {
template <class MatrixA, class MatrixB, class TransformationMatrix1,
class TransformationMatrix2>
void addTransformedMatrix(MatrixA& A, const TransformationMatrix1& T1,
const MatrixB& B, const TransformationMatrix2& T2) {
TransformMatrixHelper<
MatrixA, MatrixB, TransformationMatrix, ScalarTraits<MatrixA>::isScalar,
ScalarTraits<MatrixB>::isScalar,
ScalarTraits<TransformationMatrix>::isScalar>::addTransformedMatrix(A,
T1,
B,
T2);
MatrixA, MatrixB, TransformationMatrix1, TransformationMatrix2,
ScalarTraits<MatrixA>::isScalar, ScalarTraits<MatrixB>::isScalar,
ScalarTraits<TransformationMatrix1>::isScalar,
ScalarTraits<
TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B,
T2);
}
template <class MatrixA, class MatrixB, class TransformationMatrix>
void transformMatrix(MatrixA& A, const TransformationMatrix& T1,
const MatrixB& B, const TransformationMatrix& T2) {
template <class MatrixA, class MatrixB, class TransformationMatrix1,
class TransformationMatrix2>
void transformMatrix(MatrixA& A, const TransformationMatrix1& T1,
const MatrixB& B, const TransformationMatrix2& T2) {
A = 0;
TransformMatrixHelper<
MatrixA, MatrixB, TransformationMatrix, ScalarTraits<MatrixA>::isScalar,
ScalarTraits<MatrixB>::isScalar,
ScalarTraits<TransformationMatrix>::isScalar>::addTransformedMatrix(A,
T1,
B,
T2);
MatrixA, MatrixB, TransformationMatrix1, TransformationMatrix2,
ScalarTraits<MatrixA>::isScalar, ScalarTraits<MatrixB>::isScalar,
ScalarTraits<TransformationMatrix1>::isScalar,
ScalarTraits<
TransformationMatrix2>::isScalar>::addTransformedMatrix(A, T1, B,
T2);
}
template <class MatrixBlockA, class MatrixB, class TransformationMatrix>
template <class MatrixBlockA, class MatrixB, class TransformationMatrix1,
class TransformationMatrix2>
static void transformMatrix(Dune::BCRSMatrix<MatrixBlockA>& A,
const TransformationMatrix& T1, const MatrixB& B,
const TransformationMatrix& T2) {
const TransformationMatrix1& T1, const MatrixB& B,
const TransformationMatrix2& T2) {
transformMatrixPattern(A, T1, B, T2);
A = 0.0;
for (size_t k = 0; k < B.N(); ++k) {
......@@ -294,11 +313,12 @@ namespace MatrixVector {
}
}
template <class MatrixBlockA, class MatrixB, class TransformationMatrix>
template <class MatrixBlockA, class MatrixB, class TransformationMatrix1,
class TransformationMatrix2>
static void transformMatrixPattern(Dune::BCRSMatrix<MatrixBlockA>& A,
const TransformationMatrix& T1,
const TransformationMatrix1& T1,
const MatrixB& B,
const TransformationMatrix& T2) {
const TransformationMatrix2& T2) {
Dune::MatrixIndexSet indices(T1.M(), T2.M());
for (size_t k = 0; k < B.N(); ++k) {
for (auto BklIt = B[k].begin(); BklIt != B[k].end(); ++BklIt) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment