Skip to content

Commit 843dd8b

Browse files
authored
Merge pull request #16 from pylhc/sps_remove_trailing_planes
remove trailing planes from sps data
2 parents 0cdccf9 + c3d1d03 commit 843dd8b

File tree

5 files changed

+86
-28
lines changed

5 files changed

+86
-28
lines changed

tests/test_sps.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from turn_by_turn import sps, TbtData, TransverseData
99

1010

11-
def test_read_write_real_data(_sps_file, tmp_path):
12-
input_sdds = sps.read_tbt(_sps_file)
11+
@pytest.mark.parametrize("remove_planes", [True, False])
12+
def test_read_write_real_data(_sps_file, tmp_path, remove_planes):
13+
input_sdds = sps.read_tbt(_sps_file, remove_trailing_bpm_plane=remove_planes)
1314
tmp_sdds = tmp_path / "sps.sdds"
14-
sps.write_tbt(tmp_sdds, input_sdds)
15-
read_sdds = sps.read_tbt(tmp_sdds)
15+
sps.write_tbt(tmp_sdds, input_sdds, add_trailing_bpm_plane=remove_planes)
16+
read_sdds = sps.read_tbt(tmp_sdds, remove_trailing_bpm_plane=remove_planes)
1617
compare_tbt(input_sdds, read_sdds, no_binary=True)
1718

1819

@@ -25,21 +26,44 @@ def test_write_read(tmp_path):
2526
matrices=[
2627
TransverseData(
2728
X=pd.DataFrame(
28-
index=[f"BPM{i}.H" for i in range(nbpms_x)],
29+
index=[f"BPMH{i}.H" for i in range(nbpms_x)],
2930
data=create_data(np.linspace(-np.pi, np.pi, nturns, endpoint=False), nbpms_x, np.sin)
3031
),
3132
Y=pd.DataFrame(
32-
index=[f"BPM{i}.V" for i in range(nbpms_y)],
33+
index=[f"BPMV{i}.V" for i in range(nbpms_y)],
3334
data=create_data(np.linspace(-np.pi, np.pi, nturns, endpoint=False), nbpms_y, np.cos)
3435
),
3536
)
3637
],
3738
)
3839
tmp_sdds = tmp_path / "sps_fake_data.sdds"
39-
sps.write_tbt(tmp_sdds, original)
40-
read_sdds = sps.read_tbt(tmp_sdds)
40+
# Normal read/write test
41+
sps.write_tbt(tmp_sdds, original, add_trailing_bpm_plane=False)
42+
read_sdds = sps.read_tbt(tmp_sdds, remove_trailing_bpm_plane=False)
4143
compare_tbt(original, read_sdds, no_binary=True)
4244

45+
# Test no name changes when writing and planes already present
46+
sps.write_tbt(tmp_sdds, original, add_trailing_bpm_plane=True)
47+
read_sdds = sps.read_tbt(tmp_sdds, remove_trailing_bpm_plane=False)
48+
compare_tbt(original, read_sdds, no_binary=True)
49+
50+
# Test plane removal on reading
51+
read_sdds = sps.read_tbt(tmp_sdds, remove_trailing_bpm_plane=True)
52+
assert not any(read_sdds.matrices[0].X.index.str.endswith(".H"))
53+
assert not any(read_sdds.matrices[0].Y.index.str.endswith(".V"))
54+
55+
# Test planes stay off when writing
56+
sps.write_tbt(tmp_sdds, read_sdds, add_trailing_bpm_plane=False)
57+
read_sdds = sps.read_tbt(tmp_sdds, remove_trailing_bpm_plane=False)
58+
assert not any(read_sdds.matrices[0].X.index.str.endswith(".H"))
59+
assert not any(read_sdds.matrices[0].Y.index.str.endswith(".V"))
60+
61+
# Test adding planes again
62+
sps.write_tbt(tmp_sdds, read_sdds, add_trailing_bpm_plane=True)
63+
read_sdds = sps.read_tbt(tmp_sdds, remove_trailing_bpm_plane=False)
64+
compare_tbt(original, read_sdds, no_binary=True)
65+
66+
4367

4468
@pytest.fixture()
4569
def _sps_file() -> Path:

turn_by_turn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
__title__ = "turn_by_turn"
66
__description__ = "Read and write turn-by-turn measurement files from different particle accelerator formats."
77
__url__ = "https://github.com/pylhc/turn_by_turn"
8-
__version__ = "0.5.0"
8+
__version__ = "0.6.0"
99
__author__ = "pylhc"
1010
__author_email__ = "pylhc@github.com"
1111
__license__ = "MIT"

