Skip to content

Commit 1a9bc7d

Browse files
authored
Merge pull request #203 from leaf-ai/prepend-past-to-presc
Prepend past ips to prescriptions
2 parents dea9275 + 6d456d2 commit 1a9bc7d

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

covid_xprize/scoring/prescriptor_scoring.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def weight_prescriptions_by_cost(pres_df, cost_df):
1616
return weighted_df
1717

1818

19-
def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescription_file, costs_file):
19+
def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescription_file, costs_file, past_ips_file=None):
2020
start_time = time.time()
2121
# Load the prescriptions, handling Date and regions
2222
pres_df = XPrizePredictor.load_original_data(prescription_file)
@@ -27,6 +27,14 @@ def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescr
2727
for idx in pres_df['PrescriptionIndex'].unique():
2828
idx_df = pres_df[pres_df['PrescriptionIndex'] == idx]
2929
idx_df = idx_df.drop(columns='PrescriptionIndex') # Predictor doesn't need this
30+
31+
# Prepend past ips if provided. This is used to handle the case when
32+
# prescriptions start after predictor's dataset ends.
33+
if past_ips_file:
34+
past_ips_df = XPrizePredictor.load_original_data(past_ips_file)
35+
past_ips_df = past_ips_df[past_ips_df['Date'] < start_date]
36+
idx_df = past_ips_df.append(idx_df)
37+
3038
# Generate the predictions
3139
pred_df = predictor.predict_from_df(start_date, end_date, idx_df)
3240
print(f"Generated predictions for PrescriptionIndex {idx}")

prescriptor_robojudge.ipynb

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@
6666
{
6767
"cell_type": "code",
6868
"execution_count": null,
69-
"metadata": {},
69+
"metadata": {
70+
"scrolled": false
71+
},
7072
"outputs": [],
7173
"source": [
7274
"from covid_xprize.scoring.predictor_scoring import load_dataset\n",
@@ -81,7 +83,9 @@
8183
{
8284
"cell_type": "code",
8385
"execution_count": null,
84-
"metadata": {},
86+
"metadata": {
87+
"scrolled": false
88+
},
8589
"outputs": [],
8690
"source": [
8791
"IP_FILE = \"prescriptions/robojudge_test_scenario.csv\"\n",
@@ -100,7 +104,9 @@
100104
{
101105
"cell_type": "code",
102106
"execution_count": null,
103-
"metadata": {},
107+
"metadata": {
108+
"scrolled": false
109+
},
104110
"outputs": [],
105111
"source": [
106112
"# Cost weightings for each IP for each geo\n",
@@ -128,7 +134,9 @@
128134
{
129135
"cell_type": "code",
130136
"execution_count": null,
131-
"metadata": {},
137+
"metadata": {
138+
"scrolled": false
139+
},
132140
"outputs": [],
133141
"source": [
134142
"# Generate blind_greedy prescriptions\n",
@@ -159,7 +167,9 @@
159167
{
160168
"cell_type": "code",
161169
"execution_count": null,
162-
"metadata": {},
170+
"metadata": {
171+
"scrolled": false
172+
},
163173
"outputs": [],
164174
"source": [
165175
"# Validate the prescription files\n",
@@ -183,15 +193,15 @@
183193
"cell_type": "code",
184194
"execution_count": null,
185195
"metadata": {
186-
"scrolled": false
196+
"scrolled": true
187197
},
188198
"outputs": [],
189199
"source": [
190200
"# Collect case and stringency data for all prescriptors\n",
191201
"dfs = []\n",
192202
"for prescriptor_name, prescription_file in sorted(prescription_files.items()):\n",
193203
" print(\"Generating predictions for\", prescriptor_name)\n",
194-
" df, _ = generate_cases_and_stringency_for_prescriptions(START_DATE, END_DATE, prescription_file, TEST_COST)\n",
204+
" df, preds = generate_cases_and_stringency_for_prescriptions(START_DATE, END_DATE, prescription_file, TEST_COST, IP_FILE)\n",
195205
" df['PrescriptorName'] = prescriptor_name\n",
196206
" dfs.append(df)\n",
197207
"df = pd.concat(dfs)"
@@ -200,7 +210,9 @@
200210
{
201211
"cell_type": "code",
202212
"execution_count": null,
203-
"metadata": {},
213+
"metadata": {
214+
"scrolled": false
215+
},
204216
"outputs": [],
205217
"source": [
206218
"df[df['CountryName'] == 'Afghanistan']"
@@ -232,7 +244,9 @@
232244
{
233245
"cell_type": "code",
234246
"execution_count": null,
235-
"metadata": {},
247+
"metadata": {
248+
"scrolled": false
249+
},
236250
"outputs": [],
237251
"source": [
238252
"def plot_pareto_curve(objective1_list, objective2_list):\n",
@@ -321,7 +335,9 @@
321335
{
322336
"cell_type": "code",
323337
"execution_count": null,
324-
"metadata": {},
338+
"metadata": {
339+
"scrolled": false
340+
},
325341
"outputs": [],
326342
"source": [
327343
"# Plot stringency and cases of each prescription for a particular country\n",
@@ -347,7 +363,9 @@
347363
{
348364
"cell_type": "code",
349365
"execution_count": null,
350-
"metadata": {},
366+
"metadata": {
367+
"scrolled": false
368+
},
351369
"outputs": [],
352370
"source": []
353371
}

0 commit comments

Comments
 (0)