From b18fce9af710a467b7ed797ce59d89de13ab518c Mon Sep 17 00:00:00 2001
From: Patrick Jaap <patrick.jaap@tu-dresden.de>
Date: Mon, 13 Jul 2020 11:54:18 +0200
Subject: [PATCH] DuneFunctionsOperatorAssembler: use one localView twice if
 bases are the same

For operators with equal ansatz and trial space we can use one localView twice
instead of calling the expensive bind() twice.
---
 .../dunefunctionsoperatorassembler.hh         | 51 ++++++++++++++-----
 1 file changed, 38 insertions(+), 13 deletions(-)

diff --git a/dune/fufem/assemblers/dunefunctionsoperatorassembler.hh b/dune/fufem/assemblers/dunefunctionsoperatorassembler.hh
index 8e23f5d6..39f0c327 100644
--- a/dune/fufem/assemblers/dunefunctionsoperatorassembler.hh
+++ b/dune/fufem/assemblers/dunefunctionsoperatorassembler.hh
@@ -3,6 +3,8 @@
 #ifndef DUNE_FUFEM_ASSEMBLERS_DUNEFUNCTIONSOPERATORASSEMBLER_HH
 #define DUNE_FUFEM_ASSEMBLERS_DUNEFUNCTIONSOPERATORASSEMBLER_HH
 
+#include <type_traits>
+
 #include <dune/istl/matrix.hh>
 #include <dune/istl/matrixindexset.hh>
 
@@ -37,15 +39,15 @@ public:
   {
     patternBuilder.resize(trialBasis_, ansatzBasis_);
 
-    auto trialLocalView     = trialBasis_.localView();
-
-    auto ansatzLocalView     = ansatzBasis_.localView();
+    // create two localViews but use only one if bases are the same
+    auto ansatzLocalView = ansatzBasis_.localView();
+    auto seperateTrialLocalView = trialBasis_.localView();
+    auto& trialLocalView = selectTrialLocalView(ansatzLocalView, seperateTrialLocalView);
 
     for (const auto& element : elements(trialBasis_.gridView()))
     {
-      trialLocalView.bind(element);
-
-      ansatzLocalView.bind(element);
+      // bind the localViews to the element
+      bind(ansatzLocalView, trialLocalView, element);
 
       // Add element stiffness matrix onto the global stiffness matrix
       for (size_t i=0; i<trialLocalView.size(); ++i)
@@ -119,9 +121,10 @@ public:
   template <class MatrixBackend, class LocalAssembler>
   void assembleBulkEntries(MatrixBackend&& matrixBackend, LocalAssembler&& localAssembler) const
   {
-    auto trialLocalView     = trialBasis_.localView();
-
-    auto ansatzLocalView     = ansatzBasis_.localView();
+    // create two localViews but use only one if bases are the same
+    auto ansatzLocalView = ansatzBasis_.localView();
+    auto seperateTrialLocalView = trialBasis_.localView();
+    auto& trialLocalView = selectTrialLocalView(ansatzLocalView, seperateTrialLocalView);
 
     using Field = std::decay_t<decltype(matrixBackend(trialLocalView.index(0), ansatzLocalView.index(0)))>;
     using LocalMatrix = Dune::Matrix<Dune::FieldMatrix<Field,1,1>>;
@@ -130,12 +133,10 @@ public:
 
     for (const auto& element : elements(trialBasis_.gridView()))
     {
-      trialLocalView.bind(element);
-
-      ansatzLocalView.bind(element);
+      // bind the localViews to the element
+      bind(ansatzLocalView, trialLocalView, element);
 
       localMatrix.setSize(trialLocalView.size(), ansatzLocalView.size());
-
       localAssembler(element, localMatrix, trialLocalView, ansatzLocalView);
 
       // Add element stiffness matrix onto the global stiffness matrix
@@ -270,6 +271,30 @@ public:
     assembleBulkEntries(matrixBackend, std::forward<LocalAssembler>(localAssembler));
   }
 
+private:
+
+//! helper function to select a trialLocalView (possibly a reference to ansatzLocalView if bases are the same)
+template<class AnsatzLocalView, class TrialLocalView>
+TrialLocalView& selectTrialLocalView(AnsatzLocalView& ansatzLocalView, TrialLocalView& trialLocalView) const
+{
+   if constexpr (std::is_same<TrialBasis,AnsatzBasis>::value)
+     if (&trialBasis_ == &ansatzBasis_)
+       return ansatzLocalView;
+   return trialLocalView;
+}
+
+//! small helper that checks whether two localViews are the same and binds one or both to an element
+template<class AnsatzLocalView, class TrialLocalView, class E>
+void bind(AnsatzLocalView& ansatzLocalView, TrialLocalView& trialLocalView, const E& e) const
+{
+  ansatzLocalView.bind(e);
+  if constexpr (std::is_same<TrialBasis,AnsatzBasis>::value)
+    if (&trialLocalView == &ansatzLocalView)
+      return;
+  // localViews differ: bind trialLocalView too
+  trialLocalView.bind(e);
+}
+
 protected:
   const TrialBasis& trialBasis_;
   const AnsatzBasis& ansatzBasis_;
-- 
GitLab