turn_by_turn/ascii.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
containing the columns:
88
- Plane (0 for horizontal, 1 for vertical)
99
- Observation point (i.e. BPM name)
10-
- BPM index,
10+
- BPM index/longitunial location
1111
- Value Turn 1, Turn 2, etc.
1212
"""
1313
import logging
1414
from datetime import datetime
1515
from pathlib import Path
16+
import re
1617
from typing import TextIO, Union, Tuple, List, Optional
1718

1819
import numpy as np
@@ -98,23 +99,27 @@ def _write_tbt_data(tbt_data: TbtData, bunch_id: int, output_file: TextIO) -> No
9899

99100
# ----- Reader ----- #
100101

101-
def read_tbt(
102-
file_path: Union[str, Path]
103-
) -> Tuple[List[TransverseData], Optional[datetime]]:
102+
def read_tbt(file_path: Union[str, Path], bunch_id: int = None) -> TbtData:
104103
"""
105104
Reads turn-by-turn data from an ASCII turn-by-turn format file, and return the date as well as
106105
parsed matrices for construction of a ``TbtData`` object.
107106
108107
Args:
109108
file_path (Union[str, Path]): path to the turn-by-turn measurement file.
109+
bunch_id (int, optional): the bunch id associated with this file.
110+
Defaults to `None`, but is then attempted to parsed
111+
from the filename. If not found, `0` is used.
110112
111113
Returns:
112-
Turn-by-turn data matrices and
114+
Turn-by-turn data
113115
"""
114116
data_lines = Path(file_path).read_text().splitlines()
115117
bpm_names = {"X": [], "Y": []}
116118
bpm_data = {"X": [], "Y": []}
117119
date = None # will switch to TbtData.date's default if not found in file
120+
121+
if bunch_id is None:
122+
bunch_id = _parse_bunch_id(file_path)
118123

119124
for line in data_lines:
120125
line = line.strip()
@@ -144,17 +149,18 @@ def read_tbt(
144149
Y=pd.DataFrame(index=bpm_names["Y"], data=np.array(bpm_data["Y"])),
145150
)
146151
]
147-
return matrices, date
152+
return TbtData(matrices=matrices, date=date, bunch_ids=[bunch_id], nturns=matrices[0].X.shape[1])
148153

149154

150155
# ----- Helpers ----- #
151156

157+
152158
def _parse_samples(line: str) -> Tuple[str, str, np.ndarray]:
153159
"""Parse a line into its different elements."""
154160
parts = line.split()
155161
plane_num = parts[0]
156162
bpm_name = parts[1]
157-
# bunch_id = part[2] # not used, comment for clarification
163+
# bpm_location = part[2] # not used, comment for clarification
158164
bpm_samples = np.array([float(part) for part in parts[3:]])
159165
return plane_num, bpm_name, bpm_samples
160166

@@ -172,6 +178,17 @@ def _parse_date(line: str) -> datetime:
172178
return datetime.today().replace(tzinfo=tz.tzutc())
173179

174180

181+
def _parse_bunch_id(file_path: Path) -> int:
182+
"""Parse the bunch_id from the filename."""
183+
bunch_id_match = re.match(r".*_(?P<bunch_id>\d+)(.sdds)?$", file_path.name)
184+
if bunch_id_match:
185+
try:
186+
return int(bunch_id_match.group("bunch_id"))
187+
except ValueError:
188+
pass
189+
return 0
190+
191+
175192
# For backwards compatibility <0.4.2:
176193
write_ascii = write_tbt
177194
read_ascii = read_tbt

turn_by_turn/lhc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def read_tbt(file_path: Union[str, Path]) -> TbtData:
5151
LOGGER.debug(f"Reading LHC file at path: '{file_path.absolute()}'")
5252

5353
if is_ascii_file(file_path):
54-
matrices, date = read_ascii(file_path)
55-
return TbtData(matrices, date, [0], matrices[0].X.shape[1])
54+
return read_ascii(file_path)
5655

5756
sdds_file = sdds.read(file_path)
5857
nbunches = sdds_file.values[N_BUNCHES]

turn_by_turn/sps.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
from datetime import datetime
99
from pathlib import Path
10+
import re
1011
from typing import Union
1112

1213
import numpy as np
@@ -26,13 +27,17 @@
2627
BPM_PLANES: str = "MonPlanes"
2728

2829

29-
def read_tbt(file_path: Union[str, Path]) -> TbtData:
30+
def read_tbt(file_path: Union[str, Path], remove_trailing_bpm_plane: bool = True) -> TbtData:
3031
"""
3132
Reads turn-by-turn data from the ``SPS``'s **SDDS** format file.
3233
Will first determine if it is in ASCII format to figure out which reading method to use.
3334
3435
Args:
3536
file_path (Union[str, Path]): path to the turn-by-turn measurement file.
37+
remove_trailing_bpm_plane (bool, optional): if ``True``, will remove the trailing
38+
BPM plane ('.H', '.V') from the BPM-names.
39+
This makes the measurement data compatible with the madx-models.
40+
Defaults to ``True``.
3641
3742
Returns:
3843
A ``TbTData`` object with the loaded data.
@@ -41,8 +46,7 @@ def read_tbt(file_path: Union[str, Path]) -> TbtData:
4146
LOGGER.debug(f"Reading SPS file at path: '{file_path.absolute()}'")
4247

