Skip to content

Commit 98abed8

Browse files
Merge pull request #223 from cognizant-ai-labs/fix-tests
reduce the size of the test cases so automated testing runs faster
2 parents a38ea0b + a6d781c commit 98abed8

File tree

7 files changed

+57
-41
lines changed

7 files changed

+57
-41
lines changed

covid_xprize/examples/predictors/conditional_lstm/conditional_xprize_predictor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,21 @@ def __init__(self, path_to_model_weights, data_url):
4646

4747
self.df = prepare_cases_dataframe(data_url, threshold_min_cases=True)
4848

49-
def train(self, return_results=False) -> Union[Model, tuple[Model, dict]]:
49+
def train(self,
50+
return_results=False,
51+
nb_training_geos: int = NB_TRAINING_DAYS,
52+
nb_testing_geos: int = NB_TESTING_GEOS,
53+
nb_trials: int = NUM_TRIALS,
54+
nb_epochs: int = NUM_EPOCHS,) -> Union[Model, tuple[Model, dict]]:
5055
best_model, results_df = train_predictor(
5156
training_data=self.df,
5257
nb_lookback_days=NB_LOOKBACK_DAYS,
5358
nb_training_days=NB_TRAINING_DAYS,
5459
nb_test_days=NB_TEST_DAYS,
55-
nb_training_geos=NB_TRAINING_GEOS,
56-
nb_testing_geos=NB_TESTING_GEOS,
57-
nb_trials=NUM_TRIALS,
58-
nb_epochs=NUM_EPOCHS,
60+
nb_training_geos=nb_training_geos,
61+
nb_testing_geos=nb_testing_geos,
62+
nb_trials=nb_trials,
63+
nb_epochs=nb_epochs,
5964
lstm_size=LSTM_SIZE,
6065
)
6166
if return_results:
Binary file not shown.

covid_xprize/examples/predictors/conditional_lstm/tests/test_conditional_xprize_predictor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
22

33
import os
4+
from pathlib import Path
45
import unittest
56
import urllib.request
67

@@ -9,11 +10,11 @@
910
from covid_xprize.examples.predictors.conditional_lstm.conditional_xprize_predictor import ConditionalXPrizePredictor
1011
from covid_xprize.oxford_data import load_oxford_data_trimmed
1112

12-
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
13-
FIXTURES_PATH = os.path.join(ROOT_DIR, 'fixtures')
14-
EXAMPLE_INPUT_FILE = os.path.join(ROOT_DIR, "../../../../validation/data/2020-09-30_historical_ip.csv")
15-
DATA_FILE = os.path.join(FIXTURES_PATH, "OxCGRT_trimmed.csv")
16-
PREDICTOR_WEIGHTS = os.path.join(FIXTURES_PATH, "trained_model_weights_for_tests.h5")
13+
ROOT_DIR = Path(__file__).parent
14+
FIXTURES_PATH = ROOT_DIR / 'fixtures'
15+
EXAMPLE_INPUT_FILE = (ROOT_DIR / "../../../../validation/data/2020-09-30_historical_ip.csv").absolute()
16+
DATA_FILE = FIXTURES_PATH / "OxCGRT_trimmed.csv"
17+
PREDICTOR_WEIGHTS = FIXTURES_PATH / "trained_model_weights_for_tests.h5"
1718

1819
TRAINING_START_DATE = "2020-06-01"
1920
TRAINING_END_DATE = "2020-07-31"
@@ -25,12 +26,16 @@ class TestConditionalXPrizePredictor(unittest.TestCase):
2526

2627
@classmethod
2728
def setUpClass(cls):
28-
df = load_oxford_data_trimmed(end_date=TRAINING_END_DATE, start_date=TRAINING_START_DATE)
29-
df.to_csv(DATA_FILE, index=False)
29+
FIXTURES_PATH.mkdir(exist_ok=True)
30+
if not DATA_FILE.exists():
31+
df = load_oxford_data_trimmed(end_date=TRAINING_END_DATE, start_date=TRAINING_START_DATE)
32+
df.to_csv(DATA_FILE, index=False)
3033

