Skip to content

Commit da3e3c5

Browse files
committed
Upgraded polars and improved constraint checking for static constraints
1 parent 047da2a commit da3e3c5

File tree

3 files changed

+11
-15
lines changed

3 files changed

+11
-15
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ classifiers = [
1818
"Operating System :: OS Independent",
1919
]
2020
dependencies = [
21-
"polars ~= 1.27.1",
21+
"polars ~= 1.30.0",
2222
"numpy < 1.29.0",
2323
"bigtree == 0.18.*",
2424
"ruamel.yaml == 0.18.*",
2525
"hydra-core ~= 1.3.2",
2626
"pytimeparse == 1.1.*",
2727
"networkx == 3.3.*",
28-
"pyarrow == 17.*",
28+
"pyarrow",
2929
"meds ~= 0.4.0",
3030
]
3131

src/aces/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def MEDS_eval_expr(self) -> pl.Expr:
9393
col("code").str.contains(["^foo.*"]).all_horizontal([[(col("numeric_value")) >
9494
(dyn int: 120)]])
9595
>>> print(PlainPredicateConfig(code={'any': ['foo', 'bar']}).MEDS_eval_expr())
96-
col("code").is_in([Series])
96+
col("code").is_in([["foo", "bar"]])
9797
"""
9898
criteria = []
9999
if isinstance(self.code, dict):

src/aces/constraints.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,21 +169,17 @@ def check_static_variables(patient_demographics: list[str], predicates_df: pl.Da
169169
...
170170
ValueError: Static predicate 'female' not found in the predicates dataframe.
171171
"""
172+
173+
predicates_constraints = []
174+
172175
for demographic in patient_demographics:
173176
if demographic not in predicates_df.columns:
174177
raise ValueError(f"Static predicate '{demographic}' not found in the predicates dataframe.")
175178

176-
keep_expr = ((pl.col("timestamp").is_null()) & (pl.col(demographic) == 1)).alias("keep_expr")
177-
178-
exclude_expr = ~keep_expr
179-
exclude_count = predicates_df.filter(exclude_expr).shape[0]
180-
181-
logger.info(f"Excluding {exclude_count:,} rows due to the '{demographic}' criteria.")
182-
183-
predicates_df = predicates_df.filter(
184-
pl.col("subject_id").is_in(predicates_df.filter(keep_expr).select("subject_id").unique())
179+
predicates_constraints.append(
180+
(pl.col("timestamp").is_null() & (pl.col(demographic) > 0)).any().over("subject_id")
185181
)
186182

187-
return predicates_df.drop_nulls(subset=["timestamp"]).drop(
188-
*[x for x in patient_demographics if x in predicates_df.columns]
189-
)
183+
predicate_filter = pl.all_horizontal(predicates_constraints)
184+
185+
return predicates_df.filter(predicate_filter).drop_nulls(subset=["timestamp"]).drop(patient_demographics)

0 commit comments

Comments
 (0)