| 
4 | 4 | import optax  | 
5 | 5 | import jax  | 
6 | 6 | from jax import numpy as jnp  | 
 | 7 | +import numpy  | 
7 | 8 | import pandas as pd  | 
8 | 9 | from pathlib import Path  | 
9 | 10 | from trecs.scripts import train  | 
@@ -163,3 +164,56 @@ def loss_fn(model, inputs, labels, key):  | 
163 | 164 |         )  | 
164 | 165 |         < 1e-6  | 
165 | 166 |     )  | 
 | 167 | + | 
 | 168 | + | 
 | 169 | +def test_data_loader_regression(example_db_path: Path, tmp_path: Path) -> None:  | 
 | 170 | +    # This test checks that the output of the data iterator conforms to a specific  | 
 | 171 | +    # output as a regression test.  | 
 | 172 | +    experiment = train.load_experiment_from_file(  | 
 | 173 | +        Path(decoder_only_mini_experiment_path)  | 
 | 174 | +    )  | 
 | 175 | +    with patch.dict("os.environ", MPD=str(example_db_path)):  | 
 | 176 | +        experiment.setup_output(tmp_path)  | 
 | 177 | +        data_source = experiment.create_data_source("train")  | 
 | 178 | +        data_loader = experiment.create_data_loader("train", data_source, 17)  | 
 | 179 | + | 
 | 180 | +    inputs, labels = next(iter(data_loader))  | 
 | 181 | + | 
 | 182 | +    expected_pos = [  | 
 | 183 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  | 
 | 184 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  | 
 | 185 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0],  | 
 | 186 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  | 
 | 187 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  | 
 | 188 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  | 
 | 189 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  | 
 | 190 | +        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0],  | 
 | 191 | +    ]  | 
 | 192 | +    numpy.testing.assert_array_equal(inputs["pos"], expected_pos)  | 
 | 193 | + | 
 | 194 | +    expected_track_id = [  | 
 | 195 | +        [0, 3834, 3835, 3836, 3837, 3838, 3839, 3840, 3841, 3842, 3843, 2, 3845, 3846],  | 
 | 196 | +        [0, 2869, 56, 1398, 2870, 2871, 2872, 2873, 2874, 2875, 1124, 2876, 2877, 2878],  | 
 | 197 | +        [0, 2228, 2229, 2230, 2231, 2232, 2233, 2, 2235, 2236, 2237, 2238, 1, 1],  | 
 | 198 | +        [0, 3652, 1182, 2761, 1130, 2384, 3653, 3654, 211, 615, 3655, 1687, 3656, 1715],  | 
 | 199 | +        [0, 283, 284, 285, 286, 287, 288, 289, 290, 291, 2, 293, 294, 295],  | 
 | 200 | +        [  | 
 | 201 | +            0,  | 
 | 202 | +            342,  | 
 | 203 | +            4135,  | 
 | 204 | +            4136,  | 
 | 205 | +            4137,  | 
 | 206 | +            4138,  | 
 | 207 | +            4139,  | 
 | 208 | +            3658,  | 
 | 209 | +            4140,  | 
 | 210 | +            4141,  | 
 | 211 | +            4142,  | 
 | 212 | +            3734,  | 
 | 213 | +            4143,  | 
 | 214 | +            4144,  | 
 | 215 | +        ],  | 
 | 216 | +        [0, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311],  | 
 | 217 | +        [0, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 2, 802, 1110, 1],  | 
 | 218 | +    ]  | 
 | 219 | +    numpy.testing.assert_array_equal(inputs["track_id"], expected_track_id)  | 
0 commit comments