2525ADDITIONAL_US_STATES_CONTEXT = os .path .join (DATA_PATH , "US_states_populations.csv" )
2626ADDITIONAL_UK_CONTEXT = os .path .join (DATA_PATH , "uk_populations.csv" )
2727ADDITIONAL_BRAZIL_CONTEXT = os .path .join (DATA_PATH , "brazil_populations.csv" )
28+ # Fixed weights for the standard predictor.
29+ MODEL_WEIGHTS_FILE = os .path .join (ROOT_DIR , "models" , "trained_model_weights.h5" )
2830
2931NPI_COLUMNS = ['C1_School closing' ,
3032 'C2_Workplace closing' ,
@@ -72,7 +74,7 @@ class XPrizePredictor(object):
7274 A class that computes a fitness for Prescriptor candidates.
7375 """
7476
75- def __init__ (self , path_to_model_weights , data_url ):
77+ def __init__ (self , path_to_model_weights = MODEL_WEIGHTS_FILE , data_url = DATA_FILE_PATH ):
7678 if path_to_model_weights :
7779
7880 # Load model weights
@@ -94,13 +96,18 @@ def predict(self,
9496 start_date_str : str ,
9597 end_date_str : str ,
9698 path_to_ips_file : str ) -> pd .DataFrame :
99+ # Load the npis into a DataFrame, handling regions
100+ npis_df = self .load_original_data (path_to_ips_file )
101+ return self .predict_from_df (start_date_str , end_date_str , npis_df )
102+
103+ def predict_from_df (self ,
104+ start_date_str : str ,
105+ end_date_str : str ,
106+ npis_df : pd .DataFrame ) -> pd .DataFrame :
97107 start_date = pd .to_datetime (start_date_str , format = '%Y-%m-%d' )
98108 end_date = pd .to_datetime (end_date_str , format = '%Y-%m-%d' )
99109 nb_days = (end_date - start_date ).days + 1
100110
101- # Load the npis into a DataFrame, handling regions
102- npis_df = self ._load_original_data (path_to_ips_file )
103-
104111 # Prepare the output
105112 forecast = {"CountryName" : [],
106113 "RegionName" : [],
@@ -177,7 +184,7 @@ def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
177184 :return: a Pandas DataFrame with the historical data
178185 """
179186 # Original df from Oxford
180- df1 = self ._load_original_data (data_url )
187+ df1 = self .load_original_data (data_url )
181188
182189 # Additional context df (e.g Population for each country)
183190 df2 = self ._load_additional_context_df ()
@@ -224,7 +231,7 @@ def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
224231 return df
225232
226233 @staticmethod
227- def _load_original_data (data_url ):
234+ def load_original_data (data_url ):
228235 latest_df = pd .read_csv (data_url ,
229236 parse_dates = ['Date' ],
230237 encoding = "ISO-8859-1" ,
0 commit comments