Skip to content

Commit 26e4fa5

Browse files
committed
add merlion ad
1 parent 0a9c5d8 commit 26e4fa5

File tree

8 files changed

+251
-3
lines changed

8 files changed

+251
-3
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,11 @@ rrcf
471471
* Source code: https://github.com/kLabUM/rrcf
472472
* Project home: https://github.com/kLabUM/rrcf
473473

474+
Merlion
475+
* Copyright 2021 Salesforce.com Inc
476+
* License: BSD 3-Clause License
477+
* Source code: https://github.com/salesforce/Merlion
478+
* Project home: https://github.com/salesforce/Merlion
474479

475480
=============================== Licenses ===============================
476481
------------------------------------------------------------------------

ads/opctl/operator/lowcode/anomaly/const.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import random
78
from ads.common.extended_enum import ExtendedEnumMeta
89
from ads.opctl.operator.lowcode.common.const import DataColumns
10+
from merlion.models.anomaly import autoencoder, deep_point_anomaly_detector, isolation_forest, spectral_residual, windstats, windstats_monthly
11+
from merlion.models.anomaly.change_point import bocpd
12+
from merlion.models.forecast import prophet
913

1014

1115
class SupportedModels(str, metaclass=ExtendedEnumMeta):
@@ -14,6 +18,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
1418
AutoMLX = "automlx"
1519
AutoTS = "autots"
1620
Auto = "auto"
21+
MerilonAD = "merlion_ad"
1722
# TODS = "tods"
1823

1924
class NonTimeADSupportedModels(str, metaclass=ExtendedEnumMeta):
@@ -56,6 +61,84 @@ class TODSSubModels(str, metaclass=ExtendedEnumMeta):
5661
}
5762

5863

64+
class MerlionADSubmodels(str, metaclass=ExtendedEnumMeta):
    """Names of the Merlion anomaly detection sub models known to the operator.

    Each attribute holds the string identifier of one sub model. The
    attributes are used as keys of the module-path and class-name lookup
    tables defined in this module.
    """

    # Point anomaly detectors.
    AUTOENCODER = "autoencoder"
    DAGMM = "dagmm"
    DBL = "dbl"
    DEEP_POINT_ANOMALY_DETECTOR = "deep_point_anomaly_detector"
    ISOLATION_FOREST = "isolation_forest"
    LOF = "lof"
    LSTM_ED = "lstm_ed"
    # RANDOM_CUT_FOREST = "random_cut_forest"
    SPECTRAL_RESIDUAL = "spectral_residual"
    STAT_RESIDUAL = "stat_residual"
    VAE = "vae"
    WINDSTATS = "windstats"
    WINDSTATS_MONTHLY = "windstats_monthly"
    ZMS = "zms"

    # Forecast-based detectors.
    ARIMA = "arima"
    ETS = "ets"
    MSES = "mses"
    PROPHET = "prophet"
    SARIMA = "sarima"

    # Change-point detectors.
    BOCPD = "bocpd"
93+
94+
MERLIONAD_IMPORT_MODEL_MAP = {
95+
MerlionADSubmodels.AUTOENCODER: ".autoendcoder",
96+
MerlionADSubmodels.DAGMM: ".dagmm",
97+
MerlionADSubmodels.DBL: ".dbl",
98+
MerlionADSubmodels.DEEP_POINT_ANOMALY_DETECTOR: ".deep_point_anomaly_detector",
99+
MerlionADSubmodels.ISOLATION_FOREST: ".isolation_forest",
100+
MerlionADSubmodels.LOF: ".lof",
101+
MerlionADSubmodels.LSTM_ED: ".lstm_ed",
102+
# MerlionADSubmodels.RANDOM_CUT_FOREST: ".random_cut_forest",
103+
MerlionADSubmodels.SPECTRAL_RESIDUAL: ".spectral_residual",
104+
MerlionADSubmodels.STAT_RESIDUAL: ".stat_residual",
105+
MerlionADSubmodels.VAE: ".vae",
106+
MerlionADSubmodels.WINDSTATS: ".windstats",
107+
MerlionADSubmodels.WINDSTATS_MONTHLY: ".windstats_monthly",
108+
MerlionADSubmodels.ZMS: ".zms",
109+
MerlionADSubmodels.ARIMA: ".forecast_based.arima",
110+
MerlionADSubmodels.ETS: ".forecast_based.ets",
111+
MerlionADSubmodels.MSES: ".forecast_based.mses",
112+
MerlionADSubmodels.PROPHET: ".forecast_based.prophet",
113+
MerlionADSubmodels.SARIMA: ".forecast_based.sarima",
114+
MerlionADSubmodels.BOCPD: ".change_point.bocpd",
115+
}
116+
117+
118+
MERLIONAD_MODEL_MAP = {
119+
MerlionADSubmodels.AUTOENCODER: "AutoEncoder",
120+
MerlionADSubmodels.DAGMM: "DAGMM",
121+
MerlionADSubmodels.DBL: "DynamicBaseline",
122+
MerlionADSubmodels.DEEP_POINT_ANOMALY_DETECTOR: "DeepPointAnomalyDetector",
123+
MerlionADSubmodels.ISOLATION_FOREST: "IsolationForest",
124+
MerlionADSubmodels.LOF: "LOF",
125+
MerlionADSubmodels.LSTM_ED: "LSTMED",
126+
# MerlionADSubmodels.RANDOM_CUT_FOREST: "RandomCutForest",
127+
MerlionADSubmodels.SPECTRAL_RESIDUAL: "SpectralResidual",
128+
MerlionADSubmodels.STAT_RESIDUAL: "StatThreshold",
129+
MerlionADSubmodels.VAE: "VAE",
130+
MerlionADSubmodels.WINDSTATS: "WindStats",
131+
MerlionADSubmodels.WINDSTATS_MONTHLY: "MonthlyWindStats",
132+
MerlionADSubmodels.ZMS: "ZMS",
133+
MerlionADSubmodels.ARIMA: "ArimaDetector",
134+
MerlionADSubmodels.ETS: "ETSDetector",
135+
MerlionADSubmodels.MSES: "MSESDetector",
136+
MerlionADSubmodels.PROPHET: "ProphetDetector",
137+
MerlionADSubmodels.SARIMA: "SarimaDetector",
138+
MerlionADSubmodels.BOCPD: "BOCPD",
139+
}
140+
141+
59142
class SupportedMetrics(str, metaclass=ExtendedEnumMeta):
60143
UNSUPERVISED_UNIFY95 = "unsupervised_unify95"
61144
UNSUPERVISED_UNIFY95_LOG_LOSS = "unsupervised_unify95_log_loss"
@@ -94,5 +177,6 @@ class OutputColumns(str, metaclass=ExtendedEnumMeta):
94177
Series = DataColumns.Series
95178

