Skip to content

Commit 8788679

Browse files
committed
Losen dependencies
1 parent 6e6f0dd commit 8788679

File tree

2 files changed

+823
-0
lines changed

2 files changed

+823
-0
lines changed

pyindicators/indicators/adx.py

Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
import pandas as pd
2+
import polars as pl
3+
from typing import Union
4+
from pyindicators.exceptions import PyIndicatorException
5+
6+
from .utils import pad_zero_values_pandas, pad_zero_values_polars
7+
8+
9+
def adx(
10+
data: Union[pd.DataFrame, pl.DataFrame],
11+
period=14,
12+
high_column="High",
13+
low_column="Low",
14+
close_column="Close",
15+
result_adx_column="ADX",
16+
result_pdi_column="+DI",
17+
result_ndi_column="-DI",
18+
) -> Union[pd.DataFrame, pl.DataFrame]:
19+
"""
20+
Calculate the Average Directional Index (ADX) for a given DataFrame.
21+
22+
Args:
23+
data (Union[pd.DataFrame, pl.DataFrame]): Input data containing
24+
the price series.
25+
period (int, optional): Period for the ADX calculation (default: 14).
26+
high_column (str, optional): Column name for the high price series.
27+
low_column (str, optional): Column name for the low price series.
28+
close_column (str, optional): Column name for the close price series.
29+
result_adx_column (str, optional): Column name to store the ADX.
30+
result_pdi_column (str, optional): Column name to store the +DI.
31+
result_ndi_column (str, optional): Column name to store the -DI.
32+
33+
Returns:
34+
Union[pd.DataFrame, pl.DataFrame]: DataFrame with ADX, +DI, and -DI.
35+
"""
36+
37+
# Check if the high, low, and close columns are in the DataFrame
38+
if high_column not in data.columns:
39+
raise PyIndicatorException(
40+
f"Column '{high_column}' not found in DataFrame"
41+
)
42+
43+
if low_column not in data.columns:
44+
raise PyIndicatorException(
45+
f"Column '{low_column}' not found in DataFrame"
46+
)
47+
48+
if close_column not in data.columns:
49+
raise PyIndicatorException(
50+
f"Column '{close_column}' not found in DataFrame"
51+
)
52+
53+
if isinstance(data, pd.DataFrame):
54+
# Pandas version of the ADX calculation
55+
high = data[high_column]
56+
low = data[low_column]
57+
close = data[close_column]
58+
59+
# Calculate True Range (TR)
60+
tr = pd.DataFrame({
61+
'TR': pd.concat([
62+
high - low,
63+
(high - close.shift(1)).abs(),
64+
(low - close.shift(1)).abs()
65+
], axis=1).max(axis=1)
66+
})
67+
68+
# Calculate Directional Movement (+DM and -DM)
69+
plus_dm = pd.DataFrame(
70+
{'+DM': (high.diff() > low.diff()).astype(int)
71+
* (high.diff().clip(lower=0))}
72+
)
73+
minus_dm = pd.DataFrame(
74+
{'-DM': (low.diff() > high.diff()).astype(int)
75+
* (-low.diff().clip(upper=0))}
76+
)
77+
78+
# Smooth the TR, +DM, and -DM over the period
79+
tr_smooth = tr['TR'].rolling(window=period).mean()
80+
plus_dm_smooth = plus_dm['+DM'].rolling(window=period).mean()
81+
minus_dm_smooth = minus_dm['-DM'].rolling(window=period).mean()
82+
83+
# Calculate +DI and -DI
84+
pdi = 100 * (plus_dm_smooth / tr_smooth)
85+
ndi = 100 * (minus_dm_smooth / tr_smooth)
86+
87+
# Smooth the difference to get ADX
88+
adx = pd.DataFrame({
89+
result_adx_column: (pdi - ndi).abs().rolling(window=period).mean()
90+
})
91+
92+
# Add columns to the original dataframe
93+
data[result_adx_column] = adx
94+
data[result_pdi_column] = pdi
95+
data[result_ndi_column] = ndi
96+
97+
pad_zero_values_pandas(data, result_adx_column, period)
98+
pad_zero_values_pandas(data, result_pdi_column, period - 1)
99+
pad_zero_values_pandas(data, result_ndi_column, period - 1)
100+
return data
101+
102+
elif isinstance(data, pl.DataFrame):
103+
# Polars version of the ADX calculation
104+
high = data[high_column]
105+
low = data[low_column]
106+
close = data[close_column]
107+
108+
# Calculate True Range (TR)
109+
tr = pl.max_horizontal([
110+
high - low,
111+
(high - close.shift(1)).abs(),
112+
(low - close.shift(1)).abs()
113+
])
114+
115+
# Calculate Directional Movement (+DM and -DM)
116+
plus_dm = high.diff().clip_min(0)
117+
minus_dm = (-low.diff()).clip_min(0).abs()
118+
119+
# Smooth the TR, +DM, and -DM over the period
120+
# (use rolling sum, not mean)
121+
tr_smooth = tr.rolling_sum(window_size=period, min_periods=1)
122+
plus_dm_smooth = plus_dm.rolling_sum(window_size=period, min_periods=1)
123+
minus_dm_smooth = minus_dm.rolling_sum(
124+
window_size=period, min_periods=1
125+
)
126+
127+
# Calculate +DI and -DI
128+
pdi = 100 * (plus_dm_smooth / tr_smooth)
129+
ndi = 100 * (minus_dm_smooth / tr_smooth)
130+
131+
# Calculate ADX (average of the absolute difference
132+
# between +DI and -DI)
133+
134+
di_diff = (pdi - ndi).abs()
135+
# Smooth the difference to get ADX
136+
adx = di_diff.rolling_mean(window_size=period)
137+
138+
# Add columns to the original dataframe
139+
data = data.with_columns([
140+
adx.alias(result_adx_column),
141+
pdi.alias(result_pdi_column),
142+
ndi.alias(result_ndi_column)
143+
])
144+
145+
# Pad the first `period` rows with zero values
146+
data = pad_zero_values_polars(data, result_adx_column, period)
147+
data = pad_zero_values_polars(data, result_pdi_column, period - 1)
148+
data = pad_zero_values_polars(data, result_ndi_column, period - 1)
149+
150+
return data
151+
else:
152+
raise PyIndicatorException(
153+
"Input data must be either a pandas or polars DataFrame."
154+
)
155+
156+
157+
def adx_v2(
158+
data: Union[pd.DataFrame, pl.DataFrame],
159+
period=14,
160+
high_column="High",
161+
low_column="Low",
162+
close_column="Close",
163+
result_adx_column="ADX",
164+
result_pdi_column="+DI",
165+
result_ndi_column="-DI",
166+
) -> Union[pd.DataFrame, pl.DataFrame]:
167+
"""
168+
Calculate the Average Directional Index (ADX) using Wilder's smoothing.
169+
Matches Tulipy's ADX calculation.
170+
171+
Args:
172+
data: Input DataFrame (Pandas or Polars).
173+
period: Period for the ADX calculation (default: 14).
174+
high_column, low_column, close_column: Column names for price data.
175+
result_adx_column, result_pdi_column,
176+
result_ndi_column: Output column names.
177+
178+
Returns:
179+
DataFrame with ADX, +DI, and -DI.
180+
"""
181+
if high_column not in data.columns \
182+
or low_column not in data.columns \
183+
or close_column not in data.columns:
184+
raise PyIndicatorException(
185+
"High, Low, or Close column not found in DataFrame."
186+
)
187+
188+
if isinstance(data, pd.DataFrame):
189+
# Pandas version
190+
high, low, close = data[high_column], data[low_column], \
191+
data[close_column]
192+
193+
tr = pd.concat([
194+
high - low,
195+
(high - close.shift(1)).abs(),
196+
(low - close.shift(1)).abs()
197+
], axis=1).max(axis=1)
198+
199+
plus_dm = high.diff().clip(lower=0)
200+
minus_dm = -low.diff().clip(upper=0)
201+
202+
# Wilder’s smoothing with EMA
203+
tr_smooth = tr.ewm(span=period, adjust=False).mean()
204+
plus_dm_smooth = plus_dm.ewm(span=period, adjust=False).mean()
205+
minus_dm_smooth = minus_dm.ewm(span=period, adjust=False).mean()
206+
207+
pdi = 100 * (plus_dm_smooth / tr_smooth)
208+
ndi = 100 * (minus_dm_smooth / tr_smooth)
209+
adx = (100 * (pdi - ndi).abs().ewm(span=period, adjust=False).mean())
210+
211+
# Add results to DataFrame
212+
data[result_adx_column] = adx
213+
data[result_pdi_column] = pdi
214+
data[result_ndi_column] = ndi
215+
216+
# Pad with zeros
217+
pad_zero_values_pandas(data, result_adx_column, period)
218+
pad_zero_values_pandas(data, result_pdi_column, period - 1)
219+
pad_zero_values_pandas(data, result_ndi_column, period - 1)
220+
221+
return data
222+
223+
elif isinstance(data, pl.DataFrame):
224+
# Polars version
225+
high, low, close = data[high_column], data[low_column], \
226+
data[close_column]
227+
228+
tr = pl.max_horizontal([
229+
high - low,
230+
(high - close.shift(1)).abs(),
231+
(low - close.shift(1)).abs()
232+
])
233+
234+
plus_dm = high.diff().clip_min(0)
235+
minus_dm = (-low.diff()).clip_min(0).abs()
236+
237+
# Wilder’s smoothing (manual EMA for Polars)
238+
def wilder_ema(series, period):
239+
alpha = 1 / period
240+
return series.cumsum() * alpha
241+
242+
tr_smooth = wilder_ema(tr, period)
243+
plus_dm_smooth = wilder_ema(plus_dm, period)
244+
minus_dm_smooth = wilder_ema(minus_dm, period)
245+
246+
pdi = 100 * (plus_dm_smooth / tr_smooth)
247+
ndi = 100 * (minus_dm_smooth / tr_smooth)
248+
adx = (100 * (pdi - ndi).abs()).cumsum() / period
249+
250+
# Add results to DataFrame
251+
data = data.with_columns([
252+
adx.alias(result_adx_column),
253+
pdi.alias(result_pdi_column),
254+
ndi.alias(result_ndi_column)
255+
])
256+
257+
# Pad with zeros
258+
data = pad_zero_values_polars(data, result_adx_column, period)
259+
data = pad_zero_values_polars(data, result_pdi_column, period - 1)
260+
data = pad_zero_values_polars(data, result_ndi_column, period - 1)
261+
262+
return data
263+
264+
else:
265+
raise PyIndicatorException(
266+
"Input data must be either a pandas or polars DataFrame."
267+
)
268+
269+
270+
def di(
271+
data: Union[pd.DataFrame, pl.DataFrame],
272+
period=14,
273+
high_column="High",
274+
low_column="Low",
275+
close_column="Close",
276+
result_pdi_column="+DI",
277+
result_ndi_column="-DI",
278+
) -> Union[pd.DataFrame, pl.DataFrame]:
279+
"""
280+
Calculate the +DI and -DI indicators exactly like Tulipy,
281+
supporting both Pandas and Polars.
282+
283+
Args:
284+
data (Union[pd.DataFrame, pl.DataFrame]): Input data
285+
containing the price series.
286+
period (int, optional): Period for the DI calculation (default: 14).
287+
high_column (str, optional): Column name for the high price series.
288+
low_column (str, optional): Column name for the low price series.
289+
close_column (str, optional): Column name for the close price series.
290+
result_pdi_column (str, optional): Column name to store the +DI.
291+
result_ndi_column (str, optional): Column name to store the -DI.
292+
293+
Returns:
294+
Union[pd.DataFrame, pl.DataFrame]: DataFrame with +DI and -DI.
295+
"""
296+
297+
if isinstance(data, pd.DataFrame):
298+
high = data[high_column]
299+
low = data[low_column]
300+
close = data[close_column]
301+
302+
# True Range
303+
tr = pd.concat([
304+
high - low,
305+
(high - close.shift(1)).abs(),
306+
(low - close.shift(1)).abs()
307+
], axis=1).max(axis=1)
308+
309+
# Directional Movement
310+
plus_dm = (
311+
(high.diff() > low.shift(1) - low) & (high.diff() > 0)
312+
) * high.diff()
313+
minus_dm = (
314+
(low.shift(1) - low > high.diff()) & (low.shift(1) - low > 0)
315+
) * (low.shift(1) - low)
316+
317+
# Smoothed values
318+
tr_smooth = tr.rolling(window=period).sum()
319+
plus_dm_smooth = plus_dm.rolling(window=period).sum()
320+
minus_dm_smooth = minus_dm.rolling(window=period).sum()
321+
322+
# Calculate +DI and -DI
323+
pdi = 100 * (plus_dm_smooth / tr_smooth)
324+
ndi = 100 * (minus_dm_smooth / tr_smooth)
325+
326+
# Add to DataFrame
327+
data[result_pdi_column] = pdi
328+
data[result_ndi_column] = ndi
329+
330+
# Pad initial values with zero
331+
# (replace NaN values for first `period-1` rows)
332+
data[result_pdi_column].iloc[:period-1] = 0
333+
data[result_ndi_column].iloc[:period-1] = 0
334+
335+
return data
336+
337+
elif isinstance(data, pl.DataFrame):
338+
high = data[high_column]
339+
low = data[low_column]
340+
close = data[close_column]
341+
342+
# True Range
343+
tr = pl.max_horizontal([
344+
high - low,
345+
(high - close.shift(1)).abs(),
346+
(low - close.shift(1)).abs()
347+
])
348+
349+
# Directional Movement
350+
plus_dm = (high.diff() > low.shift(1) - low) & (high.diff() > 0)
351+
plus_dm = plus_dm * high.diff()
352+
353+
minus_dm = (
354+
low.shift(1) - low > high.diff()
355+
) & (low.shift(1) - low > 0)
356+
minus_dm = minus_dm * (low.shift(1) - low)
357+
358+
# Smoothed values
359+
tr_smooth = tr.rolling_sum(window_size=period)
360+
plus_dm_smooth = plus_dm.rolling_sum(window_size=period)
361+
minus_dm_smooth = minus_dm.rolling_sum(window_size=period)
362+
363+
# Calculate +DI and -DI
364+
pdi = 100 * (plus_dm_smooth / tr_smooth)
365+
ndi = 100 * (minus_dm_smooth / tr_smooth)
366+
367+
# Add to DataFrame
368+
data = data.with_columns([
369+
pdi.alias(result_pdi_column),
370+
ndi.alias(result_ndi_column)
371+
])
372+
373+
# Pad initial values with zero
374+
# (replace NaN values for first `period-1` rows)
375+
data = data.with_columns([
376+
pl.when(pl.col(result_pdi_column).is_null()).then(0)
377+
.otherwise(pl.col(result_pdi_column)).alias(result_pdi_column),
378+
pl.when(pl.col(result_ndi_column).is_null()).then(0)
379+
.otherwise(pl.col(result_ndi_column)).alias(result_ndi_column)
380+
])
381+
382+
return data
383+
384+
else:
385+
raise ValueError(
386+
"Input data must be either a pandas or polars DataFrame."
387+
)

0 commit comments

Comments
 (0)