Skip to content

PR: Improve backtesting speed #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/backtest_example/run_backtest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import logging.config
from datetime import datetime, timedelta

Expand Down Expand Up @@ -156,13 +157,17 @@ def apply_strategy(self, algorithm: Algorithm, market_data):

if __name__ == "__main__":
end_date = datetime(2023, 12, 2)
start_date = end_date - timedelta(days=100)
start_date = end_date - timedelta(days=400)
date_range = BacktestDateRange(
start_date=start_date,
end_date=end_date
)
start_time = time.time()

backtest_report = app.run_backtest(
algorithm=algorithm,
backtest_date_range=date_range,
)
pretty_print_backtest(backtest_report)
end_time = time.time()
print(f"Execution Time: {end_time - start_time:.6f} seconds")
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
from datetime import timedelta, datetime, timezone
from dateutil.parser import parse
import polars
from dateutil import parser

Expand Down Expand Up @@ -57,6 +56,9 @@ def __init__(
self.data = None
self._start_date_data_source = None
self._end_date_data_source = None
self.backtest_end_index = self.window_size
self.backtest_start_index = 0
self.window_cache = {}

def prepare_data(
self,
Expand Down Expand Up @@ -100,8 +102,6 @@ def prepare_data(

self.backtest_data_start_date = backtest_data_start_date\
.replace(microsecond=0)
self.backtest_data_index_date = backtest_data_start_date\
.replace(microsecond=0)
self.backtest_data_end_date = backtest_end_date.replace(microsecond=0)

# Creating the backtest data directory and file
Expand Down Expand Up @@ -148,14 +148,30 @@ def prepare_data(
self.write_data_to_file_path(file_path, ohlcv)

self.load_data()
self._precompute_sliding_windows() # Precompute sliding windows!

def _precompute_sliding_windows(self):
"""
Precompute all sliding windows for fast retrieval.
"""
self.window_cache = {}
timestamps = self.data["Datetime"].to_list()

for i in range(len(timestamps) - self.window_size + 1):
# Use last timestamp as key
end_time = timestamps[i + self.window_size - 1]
self.window_cache[end_time] = self.data.slice(i, self.window_size)

def load_data(self):
file_path = self._create_file_path()
self.data = polars.read_csv(file_path)
self.data = polars.read_csv(
file_path, dtypes={"Datetime": polars.Datetime}, low_memory=True
) # Faster parsing
first_row = self.data.head(1)
last_row = self.data.tail(1)
self._start_date_data_source = parse(first_row["Datetime"][0])
self._end_date_data_source = parse(last_row["Datetime"][0])

self._start_date_data_source = first_row["Datetime"][0]
self._end_date_data_source = last_row["Datetime"][0]

def _create_file_path(self):
"""
Expand Down Expand Up @@ -190,38 +206,21 @@ def get_data(
source. This implementation will use polars to load and filter the
data.
"""
if self.data is None:
self.load_data()

end_date = date

if end_date is None:
return self.data
data = self.window_cache.get(date)
if data is not None:
return data

start_date = self.create_start_date(
end_date, self.time_frame, self.window_size
)
# Find closest previous timestamp
sorted_timestamps = sorted(self.window_cache.keys())

if start_date < self._start_date_data_source:
raise OperationalException(
f"Start date {start_date} is before the start date "
f"of the data source {self._start_date_data_source}"
)
closest_date = None
for ts in reversed(sorted_timestamps):
if ts < date:
closest_date = ts
break

if end_date > self._end_date_data_source:
raise OperationalException(
f"End date {end_date} is after the end date "
f"of the data source {self._end_date_data_source}"
)

time_frame = TimeFrame.from_string(self.time_frame)
start_date = start_date - \
timedelta(minutes=time_frame.amount_of_minutes)
selection = self.data.filter(
(self.data['Datetime'] >= start_date.strftime(DATETIME_FORMAT))
& (self.data['Datetime'] <= end_date.strftime(DATETIME_FORMAT))
)
return selection
return self.window_cache.get(closest_date) if closest_date else None

def to_backtest_market_data_source(self) -> BacktestMarketDataSource:
# Ignore this method for now
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(
self._columns = [
"Datetime", "Open", "High", "Low", "Close", "Volume"
]
df = polars.read_csv(csv_file_path)

df = polars.read_csv(self._csv_file_path)

# Check if all column names are in the csv file
if not all(column in df.columns for column in self._columns):
Expand All @@ -53,15 +54,25 @@ def __init__(
f"Missing columns: {missing_columns}"
)

first_row = df.head(1)
last_row = df.tail(1)
self._start_date_data_source = parse(first_row["Datetime"][0])
self._end_date_data_source = parse(last_row["Datetime"][0])
self.data = self._load_data(self.csv_file_path)
first_row = self.data.head(1)
last_row = self.data.tail(1)
self._start_date_data_source = first_row["Datetime"][0]
self._end_date_data_source = last_row["Datetime"][0]

@property
def csv_file_path(self):
    """Read-only path of the CSV file backing this data source."""
    return self._csv_file_path

def _load_data(self, file_path):
    """
    Read an OHLCV csv file into a polars DataFrame.

    The "Datetime" column is parsed as a datetime while reading
    (``dtypes``) and then cast to millisecond precision with a UTC
    time zone.  ``low_memory=True`` trades a little speed for a
    smaller peak footprint during the read.
    """
    frame = polars.read_csv(
        file_path,
        dtypes={"Datetime": polars.Datetime},
        low_memory=True,
    )
    utc_millis = polars.Datetime(time_unit="ms", time_zone="UTC")
    return frame.with_columns(polars.col("Datetime").cast(utc_millis))

def get_data(
self,
start_date: datetime = None,
Expand All @@ -86,9 +97,7 @@ def get_data(
"""

if start_date is None and end_date is None:
return polars.read_csv(
self.csv_file_path, columns=self._columns, separator=","
)
return self.data

if end_date is not None and start_date is not None:

Expand All @@ -101,13 +110,10 @@ def get_data(
if start_date > self._end_date_data_source:
return polars.DataFrame()

df = polars.read_csv(
self.csv_file_path, columns=self._columns, separator=","
)

df = self.data
df = df.filter(
(df['Datetime'] >= start_date.strftime(DATETIME_FORMAT))
& (df['Datetime'] <= end_date.strftime(DATETIME_FORMAT))
(df['Datetime'] >= start_date)
& (df['Datetime'] <= end_date)
)
return df

Expand All @@ -119,11 +125,9 @@ def get_data(
if start_date > self._end_date_data_source:
return polars.DataFrame()

df = polars.read_csv(
self.csv_file_path, columns=self._columns, separator=","
)
df = self.data
df = df.filter(
(df['Datetime'] >= start_date.strftime(DATETIME_FORMAT))
(df['Datetime'] >= start_date)
)
df = df.head(self.window_size)
return df
Expand All @@ -136,11 +140,9 @@ def get_data(
if end_date > self._end_date_data_source:
return polars.DataFrame()

df = polars.read_csv(
self.csv_file_path, columns=self._columns, separator=","
)
df = self.data
df = df.filter(
(df['Datetime'] <= end_date.strftime(DATETIME_FORMAT))
(df['Datetime'] <= end_date)
)
df = df.tail(self.window_size)
return df
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from queue import PriorityQueue
from dateutil import parser

from investing_algorithm_framework.domain import OrderStatus, \
TradeStatus, Trade, OperationalException, MarketDataType
Expand Down Expand Up @@ -248,9 +247,7 @@ def update_trades_with_market_data(self, market_data):
last_row = data.tail(1)
update_data = {
"last_reported_price": last_row["Close"][0],
"updated_at": parser.parse(
last_row["Datetime"][0]
)
"updated_at": last_row["Datetime"][0]
}
price = last_row["Close"][0]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest import TestCase

from dateutil import parser
Expand Down Expand Up @@ -42,7 +42,7 @@ def test_right_columns(self):
f"{file_name}",
window_size=10,
)
date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
date = datetime(2023, 8, 7, 8, 0, tzinfo=timezone.utc)
df = data_source.get_data(start_date=date)
self.assertEqual(
["Datetime", "Open", "High", "Low", "Close", "Volume"], df.columns
Expand All @@ -61,7 +61,7 @@ def test_throw_exception_when_missing_column_names_columns(self):
)

def test_start_date(self):
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=timezone.utc)
file_name = "OHLCV_BTC-EUR_BINANCE" \
"_2h_2023-08-07-07-59_2023-12-02-00-00.csv"
csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource(
Expand All @@ -78,7 +78,7 @@ def test_start_date(self):

def test_start_date_with_window_size(self):
start_date = datetime(
year=2023, month=8, day=7, hour=10, minute=0, tzinfo=tzutc()
year=2023, month=8, day=7, hour=10, minute=0, tzinfo=timezone.utc
)
file_name = "OHLCV_BTC-EUR_BINANCE" \
"_2h_2023-08-07-07-59_2023-12-02-00-00.csv"
Expand All @@ -92,7 +92,7 @@ def test_start_date_with_window_size(self):
start_date=start_date
)
self.assertEqual(12, len(data))
first_date = parser.parse(data["Datetime"][0])
first_date = data["Datetime"][0]
self.assertEqual(
start_date.strftime(DATETIME_FORMAT),
first_date.strftime(DATETIME_FORMAT)
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_empty(self):
f"{file_name}",
window_size=10,
)
start_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc())
start_date = datetime(2023, 12, 2, 0, 0, tzinfo=timezone.utc)
self.assertFalse(data_source.empty(start_date))

def test_get_data(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
from unittest import TestCase
from datetime import datetime, timedelta
from investing_algorithm_framework.infrastructure.models\
.market_data_sources.ccxt import CCXTOHLCVBacktestMarketDataSource
from investing_algorithm_framework.domain import RESOURCE_DIRECTORY, \
BACKTEST_DATA_DIRECTORY_NAME


class TestCCXTOHLCVBacktestDataSource(TestCase):
    """
    Tests for CCXTOHLCVBacktestMarketDataSource window retrieval
    against the pre-downloaded backtest data in the resources folder.
    """

    def setUp(self):
        # Walk four directories up from this file and into "resources".
        # os.path.join is variadic, so one call replaces the previous
        # five nested calls and produces the identical path.
        self.resource_dir = os.path.abspath(
            os.path.join(
                os.path.realpath(__file__),
                os.pardir,
                os.pardir,
                os.pardir,
                os.pardir,
                "resources",
            )
        )
        self.backtest_data_dir = "market_data_sources_for_testing"

    def test_prepare_data(self):
        # TODO: exercise prepare_data directly; it is currently only
        # covered indirectly through test_get_data below.
        pass

    def test_get_data(self):
        """
        Stepping through a year in 2-hour increments should return a
        ~200-row window for every step.
        """
        data_source = CCXTOHLCVBacktestMarketDataSource(
            identifier="bitvavo",
            market="BITVAVO",
            symbol="BTC/EUR",
            time_frame="2h",
            window_size=200,
        )
        config = {
            RESOURCE_DIRECTORY: self.resource_dir,
            BACKTEST_DATA_DIRECTORY_NAME: self.backtest_data_dir
        }
        data_source.prepare_data(
            config=config,
            backtest_start_date=datetime(2021, 1, 1),
            backtest_end_date=datetime(2025, 1, 1),
        )
        backtest_start_date = datetime(2021, 1, 1)
        backtest_end_date = datetime(2022, 1, 1)
        interval = timedelta(hours=2)  # one retrieval every 2 hours
        delta = backtest_end_date - backtest_start_date
        # Inclusive of both endpoints; 7200 seconds per 2-hour step.
        # Integer division keeps the expected count an int (the value
        # is unchanged from the previous float expression).
        expected_retrievals = int(delta.total_seconds() // 7200) + 1

        number_of_data_retrievals = 0
        current_date = backtest_start_date

        while current_date <= backtest_end_date:
            data = data_source.get_data(date=current_date)

            if data is not None:
                number_of_data_retrievals += 1
                # Windows may be a few rows short around gaps in the
                # downloaded data, so allow a small tolerance.
                self.assertTrue(abs(200 - len(data)) <= 4)

            current_date += interval

        self.assertEqual(expected_retrievals, number_of_data_retrievals)
Loading
Loading