4348
if is_ascii_file(file_path):
44-
matrices, date = read_ascii(file_path)
45-
return TbtData(matrices, date, [0], matrices[0].X.shape[1])
49+
return read_ascii(file_path)
4650

4751
sdds_file = sdds.read(file_path)
4852

@@ -53,30 +57,39 @@ def read_tbt(file_path: Union[str, Path]) -> TbtData:
5357
bpm_names = np.array(sdds_file.values[BPM_NAMES])
5458
bpm_planes = np.array(sdds_file.values[BPM_PLANES]).astype(bool)
5559

56-
ver_bpms = bpm_names[bpm_planes]
57-
hor_bpms = bpm_names[~bpm_planes]
60+
bpm_names_y = bpm_names[bpm_planes]
61+
bpm_names_x = bpm_names[~bpm_planes]
5862

59-
tbt_data_x = [sdds_file.values[bpm] for bpm in hor_bpms]
60-
tbt_data_y = [sdds_file.values[bpm] for bpm in ver_bpms]
63+
tbt_data_x = [sdds_file.values[bpm] for bpm in bpm_names_x]
64+
tbt_data_y = [sdds_file.values[bpm] for bpm in bpm_names_y]
65+
66+
if remove_trailing_bpm_plane:
67+
pattern = re.compile("\.[HV]$", flags=re.IGNORECASE)
68+
bpm_names_x = [pattern.sub("", bpm) for bpm in bpm_names_x]
69+
bpm_names_y = [pattern.sub("", bpm) for bpm in bpm_names_y]
6170

6271
matrices = [
6372
TransverseData(
64-
X=pd.DataFrame(index=hor_bpms, data=tbt_data_x, dtype=float),
65-
Y=pd.DataFrame(index=ver_bpms, data=tbt_data_y, dtype=float),
73+
X=pd.DataFrame(index=bpm_names_x, data=tbt_data_x, dtype=float),
74+
Y=pd.DataFrame(index=bpm_names_y, data=tbt_data_y, dtype=float),
6675
)
6776
]
6877

6978
return TbtData(matrices, date, [0], nturns)
7079

7180

72-
def write_tbt(output_path: Union[str, Path], tbt_data: TbtData) -> None:
81+
def write_tbt(output_path: Union[str, Path], tbt_data: TbtData, add_trailing_bpm_plane: bool = True) -> None:
7382
"""
7483
Write a ``TbtData`` object's data to file, in a ``SPS``'s **SDDS** format.
7584
The format is reduced to the necessary parameters used by the reader.
7685
7786
Args:
7887
output_path (Union[str, Path]): path to a the disk location where to write the data.
7988
tbt_data (TbtData): the ``TbtData`` object to write to disk.
89+
add_trailing_bpm_plane (bool, optional): if ``True``, will add the trailing
90+
BPM plane ('.H', '.V') to the BPM-names. This assures that all BPM-names are unique,
91+
and that the measurement data is compatible with the sdds files from the FESA-class.
92+
Defaults to ``True``.
8093
"""
8194
output_path = Path(output_path)
8295
LOGGER.info(f"Writing TbTdata in binary SDDS (SPS) format at '{output_path.absolute()}'")
@@ -85,6 +98,11 @@ def write_tbt(output_path: Union[str, Path], tbt_data: TbtData) -> None:
8598

8699
# bpm names
87100
bpm_names_x, bpm_names_y = df_x.index.to_list(), df_y.index.to_list()
101+
102+
if add_trailing_bpm_plane:
103+
bpm_names_x = [f"{bpm_name}.H" if not bpm_name.endswith(".H") else bpm_name for bpm_name in bpm_names_x]
104+
bpm_names_y = [f"{bpm_name}.V" if not bpm_name.endswith(".V") else bpm_name for bpm_name in bpm_names_y]
105+
88106
bpm_names = bpm_names_x + bpm_names_y
89107

90108
# bpm planes

0 commit comments

Comments
 (0)