Skip to content

Commit f8e0eed

Browse files
committed
Fix tests and flake8 warnings
1 parent 0252e51 commit f8e0eed

File tree

6 files changed

+2416
-2330
lines changed

6 files changed

+2416
-2330
lines changed

pyindicators/indicators/rsi.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ def rsi(
1616
data (Union[pd.DataFrame, pl.DataFrame]): The input data.
1717
source_column (str): The name of the series.
1818
period (int): The period for the RSI calculation.
19-
result_column (str, optional): The name of the column to store the RSI values.
20-
Defaults to None, which means it will be named "RSI_{period}".
19+
result_column (str, optional): The name of the column to store
20+
the RSI values. Defaults to None, which means it will
21+
be named "RSI_{period}".
2122
2223
Returns:
23-
Union[pd.DataFrame, pl.DataFrame]: The DataFrame with the RSI column added.
24+
Union[pd.DataFrame, pl.DataFrame]: The DataFrame with the RSI
25+
column added.
2426
"""
2527

2628
if result_column is None:
@@ -53,8 +55,8 @@ def rsi(
5355
delta = data[source_column].diff().fill_null(0)
5456

5557
# Compute gains and losses
56-
gain = delta.clip_min(0)
57-
loss = (-delta).clip_min(0)
58+
gain = delta.clip(0)
59+
loss = (-delta).clip(0)
5860

5961
# Compute rolling averages of gains and losses
6062
avg_gain = gain.rolling_mean(window_size=period, min_periods=period)
@@ -65,7 +67,7 @@ def rsi(
6567
rsi_values = 100 - (100 / (1 + rs))
6668

6769
# Replace first `period` values with nulls (polars uses `None`)
68-
rsi_values = rsi_values.set_at_idx(list(range(period)), None)
70+
rsi_values = rsi_values.scatter(list(range(period)), None)
6971

7072
# Add column to DataFrame
7173
data = data.with_columns(rsi_values.alias(result_column))
@@ -110,8 +112,12 @@ def wilders_rsi(
110112

111113
# Apply Wilder's Smoothing for the remaining values
112114
for i in range(period, len(data)):
113-
avg_gain.iloc[i] = (avg_gain.iloc[i - 1] * (period - 1) + gain.iloc[i]) / period
114-
avg_loss.iloc[i] = (avg_loss.iloc[i - 1] * (period - 1) + loss.iloc[i]) / period
115+
avg_gain.iloc[i] = (
116+
avg_gain.iloc[i - 1] * (period - 1) + gain.iloc[i]
117+
) / period
118+
avg_loss.iloc[i] = (
119+
avg_loss.iloc[i - 1] * (period - 1) + loss.iloc[i]
120+
) / period
115121

116122
rs = avg_gain / avg_loss
117123
data[result_column] = 100 - (100 / (1 + rs))
@@ -121,26 +127,32 @@ def wilders_rsi(
121127

122128
elif isinstance(data, pl.DataFrame):
123129
delta = data[source_column].diff().fill_null(0)
124-
gain = delta.clip_min(0)
125-
loss = (-delta).clip_min(0)
130+
gain = delta.clip(0)
131+
loss = (-delta).clip(0)
126132

127133
# Compute initial SMA (first `period` rows)
128134
avg_gain = gain.rolling_mean(window_size=period, min_periods=period)
129135
avg_loss = loss.rolling_mean(window_size=period, min_periods=period)
130136

137+
# Initialize smoothed gains/losses with the first SMA values
138+
smoothed_gain = avg_gain[:period].to_list()
139+
smoothed_loss = avg_loss[:period].to_list()
140+
131141
# Apply Wilder's Smoothing
132-
smoothed_gain = [None] * period
133-
smoothed_loss = [None] * period
134142
for i in range(period, len(data)):
135-
smoothed_gain.append((smoothed_gain[-1] * (period - 1) + gain[i]) / period)
136-
smoothed_loss.append((smoothed_loss[-1] * (period - 1) + loss[i]) / period)
143+
smoothed_gain.append(
144+
(smoothed_gain[-1] * (period - 1) + gain[i]) / period
145+
)
146+
smoothed_loss.append(
147+
(smoothed_loss[-1] * (period - 1) + loss[i]) / period
148+
)
137149

138150
# Compute RSI
139151
rs = pl.Series(smoothed_gain) / pl.Series(smoothed_loss)
140152
rsi_values = 100 - (100 / (1 + rs))
141153

142154
# Replace first `period` values with None
143-
rsi_values = rsi_values.set_at_idx(list(range(period)), None)
155+
rsi_values = rsi_values.scatter(list(range(period)), None)
144156

145157
# Add column to DataFrame
146158
data = data.with_columns(rsi_values.alias(result_column))

static/images/indicators/rsi.png

26.6 KB
Loading
7.99 KB
Loading

tests/indicators/test_rsi.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,34 +45,32 @@ def test_comparison_pandas(self):
4545
correct_output_pd["Datetime"] = \
4646
pd.to_datetime(correct_output_pd["Datetime"]).dt.tz_localize(None)
4747

48-
print(correct_output_pd.head(40))
49-
print(output.head(40))
50-
# pdt.assert_frame_equal(correct_output_pd, output)
48+
pdt.assert_frame_equal(correct_output_pd, output)
5149

52-
# def test_comparison_polars(self):
50+
def test_comparison_polars(self):
5351

54-
# # Load the correct output in a polars dataframe
55-
# correct_output_pl = pl.read_csv(self.get_correct_output_csv_path())
52+
# Load the correct output in a polars dataframe
53+
correct_output_pl = pl.read_csv(self.get_correct_output_csv_path())
5654

57-
# # Load the source in a polars dataframe
58-
# source = pl.read_csv(self.get_source_csv_path())
55+
# Load the source in a polars dataframe
56+
source = pl.read_csv(self.get_source_csv_path())
5957

60-
# # Generate the polars dataframe
61-
# output = self.generate_polars_df(source)
58+
# Generate the polars dataframe
59+
output = self.generate_polars_df(source)
6260

63-
# # Convert the datetime columns to datetime
64-
# # Convert the 'Datetime' column in both DataFrames to datetime
65-
# output = output.with_columns(
66-
# pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
67-
# )
61+
# Convert the datetime columns to datetime
62+
# Convert the 'Datetime' column in both DataFrames to datetime
63+
output = output.with_columns(
64+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
65+
)
6866

69-
# correct_output_pl = correct_output_pl.with_columns(
70-
# pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
71-
# )
72-
# output = output[correct_output_pl.columns]
73-
# output = self.make_polars_column_datetime_naive(output, "Datetime")
74-
# correct_output_pl = self.make_polars_column_datetime_naive(
75-
# correct_output_pl, "Datetime"
76-
# )
67+
correct_output_pl = correct_output_pl.with_columns(
68+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
69+
)
70+
output = output[correct_output_pl.columns]
71+
output = self.make_polars_column_datetime_naive(output, "Datetime")
72+
correct_output_pl = self.make_polars_column_datetime_naive(
73+
correct_output_pl, "Datetime"
74+
)
7775

78-
# assert_frame_equal(correct_output_pl, output)
76+
assert_frame_equal(correct_output_pl, output)

tests/indicators/test_wilders_rsi.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import pandas as pd
2+
import polars as pl
3+
import pandas.testing as pdt
4+
from polars.testing import assert_frame_equal
5+
6+
from tests.resources import TestBaseline
7+
from pyindicators import wilders_rsi
8+
9+
10+
class Test(TestBaseline):
11+
correct_output_csv_filename = \
12+
"WILDERS_RSI_14_BTC-EUR_BINANCE_15m_2023-12-01:00:00_2023-12-25:00:00.csv"
13+
14+
def generate_pandas_df(self, polars_source_df):
15+
polars_source_df = wilders_rsi(
16+
data=polars_source_df,
17+
period=14,
18+
result_column="RSI_14",
19+
source_column="Close"
20+
)
21+
return polars_source_df
22+
23+
def generate_polars_df(self, pandas_source_df):
24+
pandas_source_df = wilders_rsi(
25+
data=pandas_source_df,
26+
period=14,
27+
result_column="RSI_14",
28+
source_column="Close"
29+
)
30+
return pandas_source_df
31+
32+
def test_comparison_pandas(self):
33+
34+
# Load the correct output in a pandas dataframe
35+
correct_output_pd = pd.read_csv(self.get_correct_output_csv_path())
36+
37+
# Load the source in a pandas dataframe
38+
source = pd.read_csv(self.get_source_csv_path())
39+
40+
# Generate the pandas dataframe
41+
output = self.generate_pandas_df(source)
42+
output = output[correct_output_pd.columns]
43+
output["Datetime"] = \
44+
pd.to_datetime(output["Datetime"]).dt.tz_localize(None)
45+
correct_output_pd["Datetime"] = \
46+
pd.to_datetime(correct_output_pd["Datetime"]).dt.tz_localize(None)
47+
48+
pdt.assert_frame_equal(correct_output_pd, output)
49+
50+
def test_comparison_polars(self):
51+
52+
# Load the correct output in a polars dataframe
53+
correct_output_pl = pl.read_csv(self.get_correct_output_csv_path())
54+
55+
# Load the source in a polars dataframe
56+
source = pl.read_csv(self.get_source_csv_path())
57+
58+
# Generate the polars dataframe
59+
output = self.generate_polars_df(source)
60+
61+
# Convert the datetime columns to datetime
62+
# Convert the 'Datetime' column in both DataFrames to datetime
63+
output = output.with_columns(
64+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
65+
)
66+
67+
correct_output_pl = correct_output_pl.with_columns(
68+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
69+
)
70+
output = output[correct_output_pl.columns]
71+
output = self.make_polars_column_datetime_naive(output, "Datetime")
72+
correct_output_pl = self.make_polars_column_datetime_naive(
73+
correct_output_pl, "Datetime"
74+
)
75+
76+
assert_frame_equal(correct_output_pl, output)

0 commit comments

Comments
 (0)