7
7
8
8
import polars as pl
9
9
import pyarrow as pa
10
- from meds import label_schema , subject_id_field
10
+ from meds import DataSchema , LabelSchema
11
11
from yaml import load as load_yaml
12
12
13
13
from .utils import (
36
36
37
37
# TODO: Make use meds library
38
38
MEDS_PL_SCHEMA = {
39
- subject_id_field : pl .Int64 ,
40
- "time" : pl .Datetime ("us" ),
41
- "code" : pl .Utf8 ,
42
- "numeric_value" : pl .Float32 ,
43
- "numeric_value/is_inlier" : pl .Boolean ,
39
+ DataSchema .subject_id_name : pl .Int64 ,
40
+ DataSchema .time_name : pl .Datetime ("us" ),
41
+ DataSchema .code_name : pl .Utf8 ,
42
+ DataSchema .numeric_value_name : pl .Float32 ,
44
43
}
45
44
46
45
47
46
MEDS_LABEL_MANDATORY_TYPES = {
48
- subject_id_field : pl .Int64 ,
47
+ LabelSchema . subject_id_name : pl .Int64 ,
49
48
}
50
49
51
50
MEDS_LABEL_OPTIONAL_TYPES = {
52
- "boolean_value" : pl .Boolean ,
53
- "integer_value" : pl .Int64 ,
54
- "float_value" : pl .Float64 ,
55
- "categorical_value" : pl .String ,
56
- "prediction_time" : pl .Datetime ("us" ),
51
+ LabelSchema . boolean_value_name : pl .Boolean ,
52
+ LabelSchema . integer_value_name : pl .Int64 ,
53
+ LabelSchema . float_value_name : pl .Float64 ,
54
+ LabelSchema . categorical_value_name : pl .String ,
55
+ LabelSchema . prediction_time_name : pl .Datetime ("us" ),
57
56
}
58
57
59
58
@@ -113,16 +112,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
113
112
)
114
113
df = df .drop (extra_cols )
115
114
116
- df = df .select (
117
- subject_id_field ,
118
- "prediction_time" ,
119
- "boolean_value" ,
120
- "integer_value" ,
121
- "float_value" ,
122
- "categorical_value" ,
123
- )
124
-
125
- return df .to_arrow ().cast (label_schema )
115
+ return LabelSchema .align (df .to_arrow ())
126
116
127
117
128
118
def parse_meds_csvs (
@@ -140,7 +130,7 @@ def reader(csv_str: str) -> pl.DataFrame:
140
130
cols = csv_str .strip ().split ("\n " )[0 ].split ("," )
141
131
read_schema = {k : v for k , v in default_read_schema .items () if k in cols }
142
132
return pl .read_csv (StringIO (csv_str ), schema = read_schema ).with_columns (
143
- pl .col ("time" ).str .strptime (MEDS_PL_SCHEMA ["time" ], DEFAULT_CSV_TS_FORMAT )
133
+ pl .col ("time" ).str .strptime (MEDS_PL_SCHEMA [DataSchema . time_name ], DEFAULT_CSV_TS_FORMAT )
144
134
)
145
135
146
136
if isinstance (csvs , str ):
@@ -169,9 +159,9 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
169
159
170
160
# Data (input)
171
161
MEDS_SHARDS = parse_shards_yaml (
172
- f """
162
+ """
173
163
"train/0": |-
174
- { subject_id_field } ,time,code,numeric_value
164
+ subject_id ,time,code,numeric_value
175
165
2,,SNP//rs234567,
176
166
2,,SNP//rs345678,
177
167
2,,GENDER//FEMALE,
@@ -196,7 +186,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
196
186
2,6/8/1996 3:00,DEATH,
197
187
198
188
"train/1": |-2
199
- { subject_id_field } ,time,code,numeric_value
189
+ subject_id ,time,code,numeric_value
200
190
4,,GENDER//MALE,
201
191
4,,SNP//rs123456,
202
192
4,12/1/1989 12:03,ADMISSION//CARDIAC,
@@ -246,7 +236,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
246
236
6,3/12/1996 0:00,DEATH,
247
237
248
238
"held_out/0/0": |-2
249
- { subject_id_field } ,time,code,numeric_value
239
+ subject_id ,time,code,numeric_value
250
240
3,,GENDER//FEMALE,
251
241
3,,SNP//rs234567,
252
242
3,,SNP//rs345678,
@@ -261,10 +251,10 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
261
251
3,3/12/1996 0:00,DEATH,
262
252
263
253
"empty_shard": |-2
264
- { subject_id_field } ,time,code,numeric_value
254
+ subject_id ,time,code,numeric_value
265
255
266
256
"held_out": |-2
267
- { subject_id_field } ,time,code,numeric_value
257
+ subject_id ,time,code,numeric_value
268
258
1,,GENDER//MALE,
269
259
1,,SNP//rs123456,
270
260
1,12/1/1989 12:03,ADMISSION//CARDIAC,
@@ -349,22 +339,22 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
349
339
"""
350
340
351
341
WANT_SHARDS = parse_labels_yaml (
352
- f """
342
+ """
353
343
"train/0": |-2
354
- { subject_id_field } ,prediction_time,boolean_value,integer_value,float_value,categorical_value
344
+ subject_id ,prediction_time,boolean_value,integer_value,float_value,categorical_value
355
345
356
346
"train/1": |-2
357
- { subject_id_field } ,prediction_time,boolean_value,integer_value,float_value,categorical_value
347
+ subject_id ,prediction_time,boolean_value,integer_value,float_value,categorical_value
358
348
4,1/28/1991 23:32,False,,,,
359
349
360
350
"held_out/0/0": |-2
361
- { subject_id_field } ,prediction_time,boolean_value,integer_value,float_value,categorical_value
351
+ subject_id ,prediction_time,boolean_value,integer_value,float_value,categorical_value
362
352
363
353
"empty_shard": |-2
364
- { subject_id_field } ,prediction_time,boolean_value,integer_value,float_value,categorical_value
354
+ subject_id ,prediction_time,boolean_value,integer_value,float_value,categorical_value
365
355
366
356
"held_out": |-2
367
- { subject_id_field } ,prediction_time,boolean_value,integer_value,float_value,categorical_value
357
+ subject_id ,prediction_time,boolean_value,integer_value,float_value,categorical_value
368
358
1,1/28/1991 23:32,False,,,,
369
359
"""
370
360
)
0 commit comments