Skip to content

Commit 9115fdf

Browse files
committed
Add regression test for data pre-processing.
1 parent 07cb58b commit 9115fdf

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

tests/assets/decoder_only_mini_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def setup():
99
seed=17,
1010
eval_every=1000,
1111
checkpoint_every=1000,
12-
context_length=13,
12+
context_length=14,
1313
num_layers=3,
1414
num_heads=8,
1515
num_features=16,

tests/scripts/test_train.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import optax
55
import jax
66
from jax import numpy as jnp
7+
import numpy
78
import pandas as pd
89
from pathlib import Path
910
from trecs.scripts import train
@@ -163,3 +164,56 @@ def loss_fn(model, inputs, labels, key):
163164
)
164165
< 1e-6
165166
)
167+
168+
169+
def test_data_loader_regression(example_db_path: Path, tmp_path: Path) -> None:
170+
# This test checks that the output of the data iterator conforms to a specific
171+
# output as a regression test.
172+
experiment = train.load_experiment_from_file(
173+
Path(decoder_only_mini_experiment_path)
174+
)
175+
with patch.dict("os.environ", MPD=str(example_db_path)):
176+
experiment.setup_output(tmp_path)
177+
data_source = experiment.create_data_source("train")
178+
data_loader = experiment.create_data_loader("train", data_source, 17)
179+
180+
inputs, labels = next(iter(data_loader))
181+
182+
expected_pos = [
183+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
184+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
185+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0],
186+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
187+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
188+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
189+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
190+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0],
191+
]
192+
numpy.testing.assert_array_equal(inputs["pos"], expected_pos)
193+
194+
expected_track_id = [
195+
[0, 3834, 3835, 3836, 3837, 3838, 3839, 3840, 3841, 3842, 3843, 2, 3845, 3846],
196+
[0, 2869, 56, 1398, 2870, 2871, 2872, 2873, 2874, 2875, 1124, 2876, 2877, 2878],
197+
[0, 2228, 2229, 2230, 2231, 2232, 2233, 2, 2235, 2236, 2237, 2238, 1, 1],
198+
[0, 3652, 1182, 2761, 1130, 2384, 3653, 3654, 211, 615, 3655, 1687, 3656, 1715],
199+
[0, 283, 284, 285, 286, 287, 288, 289, 290, 291, 2, 293, 294, 295],
200+
[
201+
0,
202+
342,
203+
4135,
204+
4136,
205+
4137,
206+
4138,
207+
4139,
208+
3658,
209+
4140,
210+
4141,
211+
4142,
212+
3734,
213+
4143,
214+
4144,
215+
],
216+
[0, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311],
217+
[0, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 2, 802, 1110, 1],
218+
]
219+
numpy.testing.assert_array_equal(inputs["track_id"], expected_track_id)

0 commit comments

Comments
 (0)