Skip to content

Commit 2f5836a

Browse files
committed
Avoid killed kernel with badly defined analytic functions
1 parent 01adce9 commit 2f5836a

File tree

4 files changed

+56
-1
lines changed

4 files changed

+56
-1
lines changed

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Check out the :doc:`usage` section for further information, including how to :re
2525
notebooks/poisson_equation
2626
notebooks/helmholtz_equation
2727
notebooks/hydrogen_atom
28+
notebooks/helium_atom
2829
notebooks/multiwavelets
2930
notebooks/PCMSolvent
3031

src/vampyr/tests/test_projector1d.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
3+
from vampyr import vampyr1d as vp
4+
5+
def test_ScalingProjector():
6+
def f(x):
7+
return x
8+
9+
mra = vp.MultiResolutionAnalysis(box=[0, 1], order=7)
10+
P_scaling = vp.ScalingProjector(mra, 2)
11+
P_wavelet = vp.WaveletProjector(mra, 2)
12+
13+
with pytest.raises(Exception):
14+
P_scaling(f)
15+
16+
with pytest.raises(Exception):
17+
P_wavelet(f)

src/vampyr/tests/test_projector3d.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
3+
from vampyr import vampyr3d as vp
4+
5+
def test_ScalingProjector():
6+
def f(x):
7+
return x
8+
9+
mra = vp.MultiResolutionAnalysis(box=[0, 1], order=7)
10+
P_scaling = vp.ScalingProjector(mra, 2)
11+
P_wavelet = vp.WaveletProjector(mra, 2)
12+
13+
with pytest.raises(Exception):
14+
P_scaling(f)
15+
16+
with pytest.raises(Exception):
17+
P_wavelet(f)

src/vampyr/treebuilders/project.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#pragma once
22

3-
#include <pybind11/functional.h>
3+
#include <typeinfo>
44

5+
#include <pybind11/functional.h>
56
#include "PyProjectors.h"
67

78
namespace vampyr {
@@ -18,6 +19,15 @@ template <int D> void project(pybind11::module &m) {
1819
.def(
1920
"__call__",
2021
[](PyScalingProjector<D> &P, std::function<double(const Coord<D> &r)> func) {
22+
23+
try {
24+
auto arr = std::array<double, D>();
25+
arr.fill(111111.111); // A number which hopefully does not divide by zero
26+
func(arr);
27+
} catch (py::cast_error &e) {
28+
py::print("Error: Invalid definition of analytic function");
29+
throw;
30+
}
2131
auto old_threads = mrcpp_get_num_threads();
2232
set_max_threads(1);
2333
auto out = P(func);
@@ -33,6 +43,16 @@ template <int D> void project(pybind11::module &m) {
3343
.def(
3444
"__call__",
3545
[](PyWaveletProjector<D> &P, std::function<double(const Coord<D> &r)> func) {
46+
47+
try {
48+
auto arr = std::array<double, D>();
49+
arr.fill(111111.111); // A number which hopefully does not divide by zero
50+
func(arr);
51+
} catch (py::cast_error &e) {
52+
py::print("Error: Invalid definition of analytic function");
53+
throw;
54+
}
55+
3656
auto old_threads = mrcpp_get_num_threads();
3757
set_max_threads(1);
3858
auto out = P(func);

0 commit comments

Comments
 (0)