Skip to content

Commit 2220351

Browse files
committed
add alm_threshold
1 parent a2c8257 commit 2220351

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import pandas as pd
10+
from merlion.post_process.threshold import AggregateAlarms
1011
from merlion.utils import TimeSeries
1112

1213
from ads.common.decorator.runtime_dependency import runtime_dependency
@@ -83,7 +84,16 @@ def _build_model(self) -> AnomalyOutput:
8384
data = df.set_index(date_column)
8485
data = TimeSeries.from_pd(data)
8586
for model_name, (model_config, model) in model_config_map.items():
86-
model_config = model_config(**self.spec.model_kwargs)
87+
model_config = model_config(
88+
**{
89+
**self.spec.model_kwargs,
90+
"threshold": AggregateAlarms(
91+
alm_threshold=model_kwargs.get("alm_threshold")
92+
if model_kwargs.get("alm_threshold")
93+
else None
94+
),
95+
}
96+
)
8797
if hasattr(model_config, "target_seq_index"):
8898
model_config.target_seq_index = df.columns.get_loc(
8999
self.spec.target_column

0 commit comments

Comments
 (0)