Skip to content

Commit 1546e10

Browse files
committed
Add stochastic oscillator tests and data
1 parent 8bb99d0 commit 1546e10

File tree

2 files changed

+2388
-0
lines changed

2 files changed

+2388
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import pandas as pd
2+
import polars as pl
3+
import pandas.testing as pdt
4+
from polars.testing import assert_frame_equal
5+
6+
from tests.resources import TestBaseline
7+
from pyindicators import stochastic_oscillator
8+
9+
10+
class Test(TestBaseline):
11+
correct_output_csv_filename = \
12+
"STOCHASTIC_OSCILLATOR_BTC-EUR_BINANCE_15m_2023-12-01:00:00_2023-12-25:00:00.csv"
13+
14+
def generate_pandas_df(self, polars_source_df):
15+
polars_source_df = stochastic_oscillator(
16+
data=polars_source_df,
17+
high_column="High",
18+
low_column="Low",
19+
close_column="Close",
20+
k_period=4,
21+
d_period=3,
22+
k_slowing=10
23+
)
24+
return polars_source_df
25+
26+
def generate_polars_df(self, pandas_source_df):
27+
pandas_source_df = stochastic_oscillator(
28+
data=pandas_source_df,
29+
high_column="High",
30+
low_column="Low",
31+
close_column="Close",
32+
k_period=4,
33+
d_period=3,
34+
k_slowing=10
35+
)
36+
return pandas_source_df
37+
38+
def test_comparison_pandas(self):
39+
40+
# Load the correct output in a pandas dataframe
41+
correct_output_pd = pd.read_csv(self.get_correct_output_csv_path())
42+
43+
# Load the source in a pandas dataframe
44+
source = pd.read_csv(self.get_source_csv_path())
45+
46+
# Generate the pandas dataframe
47+
output = self.generate_pandas_df(source)
48+
output = output[correct_output_pd.columns]
49+
output["Datetime"] = \
50+
pd.to_datetime(output["Datetime"]).dt.tz_localize(None)
51+
correct_output_pd["Datetime"] = \
52+
pd.to_datetime(correct_output_pd["Datetime"]).dt.tz_localize(None)
53+
54+
pdt.assert_frame_equal(correct_output_pd, output)
55+
56+
def test_comparison_polars(self):
57+
58+
# Load the correct output in a polars dataframe
59+
correct_output_pl = pl.read_csv(self.get_correct_output_csv_path())
60+
61+
# Load the source in a polars dataframe
62+
source = pl.read_csv(self.get_source_csv_path())
63+
64+
# Generate the polars dataframe
65+
output = self.generate_polars_df(source)
66+
67+
# Convert the datetime columns to datetime
68+
# Convert the 'Datetime' column in both DataFrames to datetime
69+
output = output.with_columns(
70+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
71+
)
72+
73+
correct_output_pl = correct_output_pl.with_columns(
74+
pl.col("Datetime").str.strptime(pl.Datetime).alias("Datetime")
75+
)
76+
output = output[correct_output_pl.columns]
77+
output = self.make_polars_column_datetime_naive(output, "Datetime")
78+
correct_output_pl = self.make_polars_column_datetime_naive(
79+
correct_output_pl, "Datetime"
80+
)
81+
82+
assert_frame_equal(correct_output_pl, output)

0 commit comments

Comments
 (0)