Skip to content

Commit 1fc07b2

Browse files
authored
Merge pull request #192 from leaf-ai/predict_from_df
#189 Predict directly from DataFrame objects
2 parents ed03d55 + 77842ad commit 1fc07b2

File tree

3 files changed: 21 additions (+21) and 31 deletions (-31)

covid_xprize/scoring/prescriptor_scoring.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import os
2-
31
import pandas as pd
42

5-
from covid_xprize.standard_predictor.predict import predict
3+
from covid_xprize.standard_predictor.xprize_predictor import XPrizePredictor
64
from covid_xprize.standard_predictor.xprize_predictor import NPI_COLUMNS
75

86

@@ -17,28 +15,18 @@ def weight_prescriptions_by_cost(pres_df, cost_df):
1715

1816

1917
def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescription_file, costs_file):
20-
# Load prescriptions
21-
pres_df = pd.read_csv(prescription_file)
18+
# Load the prescriptions, handling Date and regions
19+
pres_df = XPrizePredictor.load_original_data(prescription_file)
2220

2321
# Generate predictions for all prescriptions
22+
predictor = XPrizePredictor()
2423
pred_dfs = []
2524
for idx in pres_df['PrescriptionIndex'].unique():
2625
idx_df = pres_df[pres_df['PrescriptionIndex'] == idx]
2726
idx_df = idx_df.drop(columns='PrescriptionIndex') # Predictor doesn't need this
28-
ip_file_path = 'prescriptions/prescription_{}.csv'.format(idx)
29-
os.makedirs(os.path.dirname(ip_file_path), exist_ok=True)
30-
idx_df.to_csv(ip_file_path)
31-
preds_file_path = 'predictions/predictions_{}.csv'.format(idx)
32-
os.makedirs(os.path.dirname(preds_file_path), exist_ok=True)
33-
34-
# Run predictor
35-
predict(start_date, end_date, ip_file_path, preds_file_path)
36-
37-
# Collect predictions
38-
pred_df = pd.read_csv(preds_file_path,
39-
parse_dates=['Date'],
40-
encoding="ISO-8859-1",
41-
error_bad_lines=True)
27+
# Generate the predictions
28+
pred_df = predictor.predict_from_df(start_date, end_date, idx_df)
29+
print(f"Generated predictions for PrescriptionIndex {idx}")
4230
pred_df['PrescriptionIndex'] = idx
4331
pred_dfs.append(pred_df)
4432
pred_df = pd.concat(pred_dfs)

covid_xprize/standard_predictor/predict.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77

88
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
99

10-
# Fixed weights for the standard predictor.
11-
MODEL_WEIGHTS_FILE = os.path.join(ROOT_DIR, "models", "trained_model_weights.h5")
12-
13-
DATA_FILE = os.path.join(ROOT_DIR, 'data', "OxCGRT_latest.csv")
14-
1510

1611
def predict(start_date: str,
1712
end_date: str,
@@ -29,7 +24,7 @@ def predict(start_date: str,
2924
with columns "CountryName,RegionName,Date,PredictedDailyNewCases"
3025
"""
3126
# !!! YOUR CODE HERE !!!
32-
predictor = XPrizePredictor(MODEL_WEIGHTS_FILE, DATA_FILE)
27+
predictor = XPrizePredictor()
3328
# Generate the predictions
3429
preds_df = predictor.predict(start_date, end_date, path_to_ips_file)
3530
# Create the output path

covid_xprize/standard_predictor/xprize_predictor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
ADDITIONAL_US_STATES_CONTEXT = os.path.join(DATA_PATH, "US_states_populations.csv")
2626
ADDITIONAL_UK_CONTEXT = os.path.join(DATA_PATH, "uk_populations.csv")
2727
ADDITIONAL_BRAZIL_CONTEXT = os.path.join(DATA_PATH, "brazil_populations.csv")
28+
# Fixed weights for the standard predictor.
29+
MODEL_WEIGHTS_FILE = os.path.join(ROOT_DIR, "models", "trained_model_weights.h5")
2830

2931
NPI_COLUMNS = ['C1_School closing',
3032
'C2_Workplace closing',
@@ -72,7 +74,7 @@ class XPrizePredictor(object):
7274
A class that computes a fitness for Prescriptor candidates.
7375
"""
7476

75-
def __init__(self, path_to_model_weights, data_url):
77+
def __init__(self, path_to_model_weights=MODEL_WEIGHTS_FILE, data_url=DATA_FILE_PATH):
7678
if path_to_model_weights:
7779

7880
# Load model weights
@@ -94,13 +96,18 @@ def predict(self,
9496
start_date_str: str,
9597
end_date_str: str,
9698
path_to_ips_file: str) -> pd.DataFrame:
99+
# Load the npis into a DataFrame, handling regions
100+
npis_df = self.load_original_data(path_to_ips_file)
101+
return self.predict_from_df(start_date_str, end_date_str, npis_df)
102+
103+
def predict_from_df(self,
104+
start_date_str: str,
105+
end_date_str: str,
106+
npis_df: pd.DataFrame) -> pd.DataFrame:
97107
start_date = pd.to_datetime(start_date_str, format='%Y-%m-%d')
98108
end_date = pd.to_datetime(end_date_str, format='%Y-%m-%d')
99109
nb_days = (end_date - start_date).days + 1
100110

101-
# Load the npis into a DataFrame, handling regions
102-
npis_df = self._load_original_data(path_to_ips_file)
103-
104111
# Prepare the output
105112
forecast = {"CountryName": [],
106113
"RegionName": [],
@@ -177,7 +184,7 @@ def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
177184
:return: a Pandas DataFrame with the historical data
178185
"""
179186
# Original df from Oxford
180-
df1 = self._load_original_data(data_url)
187+
df1 = self.load_original_data(data_url)
181188

182189
# Additional context df (e.g Population for each country)
183190
df2 = self._load_additional_context_df()
@@ -224,7 +231,7 @@ def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
224231
return df
225232

226233
@staticmethod
227-
def _load_original_data(data_url):
234+
def load_original_data(data_url):
228235
latest_df = pd.read_csv(data_url,
229236
parse_dates=['Date'],
230237
encoding="ISO-8859-1",

0 commit comments

Comments
 (0)