Skip to content

Commit edd95a4

Browse files
Merge pull request #228 from cognizant-ai-labs/conditional_lstm_data_processing
Conditional LSTM data processing
2 parents 14cd81a + 48eb714 commit edd95a4

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

covid_xprize/examples/predictors/conditional_lstm/conditional_xprize_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, path_to_model_weights, data_url):
4444
nb_lookback_days=NB_LOOKBACK_DAYS)
4545
self.predictor.load_weights(path_to_model_weights)
4646

47-
self.df = prepare_cases_dataframe(data_url, threshold_min_cases=True)
47+
self.df = prepare_cases_dataframe(data_url)
4848

4949
def train(self,
5050
return_results=False,

covid_xprize/examples/predictors/conditional_lstm/train_predictor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from keras.constraints import Constraint
2222

2323
from covid_xprize.examples.predictors.conditional_lstm.conditional_lstm_model import construct_conditional_lstm_model
24-
from covid_xprize.oxford_data import most_affected_countries, create_country_samples
24+
from covid_xprize.oxford_data import most_affected_countries, create_country_samples, threshold_min_cases
2525

2626

2727
def construct_model(nb_context: int, nb_action: int, lstm_size: int = 32, nb_lookback_days: int = 21) -> Model:
@@ -71,6 +71,9 @@ def train_predictor(training_data: pd.DataFrame,
7171
df = training_data
7272
context_column = 'SmoothNewCasesPer100K'
7373

74+
# Only look at data with # cases above a minimum.
75+
df = threshold_min_cases(df)
76+
7477
# Create data set for training
7578
if nb_training_geos == None: # Use all countries
7679
nb_training_geos = len(df.GeoID.unique())

covid_xprize/oxford_data/oxford_data.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def add_population_column(df: pd.DataFrame) -> pd.DataFrame:
177177
return df.merge(pop_df[['GeoID', 'Population']], on=['GeoID'], how='left', suffixes=('', '_y'))
178178

179179

180-
def prepare_cases_dataframe(data_url: str, threshold_min_cases=False) -> pd.DataFrame:
180+
def prepare_cases_dataframe(data_url: str) -> pd.DataFrame:
181181
"""
182182
Loads the cases dataset from the given file, cleans it, and computes cases columns.
183183
:param data_url: the url containing the original data
@@ -218,10 +218,6 @@ def prepare_cases_dataframe(data_url: str, threshold_min_cases=False) -> pd.Data
218218
df['DeathRatio'] = df.groupby('GeoID', group_keys=False).SmoothNewDeaths.pct_change(
219219
).fillna(0).replace(np.inf, 0) + 1
220220

221-
# Remove all rows with too few cases
222-
if threshold_min_cases:
223-
df.drop(df[df.ConfirmedCases < MIN_CASES].index, inplace=True)
224-
225221
# Add column for proportion of population infected
226222
df['ProportionInfected'] = df['ConfirmedCases'] / df['Population']
227223

@@ -234,6 +230,11 @@ def prepare_cases_dataframe(data_url: str, threshold_min_cases=False) -> pd.Data
234230
return df
235231

236232

233+
def threshold_min_cases(df: pd.DataFrame) -> pd.DataFrame:
234+
"""Remove all rows with too few cases"""
235+
return df.drop(df[df.ConfirmedCases < MIN_CASES].index)
236+
237+
237238
def create_country_samples(df: pd.DataFrame,
238239
countries: list[str],
239240
context_column: str,

0 commit comments

Comments (0)