Skip to content

Commit b01e022

Browse files
committed
add alm_threshold
1 parent 2220351 commit b01e022

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

ads/opctl/operator/lowcode/anomaly/model/anomaly_merlion.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MERLIONAD_IMPORT_MODEL_MAP,
1616
MERLIONAD_MODEL_MAP,
1717
OutputColumns,
18+
SupportedModels,
1819
)
1920

2021
from .anomaly_dataset import AnomalyOutput
@@ -84,16 +85,19 @@ def _build_model(self) -> AnomalyOutput:
8485
data = df.set_index(date_column)
8586
data = TimeSeries.from_pd(data)
8687
for model_name, (model_config, model) in model_config_map.items():
87-
model_config = model_config(
88-
**{
89-
**self.spec.model_kwargs,
90-
"threshold": AggregateAlarms(
91-
alm_threshold=model_kwargs.get("alm_threshold")
92-
if model_kwargs.get("alm_threshold")
93-
else None
94-
),
95-
}
96-
)
88+
if self.spec.model == SupportedModels.BOCPD:
89+
model_config = model_config(**self.spec.model_kwargs)
90+
else:
91+
model_config = model_config(
92+
**{
93+
**self.spec.model_kwargs,
94+
"threshold": AggregateAlarms(
95+
alm_threshold=model_kwargs.get("alm_threshold")
96+
if model_kwargs.get("alm_threshold")
97+
else None
98+
),
99+
}
100+
)
97101
if hasattr(model_config, "target_seq_index"):
98102
model_config.target_seq_index = df.columns.get_loc(
99103
self.spec.target_column

0 commit comments

Comments
 (0)