Skip to content

Commit 3985f8d

Browse files
authored
Upgraded to MEDS v0.4 (#181)
1 parent 58dd44d commit 3985f8d

File tree

3 files changed

+44
-63
lines changed

3 files changed

+44
-63
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"pytimeparse == 1.1.*",
2727
"networkx == 3.3.*",
2828
"pyarrow == 17.*",
29-
"meds == 0.3.3",
29+
"meds ~= 0.4.0",
3030
]
3131

3232
[tool.setuptools]

src/aces/run.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import polars as pl
1212
import pyarrow as pa
1313
import pyarrow.parquet as pq
14-
from meds import label_schema, prediction_time_field, subject_id_field
14+
from meds import LabelSchema
1515
from omegaconf import DictConfig, OmegaConf
1616

1717
from . import config, predicates, query
@@ -20,15 +20,15 @@
2020
config_yaml = files("aces").joinpath("configs/_aces.yaml")
2121

2222
MEDS_LABEL_MANDATORY_TYPES = {
23-
subject_id_field: pl.Int64,
23+
LabelSchema.subject_id_name: pl.Int64,
2424
}
2525

2626
MEDS_LABEL_OPTIONAL_TYPES = {
27-
"boolean_value": pl.Boolean,
28-
"integer_value": pl.Int64,
29-
"float_value": pl.Float64,
30-
"categorical_value": pl.String,
31-
prediction_time_field: pl.Datetime("us"),
27+
LabelSchema.prediction_time_name: pl.Datetime("us"),
28+
LabelSchema.boolean_value_name: pl.Boolean,
29+
LabelSchema.integer_value_name: pl.Int64,
30+
LabelSchema.float_value_name: pl.Float64,
31+
LabelSchema.categorical_value_name: pl.String,
3232
}
3333

3434

