Skip to content

Commit bb7d5f6

Browse files
committed
Refactor data source registration
1 parent 07f69f2 commit bb7d5f6

File tree

15 files changed

+961
-111
lines changed

15 files changed

+961
-111
lines changed

investing_algorithm_framework/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
TickerMarketDataSource, MarketService, BacktestReportsEvaluation, \
1111
pretty_print_backtest_reports_evaluation, load_backtest_reports, \
1212
RESERVED_BALANCES, APP_MODE, AppMode, DATETIME_FORMAT, \
13-
load_backtest_report, BacktestDateRange
13+
load_backtest_report, BacktestDateRange, create_ema_graph, \
14+
create_prices_graph, create_rsi_graph, get_price_efficiency_ratio
1415
from investing_algorithm_framework.infrastructure import \
1516
CCXTOrderBookMarketDataSource, CCXTOHLCVMarketDataSource, \
1617
CCXTTickerMarketDataSource, CSVOHLCVMarketDataSource, \
@@ -67,5 +68,9 @@
6768
"load_backtest_report",
6869
"BacktestDateRange",
6970
"create_trade_exit_markers_chart",
70-
"create_trade_entry_markers_chart"
71+
"create_trade_entry_markers_chart",
72+
"create_ema_graph",
73+
"create_prices_graph",
74+
"create_rsi_graph",
75+
"get_price_efficiency_ratio"
7176
]