96179

180+
MERLION_DEFAULT_MODEL = "prophet"
97181
TODS_DEFAULT_MODEL = "ocsvm"
98182
SUBSAMPLE_THRESHOLD = 1000
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import importlib
8+
from collections import defaultdict
9+
10+
import numpy as np
11+
import pandas as pd
12+
from merlion.utils import TimeSeries
13+
14+
from ads.common.decorator.runtime_dependency import runtime_dependency
15+
from ads.opctl.operator.lowcode.anomaly.const import (
16+
MERLION_DEFAULT_MODEL,
17+
MERLIONAD_IMPORT_MODEL_MAP,
18+
MERLIONAD_MODEL_MAP,
19+
OutputColumns,
20+
)
21+
from tests.integration import other
22+
23+
from .anomaly_dataset import AnomalyOutput
24+
from .base_model import AnomalyOperatorBaseModel
25+
26+
27+
class AnomalyMerlionOperatorModel(AnomalyOperatorBaseModel):
    """Class representing Merlion Anomaly Detection operator model."""

    @runtime_dependency(
        module="merlion",
        err_msg=(
            "Please run `pip3 install salesforce-merlion[all]` to "
            "install the required packages."
        ),
    )
    def _get_config_model(self, model_list):
        """
        Resolves Merlion sub model names to their config and model classes.

        Parameters
        ----------
        model_list : list or str
            A list of sub model names. A single name given as a string is
            treated as a one-element list.

        Returns
        -------
        dict
            A dictionary with model names as keys and a two-element list
            ``[config_class, model_class]`` as values.

        Raises
        ------
        ValueError
            If a name has no entry in the sub model lookup tables.
        """
        # Iterating a bare string would yield one "model name" per character.
        if isinstance(model_list, str):
            model_list = [model_list]

        model_config_map = {}
        for model_name in model_list:
            module_path = MERLIONAD_IMPORT_MODEL_MAP.get(model_name)
            class_name = MERLIONAD_MODEL_MAP.get(model_name)
            if module_path is None or class_name is None:
                # Fail with a readable message instead of passing None to
                # importlib.import_module.
                supported = sorted(str(name) for name in MERLIONAD_IMPORT_MODEL_MAP)
                raise ValueError(
                    f"Unsupported Merlion sub model: {model_name!r}. "
                    f"Supported sub models: {supported}."
                )

            # module_path is relative (starts with "."), hence the package.
            model_module = importlib.import_module(
                name=module_path,
                package="merlion.models.anomaly",
            )
            model_config = getattr(model_module, class_name + "Config")
            model = getattr(model_module, class_name)
            model_config_map[model_name] = [model_config, model]
        return model_config_map

    def _build_model(self) -> AnomalyOutput:
        """
        Builds a Merlion anomaly detection model and trains it using the given data.

        The sub models are taken from ``model_kwargs["sub_model"]``; when that
        entry is missing or empty, the Prophet detector is used. Each target
        series is labelled with the model's own anomaly labels. If producing
        those labels raises, scores above the ``anomaly_threshold`` percentile
        (default 95) are flagged instead.

        Parameters
        ----------
        None

        Returns
        -------
        AnomalyOutput
            An AnomalyOutput object containing the anomaly detection results.
        """
        model_kwargs = self.spec.model_kwargs
        anomaly_threshold = model_kwargs.get("anomaly_threshold", 95)
        sub_model = model_kwargs.get("sub_model", None)

        model_config_map = {}
        if sub_model:
            model_config_map = self._get_config_model(sub_model)
        else:
            from merlion.models.anomaly.forecast_based.prophet import (  # noqa: I001
                ProphetDetector,
                ProphetDetectorConfig,
            )

            model_config_map[MERLION_DEFAULT_MODEL] = [
                ProphetDetectorConfig,
                ProphetDetector,
            ]

        date_column = self.spec.datetime_column.name
        anomaly_output = AnomalyOutput(date_column=date_column)

        for target, df in self.datasets.full_data_dict.items():
            data = TimeSeries.from_pd(df.set_index(date_column))

            for config_cls, model_cls in model_config_map.values():
                model = model_cls(config_cls(**self.spec.model_kwargs))

                scores = model.train(train_data=data, anomaly_labels=None)
                # Computed once; used by the fallback and the score frame.
                score_values = scores.to_pd().reset_index()["anom_score"]

                try:
                    labels = model.get_anomaly_label(data)
                    y_pred = (
                        labels.to_pd().reset_index()["anom_score"] > 0
                    ).astype(int)
                except Exception:
                    # Fallback when labelling fails: flag scores above the
                    # configured percentile of the training scores.
                    y_pred = (
                        score_values
                        > np.percentile(score_values, anomaly_threshold)
                    ).astype(int)

                # NOTE(review): assumes the first column of df is the date
                # column - confirm against the dataset loader.
                index_col = df.columns[0]

                anomaly = pd.DataFrame(
                    {index_col: df[index_col], OutputColumns.ANOMALY_COL: y_pred}
                ).reset_index(drop=True)
                score = pd.DataFrame(
                    {
                        index_col: df[index_col],
                        OutputColumns.SCORE_COL: score_values,
                    }
                ).reset_index(drop=True)

            # NOTE(review): recorded once per target, so with several sub
            # models only the last one's result is kept - confirm intended.
            anomaly_output.add_output(target, anomaly, score)
        return anomaly_output

    def _generate_report(self):
        """
        Generates the model-specific parts of the report.

        Returns
        -------
        tuple
            ``(model_description, other_sections)``: a text block describing
            the model, and a list holding a heading and an introductory text.
        """
        import report_creator as rc

        other_sections = [
            rc.Heading("Selected Models Overview", level=2),
            rc.Text(
                "The following tables provide information regarding the chosen model."
            ),
        ]

        model_description = rc.Text(
            "The Merlion anomaly detection model is a full-stack automated machine learning system for anomaly detection."
        )

        return (
            model_description,
            other_sections,
        )