@@ -56,9 +56,9 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
5656
>>> get_and_validate_label_schema(df)
5757
Traceback (most recent call last):
5858
...
59-
ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64.
59+
ValueError: MEDS Label DataFrame must have a 'subject_id' column of type Int64.
6060
>>> df = pl.DataFrame({
61-
... subject_id_field: pl.Series([1, 3, 2], dtype=pl.UInt32),
61+
... "subject_id": pl.Series([1, 3, 2], dtype=pl.UInt32),
6262
... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
6363
... "boolean_value": [1, 0, 100],
6464
... })
@@ -68,7 +68,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
6868
prediction_time: timestamp[us]
6969
boolean_value: bool
7070
integer_value: int64
71-
float_value: double
71+
float_value: float
7272
categorical_value: string
7373
----
7474
subject_id: [[1,3,2]]
@@ -80,7 +80,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
8080
"""
8181

8282
schema = df.schema
83-
if "prediction_time" not in schema:
83+
if LabelSchema.prediction_time_name not in schema:
8484
logger.warning(
8585
"Output DataFrame is missing a 'prediction_time' column. If this is not intentional, add a "
8686
"'index_timestamp' (yes, it should be different) key to the task configuration identifying "
@@ -92,7 +92,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
9292
if col in schema and schema[col] != dtype:
9393
df = df.with_columns(pl.col(col).cast(dtype, strict=False))
9494
elif col not in schema:
95-
errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")
95+
errors.append(f"MEDS Label DataFrame must have a '{col}' column of type {dtype}.")
9696

9797
if errors:
9898
raise ValueError("\n".join(errors))
@@ -115,16 +115,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
115115
)
116116
df = df.drop(extra_cols)
117117

118-
df = df.select(
119-
subject_id_field,
120-
"prediction_time",
121-
"boolean_value",
122-
"integer_value",
123-
"float_value",
124-
"categorical_value",
125-
)
126-
127-
return df.to_arrow().cast(label_schema)
118+
return LabelSchema.align(df.to_arrow())
128119

129120

130121
@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
@@ -154,18 +145,18 @@ def main(cfg: DictConfig) -> None: # pragma: no cover
154145

155146
if cfg.data.standard.lower() == "meds":
156147
for in_col, out_col in [
157-
("subject_id", subject_id_field),
158-
("index_timestamp", "prediction_time"),
159-
("label", "boolean_value"),
148+
("subject_id", LabelSchema.subject_id_name),
149+
("index_timestamp", LabelSchema.prediction_time_name),
150+
("label", LabelSchema.boolean_value_name),
160151
]:
161152
if in_col in result.columns:
162153
result = result.rename({in_col: out_col})
163-
if subject_id_field not in result.columns:
154+
if LabelSchema.subject_id_name not in result.columns:
164155
if not result_is_empty:
165156
raise ValueError("Output dataframe is missing a 'subject_id' column.")
166157
else:
167158
logger.warning("Output dataframe is empty; adding an empty patient ID column.")
168-
result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias(subject_id_field))
159+
result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias(LabelSchema.subject_id_name))
169160
result = result.head(0)
170161
if cfg.window_stats_dir:
171162
Path(cfg.window_stats_filepath).parent.mkdir(exist_ok=True, parents=True)

tests/test_meds.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import polars as pl
99
import pyarrow as pa
10-
from meds import label_schema, subject_id_field
10+
from meds import DataSchema, LabelSchema
1111
from yaml import load as load_yaml
1212

1313
from .utils import (
@@ -36,24 +36,23 @@
3636

3737
# TODO: Make this use the meds library
3838
MEDS_PL_SCHEMA = {
39-
subject_id_field: pl.Int64,
40-
"time": pl.Datetime("us"),
41-
"code": pl.Utf8,
42-
"numeric_value": pl.Float32,
43-
"numeric_value/is_inlier": pl.Boolean,
39+
DataSchema.subject_id_name: pl.Int64,
40+
DataSchema.time_name: pl.Datetime("us"),
41+
DataSchema.code_name: pl.Utf8,
42+
DataSchema.numeric_value_name: pl.Float32,
4443
}
4544

4645

4746
MEDS_LABEL_MANDATORY_TYPES = {
48-
subject_id_field: pl.Int64,
47+
LabelSchema.subject_id_name: pl.Int64,
4948
}
5049

5150
MEDS_LABEL_OPTIONAL_TYPES = {
52-
"boolean_value": pl.Boolean,
53-
"integer_value": pl.Int64,
54-
"float_value": pl.Float64,
55-
"categorical_value": pl.String,
56-
"prediction_time": pl.Datetime("us"),
51+
LabelSchema.boolean_value_name: pl.Boolean,
52+
LabelSchema.integer_value_name: pl.Int64,
53+
LabelSchema.float_value_name: pl.Float64,
54+
LabelSchema.categorical_value_name: pl.String,
55+
LabelSchema.prediction_time_name: pl.Datetime("us"),
5756
}
5857

5958

@@ -113,16 +112,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
113112
)
114113
df = df.drop(extra_cols)
115114

116-
df = df.select(
117-
subject_id_field,
118-
"prediction_time",
119-
"boolean_value",
120-
"integer_value",
121-
"float_value",
122-
"categorical_value",
123-
)
124-
125-
return df.to_arrow().cast(label_schema)
115+
return LabelSchema.align(df.to_arrow())
126116

127117

128118
def parse_meds_csvs(
@@ -140,7 +130,7 @@ def reader(csv_str: str) -> pl.DataFrame:
140130
cols = csv_str.strip().split("\n")[0].split(",")
141131
read_schema = {k: v for k, v in default_read_schema.items() if k in cols}
142132
return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns(
143-
pl.col("time").str.strptime(MEDS_PL_SCHEMA["time"], DEFAULT_CSV_TS_FORMAT)
133+
pl.col("time").str.strptime(MEDS_PL_SCHEMA[DataSchema.time_name], DEFAULT_CSV_TS_FORMAT)
144134
)
145135

146136
if isinstance(csvs, str):
@@ -169,9 +159,9 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
169159

170160
# Data (input)
171161
MEDS_SHARDS = parse_shards_yaml(
172-
f"""
162+
"""
173163
"train/0": |-
174-
{subject_id_field},time,code,numeric_value
164+
subject_id,time,code,numeric_value
175165
2,,SNP//rs234567,
176166
2,,SNP//rs345678,
177167
2,,GENDER//FEMALE,
@@ -196,7 +186,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
196186
2,6/8/1996 3:00,DEATH,
197187
198188
"train/1": |-2
199-
{subject_id_field},time,code,numeric_value
189+
subject_id,time,code,numeric_value
200190
4,,GENDER//MALE,
201191
4,,SNP//rs123456,
202192
4,12/1/1989 12:03,ADMISSION//CARDIAC,
@@ -246,7 +236,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
246236
6,3/12/1996 0:00,DEATH,
247237
248238
"held_out/0/0": |-2
249-
{subject_id_field},time,code,numeric_value
239+
subject_id,time,code,numeric_value
250240
3,,GENDER//FEMALE,
251241
3,,SNP//rs234567,
252242
3,,SNP//rs345678,
@@ -261,10 +251,10 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
261251
3,3/12/1996 0:00,DEATH,
262252
263253
"empty_shard": |-2
264-
{subject_id_field},time,code,numeric_value
254+
subject_id,time,code,numeric_value
265255
266256
"held_out": |-2
267-
{subject_id_field},time,code,numeric_value
257+
subject_id,time,code,numeric_value
268258
1,,GENDER//MALE,
269259
1,,SNP//rs123456,
270260
1,12/1/1989 12:03,ADMISSION//CARDIAC,
@@ -349,22 +339,22 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
349339
"""
350340

351341
WANT_SHARDS = parse_labels_yaml(
352-
f"""
342+
"""
353343
"train/0": |-2
354-
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
344+
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
355345
356346
"train/1": |-2
357-
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
347+
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
358348
4,1/28/1991 23:32,False,,,,
359349
360350
"held_out/0/0": |-2
361-
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
351+
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
362352
363353
"empty_shard": |-2
364-
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
354+
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
365355
366356
"held_out": |-2
367-
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
357+
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
368358
1,1/28/1991 23:32,False,,,,
369359
"""
370360
)

0 commit comments

Comments
 (0)