3134
def test_train_and_predict(self):
3235
predictor = ConditionalXPrizePredictor(None, DATA_FILE)
33-
model = predictor.train()
36+
37+
# Testing on a small number of epochs, trials and geos to make sure everything is wired correctly
38+
model = predictor.train(nb_epochs=2, nb_trials=2, nb_testing_geos=2, nb_training_geos=2)
3439
model.save_weights(PREDICTOR_WEIGHTS)
3540
self.assertIsNotNone(model)
3641

covid_xprize/examples/predictors/conditional_lstm/train_predictor.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,24 +149,22 @@ def train_predictor(training_data: pd.DataFrame,
149149
# Gather test info
150150
country_indeps = []
151151
country_predss = []
152-
country_casess = []
153152
for model in models:
154-
country_indep, country_preds, country_cases = _lstm_get_test_rollouts(model,
155-
df,
156-
test_countries,
157-
country_samples,
158-
context_column)
153+
country_indep, country_preds = _lstm_get_test_rollouts( model,
154+
df,
155+
test_countries,
156+
country_samples,
157+
context_column)
159158
country_indeps.append(country_indep)
160159
country_predss.append(country_preds)
161-
country_casess.append(country_cases)
162160

163161
# Compute daily smooth cases per 100K mae
164162
test_case_maes = []
165163
for m in range(len(models)):
166164
total_loss = 0.
167165
for c in test_countries:
168166
true_cases = np.array(df[df.GeoID == c].SmoothNewCasesPer100K)[-nb_test_days:]
169-
pred_cases = country_casess[m][c][-nb_test_days:]
167+
pred_cases = country_predss[m][c][-nb_test_days:]
170168
if true_cases.shape != pred_cases.shape: # Insufficient data
171169
continue
172170
total_loss += np.mean(np.abs(true_cases - pred_cases))
@@ -228,7 +226,6 @@ def _lstm_roll_out_predictions(model, initial_context_input, initial_action_inpu
228226
def _lstm_get_test_rollouts(model, df, top_countries, country_samples, context_column):
229227
country_indep = {}
230228
country_preds = {}
231-
country_cases = {}
232229
for c in top_countries:
233230
X_test_context = country_samples[c]['X_test_context']
234231
X_test_action = country_samples[c]['X_test_action']
@@ -250,4 +247,4 @@ def _lstm_get_test_rollouts(model, df, top_countries, country_samples, context_c
250247
future_action_sequence)
251248
country_preds[c] = preds
252249

253-
return country_indep, country_preds, country_cases
250+
return country_indep, country_preds

covid_xprize/examples/predictors/lstm/tests/test_xprize_predictor.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
# Copyright 2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
22

33
import os
4+
from pathlib import Path
45
import unittest
56
import urllib.request
67

78
import pandas as pd
89

910
from covid_xprize.examples.predictors.lstm.xprize_predictor import XPrizePredictor
11+
from covid_xprize.oxford_data import load_oxford_data_trimmed
1012

11-
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
12-
FIXTURES_PATH = os.path.join(ROOT_DIR, 'fixtures')
13-
EXAMPLE_INPUT_FILE = os.path.join(ROOT_DIR, "../../../../validation/data/2020-09-30_historical_ip.csv")
14-
DATA_FILE = os.path.join(FIXTURES_PATH, "OxCGRT_latest.csv")
15-
DATA_URL =\
16-
"https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker-legacy/main/legacy_data_202207/OxCGRT_latest.csv"
17-
PREDICTOR_WEIGHTS = os.path.join(FIXTURES_PATH, "trained_model_weights_for_tests.h5")
13+
ROOT_DIR = Path(__file__).parent
14+
FIXTURES_PATH = ROOT_DIR / 'fixtures'
15+
EXAMPLE_INPUT_FILE = (ROOT_DIR / "../../../../validation/data/2020-09-30_historical_ip.csv").absolute()
16+
DATA_FILE = FIXTURES_PATH / "OxCGRT_trimmed.csv"
17+
PREDICTOR_WEIGHTS = FIXTURES_PATH / "trained_model_weights_for_tests.h5"
1818

19+
TRAINING_START_DATE = "2020-06-01"
20+
TRAINING_END_DATE = "2020-07-31"
1921
START_DATE = "2020-08-01"
2022
END_DATE = "2020-08-04"
2123

@@ -24,15 +26,19 @@ class TestXPrizePredictor(unittest.TestCase):
2426

2527
@classmethod
2628
def setUpClass(cls):
27-
# Download and cache the raw data
28-
urllib.request.urlretrieve(DATA_URL, DATA_FILE)
29+
FIXTURES_PATH.mkdir(exist_ok=True)
30+
if not DATA_FILE.exists():
31+
df = load_oxford_data_trimmed(end_date=TRAINING_END_DATE, start_date=TRAINING_START_DATE)
32+
df.to_csv(DATA_FILE, index=False)
33+
34+
def test_train_and_predict(self):
35+
predictor = XPrizePredictor(None, DATA_FILE)
36+
37+
# Testing on a small number of epochs and trials to make sure everything is wired correctly
38+
model = predictor.train(num_trials=2, num_epochs=2)
39+
model.save_weights(PREDICTOR_WEIGHTS)
40+
self.assertIsNotNone(model)
2941

30-
def test_predict(self):
3142
predictor = XPrizePredictor(PREDICTOR_WEIGHTS, DATA_FILE)
3243
pred_df = predictor.predict(START_DATE, END_DATE, EXAMPLE_INPUT_FILE)
3344
self.assertIsInstance(pred_df, pd.DataFrame)
34-
35-
def test_train(self):
36-
predictor = XPrizePredictor(None, DATA_FILE)
37-
model = predictor.train()
38-
self.assertIsNotNone(model)

covid_xprize/examples/predictors/lstm/xprize_predictor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
WINDOW_SIZE = 7
2828
US_PREFIX = "United States / "
2929
NUM_TRIALS = 1
30+
NUM_EPOCHS = 1000
3031
LSTM_SIZE = 32
3132
MAX_NB_COUNTRIES = 20
3233

@@ -213,7 +214,7 @@ def _convert_ratios_to_total_cases(self,
213214
def _smooth_case_list(case_list, window):
214215
return pd.Series(case_list).rolling(window).mean().to_numpy()
215216

216-
def train(self):
217+
def train(self, num_trials=NUM_TRIALS, num_epochs=NUM_EPOCHS):
217218
print("Creating numpy arrays for Keras for each country...")
218219
geos = self._most_affected_geos(self.df, MAX_NB_COUNTRIES, NB_LOOKBACK_DAYS)
219220
country_samples = create_country_samples(self.df, geos, CONTEXT_COLUMN, NB_TEST_DAYS, NB_LOOKBACK_DAYS)
@@ -256,14 +257,14 @@ def train(self):
256257
train_losses = []
257258
val_losses = []
258259
test_losses = []
259-
for t in range(NUM_TRIALS):
260+
for t in range(num_trials):
260261
print('Trial', t)
261262
X_context, X_action, y = self._permute_data(X_context, X_action, y, seed=t)
262263
model, training_model = self._construct_model(nb_context=X_context.shape[-1],
263264
nb_action=X_action.shape[-1],
264265
lstm_size=LSTM_SIZE,
265266
nb_lookback_days=NB_LOOKBACK_DAYS)
266-
history = self._train_model(training_model, X_context, X_action, y, epochs=1000, verbose=0)
267+
history = self._train_model(training_model, X_context, X_action, y, epochs=num_epochs, verbose=0)
267268
top_epoch = np.argmin(history.history['val_loss'])
268269
train_loss = history.history['loss'][top_epoch]
269270
val_loss = history.history['val_loss'][top_epoch]

covid_xprize/oxford_data/oxford_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def create_prediction_initial_context_and_action_vectors(
320320
df = df[df.Date <= start_date]
321321
country_samples = create_country_samples(df, countries, context_column, nb_lookback_days=nb_lookback_days)
322322
for c in countries:
323+
if c not in country_samples.keys():
324+
continue
323325
context_vectors[c] = country_samples[c]['X_test_context'][-1]
324326
action_vectors[c] = country_samples[c]['X_test_action'][-1]
325327
return context_vectors, action_vectors

0 commit comments

Comments
 (0)