diff --git a/pyradau13.c b/pyradau13.c index 1d926dea91cb460ca9509dafbf01f0a5bc7eea9c..2b0d2a03795bdd2af0f7282614d242827022554e 100644 --- a/pyradau13.c +++ b/pyradau13.c @@ -125,12 +125,16 @@ static void radau_rhs(int *n, double *x, double *y, double *f, float *rpar, int Py_DECREF(y_current); if(rhs_retval == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a value"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a value"); + } return; } PyArrayObject *rv_array = (PyArrayObject *)PyArray_FROM_OTF(rhs_retval, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); if(rv_array == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a vector"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The RHS function must return a vector"); + } return; } int use_n = PyArray_SIZE(rv_array); @@ -148,12 +152,16 @@ static void radau_jacobian(int *n, double *x, double *y, double *dfy, int *ldfy, Py_DECREF(y_current); if(jacobian_retval == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a value"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a value"); + } return; } PyArrayObject *rv_array = (PyArrayObject *)PyArray_FROM_OTF(jacobian_retval, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_F_CONTIGUOUS); if(rv_array == NULL) { - PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a matrix"); + if(PyErr_Occurred() == NULL) { + PyErr_SetString(PyExc_RuntimeError, "The jacobian function must return a matrix"); + } return; } int use_n = PyArray_SIZE(rv_array); @@ -193,7 +201,12 @@ static void radau_dense_feedback(int *nr, double *xold, double *x, double *y, do evaluator->cont = NULL; Py_DECREF(evaluator); - if(rhs_retval == NULL || PyObject_IsTrue(rhs_retval)) { + + if(PyErr_Occurred() != NULL) { + // Must abort: An error occured + *irtrn = 1; + } + else if(rhs_retval == NULL || PyObject_IsTrue(rhs_retval)) { *irtrn = -1; } else { diff --git a/test.py b/test.py index fae9025a03ee28398c6f134efa88354432bbb37b..2c84165bfa0f31dcd57b1951c02c5637c279f168 100755 --- a/test.py +++ b/test.py @@ -9,6 +9,18 @@ class TestIntegration(unittest.TestCase): def _pendulum_rhs(self, t, x): return [ x[1], -x[0] ] + def test_exception_from_callback(self): + class _TestException(Exception): + pass + def _rhs(t, x): + raise _TestException() + def _dense_cb(told, t, x, cont): + raise _TestException() + + self.assertRaises(_TestException, lambda: radau13(_rhs, 0, 1)) + self.assertRaises(_TestException, + lambda: radau13(lambda t, x: 1, 0, 1, dense_callback=_dense_cb)) + def test_exp(self): self.assertAlmostEqual(float(radau13(lambda t, x: x, 1, 1)), np.exp(1))