Skip to content

Commit 0cdccf9

Browse files
fscarlier and fsoubelet
authored
New datatype and loading for tracking data
* test * added new class for full simulation data * deleted random file * docstring, new constant to be used for tests * adapt compare_tbt to be able to check all fields if tracking simulation data is specified * add test to read trackone file fully (all fields) as SimulationData in the TbtData matrices * up minor version * update type hint of matrices in TbtData * field instead of plane * make fieldnames classmethod and return a list of the fields * with docstrings * change full_sim_data for is_tracking_data * group numpy_to_tbt and numpy_to_sim_tbt into one function, with argument to be provided * rename to is_tracking_data here too * remove SIMDATA_FIELDS and use the fieldnames classmethod instead * generate_average_tbtdata now goes through fields automatically * differenciate between index of bunch and index of field * declare DataType in structures, import and use it for type hints * TrackingData instead of SimulationData --------- Co-authored-by: Felix Soubelet <felix.soubelet@protonmail.com>
1 parent cedbe25 commit 0cdccf9

File tree

6 files changed

+124
-40
lines changed

6 files changed

+124
-40
lines changed

tests/test_lhc_and_general.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55

6-
from turn_by_turn.constants import PLANES, PRINT_PRECISION
6+
from turn_by_turn.constants import PRINT_PRECISION
77
from turn_by_turn.errors import DataTypeError
88
from turn_by_turn.io import read_tbt, write_lhc_ascii, write_tbt
99
from turn_by_turn.structures import TbtData
@@ -56,15 +56,16 @@ def test_tbt_write_read_ascii(_sdds_file, _test_file):
5656
# ----- Helpers ----- #
5757

5858

59-
def compare_tbt(origin: TbtData, new: TbtData, no_binary: bool, max_deviation=ASCII_PRECISION) -> None:
59+
def compare_tbt(origin: TbtData, new: TbtData, no_binary: bool, max_deviation = ASCII_PRECISION, is_tracking_data: bool = False) -> None:
6060
assert new.nturns == origin.nturns
6161
assert new.nbunches == origin.nbunches
6262
assert new.bunch_ids == origin.bunch_ids
6363
for index in range(origin.nbunches):
64-
for plane in PLANES:
65-
assert np.all(new.matrices[index][plane].index == origin.matrices[index][plane].index)
66-
origin_mat = origin.matrices[index][plane].to_numpy()
67-
new_mat = new.matrices[index][plane].to_numpy()
64+
# In matrices are either TransverseData or TrackingData and we can get all the fields from the `fieldnames` classmethod
65+
for field in origin.matrices[0].fieldnames():
66+
assert np.all(new.matrices[index][field].index == origin.matrices[index][field].index)
67+
origin_mat = origin.matrices[index][field].to_numpy()
68+
new_mat = new.matrices[index][field].to_numpy()
6869
if no_binary:
6970
assert np.nanmax(np.abs(origin_mat - new_mat)) < max_deviation
7071
else:

tests/test_ptc_trackone.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import pandas as pd
77
import pytest
88

9-
from tests.test_lhc_and_general import compare_tbt, INPUTS_DIR
9+
from tests.test_lhc_and_general import ASCII_PRECISION, INPUTS_DIR, compare_tbt
1010
from turn_by_turn import ptc, trackone
1111
from turn_by_turn.errors import PTCFormatError
12-
from turn_by_turn.structures import TbtData, TransverseData
12+
from turn_by_turn.structures import TbtData, TrackingData, TransverseData
1313

1414

1515
def test_read_ptc(_ptc_file):
@@ -63,14 +63,23 @@ def test_read_trackone_looseparticles(_ptc_file_losses):
6363
assert not new.matrices[0].X.isna().any().any()
6464

6565

