Skip to content

Commit 191f969

Browse files
authored
ENH Enable prediction of isolation forest in parallel (scikit-learn#28622)
1 parent 610d4f7 commit 191f969

File tree

4 files changed

+362
-12
lines changed

4 files changed

+362
-12
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
==========================================
3+
IsolationForest prediction benchmark
4+
==========================================
5+
A test of IsolationForest on classical anomaly detection datasets.
6+
7+
The benchmark is run as follows:
8+
1. The dataset is randomly split into a training set and a test set, both
9+
assumed to contain outliers.
10+
2. Isolation Forest is trained on the training set fixed at 1000 samples.
11+
3. The test samples are scored using the trained model at:
12+
- 1000, 10000, 50000 samples
13+
- 10, 100, 1000 features
14+
- 0.01, 0.1, 0.5 contamination
15+
- 1, 2, 3, 4 n_jobs
16+
17+
We compare the prediction time at the very end.
18+
19+
Here are instructions for running this benchmark to compare runtime against main branch:
20+
21+
1. Build and run on a branch or main, e.g. for a branch named `pr`:
22+
23+
```bash
24+
python bench_isolation_forest_predict.py bench ~/bench_results pr
25+
```
26+
27+
2. Plotting to compare two branches `pr` and `main`:
28+
29+
```bash
30+
python bench_isolation_forest_predict.py plot ~/bench_results pr main results_image.png
31+
```
32+
"""
33+
34+
import argparse
35+
from collections import defaultdict
36+
from pathlib import Path
37+
from time import time
38+
39+
import numpy as np
40+
import pandas as pd
41+
from joblib import parallel_config
42+
43+
from sklearn.ensemble import IsolationForest
44+
45+
print(__doc__)
46+
47+
48+
def get_data(
49+
n_samples_train, n_samples_test, n_features, contamination=0.1, random_state=0
50+
):
51+
"""Function based on code from: https://scikit-learn.org/stable/
52+
auto_examples/ensemble/plot_isolation_forest.html#sphx-glr-auto-
53+
examples-ensemble-plot-isolation-forest-py
54+
"""
55+
rng = np.random.RandomState(random_state)
56+
57+
X = 0.3 * rng.randn(n_samples_train, n_features)
58+
X_train = np.r_[X + 2, X - 2]
59+
60+
X = 0.3 * rng.randn(n_samples_test, n_features)
61+
X_test = np.r_[X + 2, X - 2]
62+
63+
n_outliers = int(np.floor(contamination * n_samples_test))
64+
X_outliers = rng.uniform(low=-4, high=4, size=(n_outliers, n_features))
65+
66+
outlier_idx = rng.choice(np.arange(0, n_samples_test), n_outliers, replace=False)
67+
X_test[outlier_idx, :] = X_outliers
68+
69+
return X_train, X_test
70+
71+
72+
def plot(args):
73+
import matplotlib.pyplot as plt
74+
import seaborn as sns
75+
76+
bench_results = Path(args.bench_results)
77+
pr_name = args.pr_name
78+
main_name = args.main_name
79+
image_path = args.image_path
80+
81+
results_path = Path(bench_results)
82+
pr_path = results_path / f"{pr_name}.csv"
83+
main_path = results_path / f"{main_name}.csv"
84+
image_path = results_path / image_path
85+
86+
df_pr = pd.read_csv(pr_path).assign(branch=pr_name)
87+
df_main = pd.read_csv(main_path).assign(branch=main_name)
88+
89+
# Merge the two datasets on the common columns
90+
merged_data = pd.merge(
91+
df_pr,
92+
df_main,
93+
on=["n_samples_test", "n_jobs"],
94+
suffixes=("_pr", "_main"),
95+
)
96+
97+
# Set up the plotting grid
98+
sns.set(style="whitegrid", context="notebook", font_scale=1.5)
99+
100+
# Create a figure with subplots
101+
fig, axes = plt.subplots(1, 2, figsize=(18, 6), sharex=True, sharey=True)
102+
103+
# Plot predict time as a function of n_samples_test with different n_jobs
104+
print(merged_data["n_jobs"].unique())
105+
ax = axes[0]
106+
sns.lineplot(
107+
data=merged_data,
108+
x="n_samples_test",
109+
y="predict_time_pr",
110+
hue="n_jobs",
111+
style="n_jobs",
112+
markers="o",
113+
ax=ax,
114+
legend="full",
115+
)
116+
ax.set_title(f"Predict Time vs. n_samples_test - {pr_name} branch")
117+
ax.set_ylabel("Predict Time (Seconds)")
118+
ax.set_xlabel("n_samples_test")
119+
120+
ax = axes[1]
121+
sns.lineplot(
122+
data=merged_data,
123+
x="n_samples_test",
124+
y="predict_time_main",
125+
hue="n_jobs",
126+
style="n_jobs",
127+
markers="X",
128+
dashes=True,
129+
ax=ax,
130+
legend=None,
131+
)
132+
ax.set_title(f"Predict Time vs. n_samples_test - {main_name} branch")
133+
ax.set_ylabel("Predict Time")
134+
ax.set_xlabel("n_samples_test")
135+
136+
# Adjust layout and display the plots
137+
plt.tight_layout()
138+
fig.savefig(image_path, bbox_inches="tight")
139+
print(f"Saved image to {image_path}")
140+
141+
142+
def bench(args):
143+
results_dir = Path(args.bench_results)
144+
branch = args.branch
145+
random_state = 1
146+
147+
results = defaultdict(list)
148+
149+
# Loop over all datasets for fitting and scoring the estimator:
150+
n_samples_train = 1000
151+
for n_samples_test in [
152+
1000,
153+
10000,
154+
50000,
155+
]:
156+
for n_features in [10, 100, 1000]:
157+
for contamination in [0.01, 0.1, 0.5]:
158+
for n_jobs in [1, 2, 3, 4]:
159+
X_train, X_test = get_data(
160+
n_samples_train,
161+
n_samples_test,
162+
n_features,
163+
contamination,
164+
random_state,
165+
)
166+
167+
print("--- Fitting the IsolationForest estimator...")
168+
model = IsolationForest(n_jobs=-1, random_state=random_state)
169+
tstart = time()
170+
model.fit(X_train)
171+
fit_time = time() - tstart
172+
173+
# clearcache
174+
for _ in range(1000):
175+
1 + 1
176+
with parallel_config("threading", n_jobs=n_jobs):
177+
tstart = time()
178+
model.decision_function(X_test) # the lower, the more abnormal
179+
predict_time = time() - tstart
180+
181+
results["predict_time"].append(predict_time)
182+
results["fit_time"].append(fit_time)
183+
results["n_samples_train"].append(n_samples_train)
184+
results["n_samples_test"].append(n_samples_test)
185+
results["n_features"].append(n_features)
186+
results["contamination"].append(contamination)
187+
results["n_jobs"].append(n_jobs)
188+
189+
df = pd.DataFrame(results)
190+
df.to_csv(results_dir / f"{branch}.csv", index=False)
191+
192+
193+
if __name__ == "__main__":
194+
parser = argparse.ArgumentParser()
195+
196+
# parse arguments for benchmarking
197+
subparsers = parser.add_subparsers()
198+
bench_parser = subparsers.add_parser("bench")
199+
bench_parser.add_argument("bench_results")
200+
bench_parser.add_argument("branch")
201+
bench_parser.set_defaults(func=bench)
202+
203+
# parse arguments for plotting
204+
plot_parser = subparsers.add_parser("plot")
205+
plot_parser.add_argument("bench_results")
206+
plot_parser.add_argument("pr_name")
207+
plot_parser.add_argument("main_name")
208+
plot_parser.add_argument("image_path")
209+
plot_parser.set_defaults(func=plot)
210+
211+
# enable the parser and run the relevant function
212+
args = parser.parse_args()
213+
args.func(args)

doc/whats_new/v1.6.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ Changelog
130130
by parallelizing the initial search for bin thresholds
131131
:pr:`28064` by :user:`Christian Lorentzen <lorentzenchr>`.
132132

133+
- |Efficiency| :class:`ensemble.IsolationForest` now runs parallel jobs
134+
during :term:`predict` offering a speedup of up to 2-4x on sample sizes
135+
larger than 2000 using `joblib`.
136+
:pr:`28622` by :user:`Adam Li <adam2392>` and
137+
:user:`Sérgio Pereira <sergiormpereira>`.
138+
133139
:mod:`sklearn.impute`
134140
.....................
135141

0 commit comments

Comments
 (0)