Skip to content

Commit 286640d

Browse files
authored
Merge pull request #193 from leaf-ai/return_preds
#189 Make generate_cases_and_stringency_for_prescriptions return the generated DataFrames
2 parents 1fc07b2 + 51d41c9 commit 286640d

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

covid_xprize/scoring/prescriptor_scoring.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import time
2+
13
import pandas as pd
24

35
from covid_xprize.standard_predictor.xprize_predictor import XPrizePredictor
@@ -15,21 +17,22 @@ def weight_prescriptions_by_cost(pres_df, cost_df):
1517

1618

1719
def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescription_file, costs_file):
20+
start_time = time.time()
1821
# Load the prescriptions, handling Date and regions
1922
pres_df = XPrizePredictor.load_original_data(prescription_file)
2023

2124
# Generate predictions for all prescriptions
2225
predictor = XPrizePredictor()
23-
pred_dfs = []
26+
pred_dfs = {}
2427
for idx in pres_df['PrescriptionIndex'].unique():
2528
idx_df = pres_df[pres_df['PrescriptionIndex'] == idx]
2629
idx_df = idx_df.drop(columns='PrescriptionIndex') # Predictor doesn't need this
2730
# Generate the predictions
2831
pred_df = predictor.predict_from_df(start_date, end_date, idx_df)
2932
print(f"Generated predictions for PrescriptionIndex {idx}")
3033
pred_df['PrescriptionIndex'] = idx
31-
pred_dfs.append(pred_df)
32-
pred_df = pd.concat(pred_dfs)
34+
pred_dfs[idx] = pred_df
35+
pred_df = pd.concat(list(pred_dfs.values()))
3336

3437
# Aggregate cases by prescription index and geo
3538
agg_pred_df = pred_df.groupby(['CountryName',
@@ -65,8 +68,12 @@ def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescr
6568
'PrescriptionIndex',
6669
'PredictedDailyNewCases',
6770
'Stringency']]
71+
end_time = time.time()
72+
elapsed_time = end_time - start_time
73+
elapsed_time_tring = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
74+
print(f"Evaluated {len(pred_dfs)} PrescriptionIndex in {elapsed_time_tring} seconds")
6875

69-
return df
76+
return df, pred_dfs
7077

7178

7279
# Compute domination relationship for each pair of prescriptors for each geo

prescriptor_robojudge.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@
191191
"dfs = []\n",
192192
"for prescriptor_name, prescription_file in sorted(prescription_files.items()):\n",
193193
" print(\"Generating predictions for\", prescriptor_name)\n",
194-
" df = generate_cases_and_stringency_for_prescriptions(START_DATE, END_DATE, prescription_file, TEST_COST)\n",
194+
" df, _ = generate_cases_and_stringency_for_prescriptions(START_DATE, END_DATE, prescription_file, TEST_COST)\n",
195195
" df['PrescriptorName'] = prescriptor_name\n",
196196
" dfs.append(df)\n",
197197
"df = pd.concat(dfs)"

0 commit comments

Comments
 (0)