investing_algorithm_framework/app/app.py

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,12 @@ def _initialize_app_for_backtest(
254254
before running a backtest or a set of backtests and should be called
255255
once.
256256
257-
:param backtest_date_range: instance of BacktestDateRange
258-
:param pending_order_check_interval: The interval at which to check
259-
pending orders (e.g. 1h, 1d, 1w)
260-
:return: None
257+
Args:
258+
backtest_date_range: instance of BacktestDateRange
259+
pending_order_check_interval: The interval at which to check
260+
pending orders (e.g. 1h, 1d, 1w)
261+
262+
Return None
261263
"""
262264
# Set all config vars for backtesting
263265
configuration_service = self.container.configuration_service()
@@ -275,7 +277,18 @@ def _initialize_app_for_backtest(
275277
# Create resource dir if not exits
276278
self._create_resource_directory_if_not_exists()
277279

278-
def _initialize_algorithm_for_backtest(self, algorithm):
280+
def _create_backtest_database_if_not_exists(self):
281+
"""
282+
Create the backtest database if it does not exist. This method
283+
should be called before running a backtest for an algorithm.
284+
It creates the database if it does not exist.
285+
286+
Args:
287+
None
288+
289+
Returns
290+
None
291+
"""
279292
configuration_service = self.container.configuration_service()
280293
resource_dir = configuration_service.config[RESOURCE_DIRECTORY]
281294

@@ -301,15 +314,27 @@ def _initialize_algorithm_for_backtest(self, algorithm):
301314
setup_sqlalchemy(self)
302315
create_all_tables()
303316

304-
# Override the MarketDataSourceService service with the backtest
305-
# market data source service equivalent. Additionally, convert the
306-
# market data sources to backtest market data sources
307-
# Get all market data source services
308-
market_data_sources = self._market_data_source_service\
317+
def _initialize_backtest_data_sources(self, algorithm):
318+
"""
319+
Initialize the backtest data sources for the algorithm. This method
320+
should be called before running a backtest. It initializes the
321+
backtest data sources for the algorithm. It takes all registered
322+
data sources and converts them to backtest equivalents
323+
324+
Args:
325+
algorithm: The algorithm to initialize for backtesting
326+
327+
Returns
328+
None
329+
"""
330+
331+
market_data_sources = self._market_data_source_service \
309332
.get_market_data_sources()
333+
backtest_market_data_sources = []
310334

311335
if algorithm.data_sources is not None \
312336
and len(algorithm.data_sources) > 0:
337+
313338
for data_source in algorithm.data_sources:
314339
self.add_market_data_source(data_source)
315340

@@ -324,16 +349,36 @@ def _initialize_algorithm_for_backtest(self, algorithm):
324349
if market_data_source is not None:
325350
market_data_source.config = self.config
326351

327-
self.container.market_data_source_service.override(
328-
BacktestMarketDataSourceService(
329-
market_data_sources=backtest_market_data_sources,
330-
market_service=self.container.market_service(),
331-
market_credential_service=self.container
332-
.market_credential_service(),
333-
configuration_service=self.container
334-
.configuration_service(),
335-
)
352+
# Override the market data source service with the backtest market
353+
# data source service
354+
self.container.market_data_source_service.override(
355+
BacktestMarketDataSourceService(
356+
market_data_sources=backtest_market_data_sources,
357+
market_service=self.container.market_service(),
358+
market_credential_service=self.container
359+
.market_credential_service(),
360+
configuration_service=self.container
361+
.configuration_service(),
336362
)
363+
)
364+
365+
# Set all data sources to the algorithm
366+
algorithm.add_data_sources(backtest_market_data_sources)
367+
368+
def _initialize_algorithm_for_backtest(self, algorithm):
369+
"""
370+
Function to initialize the algorithm for backtesting. This method
371+
should be called before running a backtest. It initializes the
372+
all data sources to backtest data sources and overrides the services
373+
with the backtest services equivalents.
374+
375+
Args:
376+
algorithm: The algorithm to initialize for backtesting
377+
378+
Return None
379+
"""
380+
self._create_backtest_database_if_not_exists()
381+
self._initialize_backtest_data_sources(algorithm)
337382

338383
# Override the portfolio service with the backtest portfolio service
339384
self.container.portfolio_service.override(
@@ -385,7 +430,6 @@ def _initialize_algorithm_for_backtest(self, algorithm):
385430
market_credential_service = self.container.market_credential_service()
386431
market_data_source_service = \
387432
self.container.market_data_source_service()
388-
389433
# Initialize all services in the algorithm
390434
algorithm.initialize_services(
391435
configuration_service=self.container.configuration_service(),
@@ -444,17 +488,19 @@ def run(
444488
raises an OperationalException. Then it initializes the algorithm
445489
with the services and the configuration.
446490
447-
After the algorithm is initialized, it initializes the app and starts
448-
the algorithm. If the app is running in stateless mode, it handles the
491+
If the app is running in stateless mode, it handles the
449492
payload. If the app is running in web mode, it starts the web app in a
450493
separate thread.
451494
452-
:param payload: The payload to handle if the app is running in
453-
stateless mode
454-
:param number_of_iterations: The number of iterations to run the
455-
algorithm for
456-
:param sync: Whether to sync the portfolio with the exchange
457-
:return: None
495+
Args:
496+
payload: The payload to handle if the app is running in
497+
stateless mode
498+
number_of_iterations: The number of iterations to run the
499+
algorithm for
500+
sync: Whether to sync the portfolio with the exchange
501+
502+
Returns:
503+
None
458504
"""
459505

460506
# Run all on_initialize hooks
@@ -676,21 +722,21 @@ def run_backtest(
676722
Run a backtest for an algorithm. This method should be called when
677723
running a backtest.
678724
679-
:param algorithm: The algorithm to run a backtest for (instance of
680-
Algorithm)
681-
:param backtest_date_range: The date range to run the backtest for
682-
(instance of BacktestDateRange)
683-
:param pending_order_check_interval: The interval at which to check
684-
pending orders
685-
:param output_directory: The directory to write the backtest report to
686-
:return: Instance of BacktestReport
725+
Args:
726+
algorithm: The algorithm to run a backtest for (instance of
727+
Algorithm)
728+
backtest_date_range: The date range to run the backtest for
729+
(instance of BacktestDateRange)
730+
pending_order_check_interval: The interval at which to check
731+
pending orders
732+
output_directory: The directory to write the backtest report to
733+
734+
Returns:
735+
Instance of BacktestReport
687736
"""
688737
logger.info("Initializing backtest")
689738
self.algorithm = algorithm
690739

691-
market_data_sources = self._market_data_source_service\
692-
.get_market_data_sources()
693-
694740
self._initialize_app_for_backtest(
695741
backtest_date_range=backtest_date_range,
696742
pending_order_check_interval=pending_order_check_interval,

investing_algorithm_framework/app/strategy.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from investing_algorithm_framework.domain import \
33
TimeUnit, StrategyProfile, Trade
44
from .algorithm import Algorithm
5+
import pandas as pd
56

67

78
class TradingStrategy:
@@ -11,6 +12,7 @@ class TradingStrategy:
1112
strategy_id: str = None
1213
decorated = None
1314
market_data_sources = None
15+
traces = None
1416

1517
def __init__(
1618
self,
@@ -46,6 +48,8 @@ def __init__(
4648

4749
if strategy_id is not None:
4850
self.strategy_id = strategy_id
51+
else:
52+
self.strategy_id = self.worker_id
4953

5054
# Check if time_unit is None
5155
if self.time_unit is None:
@@ -59,6 +63,9 @@ def __init__(
5963
f"Interval not set for strategy instance {self.strategy_id}"
6064
)
6165

66+
# context initialization
67+
self._context = None
68+
6269
def run_strategy(self, algorithm, market_data):
6370
# Check pending orders before running the strategy
6471
algorithm.check_pending_orders()
@@ -135,3 +142,72 @@ def strategy_identifier(self):
135142
return self.strategy_id
136143

137144
return self.worker_id
145+
146+
@property
147+
def context(self):
148+
return self._context
149+
150+
@context.setter
151+
def context(self, context):
152+
self._context = context
153+
154+
def add_trace(
155+
self,
156+
symbol: str,
157+
data,
158+
drop_duplicates=True
159+
) -> None:
160+
"""
161+
Add data to the straces object for a given symbol
162+
163+
Args:
164+
symbol (str): The symbol
165+
data (pd.DataFrame): The data to add to the tracing
166+
drop_duplicates (bool): Drop duplicates
167+
168+
Returns:
169+
None
170+
"""
171+
172+
# Check if data is a DataFrame
173+
if not isinstance(data, pd.DataFrame):
174+
raise ValueError(
175+
"Currently only pandas DataFrames are "
176+
"supported as tracing data objects."
177+
)
178+
179+
data: pd.DataFrame = data
180+
181+
# Check if index is a datetime object
182+
if not isinstance(data.index, pd.DatetimeIndex):
183+
raise ValueError("Dataframe Index must be a datetime object.")
184+
185+
if self.traces is None:
186+
self.traces = {}
187+
188+
# Check if the key is already in the context dictionary
189+
if symbol in self.traces:
190+
# If the key is already in the context dictionary,
191+
# append the new data to the existing data
192+
combined = pd.concat([self.traces[symbol], data])
193+
else:
194+
# If the key is not in the context dictionary,
195+
# add the new data to the context dictionary
196+
combined = data
197+
198+
if drop_duplicates:
199+
# Drop duplicates and sort the data by the index
200+
combined = combined[~combined.index.duplicated(keep='first')]
201+
202+
# Set the datetime column as the index
203+
combined.set_index(pd.DatetimeIndex(combined.index), inplace=True)
204+
self.traces[symbol] = combined
205+
206+
def get_traces(self) -> dict:
207+
"""
208+
Get the traces object
209+
210+
Returns:
211+
dict: The traces object
212+
"""
213+
return self.traces

investing_algorithm_framework/domain/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
load_backtest_report, \
3131
csv_to_list, StoppableThread, pretty_print_backtest_reports_evaluation, \
3232
pretty_print_backtest, load_csv_into_dict, load_backtest_reports
33+
from .graphs import create_prices_graph, create_ema_graph, create_rsi_graph
34+
from .metrics import get_price_efficiency_ratio
3335

3436
__all__ = [
3537
'Config',
@@ -114,4 +116,8 @@
114116
"RoundingService",
115117
"BacktestDateRange",
116118
"load_backtest_report",
119+
"create_prices_graph",
120+
"create_ema_graph",
121+
"create_rsi_graph",
122+
"get_price_efficiency_ratio"
117123
]

0 commit comments

Comments
 (0)