Skip to content

Commit 2a4d2d5

Browse files
olastrzslayoo
andauthored
Vtk exporter fix to save more than one product (#1439)
Co-authored-by: Sylwester Arabas <sylwester.arabas@agh.edu.pl>
1 parent 5641771 commit 2a4d2d5

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

PySDM/exporters/vtk_exporter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def export_products(self, particulator):
134134

135135
if isinstance(v, np.ndarray):
136136
if v.shape == particulator.mesh.grid:
137-
payload[k] = v[:, :, np.newaxis]
137+
payload[k] = v[:, :, np.newaxis].copy()
138138
else:
139139
if self.verbose:
140140
print(

tests/unit_tests/exporters/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
""" checks for VTK exporter """
2+
3+
from collections import namedtuple
4+
5+
import numpy as np
6+
7+
from PySDM.exporters import VTKExporter
8+
9+
10+
def test_vtk_exporter_copies_product_data(tmp_path):
11+
"""note: since VTK files contain unencoded binary data, we cannot use XML parsers;
12+
not to introduce a new dependency to PySDM, we read the binary data with NumPy"""
13+
# arrange
14+
productc_filename = tmp_path / "prod"
15+
sut = VTKExporter(products_filename=productc_filename)
16+
17+
grid = (1, 1)
18+
arr = np.zeros(shape=grid, dtype=float)
19+
20+
incr = 666
21+
22+
def plusplus(arr):
23+
arr += incr
24+
return arr
25+
26+
prod = namedtuple(typename="MockProductA", field_names=("get",))(
27+
get=lambda: plusplus(arr)
28+
)
29+
30+
particulator = namedtuple(
31+
typename="MockParticulator", field_names=("products", "n_steps", "dt", "mesh")
32+
)(
33+
n_steps=1,
34+
products={
35+
"a": prod,
36+
"b": prod,
37+
},
38+
dt=0,
39+
mesh=namedtuple(typename="MockMesh", field_names=("dimension", "grid", "size"))(
40+
dimension=2,
41+
grid=grid,
42+
size=(1, 1),
43+
),
44+
)
45+
46+
# act
47+
sut.export_products(particulator)
48+
49+
# assert
50+
offsets = (113, 129)
51+
with open(str(productc_filename) + "_num0000000001.vts", mode="rb") as vtk:
52+
binary_data = vtk.readlines()[14]
53+
for i, off in enumerate(offsets):
54+
assert (
55+
np.frombuffer(binary_data[off : off + 8], dtype=np.float64)
56+
== (i + 1) * incr
57+
)

0 commit comments

Comments
 (0)