diff --git a/src/compute_state.cc b/src/compute_state.cc
index adfb4cd9e99c1dba762756ddc706ebcfa82fc2e5..27a6ef67ccc58df62b53dd28bf949532a760b3cf 100644
--- a/src/compute_state.cc
+++ b/src/compute_state.cc
@@ -1,3 +1,5 @@
+#include <cassert>
+
 #include "LambertW.h"
 
 #include <gsl/gsl_sf_lambert.h>
@@ -51,3 +53,23 @@ double compute_state_update_lambert_gsl(double h, double unorm, double L,
   double const rhs = unorm / L - old_state;
   return gsl_sf_lambert_W0(h * std::exp(rhs)) - rhs;
 }
+
+double compute_state_update(double h, double unorm, double L,
+                            double old_state) {
+  double ret1 = compute_state_update_bisection(h, unorm, L, old_state);
+  assert(std::abs(1.0 / h * ret1 - (old_state - unorm / L) / h -
+                  std::exp(-ret1)) < 1e-10);
+
+  double ret2 = compute_state_update_lambert(h, unorm, L, old_state);
+  assert(std::abs(1.0 / h * ret2 - (old_state - unorm / L) / h -
+                  std::exp(-ret2)) < 1e-10);
+
+  double ret3 = compute_state_update_lambert_gsl(h, unorm, L, old_state);
+  assert(std::abs(1.0 / h * ret3 - (old_state - unorm / L) / h -
+                  std::exp(-ret3)) < 1e-10);
+
+  assert(std::abs(ret1 - ret2) < 1e-14);
+  assert(std::abs(ret1 - ret3) < 1e-14);
+
+  return ret1;
+}
diff --git a/src/compute_state.hh b/src/compute_state.hh
index 2710f625ce721a3e384336a3dbd93d6206ef7de4..6b4e8778affcf364b5a3812b8087800c1ff02ffc 100644
--- a/src/compute_state.hh
+++ b/src/compute_state.hh
@@ -1,11 +1,6 @@
 #ifndef COMPUTE_STATE_HH
 #define COMPUTE_STATE_HH
 
-double compute_state_update_bisection(double h, double unorm, double L,
-                                      double old_state);
-double compute_state_update_lambert(double h, double unorm, double L,
-                                    double old_state);
-double compute_state_update_lambert_gsl(double h, double unorm, double L,
-                                        double old_state);
+double compute_state_update(double h, double unorm, double L, double old_state);
 
 #endif
diff --git a/src/one-body-sample.cc b/src/one-body-sample.cc
index 4812c267c69bda14ddf13032744305b2be987259..0e380ee18f102dd530c726ed6277b5701a8c64f0 100644
--- a/src/one-body-sample.cc
+++ b/src/one-body-sample.cc
@@ -349,26 +349,7 @@ int main(int argc, char *argv[]) {
               double const L = parset.get<double>("boundary.friction.ruina.L");
               double const unorm = u4_diff[i].two_norm();
 
-              double ret1 =
-                  compute_state_update_bisection(h, unorm, L, s4_old[i][0]);
-              assert(std::abs(1.0 / h * ret1 - (s4_old[i] - unorm / L) / h -
-                              std::exp(-ret1)) < 1e-10);
-
-              double ret2 =
-                  compute_state_update_lambert(h, unorm, L, s4_old[i][0]);
-              assert(std::abs(1.0 / h * ret2 - (s4_old[i] - unorm / L) / h -
-                              std::exp(-ret2)) < 1e-10);
-
-              double ret3 =
-                  compute_state_update_lambert_gsl(h, unorm, L, s4_old[i][0]);
-              assert(std::abs(1.0 / h * ret3 - (s4_old[i] - unorm / L) / h -
-                              std::exp(-ret3)) < 1e-10);
-
-              assert(std::abs(ret1 - ret2) < 1e-14);
-              assert(std::abs(ret1 - ret3) < 1e-14);
-              assert(std::abs(ret2 - ret3) < 1e-14);
-
-              (*s4_new)[i][0] = ret1;
+              (*s4_new)[i][0] = compute_state_update(h, unorm, L, s4_old[i][0]);
             }
           }
         }
@@ -421,26 +402,7 @@ int main(int argc, char *argv[]) {
               double const L = parset.get<double>("boundary.friction.ruina.L");
               double const unorm = u5_diff[i].two_norm();
 
-              double ret1 =
-                  compute_state_update_bisection(h, unorm, L, s5_old[i][0]);
-              assert(std::abs(1.0 / h * ret1 - (s5_old[i] - unorm / L) / h -
-                              std::exp(-ret1)) < 1e-12);
-
-              double ret2 =
-                  compute_state_update_lambert(h, unorm, L, s5_old[i][0]);
-              assert(std::abs(1.0 / h * ret2 - (s5_old[i] - unorm / L) / h -
-                              std::exp(-ret2)) < 1e-12);
-
-              double ret3 =
-                  compute_state_update_lambert_gsl(h, unorm, L, s4_old[i][0]);
-              assert(std::abs(1.0 / h * ret3 - (s4_old[i] - unorm / L) / h -
-                              std::exp(-ret3)) < 1e-10);
-
-              assert(std::abs(ret1 - ret2) < 1e-14);
-              assert(std::abs(ret1 - ret3) < 1e-14);
-              assert(std::abs(ret2 - ret3) < 1e-14);
-
-              (*s5_new)[i][0] = ret1;
+              (*s5_new)[i][0] = compute_state_update(h, unorm, L, s5_old[i][0]);
             }
           }
         }