1+
2+ from datetime import datetime
3+ from pathlib import Path
4+
5+ import numpy as np
6+ import pandas as pd
7+ import pytest
8+ import h5py
9+
10+ from turn_by_turn .constants import PRINT_PRECISION
11+ from turn_by_turn .errors import DataTypeError
12+ from turn_by_turn .structures import TbtData , TransverseData
13+ from tests .test_lhc_and_general import create_data , compare_tbt
14+
15+ from turn_by_turn .doros import N_ORBIT_SAMPLES , read_tbt , write_tbt , DEFAULT_BUNCH_ID , POSITIONS
16+
17+ INPUTS_DIR = Path (__file__ ).parent / "inputs"
18+
19+
20+ def test_read_write_real_data (tmp_path ):
21+ tbt = read_tbt (INPUTS_DIR / "test_doros.h5" , bunch_id = 10 )
22+
23+ assert tbt .nbunches == 1
24+ assert len (tbt .matrices ) == 1
25+ assert tbt .nturns == 50000
26+ assert tbt .matrices [0 ].X .shape == (3 , tbt .nturns )
27+ assert tbt .matrices [0 ].Y .shape == (3 , tbt .nturns )
28+ assert len (set (tbt .matrices [0 ].X .index )) == 3
29+ assert np .all (tbt .matrices [0 ].X .index == tbt .matrices [0 ].Y .index )
30+
31+ file_path = tmp_path / "test_file.h5"
32+ write_tbt (tbt , file_path )
33+ new = read_tbt (file_path , bunch_id = 10 )
34+ compare_tbt (tbt , new , no_binary = False )
35+
36+
37+ def test_write_read (tmp_path ):
38+ tbt = _tbt_data ()
39+ file_path = tmp_path / "test_file.h5"
40+ write_tbt (tbt , file_path )
41+ new = read_tbt (file_path )
42+ compare_tbt (tbt , new , no_binary = False )
43+
44+
45+ def test_read_raises_different_bpm_lengths (tmp_path ):
46+ tbt = _tbt_data ()
47+ file_path = tmp_path / "test_file.h5"
48+ write_tbt (tbt , file_path )
49+
50+ bpm = tbt .matrices [0 ].X .index [0 ]
51+
52+ # modify the BPM lengths in the file
53+ with h5py .File (file_path , "r+" ) as h5f :
54+ delta = 10
55+ del h5f [bpm ][N_ORBIT_SAMPLES ]
56+ h5f [bpm ][N_ORBIT_SAMPLES ] = [tbt .matrices [0 ].X .shape [1 ] - delta ]
57+ for key in POSITIONS .values ():
58+ data = h5f [bpm ][key ][:- delta ]
59+ del h5f [bpm ][key ]
60+ h5f [bpm ][key ] = data
61+
62+ with pytest .raises (ValueError ) as e :
63+ read_tbt (file_path )
64+ assert "Not all BPMs have the same number of turns!" in str (e )
65+
66+
67+ def test_read_raises_on_different_bpm_lengths_in_data (tmp_path ):
68+ tbt = _tbt_data ()
69+ file_path = tmp_path / "test_file.h5"
70+ write_tbt (tbt , file_path )
71+
72+ bpms = [tbt .matrices [0 ].X .index [i ] for i in (0 , 2 )]
73+
74+ # modify the BPM lengths in the file
75+ with h5py .File (file_path , "r+" ) as h5f :
76+ for bpm in bpms :
77+ del h5f [bpm ][N_ORBIT_SAMPLES ]
78+ h5f [bpm ][N_ORBIT_SAMPLES ] = [tbt .matrices [0 ].X .shape [1 ] + 10 ]
79+
80+ with pytest .raises (ValueError ) as e :
81+ read_tbt (file_path )
82+ assert "Found BPMs with different data lengths" in str (e )
83+ assert all (bpm in str (e ) for bpm in bpms )
84+
85+
86+ def _tbt_data () -> TbtData :
87+ """TbT data for testing. Adding random noise, so that the data is different per BPM."""
88+ nturns = 2000
89+ bpms = ["TBPM1" , "TBPM2" , "TBPM3" , "TBPM4" ]
90+
91+ return TbtData (
92+ matrices = [
93+ TransverseData (
94+ X = pd .DataFrame (
95+ index = bpms ,
96+ data = create_data (
97+ np .linspace (- np .pi , np .pi , nturns , endpoint = False ),
98+ nbpm = len (bpms ), function = np .sin , noise = 0.02
99+ ),
100+ dtype = float ,
101+ ),
102+ Y = pd .DataFrame (
103+ index = bpms ,
104+ data = create_data (
105+ np .linspace (- np .pi , np .pi , nturns , endpoint = False ),
106+ nbpm = len (bpms ), function = np .cos , noise = 0.015
107+ ),
108+ dtype = float ,
109+ ),
110+ )
111+ ],
112+ date = datetime .now (),
113+ bunch_ids = [DEFAULT_BUNCH_ID ],
114+ nturns = nturns ,
115+ )
0 commit comments