From f23106a044a16fb27127b980d30bdf5e1ba6f952 Mon Sep 17 00:00:00 2001 From: Phillip Berndt <phillip.berndt@googlemail.com> Date: Sat, 5 Nov 2016 16:02:03 +0100 Subject: [PATCH] Do not overwrite exceptions from callbacks --- pyradau13.c | 23 ++++++++++++++++++----- test.py | 12 ++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/pyradau13.c b/pyradau13.c index 1d926de..2b0d2a0 100644 --- a/pyradau13.c +++ b/pyradau13.c @@ -125,12 +125,16 @@ static void radau_rhs(int *n, double *x, double *y, double *f, float *rpar, int Py_DECREF(y_current); if(rhs_retval == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a value"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a value"); + } return; } PyArrayObject *rv_array = (PyArrayObject *)PyArray_FROM_OTF(rhs_retval, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); if(rv_array == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a vector"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a vector"); + } return; } int use_n = PyArray_SIZE(rv_array); @@ -148,12 +152,16 @@ static void radau_jacobian(int *n, double *x, double *y, double *dfy, int *ldfy, Py_DECREF(y_current); if(jacobian_retval == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a value"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a value"); + } return; } PyArrayObject *rv_array = (PyArrayObject *)PyArray_FROM_OTF(jacobian_retval, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_F_CONTIGUOUS); if(rv_array == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a matrix"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a matrix"); + } return; } int use_n = PyArray_SIZE(rv_array); @@ -193,7 +201,12 @@ static void radau_dense_feedback(int *nr, double *xold, double *x, double *y, do evaluator->cont = NULL; Py_DECREF(evaluator); - if(rhs_retval == NULL || PyObject_IsTrue(rhs_retval)) { + + if(PyErr_Occurred() != NULL) { + // Must abort: An error occured + *irtrn = 1; + } + else if(rhs_retval == NULL || PyObject_IsTrue(rhs_retval)) { *irtrn = -1; } else { diff --git a/test.py b/test.py index fae9025..2c84165 100755 --- a/test.py +++ b/test.py @@ -9,6 +9,18 @@ class TestIntegration(unittest.TestCase): def _pendulum_rhs(self, t, x): return [ x[1], -x[0] ] + def test_exception_from_callback(self): + class _TestException(Exception): + pass + def _rhs(t, x): + raise _TestException() + def _dense_cb(told, t, x, cont): + raise _TestException() + + self.assertRaises(_TestException, lambda: radau13(_rhs, 0, 1)) + self.assertRaises(_TestException, + lambda: radau13(lambda t, x: 1, 0, 1, dense_callback=_dense_cb)) + def test_exp(self): self.assertAlmostEqual(float(radau13(lambda t, x: x, 1, 1)), np.exp(1)) -- GitLab