Skip to content

Commit 4bc3f3d

Browse files
committed
Add default params
1 parent 9042ed4 commit 4bc3f3d

File tree

1 file changed

+22
-34
lines changed

1 file changed

+22
-34
lines changed

pyindicators/indicators/williams_percent_range.py

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,63 +7,51 @@
77
def willr(
88
data: Union[pd.DataFrame, pl.DataFrame],
99
period: int = 14,
10-
result_column: str = None,
10+
result_column: str = "WILLR",
1111
high_column: str = "High",
1212
low_column: str = "Low",
1313
close_column: str = "Close"
1414
) -> Union[pd.DataFrame, pl.DataFrame]:
15-
"""
16-
Function to calculate the Williams %R indicator of a series.
1715

18-
Args:
19-
data (Union[pd.DataFrame, pl.DataFrame]): The input data.
20-
source_column (str): The name of the series.
21-
period (int): The period for the Williams %R calculation.
22-
result_column (str, optional): The name of the column to store
23-
the Williams %R values. Defaults to None, which means it will
24-
be named "WilliamsR_{period}".
25-
26-
Returns:
27-
Union[pd.DataFrame, pl.DataFrame]: The DataFrame with
28-
the Williams %R column added.
29-
"""
30-
31-
# Check if the high and low columns are present
3216
if high_column not in data.columns:
33-
raise PyIndicatorException(
34-
f"Column '{high_column}' not found in DataFrame"
35-
)
17+
raise PyIndicatorException(f"Column '{high_column}' not found in DataFrame")
3618

3719
if low_column not in data.columns:
38-
raise PyIndicatorException(
39-
f"Column '{low_column}' not found in DataFrame"
40-
)
20+
raise PyIndicatorException(f"Column '{low_column}' not found in DataFrame")
4121

4222
if isinstance(data, pd.DataFrame):
43-
data["high_n"] = data[high_column]\
44-
.rolling(window=period, min_periods=1).max()
45-
data["low_n"] = data[low_column]\
46-
.rolling(window=period, min_periods=1).min()
23+
data["high_n"] = data[high_column].rolling(window=period, min_periods=1).max()
24+
data["low_n"] = data[low_column].rolling(window=period, min_periods=1).min()
25+
4726
data[result_column] = ((data["high_n"] - data[close_column]) /
4827
(data["high_n"] - data["low_n"])) * -100
28+
29+
# Set the first `period` rows to 0 using .iloc
30+
if not data.empty:
31+
data.iloc[:period - 1, data.columns.get_loc(result_column)] = 0
32+
4933
return data.drop(columns=["high_n", "low_n"])
5034

5135
elif isinstance(data, pl.DataFrame):
52-
high_n = data.select(pl.col(high_column).rolling_max(period)
53-
.alias("high_n"))
54-
low_n = data.select(pl.col(low_column).rolling_min(period)
55-
.alias("low_n"))
36+
high_n = data.select(pl.col(high_column).rolling_max(period).alias("high_n"))
37+
low_n = data.select(pl.col(low_column).rolling_min(period).alias("low_n"))
5638

5739
data = data.with_columns([
5840
high_n["high_n"],
5941
low_n["low_n"]
6042
])
43+
6144
data = data.with_columns(
62-
((pl.col("high_n") - pl.col(close_column)) /
63-
(pl.col("high_n") - pl.col("low_n")) * -100).alias(result_column)
45+
((pl.col("high_n") - pl.col(close_column)) / (pl.col("high_n") - pl.col("low_n")) * -100)
46+
.alias(result_column)
6447
)
65-
return data.drop(["high_n", "low_n"])
6648

49+
# Set the first `period` rows of result_column to 0 directly in Polars
50+
if data.height > 0:
51+
zero_values = [0] * (period - 1)+ data[result_column].to_list()[period - 1:]
52+
data = data.with_columns(pl.Series(result_column, zero_values))
53+
54+
return data.drop(["high_n", "low_n"])
6755
else:
6856
raise PyIndicatorException(
6957
"Unsupported data type. Must be pandas or polars DataFrame."

0 commit comments

Comments
 (0)