Skip to content

Commit 4d99b5f

Browse files
authored
Add _debug_fill_halos() for arbitrary halo recalculation (#502)
1 parent 9872aa7 commit 4d99b5f

File tree

3 files changed

+143
-25
lines changed

3 files changed

+143
-25
lines changed

PyMPDATA/impl/field.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
`PyMPDATA.vector_field.VectorField` classes"""
33

44
import abc
5+
import warnings
56
from collections import namedtuple
67

8+
from numba import NumbaExperimentalFeatureWarning
9+
710
from PyMPDATA.boundary_conditions.constant import Constant
811

912
from .enumerations import INNER, INVALID_HALO_VALUE, MAX_DIM_NUM, MID3D, OUTER
@@ -120,3 +123,22 @@ def make_null(
120123
n_dims: int, traversals
121124
): # pylint: disable=missing-function-docstring
122125
raise NotImplementedError()
126+
127+
def _debug_fill_halos(self, traversals, n_threads):
128+
meta_and_data, fill_halos_fun = self.impl
129+
if self.__class__.__name__ == "VectorField":
130+
meta_and_data = (
131+
meta_and_data[0],
132+
(meta_and_data[1], meta_and_data[2], meta_and_data[3]),
133+
)
134+
sut = traversals._code[ # pylint:disable=protected-access
135+
{"ScalarField": "fill_halos_scalar", "VectorField": "fill_halos_vector"}[
136+
self.__class__.__name__
137+
]
138+
]
139+
with warnings.catch_warnings():
140+
warnings.simplefilter(
141+
action="ignore", category=NumbaExperimentalFeatureWarning
142+
)
143+
for thread_id in n_threads:
144+
sut(thread_id, *meta_and_data, fill_halos_fun, traversals.data.buffer)

tests/unit_tests/test_boundary_condition_extrapolated_2d.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,17 @@
11
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
2-
import warnings
32

43
import numpy as np
54
import pytest
6-
from numba import NumbaExperimentalFeatureWarning
75

86
from PyMPDATA import Options, ScalarField, VectorField
97
from PyMPDATA.boundary_conditions import Constant, Extrapolated
108
from PyMPDATA.impl.enumerations import MAX_DIM_NUM
11-
from PyMPDATA.impl.field import Field
129
from PyMPDATA.impl.traversals import Traversals
1310
from tests.unit_tests.quick_look import quick_look
1411

1512
JIT_FLAGS = Options().jit_flags
1613

1714

18-
def fill_halos(field: Field, traversals: Traversals, threads):
19-
field.assemble(traversals)
20-
meta_and_data, fill_halos_fun = field.impl
21-
if isinstance(field, VectorField):
22-
meta_and_data = (
23-
meta_and_data[0],
24-
(meta_and_data[1], meta_and_data[2], meta_and_data[3]),
25-
)
26-
sut = traversals._code[ # pylint:disable=protected-access
27-
{"ScalarField": "fill_halos_scalar", "VectorField": "fill_halos_vector"}[
28-
field.__class__.__name__
29-
]
30-
]
31-
with warnings.catch_warnings():
32-
warnings.simplefilter(action="ignore", category=NumbaExperimentalFeatureWarning)
33-
for thread_id in threads:
34-
sut(thread_id, *meta_and_data, fill_halos_fun, traversals.data.buffer)
35-
36-
3715
class TestBoundaryConditionExtrapolated2D:
3816
@staticmethod
3917
@pytest.mark.parametrize("n_threads", (1, 2))
@@ -55,7 +33,6 @@ def test_scalar_field(
5533
boundary_conditions=boundary_conditions,
5634
halo=n_halo,
5735
)
58-
5936
traversals = Traversals(
6037
grid=advectee.grid,
6138
halo=n_halo,
@@ -64,10 +41,13 @@ def test_scalar_field(
6441
left_first=tuple([True] * MAX_DIM_NUM),
6542
buffer_size=0,
6643
)
44+
advectee.assemble(traversals)
6745

6846
# act / plot
6947
quick_look(advectee, plot)
70-
fill_halos(advectee, traversals, threads=range(n_threads))
48+
advectee._debug_fill_halos( # pylint:disable=protected-access
49+
traversals, range(n_threads)
50+
)
7151
quick_look(advectee, plot)
7252

7353
# assert
@@ -105,10 +85,13 @@ def test_vector_field(
10585
left_first=tuple([True] * MAX_DIM_NUM),
10686
buffer_size=0,
10787
)
88+
advector.assemble(traversals)
10889

10990
# act / plot
11091
quick_look(advector, plot)
111-
fill_halos(advector, traversals, threads=range(n_threads))
92+
advector._debug_fill_halos( # pylint:disable=protected-access
93+
traversals, range(n_threads)
94+
)
11295
quick_look(advector, plot)
11396

11497
# assert
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# pylint:disable=missing-module-docstring,missing-function-docstring,duplicate-code
2+
import numpy as np
3+
import pytest
4+
5+
from PyMPDATA import Options, ScalarField, VectorField
6+
from PyMPDATA.boundary_conditions import Periodic
7+
from PyMPDATA.impl.enumerations import MAX_DIM_NUM
8+
from PyMPDATA.impl.traversals import Traversals
9+
10+
JIT_FLAGS = Options().jit_flags
11+
12+
13+
def assert_slice_size(used_slice: slice, halo):
14+
if used_slice.stop is None:
15+
if abs(used_slice.start) != halo and abs(used_slice.start) != halo - 1:
16+
raise AssertionError("Slice and halo size mismatch")
17+
elif used_slice.stop is not None:
18+
if (
19+
abs(used_slice.stop - used_slice.start) != halo
20+
and abs(used_slice.stop - used_slice.start) != halo - 1
21+
):
22+
raise AssertionError("Slice and halo size mismatch")
23+
else:
24+
assert False
25+
26+
27+
def assert_array_not_equal(arr_a, arr_b):
28+
return np.testing.assert_raises(
29+
AssertionError, np.testing.assert_array_equal, arr_a, arr_b
30+
)
31+
32+
33+
@pytest.mark.parametrize("boundary_condition", (Periodic(),))
34+
@pytest.mark.parametrize("n_threads", (1, 2))
35+
@pytest.mark.parametrize("halo", (1, 2, 3))
36+
@pytest.mark.parametrize(
37+
"field_factory",
38+
(
39+
lambda halo, boundary_condition: ScalarField(
40+
np.zeros(3), halo, boundary_condition
41+
), # 1d
42+
lambda halo, boundary_condition: VectorField(
43+
(np.zeros(3),), halo, boundary_condition
44+
), # 1d
45+
lambda halo, boundary_condition: ScalarField(
46+
np.zeros((3, 3)), halo, boundary_condition
47+
), # 2d
48+
lambda halo, boundary_condition: VectorField(
49+
(
50+
np.zeros(
51+
(4, 3),
52+
),
53+
np.zeros(
54+
(3, 4),
55+
),
56+
),
57+
halo,
58+
boundary_condition,
59+
), # 2d
60+
),
61+
)
62+
def test_explicit_fill_halos(field_factory, halo, boundary_condition, n_threads):
63+
# arange
64+
field = field_factory(halo, (boundary_condition, boundary_condition))
65+
if len(field.grid) == 1 and n_threads > 1:
66+
pytest.skip("Skip 1D tests with n_threads > 1")
67+
traversals = Traversals(
68+
grid=field.grid,
69+
halo=halo,
70+
jit_flags=JIT_FLAGS,
71+
n_threads=n_threads,
72+
left_first=tuple([True] * MAX_DIM_NUM),
73+
buffer_size=0,
74+
)
75+
field.assemble(traversals)
76+
if isinstance(field, ScalarField):
77+
field.get()[:] = np.arange(1, field.grid[0] + 1)
78+
left_halo = slice(0, halo)
79+
right_halo = slice(-halo, None)
80+
left_edge = slice(halo, 2 * halo)
81+
right_edge = slice(-2 * halo, -halo)
82+
slices = [left_halo, right_halo, left_edge, right_edge]
83+
for slice_to_check in slices:
84+
assert_slice_size(slice_to_check, halo)
85+
data = field.data
86+
elif isinstance(field, VectorField):
87+
if field.get_component(0)[:].ndim > 1:
88+
field.get_component(0)[0][:] = np.arange(1, field.grid[0] + 1)
89+
field.get_component(0)[1][:] = np.arange(1, field.grid[0] + 1)
90+
else:
91+
field.get_component(0)[:] = np.arange(1, field.grid[0] + 2)
92+
if halo == 1:
93+
pytest.skip("Skip VectorField test if halo == 1")
94+
left_halo = slice(0, halo - 1)
95+
right_halo = slice(-(halo - 1), None)
96+
left_edge = slice(halo, 2 * (halo - 1) + 1)
97+
right_edge = slice(-2 * (halo - 1) - 1, -(halo - 1) - 1)
98+
slices = [left_halo, right_halo, left_edge, right_edge]
99+
for slice_to_check in slices:
100+
assert_slice_size(slice_to_check, halo)
101+
data = field.data[0]
102+
else:
103+
assert False
104+
assert_array_not_equal(data[left_halo], data[right_edge])
105+
assert_array_not_equal(data[right_halo], data[left_edge])
106+
107+
# act
108+
# pylint:disable=protected-access
109+
field._debug_fill_halos(traversals, range(n_threads))
110+
111+
# assert
112+
np.testing.assert_array_equal(data[left_halo], data[right_edge], verbose=True)
113+
np.testing.assert_array_equal(data[right_halo], data[left_edge], verbose=True)

0 commit comments

Comments
 (0)