Skip to content

Commit fc3f586

Browse files
committed
Fix flake8 warnings
1 parent 6702bc1 commit fc3f586

File tree

3 files changed

+2416
-2
lines changed

3 files changed

+2416
-2
lines changed

pyindicators/indicators/macd.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def macd(
3939
4040
Returns:
4141
Union[pd.DataFrame, pl.DataFrame]: DataFrame with MACD, Signal
42-
Line, and Histogram.
42+
Line, and Histogram.
4343
"""
4444
if source_column not in data.columns:
4545
raise PyIndicatorException(
@@ -60,9 +60,47 @@ def macd(
6060

6161
# Calculate the MACD Histogram
6262
data[histogram_column] = data[macd_column] - data[signal_column]
63+
64+
# Delete the temporary EMA columns
65+
data = data.drop(columns=[f"EMA_{short_period}", f"EMA_{long_period}"])
6366
return data
6467
elif isinstance(data, pl.DataFrame):
65-
return None
68+
# Polars implementation
69+
data = data.with_columns([
70+
ema(
71+
data,
72+
source_column,
73+
short_period,
74+
f"EMA_{short_period}"
75+
)[f"EMA_{short_period}"],
76+
ema(
77+
data,
78+
source_column,
79+
long_period,
80+
f"EMA_{long_period}"
81+
)[f"EMA_{long_period}"]
82+
])
83+
84+
data = data.with_columns(
85+
(
86+
pl.col(f"EMA_{short_period}") - pl.col(f"EMA_{long_period}")
87+
).alias(macd_column)
88+
)
89+
90+
data = data.with_columns(
91+
ema(data, macd_column, signal_period, signal_column)[signal_column]
92+
)
93+
94+
data = data.with_columns(
95+
(
96+
pl.col(macd_column) - pl.col(signal_column)
97+
).alias(histogram_column)
98+
)
99+
100+
# Delete the temporary EMA columns
101+
data = data.drop([f"EMA_{short_period}", f"EMA_{long_period}"])
102+
103+
return data
66104
else:
67105
raise PyIndicatorException(
68106
"Unsupported DataFrame type. Use Pandas or Polars."

tests/indicators/test_macd.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pandas as pd
2+
import polars as pl
3+
import pandas.testing as pdt
4+
from polars.testing import assert_frame_equal
5+
6+
from tests.resources import TestBaseline
7+
from pyindicators import macd
8+
9+
10+
class Test(TestBaseline):
11+
correct_output_csv_filename = \
12+
"MACD_BTC-EUR_BINANCE_15m_2023-12-01-00-00_2023-12-25-00-00.csv"
13+
14+
def generate_pandas_df(self, polars_source_df):
15+
polars_source_df = macd(
16+
data=polars_source_df, source_column="Close"
17+
)
18+
return polars_source_df
19+
20+
def generate_polars_df(self, pandas_source_df):
21+
pandas_source_df = macd(
22+
data=pandas_source_df, source_column="Close"
23+
)
24+
return pandas_source_df
25+
26+
def test_comparison_pandas(self):
27+
28+
# Load the correct output in a pandas dataframe
29+
correct_output_pd = pd.read_csv(self.get_correct_output_csv_path())
30+
31+
# Load the source in a pandas dataframe
32+
source = pd.read_csv(self.get_source_csv_path())
33+
34+
# Generate the pandas dataframe
35+
output = self.generate_pandas_df(source)
36+
output = output[correct_output_pd.columns]
37+
output["Datetime"] = \
38+
pd.to_datetime(output["Datetime"]).dt.tz_localize(None)
39+
correct_output_pd["Datetime"] = \
40+
pd.to_datetime(correct_output_pd["Datetime"]).dt.tz_localize(None)
41+
42+
pdt.assert_frame_equal(correct_output_pd, output)
43+
44+
def test_comparison_polars(self):
45+
46+
# Load the correct output in a polars dataframe
47+
correct_output_pl = pl.read_csv(self.get_correct_output_csv_path())
48+
49+
# Load the source in a polars dataframe
50+
source = pl.read_csv(self.get_source_csv_path())
51+
52+
# Generate the polars dataframe
53+
output = self.generate_polars_df(source)
54+
55+
# Convert the datetime columns to datetime
56+
# Convert the 'Datetime' column in both DataFrames to datetime
57+
output = output.with_columns(
58+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
59+
)
60+
61+
correct_output_pl = correct_output_pl.with_columns(
62+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
63+
)
64+
output = output[correct_output_pl.columns]
65+
output = self.make_polars_column_datetime_naive(output, "Datetime")
66+
correct_output_pl = self.make_polars_column_datetime_naive(
67+
correct_output_pl, "Datetime"
68+
)
69+
70+
assert_frame_equal(correct_output_pl, output)

0 commit comments

Comments
 (0)