66+
def test_read_trackone_simdata(_ptc_file):
67+
new = trackone.read_tbt(_ptc_file, is_tracking_data=True) # read all fields (includes PX, PY, T, PT, S, E)
68+
origin = _original_simulation_data()
69+
compare_tbt(origin, new, True, is_tracking_data=True)
70+
71+
72+
# ----- Helpers ----- #
73+
74+
6675
def _original_trackone(track: bool = False) -> TbtData:
6776
names = np.array(["C1.BPM1"])
6877
matrix = [
69-
TransverseData(
78+
TransverseData( # first "bunch"
7079
X=pd.DataFrame(index=names, data=[[0.001, -0.0003606, -0.00165823, -0.00266631]]),
7180
Y=pd.DataFrame(index=names, data=[[0.001, 0.00070558, -0.00020681, -0.00093807]]),
7281
),
73-
TransverseData(
82+
TransverseData( # second "bunch"
7483
X=pd.DataFrame(index=names, data=[[0.0011, -0.00039666, -0.00182406, -0.00293294]]),
7584
Y=pd.DataFrame(index=names, data=[[0.0011, 0.00077614, -0.00022749, -0.00103188]]),
7685
),
@@ -79,6 +88,37 @@ def _original_trackone(track: bool = False) -> TbtData:
7988
return origin
8089

8190

91+
def _original_simulation_data() -> TbtData:
92+
names = np.array(["C1.BPM1"])
93+
matrices = [
94+
TrackingData( # first "bunch"
95+
X=pd.DataFrame(index=names, data=[[0.001, -0.000361, -0.001658, -0.002666]]),
96+
PX=pd.DataFrame(index=names, data=[[0.0, -0.000202, -0.000368, -0.00047]]),
97+
Y=pd.DataFrame(index=names, data=[[0.001, 0.000706, -0.000207, -0.000938]]),
98+
PY=pd.DataFrame(index=names, data=[[0.0, -0.000349, -0.000392, -0.000092]]),
99+
T=pd.DataFrame(index=names, data=[[0.0, -0.000008, -0.000015, -0.000023]]),
100+
PT=pd.DataFrame(index=names, data=[[0, 0, 0, 0]]),
101+
S=pd.DataFrame(index=names, data=[[0, 0, 0, 0]]),
102+
E=pd.DataFrame(index=names, data=[[500.00088, 500.00088, 500.00088, 500.00088]]),
103+
),
104+
TrackingData( # second "bunch"
105+
X=pd.DataFrame(index=names, data=[[0.0011, -0.000397, -0.001824, -0.002933]]),
106+
PX=pd.DataFrame(index=names, data=[[0.0, -0.000222, -0.000405, -0.000517]]),
107+
Y=pd.DataFrame(index=names, data=[[0.0011, 0.000776, -0.000227, -0.001032]]),
108+
PY=pd.DataFrame(index=names, data=[[0.0, -0.000384, -0.000431, -0.000101]]),
109+
T=pd.DataFrame(index=names, data=[[-0.0, -0.000009, -0.000018, -0.000028]]),
110+
PT=pd.DataFrame(index=names, data=[[0, 0, 0, 0]]),
111+
S=pd.DataFrame(index=names, data=[[0, 0, 0, 0]]),
112+
E=pd.DataFrame(index=names, data=[[500.00088, 500.00088, 500.00088, 500.00088]]),
113+
)
114+
]
115+
origin = TbtData(matrices, date=None, bunch_ids=[0, 1], nturns=4) # [0, 1] for bunch_ids because it's from tracking
116+
return origin
117+
118+
119+
# ----- Fixtures ----- #
120+
121+
82122
@pytest.fixture()
83123
def _ptc_file_no_date() -> Path:
84124
return INPUTS_DIR / "test_trackone_no_date"

turn_by_turn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
__title__ = "turn_by_turn"
66
__description__ = "Read and write turn-by-turn measurement files from different particle accelerator formats."
77
__url__ = "https://github.com/pylhc/turn_by_turn"
8-
__version__ = "0.4.2"
8+
__version__ = "0.5.0"
99
__author__ = "pylhc"
1010
__author_email__ = "pylhc@github.com"
1111
__license__ = "MIT"

turn_by_turn/structures.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
from dataclasses import dataclass, field, fields
88
from datetime import datetime
9-
from typing import Dict, List, Sequence
9+
from typing import List, Sequence, Union
1010

1111
import pandas as pd
1212
from dateutil import tz
@@ -21,23 +21,54 @@ class TransverseData:
2121
X: pd.DataFrame # horizontal data
2222
Y: pd.DataFrame # vertical data
2323

24-
def fieldnames(self):
25-
return (f.name for f in fields(self))
24+
@classmethod
25+
def fieldnames(self) -> List[str]:
26+
"""Return a list of the fields of this dataclass."""
27+
return list(f.name for f in fields(self))
2628

2729
def __getitem__(self, item): # to access X and Y like one would with a dictionary
2830
if item not in self.fieldnames():
2931
raise KeyError(f"'{item}' is not in the fields of a {self.__class__.__name__} object.")
3032
return getattr(self, item)
3133

3234

35+
@dataclass
36+
class TrackingData:
37+
"""
38+
Object holding multidimensional turn-by-turn simulation data in the form of pandas DataFrames.
39+
"""
40+
41+
X: pd.DataFrame # horizontal data
42+
PX: pd.DataFrame # horizontal momentum data
43+
Y: pd.DataFrame # vertical data
44+
PY: pd.DataFrame # vertical momentum data
45+
T: pd.DataFrame # longitudinal data
46+
PT: pd.DataFrame # longitudinal momentum data
47+
S: pd.DataFrame # longitudinal position data
48+
E: pd.DataFrame # energy data
49+
50+
@classmethod
51+
def fieldnames(self) -> List[str]:
52+
"""Return a list of the fields of this dataclass."""
53+
return list(f.name for f in fields(self))
54+
55+
def __getitem__(self, item): # to access fields like one would with a dictionary
56+
if item not in self.fieldnames():
57+
raise KeyError(f"'{item}' is not in the fields of a {self.__class__.__name__} object.")
58+
return getattr(self, item)
59+
60+
61+
DataType = Union[TransverseData, TrackingData]
62+
63+
3364
@dataclass
3465
class TbtData:
3566
"""
3667
Object holding a representation of a Turn-by-Turn data measurement. The date of the measurement,
3768
the transverse data, number of turns and bunches as well as the bunch IDs are encapsulated in this object.
3869
"""
3970

40-
matrices: Sequence[TransverseData] # each entry corresponds to a bunch
71+
matrices: Sequence[DataType] # each entry corresponds to a bunch
4172
date: datetime = None # will default in post_init
4273
bunch_ids: List[int] = None # will default in post_init
4374
nturns: int = None

turn_by_turn/trackone.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,34 @@
1313

1414
import numpy as np
1515

16-
from turn_by_turn.structures import TbtData
16+
from turn_by_turn.structures import TbtData, TrackingData, TransverseData
1717
from turn_by_turn.utils import numpy_to_tbt
1818

1919
LOGGER = logging.getLogger()
2020

2121

22-
def read_tbt(file_path: Union[str, Path]) -> TbtData:
22+
def read_tbt(file_path: Union[str, Path], is_tracking_data: bool = False) -> TbtData:
2323
"""
2424
Reads turn-by-turn data from the ``MAD-X`` **trackone** format file.
2525
2626
Args:
2727
file_path (Union[str, Path]): path to the turn-by-turn measurement file.
28+
is_tracking_data (bool): if ``True``, all (``X``, ``PX``, ``Y``, ``PY``,
29+
``T``, ``PT``, ``S``, ``E``) fields are expected in the file as it
30+
is considered a full tracking simulation output. Those are then read
31+
into ``TrackingData`` objects. Defaults to ``False``.
2832
2933
Returns:
3034
A ``TbTData`` object with the loaded data.
3135
"""
3236
nturns, npart = get_trackone_stats(file_path)
3337
names, matrix = get_structure_from_trackone(nturns, npart, file_path)
34-
# matrix[0, 2] contains just (x, y) samples.
35-
return numpy_to_tbt(names, matrix[[0, 2]])
38+
if is_tracking_data:
39+
# Converts full tracking output to TbTData.
40+
return numpy_to_tbt(names, matrix, datatype=TrackingData)
41+
else:
42+
# matrix[0, 2] contains just (x, y) samples.
43+
return numpy_to_tbt(names, matrix[[0, 2]], datatype=TransverseData)
3644

3745

3846
def get_trackone_stats(file_path: Union[str, Path], write_out: bool = False) -> Tuple[int, int]:

turn_by_turn/utils.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
Utility functions for convenience operations on turn-by-turn data objects in this package.
66
"""
77
import logging
8-
9-
from typing import Dict, Sequence
8+
from typing import Sequence, Union
109

1110
import numpy as np
1211
import pandas as pd
1312

14-
from turn_by_turn.constants import PLANES, PLANE_TO_NUM
13+
from turn_by_turn.constants import PLANE_TO_NUM, PLANES
1514
from turn_by_turn.errors import ExclusiveArgumentsError
16-
from turn_by_turn.structures import TbtData, TransverseData
15+
from turn_by_turn.structures import DataType, TbtData, TransverseData
1716

1817
LOGGER = logging.getLogger(__name__)
1918

@@ -31,19 +30,18 @@ def generate_average_tbtdata(tbtdata: TbtData) -> TbtData:
3130
"""
3231
data = tbtdata.matrices
3332
bpm_names = data[0].X.index
33+
datatype = tbtdata.matrices[0].__class__
3434

3535
new_matrices = [
36-
TransverseData(
37-
X=pd.DataFrame(
38-
index=bpm_names,
39-
data=get_averaged_data(bpm_names, data, "X", tbtdata.nturns),
40-
dtype=float,
41-
),
42-
Y=pd.DataFrame(
43-
index=bpm_names,
44-
data=get_averaged_data(bpm_names, data, "Y", tbtdata.nturns),
45-
dtype=float,
46-
),
36+
datatype( # datatype is directly the class to load data into
37+
**{ # for each field in the datatype, load the corresponding matrix
38+
field: pd.DataFrame(
39+
index=bpm_names,
40+
data=get_averaged_data(bpm_names, data, field, tbtdata.nturns),
41+
dtype=float,
42+
)
43+
for field in datatype.fieldnames()
44+
}
4745
)
4846
]
4947
return TbtData(new_matrices, tbtdata.date, [1], tbtdata.nturns)
@@ -151,14 +149,18 @@ def add_noise_to_tbt(data: TbtData, noise: float = None, sigma: float = None, se
151149
)
152150

153151

154-
def numpy_to_tbt(names: np.ndarray, matrix: np.ndarray) -> TbtData:
152+
def numpy_to_tbt(names: np.ndarray, matrix: np.ndarray, datatype: DataType = TransverseData) -> TbtData:
155153
"""
156154
Converts turn by turn matrices and names into a ``TbTData`` object.
157155
158156
Args:
159157
names (np.ndarray): Numpy array of BPM names.
160158
matrix (np.ndarray): 4D Numpy array [quantity, BPM, particle/bunch No., turn No.]
161159
quantities in order [x, y].
160+
datatype (DataType): The type of data to be converted to in the matrices. Either
161+
``TransverseData`` (which implies reading ``X`` and ``Y`` fields) or
162+
``TrackingData`` (which implies reading all 8 fields). Defaults to
163+
``TransverseData``.
162164
163165
Returns:
164166
A ``TbtData`` object loaded with the matrices in the provided numpy arrays.
@@ -167,12 +169,14 @@ def numpy_to_tbt(names: np.ndarray, matrix: np.ndarray) -> TbtData:
167169
_, _, nbunches, nturns = matrix.shape
168170
matrices = []
169171
indices = []
170-
for index in range(nbunches):
172+
for idx_bunch in range(nbunches):
171173
matrices.append(
172-
TransverseData(
173-
X=pd.DataFrame(index=names, data=matrix[0, :, index, :]),
174-
Y=pd.DataFrame(index=names, data=matrix[1, :, index, :]),
174+
datatype( # datatype is directly the class to load data into (TransverseData or TrackingData)
175+
**{ # for each field in the datatype, load the corresponding matrix
176+
field: pd.DataFrame(index=names, data=matrix[idx_field, :, idx_bunch, :])
177+
for idx_field, field in enumerate(datatype.fieldnames())
178+
}
175179
)
176180
)
177-
indices.append(index)
181+
indices.append(idx_bunch)
178182
return TbtData(matrices=matrices, bunch_ids=indices, nturns=nturns)

0 commit comments

Comments
 (0)