Skip to content
Snippets Groups Projects
Forked from agnumpde / dune-tectonic
184 commits ahead of the upstream repository.
stateupdater.hh 2.65 KiB
#ifndef SRC_TIME_STEPPING_STATE_STATEUPDATER_HH
#define SRC_TIME_STEPPING_STATE_STATEUPDATER_HH

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

/**
 * @brief Abstract interface of the state updater for one coupling.
 *
 * Each instance belongs to one body, identified by the index set through
 * setBodyIndex() (0 until it is set). StateUpdater uses bodyIndex() to hand
 * each local updater the data of its body.
 *
 * @tparam ScalarVectorTEMPLATE vector type filled by extractAlpha()
 * @tparam Vector               velocity vector type passed to solve()
 */
template <class ScalarVectorTEMPLATE, class Vector> class LocalStateUpdater {
public:
  using ScalarVector = ScalarVectorTEMPLATE;

  LocalStateUpdater() = default;
  LocalStateUpdater(const LocalStateUpdater&) = default;
  LocalStateUpdater& operator=(const LocalStateUpdater&) = default;
  LocalStateUpdater(LocalStateUpdater&&) = default;
  LocalStateUpdater& operator=(LocalStateUpdater&&) = default;
  // Polymorphic base: implementations are owned and destroyed through
  // pointers to this class, so the destructor must be virtual.
  virtual ~LocalStateUpdater() = default;

  /// Sets the index of the body this updater works on.
  void setBodyIndex(const size_t bodyIdx) {
      bodyIdx_ = bodyIdx;
  }

  /// @return index of the body this updater works on (0 if never set).
  size_t bodyIndex() const {
      return bodyIdx_;
  }

  /// Invoked by StateUpdater::nextTimeStep() on every registered updater.
  virtual void nextTimeStep() = 0;
  /// Invoked by StateUpdater::setup() with the same @p tau
  /// (presumably the time-step size — confirm against implementations).
  virtual void setup(double tau) = 0;
  /// Receives the velocity of the body selected by bodyIndex().
  virtual void solve(const Vector& velocity_field) = 0;
  /// Writes this coupling's state into the alpha vector of its body.
  virtual void extractAlpha(ScalarVector& alpha) = 0;

  /// @return an independent copy of this updater; StateUpdater::clone()
  ///         relies on this for its deep copy.
  virtual std::shared_ptr<LocalStateUpdater<ScalarVector, Vector>> clone() const = 0;

private:
  // Initialized so bodyIndex() never reads an indeterminate value.
  size_t bodyIdx_ = 0;
};


template <class ScalarVectorTEMPLATE, class Vector> class StateUpdater {
public:
  using ScalarVector = ScalarVectorTEMPLATE;
  using LocalStateUpdater = LocalStateUpdater<ScalarVector, Vector>;

  StateUpdater(const std::vector<size_t>& leafVertexCounts) :
    leafVertexCounts_(leafVertexCounts) {}

  void addLocalUpdater(std::shared_ptr<LocalStateUpdater> localStateUpdater) {
    localStateUpdaters_.emplace_back(localStateUpdater);
  }

  void nextTimeStep() {
    for (size_t i=0; i<localStateUpdaters_.size(); i++) {
        localStateUpdaters_[i]->nextTimeStep();
    }
  }

  void setup(double tau) {
    for (size_t i=0; i<localStateUpdaters_.size(); i++) {
        localStateUpdaters_[i]->setup(tau);
    }
  }

  void solve(const std::vector<Vector>& velocity_field) {
    for (size_t i=0; i<localStateUpdaters_.size(); i++) {
        auto& localStateUpdater = localStateUpdaters_[i];
        localStateUpdater->solve(velocity_field[localStateUpdater->bodyIndex()]);
    }
  }

  void extractAlpha(std::vector<ScalarVector>& alpha) {
    alpha.resize(leafVertexCounts_.size());
    for (size_t i=0; i<alpha.size(); i++) {
        alpha[i].resize(leafVertexCounts_[i], 0.0);
    }

    for (size_t i=0; i<localStateUpdaters_.size(); i++) {
        auto& localStateUpdater = localStateUpdaters_[i];
        localStateUpdater->extractAlpha(alpha[localStateUpdater->bodyIndex()]);
    }
  }

  std::shared_ptr<StateUpdater<ScalarVector, Vector>> virtual clone() const {
      auto updater = std::make_shared<StateUpdater<ScalarVector, Vector>>(leafVertexCounts_);

      for (size_t i=0; i<localStateUpdaters_.size(); i++) {
          auto localUpdater = localStateUpdaters_[i]->clone();
          updater->addLocalUpdater(localUpdater);
      }
      return updater; // std::make_shared<StateUpdater<ScalarVector, Vector>>(*this);
  }

private:
  std::vector<size_t> leafVertexCounts_;
  std::vector<std::shared_ptr<LocalStateUpdater>> localStateUpdaters_;
};
#endif