ads/opctl/operator/lowcode/anomaly/model/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ..const import NonTimeADSupportedModels, SupportedModels
99
from ..operator_config import AnomalyOperatorConfig
1010
from .anomaly_dataset import AnomalyDatasets
11+
from .anomaly_merlion import AnomalyMerlionOperatorModel
1112
from .automlx import AutoMLXOperatorModel
1213
from .autots import AutoTSOperatorModel
1314

@@ -48,6 +49,7 @@ class AnomalyOperatorModelFactory:
4849
SupportedModels.AutoMLX: AutoMLXOperatorModel,
4950
# SupportedModels.TODS: TODSOperatorModel,
5051
SupportedModels.AutoTS: AutoTSOperatorModel,
52+
SupportedModels.MerilonAD: AnomalyMerlionOperatorModel
5153
}
5254

5355
_NonTime_MAP = {

ads/opctl/operator/lowcode/anomaly/model/randomcutforest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _build_model(self) -> AnomalyOutput:
3636
# Set tree parameters
3737
num_trees = model_kwargs.get("num_trees", 200)
3838
shingle_size = model_kwargs.get("shingle_size", None)
39-
anomaly_threshold = model_kwargs.get("anamoly_threshold", 95)
39+
anomaly_threshold = model_kwargs.get("anomaly_threshold", 95)
4040

4141
for target, df in self.datasets.full_data_dict.items():
4242
try:

ads/opctl/operator/lowcode/anomaly/schema.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ spec:
364364
- oneclasssvm
365365
- isolationforest
366366
- randomcutforest
367+
- merlion_ad
367368
meta:
368369
description: "The model to be used for anomaly detection"
369370

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ anomaly = [
179179
"oracledb",
180180
"report-creator==1.0.9",
181181
"rrcf==0.4.4",
182-
"scikit-learn"
182+
"scikit-learn",
183+
"salesforce-merlion[all]==2.0.4"
183184
]
184185
recommender = [
185186
"oracle_ads[opctl]",

tests/operators/anomaly/test_anomaly_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
for d in DATASETS:
5353
parameters_short.append((m, d))
5454

55-
MODELS = ["autots", "oneclasssvm", "isolationforest", "randomcutforest"]
55+
MODELS = ["autots", "oneclasssvm", "isolationforest", "randomcutforest", "merlion_ad"]
5656

5757
@pytest.mark.parametrize("model", ["autots"])
5858
def test_artificial_big(model):

0 commit comments

Comments
 (0)