This repository demonstrates how to build and train a reinforcement learning (RL) agent that trades cryptocurrency on historical data. The project uses a recurrent Long Short-Term Memory (LSTM) policy trained with Proximal Policy Optimization (PPO), and features a custom trading environment built from scratch to be compatible with the OpenAI Gym API.
- Custom Trading Environment: Implements a fully customizable trading simulation environment.
- Advanced RL Algorithm: Uses the LSTM-based PPO algorithm from the Stable-Baselines3 Contrib library.
- Yahoo Finance Integration: Fetches and preprocesses cryptocurrency data dynamically.
- Custom Indicators: Includes SMA, RSI, and OBV as features for the agent.
- Detailed Evaluation: Tracks and visualizes rewards, profit percentages, and final net worth over time.
Install the required libraries:
pip install stable-baselines3 sb3-contrib gymnasium finta yfinance matplotlib
- Source: Yahoo Finance
- Data Features:
- Open, High, Low, Close, Volume
- Simple Moving Average (SMA)
- Relative Strength Index (RSI)
- On-Balance Volume (OBV)
import yfinance as yf
from finta import TA

btc_data = yf.download('BTC-USD', start='2024-01-01', end='2025-01-15')
# finta expects single-level, lowercase OHLCV column names ('open', 'high', ...)
btc_data.columns = [c[0].lower() if isinstance(c, tuple) else c.lower() for c in btc_data.columns]
btc_data['SMA'] = TA.SMA(btc_data, 12)
btc_data['RSI'] = TA.RSI(btc_data)
btc_data['OBV'] = TA.OBV(btc_data)
btc_data.fillna(0, inplace=True)
The environment is built entirely from scratch using the Gym API. Key components include:
- Actions: Buy, Sell
- Positions: Short, Long
- Reward Calculation: Based on price differences and position changes.
- Profit Tracking: Tracks cumulative profits and calculates gain percentage.
- Rendering: Functions to visualize trades and profits.
- Configuration: Customizable starting balance and trade fees.
import gymnasium as gym
import numpy as np


class TradingEnv(gym.Env):
    def __init__(self, df, window_size, render_mode=None):
        self.df = df
        self.window_size = window_size
        self.prices, self.signal_features = self._process_data()
        self.action_space = gym.spaces.Discrete(2)  # Buy, Sell
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(window_size, self.signal_features.shape[1]), dtype=np.float32
        )
        self.render_mode = render_mode
        self.reset()

    def step(self, action):
        # Implement trading logic, reward calculation, and profit tracking
        pass

    def reset(self, seed=None, options=None):
        # Reset environment state; return the initial observation and info dict
        pass

    def _process_data(self):
        # Preprocess input data (prices and indicators)
        pass

    def render(self, mode='human'):
        # Visualize trades and profits
        pass
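The stubs above are left for the reader to fill in. As an illustration only, here is a minimal sketch of how _process_data() and step() could be implemented inside TradingEnv, assuming a gym-anytrading-style scheme in which the agent is always either long or short, action 1 means Buy, action 0 means Sell, and a position flip realizes the price move since the last trade as the reward. The attributes self._position, self._current_tick, self._last_trade_tick, and self._total_profit are hypothetical bookkeeping fields, not part of the skeleton above.

def _process_data(self):
    # Close prices drive P&L; the full feature set becomes the observation.
    prices = self.df['close'].to_numpy(dtype=np.float32)
    signal_features = self.df[['close', 'SMA', 'RSI', 'OBV']].to_numpy(dtype=np.float32)
    return prices, signal_features

def step(self, action):
    self._current_tick += 1
    terminated = self._current_tick >= len(self.prices) - 1

    reward = 0.0
    # Buy (1) while short or Sell (0) while long flips the position and
    # realizes the price move since the last trade as the step reward.
    if (action == 1 and self._position == 'short') or (action == 0 and self._position == 'long'):
        price_diff = self.prices[self._current_tick] - self.prices[self._last_trade_tick]
        reward = price_diff if self._position == 'long' else -price_diff
        self._total_profit += reward
        self._position = 'long' if action == 1 else 'short'
        self._last_trade_tick = self._current_tick

    obs = self.signal_features[self._current_tick - self.window_size + 1:self._current_tick + 1]
    info = {'total_profit': self._total_profit, 'position': self._position}
    return obs, reward, terminated, False, info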
- Algorithm: Recurrent PPO
- Policy: MlpLstmPolicy
- Training Parameters:
- Learning Rate: 2e-4
- Number of Steps: 4096
- Batch Size: 32
- Entropy Coefficient: 0.02
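The training code below references vec_env and eval_callback, which are assumed to be created beforehand. A minimal sketch of one way to set them up with Stable-Baselines3 utilities (the window size of 5, evaluation frequency, and save path are illustrative choices, not project-confirmed settings):

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback

# Wrap the custom environment so Stable-Baselines3 can use it.
vec_env = DummyVecEnv([lambda: TradingEnv(df=btc_data, window_size=5)])

# Periodically evaluate the agent on a separate environment instance.
eval_env = DummyVecEnv([lambda: TradingEnv(df=btc_data, window_size=5)])
eval_callback = EvalCallback(eval_env, eval_freq=10_000, n_eval_episodes=3,
                             best_model_save_path="models/", verbose=1)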
from sb3_contrib import RecurrentPPO

ppo_params = {
    "policy": "MlpLstmPolicy",
    "env": vec_env,
    "learning_rate": 2e-4,
    "n_steps": 4096,
    "batch_size": 32,
    "ent_coef": 0.02,
    "verbose": 1
}

ppo_model = RecurrentPPO(**ppo_params)
ppo_model.learn(total_timesteps=100000, callback=eval_callback)
ppo_model.save("ppo_lstm_crypto_model")
The trained model is tested in a new environment using unseen data. Key metrics such as cumulative reward, profit percentage, and final net worth are visualized.
def testing_env_and_model(model):
    test_btc_data = preprocess_data()
    test_env = TradingEnv(df=test_btc_data, window_size=5)
    obs, info = test_env.reset()
    while True:
        action, _ = model.predict(obs)
        obs, reward, terminated, truncated, info = test_env.step(action)
        if terminated or truncated:
            break
    return test_env
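Because the policy is recurrent, the LSTM hidden state can be carried across predict() calls, which is the usage pattern shown in the SB3-Contrib documentation. Below is a sketch of the same loop with explicit state handling, assuming test_env and model are the objects used above (not the project's exact evaluation code):

import numpy as np

obs, info = test_env.reset()
lstm_states = None
episode_starts = np.ones((1,), dtype=bool)  # marks the start of a new episode
while True:
    action, lstm_states = model.predict(obs, state=lstm_states,
                                        episode_start=episode_starts, deterministic=True)
    obs, reward, terminated, truncated, info = test_env.step(action)
    episode_starts = np.zeros((1,), dtype=bool)
    if terminated or truncated:
        break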
Final results are visualized to display:
- Rewards over time
- Profit percentage
- Final net worth
The following code produces the final plot of rewards and profit trends over time:
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 6))
plt.cla()
test_env.render_all()
plt.show()
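render_all() is assumed to be a custom helper on the environment that plots the full price series together with the trades the agent made. A hypothetical sketch of such a method (the self._buy_ticks and self._sell_ticks lists are illustrative bookkeeping, not part of the skeleton above):

def render_all(self):
    # Plot the close-price series and mark the ticks where the agent bought or sold.
    plt.plot(self.prices, label='Close price')
    plt.scatter(self._buy_ticks, self.prices[self._buy_ticks], marker='^', color='green', label='Buy')
    plt.scatter(self._sell_ticks, self.prices[self._sell_ticks], marker='v', color='red', label='Sell')
    plt.title(f'Total profit: {self._total_profit:.2f}')
    plt.legend()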
|-- data/ # Directory for storing raw and processed data
|-- models/ # Trained models
|-- scripts/ # Scripts for training and evaluation
|-- notebooks/ # Jupyter notebooks for experimentation
|-- README.md # Project documentation (this file)
|-- requirements.txt # Python dependencies
- Clone the repository:
git clone https://github.com/your-repo-name.git
- Install dependencies:
pip install -r requirements.txt
- Run the Jupyter notebook or script to train the model.
This project is licensed under the MIT License - see the LICENSE file for details.