This library makes it easy to train agents built with Keras or PyTorch using reinforcement learning: have your agent class inherit from the RL class (Keras) or the RL_pytorch class (PyTorch), and you can train it right away. You can learn how to build an agent from the examples here. This README shows how to train, save, and restore agents built with Keras or PyTorch.
To install, download the library and unzip it into the site-packages folder of your Python environment.
Dependencies:
tensorflow>=2.16.1
pytorch>=2.3.1
gym<=0.25.2
matplotlib>=3.8.4
Python requirement:
python>=3.10
Keras: Agent built with Keras.
import tensorflow as tf
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.keras.DQN import DQN
model=DQN(4,128,2)
model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=64,update_steps=10)
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
model.train(train_loss, optimizer, 100, pool_network=False)
# If you set a criterion:
# model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=64,update_steps=10,trial_count=10,criterion=200)
# model.train(train_loss, optimizer, 100, pool_network=False)
# To save the model every 10 episodes, keeping at most 2 saved files, with the file name model.dat:
# model.path='model.dat'
# model.save_freq=10
# model.max_save_files=2
# model.train(train_loss, optimizer, 100, pool_network=False)
# To save parameters only:
# model.path='param.dat'
# model.save_freq=10
# model.max_save_files=2
# model.save_param_only=True
# model.train(train_loss, optimizer, 100, pool_network=False)
# To save only the best model:
# model.path='model.dat'
# model.save_best_only=True
# model.train(train_loss, optimizer, 100, pool_network=False)
# visualize
# model.visualize_loss()
# model.visualize_reward()
# model.visualize_reward_loss()
# animate agent
# model.animate_agent(200)
# save
# model.save_param('param.dat')
# model.save('model.dat')
# Use PPO.
import tensorflow as tf
from Note_rl.policy import SoftmaxPolicy
from Note_rl.examples.keras.PPO import PPO
model=PPO(4,128,2,0.7,0.7)
model.set(policy=SoftmaxPolicy(),pool_size=10000,batch=64,update_steps=1000,PPO=True)
optimizer = [tf.keras.optimizers.Adam(1e-4),tf.keras.optimizers.Adam(5e-3)]
train_loss = tf.keras.metrics.Mean(name='train_loss')
model.train(train_loss, optimizer, 100, pool_network=False)
# Use HER.
import tensorflow as tf
from Note_rl.noise import GaussianWhiteNoiseProcess
from Note_rl.examples.keras.DDPG_HER import DDPG
model=DDPG(128,0.1,0.98,0.005)
model.set(noise=GaussianWhiteNoiseProcess(),pool_size=10000,batch=256,criterion=-5,trial_count=10,HER=True)
optimizer = [tf.keras.optimizers.Adam(),tf.keras.optimizers.Adam()]
train_loss = tf.keras.metrics.Mean(name='train_loss')
model.train(train_loss, optimizer, 2000, pool_network=False)
# Use Multi-agent reinforcement learning.
import tensorflow as tf
from Note_rl.policy import SoftmaxPolicy
from Note_rl.examples.keras.MADDPG import DDPG
model=DDPG(128,0.1,0.98,0.005)
model.set(policy=SoftmaxPolicy(),pool_size=3000,batch=32,trial_count=10,MARL=True)
optimizer = [tf.keras.optimizers.Adam(),tf.keras.optimizers.Adam()]
train_loss = tf.keras.metrics.Mean(name='train_loss')
model.train(train_loss, optimizer, 100, pool_network=False)
# This technique uses Python’s multiprocessing module to speed up trajectory collection and storage; I call it a Pool Network.
import tensorflow as tf
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.keras.pool_network.DQN import DQN
model=DQN(4,128,2,7)
model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,update_batches=17)
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
model.train(train_loss, optimizer, 100, pool_network=True, processes=7)
PyTorch: Agent built with PyTorch.
import torch
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.pytorch.DQN import DQN
model=DQN(4,128,2)
model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=64,update_steps=10)
optimizer = torch.optim.Adam(model.param)
model.train(optimizer, 100)
# If you set a criterion:
# model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=64,update_steps=10,trial_count=10,criterion=200)
# model.train(optimizer, 100)
# To use prioritized replay:
# model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=64,update_steps=10,trial_count=10,criterion=200,PR=True,initial_TD=7,alpha=0.7)
# model.train(optimizer, 100)
# To save the model every 10 episodes, keeping at most 2 saved files, with the file name model.dat:
# model.path='model.dat'
# model.save_freq=10
# model.max_save_files=2
# model.train(optimizer, 100)
# To save parameters only:
# model.path='param.dat'
# model.save_freq=10
# model.max_save_files=2
# model.save_param_only=True
# model.train(optimizer, 100)
# To save only the best model:
# model.path='model.dat'
# model.save_best_only=True
# model.train(optimizer, 100)
# visualize
# model.visualize_loss()
# model.visualize_reward()
# model.visualize_reward_loss()
# animate agent
# model.animate_agent(200)
# save
# model.save_param('param.dat')
# model.save('model.dat')
# Use HER.
import torch
from Note_rl.noise import GaussianWhiteNoiseProcess
from Note_rl.examples.pytorch.DDPG_HER import DDPG
model=DDPG(128,0.1,0.98,0.005)
model.set(noise=GaussianWhiteNoiseProcess(),pool_size=10000,batch=256,criterion=-5,trial_count=10,HER=True)
optimizer = [torch.optim.Adam(model.param[0]),torch.optim.Adam(model.param[1])]
model.train(optimizer, 2000)
# Use Multi-agent reinforcement learning.
import torch
from Note_rl.policy import SoftmaxPolicy
from Note_rl.examples.pytorch.MADDPG import DDPG
model=DDPG(128,0.1,0.98,0.005)
model.set(policy=SoftmaxPolicy(),pool_size=3000,batch=32,trial_count=10,MARL=True)
optimizer = [torch.optim.Adam(model.param[0]),torch.optim.Adam(model.param[1])]
model.train(optimizer, 100)
# This technique uses Python’s multiprocessing module to speed up trajectory collection and storage; I call it a Pool Network.
import torch
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.pytorch.pool_network.DQN import DQN
model=DQN(4,128,2,7)
model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=64,update_batches=17)
optimizer = torch.optim.Adam(model.param)
model.train(optimizer, 100, pool_network=True, processes=7)
# Use HER.
# This technique uses Python’s multiprocessing module to speed up trajectory collection and storage; I call it a Pool Network.
# It also uses Python’s multiprocessing module to speed up sampling a batch of data.
import torch
from Note_rl.noise import GaussianWhiteNoiseProcess
from Note_rl.examples.pytorch.pool_network.DDPG_HER import DDPG
model=DDPG(128,0.1,0.98,0.005,7)
model.set(noise=GaussianWhiteNoiseProcess(),pool_size=10000,batch=256,trial_count=10,HER=True)
optimizer = [torch.optim.Adam(model.param[0]),torch.optim.Adam(model.param[1])]
model.train(optimizer, 2000, pool_network=True, processes=7, processes_her=4)
MirroredStrategy: Agent built with Keras.
import tensorflow as tf
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.keras.DQN import DQN
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
with strategy.scope():
    model=DQN(4,128,2)
    optimizer = tf.keras.optimizers.Adam()
    model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=GLOBAL_BATCH_SIZE,update_steps=10)
    model.distributed_training(optimizer, strategy, 100, pool_network=False)
# If you set a criterion:
# model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=GLOBAL_BATCH_SIZE,update_steps=10,trial_count=10,criterion=200)
# model.distributed_training(optimizer, strategy, 100, pool_network=False)
# To save the model every 10 episodes, keeping at most 2 saved files, with the file name model.dat:
# model.path='model.dat'
# model.save_freq=10
# model.max_save_files=2
# model.distributed_training(optimizer, strategy, 100, pool_network=False)
# To save parameters only:
# model.path='param.dat'
# model.save_freq=10
# model.max_save_files=2
# model.save_param_only=True
# model.distributed_training(optimizer, strategy, 100, pool_network=False)
# To save only the best model:
# model.path='model.dat'
# model.save_best_only=True
# model.distributed_training(optimizer, strategy, 100, pool_network=False)
# visualize
# model.visualize_loss()
# model.visualize_reward()
# model.visualize_reward_loss()
# animate agent
# model.animate_agent(200)
# save
# model.save_param('param.dat')
# model.save('model.dat')
# Use PPO
import tensorflow as tf
from Note_rl.policy import SoftmaxPolicy
from Note_rl.examples.keras.PPO import PPO
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
with strategy.scope():
    model=PPO(4,128,2,0.7,0.7)
    optimizer = [tf.keras.optimizers.Adam(1e-4),tf.keras.optimizers.Adam(5e-3)]
    model.set(policy=SoftmaxPolicy(),pool_size=10000,batch=GLOBAL_BATCH_SIZE,update_steps=1000,PPO=True)
    model.distributed_training(optimizer, strategy, 100, pool_network=False)
# Use HER.
import tensorflow as tf
from Note_rl.noise import GaussianWhiteNoiseProcess
from Note_rl.examples.keras.DDPG_HER import DDPG
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 256
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
with strategy.scope():
    model=DDPG(128,0.1,0.98,0.005)
    optimizer = [tf.keras.optimizers.Adam(),tf.keras.optimizers.Adam()]
    model.set(noise=GaussianWhiteNoiseProcess(),pool_size=10000,batch=GLOBAL_BATCH_SIZE,criterion=-5,trial_count=10,HER=True)
    model.distributed_training(optimizer, strategy, 2000, pool_network=False)
# Use Multi-agent reinforcement learning.
import tensorflow as tf
from Note_rl.policy import SoftmaxPolicy
from Note_rl.examples.keras.MADDPG import DDPG
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 32
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
with strategy.scope():
    model=DDPG(128,0.1,0.98,0.005)
    optimizer = [tf.keras.optimizers.Adam(),tf.keras.optimizers.Adam()]
    model.set(policy=SoftmaxPolicy(),pool_size=3000,batch=GLOBAL_BATCH_SIZE,trial_count=10,MARL=True)
    model.distributed_training(optimizer, strategy, 100, pool_network=False)
# This technique uses Python’s multiprocessing module to speed up trajectory collection and storage; I call it a Pool Network.
import tensorflow as tf
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.keras.pool_network.DQN import DQN
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
with strategy.scope():
    model=DQN(4,128,2,7)
    optimizer = tf.keras.optimizers.Adam()
    model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=GLOBAL_BATCH_SIZE,update_batches=17)
    model.distributed_training(optimizer, strategy, 100, pool_network=True, processes=7)
MultiWorkerMirroredStrategy:
import tensorflow as tf
from Note_rl.policy import EpsGreedyQPolicy
from Note_rl.examples.keras.pool_network.DQN import DQN
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ.pop('TF_CONFIG', None)
if '.' not in sys.path:
    sys.path.insert(0, '.')
tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}
strategy = tf.distribute.MultiWorkerMirroredStrategy()
per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers
with strategy.scope():
    multi_worker_model = DQN(4,128,2)
    optimizer = tf.keras.optimizers.Adam()
    multi_worker_model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=global_batch_size,update_batches=17)
    multi_worker_model.distributed_training(optimizer, strategy, num_episodes=100,
                                            pool_network=True, processes=7)
# If you set a criterion:
# multi_worker_model.set(policy=EpsGreedyQPolicy(0.01),pool_size=10000,batch=global_batch_size,update_steps=10,trial_count=10,criterion=200)
# multi_worker_model.distributed_training(optimizer, strategy, num_episodes=100,
#                                         pool_network=True, processes=7)
# To save the model every 10 episodes, keeping at most 2 saved files, with the file name model.dat:
# multi_worker_model.path='model.dat'
# multi_worker_model.save_freq=10
# multi_worker_model.max_save_files=2
# multi_worker_model.distributed_training(optimizer, strategy, num_episodes=100,
#                                         pool_network=True, processes=7)
# To save parameters only:
# multi_worker_model.path='param.dat'
# multi_worker_model.save_freq=10
# multi_worker_model.max_save_files=2
# multi_worker_model.save_param_only=True
# multi_worker_model.distributed_training(optimizer, strategy, num_episodes=100,
#                                         pool_network=True, processes=7)
# To save only the best model:
# multi_worker_model.path='model.dat'
# multi_worker_model.save_best_only=True
# multi_worker_model.distributed_training(optimizer, strategy, num_episodes=100,
#                                         pool_network=True, processes=7)
# visualize
# model.visualize_loss()
# model.visualize_reward()
# model.visualize_reward_loss()
# animate agent
# model.animate_agent(200)
# save
# model.save_param('param.dat')
# model.save('model.dat')
import pickle
output_file=open('param.dat','wb')
pickle.dump(model.param,output_file)
output_file.close()
or
model = MyModel(...)
model.save_param('param.dat')
import pickle
input_file=open('param.dat','rb')
param=pickle.load(input_file)
input_file.close()
or
model = MyModel(...)
model.restore_param('param.dat')
or
from Note import nn
param=nn.restore_param('param.dat')
model = MyModel(...)
model.save('model.dat')
# distributed training
with strategy.scope():
    model = MyModel(...)
    model.restore('model.dat')
or
model = MyModel(...)
model.restore('model.dat')
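A hedged example of restoring a saved agent and then continuing training (this assumes the Keras DQN example from the beginning of this README, where train_loss and optimizer are defined; continuing with the same train call is an assumption, not documented behavior):
model = DQN(4,128,2)
model.restore('model.dat')
model.train(train_loss, optimizer, 100, pool_network=False)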
Description:
Runs the main training loop for the RL agent (the `train` method). Supports single-process and multi-process experience collection via a pool network, distributed training strategies (Mirrored/MultiWorker/ParameterServer), just-in-time compilation for training steps, callbacks, and special replay mechanisms: Hindsight Experience Replay (HER), Prioritized Replay (PR), and PPO-compatible behavior. The method coordinates environment rollout(s), buffer aggregation, batch sampling, training updates, optional periodic trimming of replay buffers (via `window_size_fn` / `window_size_ppo`), logging, and model saving.
Arguments:
- `train_loss` (`tf.keras.metrics.Metric`): Metric used to accumulate/report the training loss (e.g. `tf.keras.metrics.Mean()`).
- `optimizer` (`tf.keras.optimizers.Optimizer` or list): Optimizer (or list of optimizers) used to apply gradients. If `self.optimizer` is already set, the passed `optimizer` is only used to initialize `self.optimizer` (see code behaviour).
- `episodes` (`int`, optional): Number of episodes to run. If `None`, training runs indefinitely (or until `self.stop_training` or the reward criterion is met).
- `jit_compile` (`bool`, optional, default=`True`): Whether to use `@tf.function(jit_compile=True)` compiled train steps. When True, the compiled train steps are used where available.
- `pool_network` (`bool`, optional, default=`True`): Enable pool-network multi-process rollouts. When True, experiences are collected in parallel by `processes` worker processes and aggregated into shared (manager) buffers.
- `processes` (`int`, optional): Number of parallel worker processes used when `pool_network=True` to collect experience.
- `processes_her` (`int`, optional): When HER is enabled, the number of processes used for HER batch generation. Affects internal multiprocessing logic and intermediate buffers.
- `processes_pr` (`int`, optional): When PR is enabled, the number of processes used for prioritized replay sampling. Affects internal multiprocessing logic and intermediate buffers.
- `window_size` (`int`, optional): Fixed window size used when trimming per-process buffers inside `pool` / `store_in_parallel`. (If `None`, the default popping behavior is used.)
- `clearing_freq` (`int`, optional): When set, triggers periodic trimming of per-process buffers every `clearing_freq` stored items.
- `window_size_` (`int`, optional): A global fallback window size used in several trimming spots when buffers exceed `self.pool_size`.
- `window_size_ppo` (`int`, optional): Default PPO-specific window trimming size used if `window_size_fn` is not supplied (used when `PPO == True` and `PR == True`).
- `random` (`bool`, optional, default=`False`): When `pool_network=True`, toggles random worker selection vs. the inverse-length selection logic used in `store_in_parallel`.
- `save_data` (`bool`, optional, default=`True`): If True, keeps collected pool lists in shared manager lists to allow saving/resuming; otherwise per-process buffers are reinitialized each run.
- `p` (`int`, optional): Controls the logging/printing frequency. If `p` is `None`, a default of 9 is used (internally the implementation derives a logging interval). If `p == 0`, the periodic logging block is disabled (the code contains `if p!=0` guards around prints). Implementation note: the code transforms the user-supplied `p` into an internal `self.p` and a derived integer `p` used for the printing-interval computation (`p` becomes roughly the number of episodes between logs).
Returns:
- If running with `distributed_flag==True`: returns `(total_loss / num_batches).numpy()` (the average distributed loss for the epoch/batch group).
- Otherwise: returns `train_loss.result().numpy()` (the metric's current value).
- If an early exit happens (e.g. `self.stop_training==True`), the function returns early (commonly the current `train_loss` value or `np.array(0.)`, depending on the branch).
Details:
- Initialization & manager setup:
  - If `pool_network=True`, a `multiprocessing.Manager()` is created and many local lists/buffers (`state_pool_list`, `action_pool_list`, `reward_pool_list`, etc.) are converted into manager lists/dicts so worker processes can append data safely.
  - Per-process data structures (e.g. `self.ratio_list`, `self.TD_list`) are initialized if `PR==True`. When `PPO==True` and `PR==True`, the code uses per-process `ratio_list` / `TD_list` and later concatenates them into `self.prioritized_replay` before training.
- Callbacks & training lifecycle:
  - Calls `on_train_begin` on registered callbacks at the start.
  - Per-episode: calls the `on_episode_begin` and `on_episode_end` callbacks with logs including `'loss'` and `'reward'`.
  - Per-batch: calls `on_batch_begin` / `on_batch_end` with batch logs (loss). This applies to both the PR/HER per-batch generation branches and the dataset-driven branches.
  - Respects `self.stop_training`: if it is set True during training, the method exits early and returns.
- Experience collection:
  - When `pool_network=True`, the function spawns `processes` worker processes (each runs `store_in_parallel`) to produce per-process pool lists, then concatenates them (or packs them into `self.state_pool[7]` etc. when `processes_pr` / `processes_her` are used).
  - If `processes_pr` / `processes_her` are set, special per-process lists (`self.state_list`, `self.action_list`, ...) are used for parallel sampling and later aggregated in `data_func()`.
- Training procedure & batching:
  - Two main modes:
    - PR/HER path: when `self.PR` or `self.HER` is `True`, batches are generated via `self.data_func()` (which may itself spawn worker processes to form batches). The loop iterates over `batches` computed from the pool length / `self.batch`. Each generated batch is turned into a small `tf.data.Dataset` (batched to `self.global_batch_size`); a short illustrative sketch follows this subsection. Then:
      - If using a MirroredStrategy, the dataset is distributed and `distributed_train_step` (or its `_` variant) is used.
      - Otherwise the code uses `train_step` / `train_step_` or directly the non-distributed loops.
    - Plain dataset path: when not PR/HER, the code creates a `tf.data.Dataset` from the entire pool (`self.state_pool`, ...) and iterates it as usual (shuffling when not `pool_network`), applying `train_step` / `train_step_` for each mini-batch.
  - `self.batch_counter` and `self.step_counter` are used to decide when to call `self.update_param()` and (if PPO + PR) when to apply `window_size_fn` / `window_size_ppo` trimming to per-process buffers.
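A rough illustrative sketch of the dataset wrapping and distribution described above (the batch arrays below are random placeholders, not the library's internal layout):
import numpy as np
import tensorflow as tf
states = np.random.rand(64, 4).astype('float32')     # placeholder batch contents
actions = np.random.randint(0, 2, size=(64,))
rewards = np.random.rand(64).astype('float32')
dataset = tf.data.Dataset.from_tensor_slices((states, actions, rewards)).batch(64)
strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.experimental_distribute_dataset(dataset)
# Each element of dist_dataset is a per-replica batch that a distributed train step can consume.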
- Distributed strategies:
  - The code supports `tf.distribute.MirroredStrategy`, `MultiWorkerMirroredStrategy`, and `ParameterServerStrategy` integration:
    - When a MirroredStrategy is detected, datasets are distributed via `strategy.experimental_distribute_dataset` and `distributed_train_step` is used.
    - For `MultiWorkerMirroredStrategy`, a custom path calls `self.CTL` (user-defined) to compute the loss over multiple workers.
    - If a ParameterServerStrategy is used and `stop_training` triggers, the code may call `self.coordinator.join()` to sync workers and exit.
- Prioritized replay (PR) & PPO interactions:
  - If `PR==True` and `PPO==True`, the training loop:
    - maintains per-process `ratio_list` / `TD_list` during collection,
    - concatenates them into `self.prioritized_replay.ratio` and `self.prioritized_replay.TD` before sampling/training,
    - and, when `self.batch_counter % self.update_batches == 0` or `self.update_steps` triggers an update, attempts to call `self.window_size_fn(p)` (if provided) for each process and trims the per-process buffers to the returned `window_size` (or uses the `window_size_ppo` fallback). This enables adaptive trimming (e.g. driven by ESS).
  - If `PR==True` but `PPO==False`, only `TD_list` is used/concatenated.
- Saving & early stopping:
  - Periodic saving: if `self.path` is set and `i % self.save_freq == 0`, calls `save_param_` or `save_` depending on `self.save_param_only`. `max_save_files` and `save_best_only` can be used in your saving implementations (not implemented here).
  - Reward-based termination: if `self.trial_count` and `self.criterion` are set, the method computes `avg_reward` over the most recent `trial_count` episodes and terminates early when `avg_reward >= criterion`. It prints summary info (episode count, average reward, elapsed time) and returns.
- Logging behavior:
  - The printed logs (loss/reward) are gated by the derived `p` logic. Passing `p==0` suppresses periodic printouts (there are many `if p!=0` guards around prints).
  - The method always updates the `self.loss_list`, `self.total_episode`, and `self.time` counters.
- Return values & possible early-exit values:
  - On normal epoch/episode completion the method returns the computed train loss (the distributed average or `train_loss.result().numpy()`).
  - On early exit (`stop_training` true or a ParameterServer coordinator join), the method may return `np.array(0.)` or the current metric, depending on the branch.
Notes / Implementation caveats:
- The `p` parameter behavior is non-standard: if you want the default printing cadence, pass `p=None` (internally this becomes 9). Pass `p=0` to disable periodic printing.
- When `PR==True` and `PPO==True`, the code expects per-process `ratio_list` / `TD_list` and relies on concatenation. Make sure those variables are initialized and that `self.window_size_fn` (if used) handles small buffer sizes (the user-provided `window_size_fn` should guard against `len(weights) < 2`).
- Be defensive around buffer sizes: many places assume `len(self.state_pool) >= self.batch`. During warm-up training you may see early returns if the pool is not yet filled.
- The method mutates internal buffers when trimming; ensure that any external references to those buffers are updated if needed (they are manager lists/dicts in `pool_network` mode).
- Callbacks are integrated; use them for logging, checkpointing, early stopping, or custom monitoring.
Description:
Runs a distributed / multi-device training loop for the RL agent (the `distributed_training` method) using TensorFlow `tf.distribute` strategies. It combines multi-process environment rollouts (pool network) with distributed model updates (MirroredStrategy / MultiWorkerMirroredStrategy) and supports special replay modes (Prioritized Replay `PR`, Hindsight ER `HER`) and PPO interactions. The method orchestrates rollout collection across OS processes, constructs aggregated replay buffers, builds distributed datasets, runs distributed train steps, calls callbacks, does periodic trimming (via `window_size_fn` / `window_size_ppo`), saving, and early stopping.
Arguments:
- `optimizer` (`tf.keras.optimizers.Optimizer` or list): Optimizer(s) to apply gradients. If `self.optimizer` is `None`, this will initialize `self.optimizer`.
- `strategy` (`tf.distribute.Strategy`): A TensorFlow distribution strategy instance (e.g. `tf.distribute.MirroredStrategy`, `tf.distribute.MultiWorkerMirroredStrategy`) under whose scope distributed training is executed.
- `episodes` (`int`, optional): Number of episodes to run (MirroredStrategy path). If `None` and `num_episodes` is supplied, `num_episodes` may be used by some branches.
- `num_episodes` (`int`, optional): Alternative name for `episodes` used by some strategy branches (e.g. the MultiWorker path). If provided, it overrides/assigns `episodes`.
- `jit_compile` (`bool`, optional, default=`True`): Whether to use JIT-compiled train steps where available (`@tf.function(jit_compile=True)`).
- `pool_network` (`bool`, optional, default=`True`): Enable multi-process environment rollouts (a pool of worker processes).
- `processes` (`int`, optional): Number of parallel worker processes to launch for rollouts when `pool_network=True`.
- `processes_her` (`int`, optional): Number of worker processes dedicated to HER sampling (if `HER=True`).
- `processes_pr` (`int`, optional): Number of worker processes dedicated to PR sampling (if `PR=True`).
- `window_size` (`int`, optional): Fixed per-process trimming window used in the collection logic.
- `clearing_freq` (`int`, optional): Periodic trimming frequency (applies to per-process buffers).
- `window_size_` (`int`, optional): Global fallback window used in some trimming branches.
- `window_size_ppo` (`int`, optional): Default PPO window-trimming fallback used if `window_size_fn` is not present (used with `PPO==True and PR==True`).
- `random` (`bool`, optional, default=`False`): Controls the per-process selection strategy in `store_in_parallel` (random vs. inverse-length selection).
- `save_data` (`bool`, optional, default=`True`): Whether to persist per-process buffers to a `multiprocessing.Manager()` so they survive across processes and can be saved.
- `p` (`int`, optional): Controls printing/logging frequency. If `None`, an internal default is used (≈9). Passing `p==0` disables periodic printing. Internally the method transforms `p` into an interval used for logging.
Returns:
- For MirroredStrategy / distributed branches: returns `(total_loss / num_batches).numpy()` when `distributed_flag==True` and that branch computes `total_loss / num_batches`.
- Otherwise returns `train_loss.result().numpy()` (the current metric value).
- The function may return early (e.g. `self.stop_training==True`, or when the reward `criterion` is met). In early-exit cases the return value depends on the branch (commonly the current metric or `np.array(0.)`).
Details:
- Distributed setup:
  - The function sets `self.distributed_flag = True` and defines a `compute_loss` closure inside `strategy.scope()` that calls `tf.nn.compute_average_loss` with `global_batch_size=self.batch`. This is used by the distributed train step to scale per-example losses (a short sketch follows this list).
  - It supports at least two strategy types explicitly:
    - `tf.distribute.MirroredStrategy`: typical synchronous multi-GPU, single-machine use; the function builds distributed datasets and uses `distributed_train_step`.
    - `tf.distribute.MultiWorkerMirroredStrategy`: multi-worker synchronous training. The code follows a slightly different loop (it uses `self.CTL` for loss aggregation in some branches).
- Pool network (multi-process rollouts):
  - If `pool_network=True`, the method creates a `multiprocessing.Manager()` and converts `self.env` and many per-process lists into manager lists/dicts so worker processes can fill them concurrently.
  - For `PR==True` and `PPO==True`, it initializes per-process `ratio_list` and `TD_list` (as `tf.Variable` wrappers) and later concatenates them into `self.prioritized_replay.ratio` / `.TD` before training.
  - Worker processes are launched using `mp.Process(target=self.store_in_parallel, args=(p, lock_list))` to collect rollouts. Note: the code references `lock_list` when launching workers in some branches, but `lock_list` is not created in every branch of this function (this is an implementation caveat; see the caveats below).
- Data aggregation & sampling:
  - When `processes_her` / `processes_pr` are provided, the code collects per-process mini-batches (`self.state_list`, `self.action_list`, etc.) and `data_func()` uses those to form training batches.
  - When not using PR/HER, the per-process pools are concatenated (`np.concatenate(self.state_pool_list)` etc.) to form the full `self.state_pool`, which is turned into a `tf.data.Dataset`.
- Training step selection:
  - For the Mirrored strategy: the dataset is wrapped with `strategy.experimental_distribute_dataset()` and the loop calls `distributed_train_step` (the JIT or non-JIT variant depending on `jit_compile`).
  - For the MultiWorker strategy: the code takes a different path and (in places) calls `self.CTL(multi_worker_dataset)`, a custom user-defined procedure expected to exist on the RL instance.
  - Non-distributed branches fall back to the reused `train1` / `train2` logic.
- PR / PPO interactions:
  - If `PR` is enabled, per-process TD / ratio lists are concatenated into the prioritized replay object before sampling/training.
  - If `PPO` + `PR`, the method uses `window_size_fn` (if present) to compute adaptive trimming for each process and trims `state_pool_list[p]` etc. accordingly after update steps; otherwise it falls back to `window_size_ppo`.
- Callbacks, saving, and early stopping:
  - Calls the callbacks `on_train_begin`, `on_episode_begin`, `on_batch_begin`, `on_batch_end`, `on_episode_end`, and `on_train_end` at the appropriate points.
  - Saves the model / params periodically when `self.path` is set, according to `self.save_freq`.
  - If `self.trial_count` and `self.criterion` are set, computes a rolling average reward over recent episodes and stops training early if the criterion is reached.
Caveats:
- Launching many OS processes for rollouts can be CPU- and memory-intensive. Use a sensible `processes` count per machine.
- MirroredStrategy moves gradient application to devices; ensure your batch sizing and `global_batch_size` match your device count to avoid under/over-scaling.
- `PR` requires additional memory for the `ratio` / `TD` arrays; be mindful when concatenating per-process lists.
- `lock_list` usage: the function passes `lock_list` into `store_in_parallel` in several places, but `lock_list` is not defined inside `distributed_training` before use. If you rely on locks to guard manager lists, make sure to construct `lock_list = [mp.Lock() for _ in range(processes)]` (as is done in the non-distributed `train` function) and pass it into the worker processes.
- Small buffer sizes: many trimming and `window_size_fn` usages assume `len(weights) >= 2`. Guard `window_size_fn` and trimming calls against tiny buffers during warm-up.
- `self.CTL` and other user hooks: the code calls `self.CTL(...)` in MultiWorker branches; make sure you implement this helper to compute the loss when using the MultiWorker strategy.
- Return values vary by branch: different strategy branches return different items (the distributed average loss or the metric result). Tests should validate the return path you use.
Description: Compute an adaptive experience-replay window size (a `window_size_fn` implementation; see the usage examples below) based on the effective sample size (ESS) of the prioritized weights. This function estimates how many recent experiences should be kept (vs. discarded) by converting the ESS into a desired number of kept samples, applying optional exponential moving average (EMA) smoothing to the ESS, and returning the number of oldest entries to drop (the window size). It supports both single-process and pool-network (multi-process) setups.
Arguments:
- `p` (`int`): Process index when `pool_network=True`. Selects which sub-pool's weight vector to evaluate. If `pool_network=False`, `p` is ignored.
- `scale` (`float`, optional, default=`1.0`): Multiplier applied to the (smoothed) ESS to compute the desired number of samples to keep. Values >1 increase the kept size (smaller window), values <1 decrease it (larger window).
- `smooth_alpha` (`float`, optional, default=`0.2`): EMA smoothing coefficient in `[0,1]` used to smooth the ESS over time. Higher values weight the newest ESS more; lower values emphasize past ESS.
Returns:
- `window_size` (`int`): Suggested number of oldest samples to remove from the experience pool. Computed as `len(weights) - desired_keep`. Guaranteed to be a non-negative integer under normal conditions (see the assumptions and edge cases below).
Details:
- Choose the source of weights:
  - If `self.pool_network == True`, the function reads `weights = np.array(self.ratio_list[p])`, the per-process ratio array used for prioritized sampling in that sub-pool.
  - If `self.pool_network == False`, the function reads `weights = np.array(self.prioritized_replay.ratio)`, the global prioritized weights array.
- Compute ESS:
  - Calls `self.compute_ess_from_weights(weights)`, which:
    - clips the weights to a minimum positive value (to avoid zeros),
    - normalizes them to a probability vector `p`,
    - computes ESS as `1 / sum(p^2)`.
  - ESS is a continuous estimate of how many "effective" independent samples exist given the weight distribution.
- EMA smoothing:
  - The function stores the smoothed ESS in `self.ema_ess`.
  - For `pool_network==True`, `self.ema_ess` is a list and `self.ema_ess[p]` is updated. For single-process mode it is a scalar.
  - The new smoothed ESS is `ema = smooth_alpha * ess + (1.0 - smooth_alpha) * prev_ema` (or `ema = ess` if no prior EMA exists).
- Desired kept samples and window size:
  - `desired_keep = np.clip(int(ema * scale), 1, len(weights) - 1)`
    - Intuition: convert the (smoothed) ESS into an integer number of samples to keep, optionally scaled.
    - The clip prevents degeneracy by requiring at least one sample kept and at most `len(weights)-1`.
  - `window_size = len(weights) - desired_keep`
    - This is the number of oldest entries to remove; the caller can then slice arrays like `state_pool = state_pool[window_size:]`.
- Side effects:
  - Updates `self.ema_ess` (or `self.ema_ess[p]`) with the new smoothed ESS value.
  - Does not modify the replay buffers or the ratio/TD arrays; it only returns the window size. The caller is responsible for actually removing entries.
- Assumptions & edge cases:
  - The function assumes `weights` has length ≥ 2. If `len(weights) <= 1`, the code `np.clip(..., 1, len(weights)-1)` may produce an invalid clip range (upper < lower) and raise a `ValueError` or produce unexpected results. It is recommended to guard against this by checking `len(weights)` before calling (or adding a small wrapper).
  - If the weights contain zeros or extremely small values, `compute_ess_from_weights` already protects against divide-by-zero by clipping to a small positive minimum.
- Complexity:
  - Time complexity is O(n), where n is the number of weights (the dominant cost is computing ESS).
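An illustrative, standalone sketch of the ESS and window-size computation described above (not the library's code; it only mirrors the formulas listed in the Details):
import numpy as np
def ess_from_weights(weights, eps=1e-12):
    w = np.clip(np.asarray(weights, dtype=np.float64), eps, None)  # avoid zeros
    p = w / w.sum()                                                # normalize to probabilities
    return 1.0 / np.sum(p ** 2)                                    # effective sample size
def window_size_from_ess(weights, prev_ema=None, scale=1.0, smooth_alpha=0.2):
    ess = ess_from_weights(weights)
    ema = ess if prev_ema is None else smooth_alpha * ess + (1.0 - smooth_alpha) * prev_ema
    desired_keep = int(np.clip(int(ema * scale), 1, len(weights) - 1))  # keep at least one sample
    return len(weights) - desired_keep, ema   # number of oldest entries to drop, new EMA
weights = np.random.rand(1000) ** 2
window, ema = window_size_from_ess(weights)
# the caller would then trim, e.g. state_pool = state_pool[window:]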
Usage Example:
https://github.com/NoteDance/Note_rl/blob/main/Note_rl/examples/keras/pool_network/PPO_pr.py https://github.com/NoteDance/Note_rl/blob/main/Note_rl/examples/pytorch/pool_network/PPO_pr.py
Description:
Compute and return an adaptive training mini-batch size based on the Effective Sample Size (ESS) computed from the replay-buffer weights. The method converts TD errors (and, for PPO, a ratio-deviation term) into sampling weights, computes the ESS, smooths the ESS with an exponential moving average (EMA), maps the ESS to a batch size (or scales the current batch relative to a `target_ess`), constrains the result to `[min_batch, max_batch]`, aligns it to a multiple (`align`), caps it to the buffer length, and optionally adapts the prioritization exponent `alpha` using a simple learning rule.
Arguments:
- `scale` (`float`, optional, default=`1.0`): Multiplicative scaling factor applied when mapping ESS → candidate batch size (or when scaling relative to `target_ess`).
- `smooth_alpha` (`float`, optional, default=`0.2`): Smoothing coefficient used when updating the EMA of ESS: `ema = smooth_alpha * ess + (1 - smooth_alpha) * prev_ema`.
- `min_batch` (`int`, optional): Minimum allowed batch size. If `None`, uses `max(1, self.batch // 2)`.
- `max_batch` (`int`, optional): Maximum allowed batch size. If `None`, uses `max(1, len(weights))`.
- `target_ess` (`float`, optional): If provided, the batch is computed by scaling the current `self.batch` by `ema/target_ess` (useful for keeping ESS near a target) instead of directly using `ema`.
- `align` (`int`, optional): The returned batch size will be rounded down to a multiple of `align`. If `None`, defaults to `self.batch`.
- `alpha_lr` (`float`, optional): If provided, enables online adjustment of `self.alpha` (the exponent used to convert TD / scores → weights). `alpha` is moved toward a `target_alpha` computed from the ESS error with learning rate `alpha_lr`.
- `alpha_min` (`float`, optional): Minimum allowed value for `self.alpha` when `alpha_lr` is used. Required if `alpha_lr` is provided.
- `alpha_max` (`float`, optional): Maximum allowed value for `self.alpha` when `alpha_lr` is used. Required if `alpha_lr` is provided.
- `smooth_beta` (`float`, optional, default=`0.2`): Smoothing coefficient used when updating `self.alpha`: `self.alpha = smooth_beta * self.alpha + (1 - smooth_beta) * target_alpha`.
Returns:
- `int`: The adjusted mini-batch size (≥ 1, and at most the buffer length) that respects `min_batch`, `max_batch`, and `align`. The call may also update `self.alpha` if `alpha_lr` is provided.
Details / Algorithm:
- Weight computation:
  - If `self.PPO` is `True`: compute per-sample `scores = self.lambda_ * TD + (1.0 - self.lambda_) * abs(ratio - 1.0)` and then `weights = (scores + 1e-7) ** self.alpha`.
  - Otherwise: `weights = (TD + 1e-7) ** self.alpha`.
  - `TD` (and `ratio`) are taken from `self.prioritized_replay` (single-process) or, in pool setups, from `self.TD_list` / `self.ratio_list`.
- Compute ESS:
  - Call `self.compute_ess_from_weights(weights)`, which normalizes the weights and computes ESS robustly: `ess = 1 / sum(p^2)` where `p = w / sum(w)`.
- EMA smoothing:
  - Update the stored EMA `self.ema_ess` (a scalar in single-process mode, or a shared per-process array in pool mode) with `smooth_alpha` to reduce the variance of the ESS estimates.
- Map ESS → batch size:
  - If `target_ess` is provided: compute `batch = round(self.batch * ema / target_ess * scale)`, i.e. scale the current batch to move ESS toward `target_ess`.
  - Else: compute `batch = round(ema * scale)`.
  - Clip `batch` to `[min_batch, max_batch]` (with sensible defaults for min/max).
  - Align `batch` down to a multiple of `align`: `new_batch = align * (batch // align)`, and ensure it is at least `1`.
  - Cap `new_batch` by the buffer length `len(weights)` (because the batch cannot exceed the number of samples).
- Optional `alpha` adaptation:
  - If `alpha_lr` is provided, compute a `target_alpha` using the normalized ESS error: `target_alpha = self.alpha + alpha_lr * (target_ess - ema) / target_ess` (the implementation multiplies the error by `alpha_lr` and then clips `target_alpha` to `[alpha_min, alpha_max]`).
  - Smoothly update `self.alpha` via `self.alpha = smooth_beta * self.alpha + (1 - smooth_beta) * target_alpha`.
  - Store `self.alpha` as a float.
- Return:
  - Return `int(new_batch)`.
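An illustrative, standalone sketch of the ESS → batch-size mapping described above (not the library's method; the helper name and arguments are hypothetical):
import numpy as np
def adaptive_batch(ema_ess, current_batch, buffer_len, scale=1.0,
                   target_ess=None, min_batch=None, max_batch=None, align=None):
    min_batch = max(1, current_batch // 2) if min_batch is None else min_batch
    max_batch = max(1, buffer_len) if max_batch is None else max_batch
    align = current_batch if align is None else align
    if target_ess is not None:
        batch = round(current_batch * ema_ess / target_ess * scale)  # steer ESS toward the target
    else:
        batch = round(ema_ess * scale)                               # use the smoothed ESS directly
    batch = int(np.clip(batch, min_batch, max_batch))
    batch = max(1, align * (batch // align))                         # align down to a multiple
    return min(batch, buffer_len)                                    # batch cannot exceed the buffer
new_batch = adaptive_batch(ema_ess=420.0, current_batch=64, buffer_len=10000)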
Edge Cases & Notes:
- If the replay buffer is empty (`len(weights) == 0`) or too small, calls to this function will be invalid; ensure the buffer has samples before calling.
- `align` should be chosen to reflect training constraints (device / data-parallel multiples, or the original `self.batch`). If `align` is larger than the buffer, the returned batch will be capped to the buffer length.
- When `target_ess` is used, the method scales the current batch to move ESS toward the target; choose `target_ess` relative to feasible ESS values (1..buffer_size).
- `alpha_lr` must be supplied together with `alpha_min` and `alpha_max` to bound `self.alpha`. The adaptation is heuristic; tune `alpha_lr` and `smooth_beta` carefully to avoid instability.
- This method updates `self.ema_ess` (and optionally `self.alpha`) as a side effect.
Usage Example:
https://github.com/NoteDance/Note_rl/blob/main/Note_rl/examples/keras/pool_network/PPO_pr.py https://github.com/NoteDance/Note_rl/blob/main/Note_rl/examples/pytorch/pool_network/PPO_pr.py
Description:
Adaptively compute a new mini-batch size by estimating the gradient variance from the current replay buffer and scaling the batch to drive that estimated noise toward a target noise level. The routine estimates the gradient variance using `estimate_gradient_variance`, maintains an exponential moving average (EMA) of that noise, maps the EMA noise → batch size (relative to the current `self.batch` and `target_noise`), enforces min/max and alignment constraints, and optionally updates the prioritization exponent `self.alpha` as a heuristic to influence the sampling weights.
Arguments:
- `num_samples` (`int`): Number of gradient samples to draw when estimating the gradient variance. Each sample computes gradients on a different random mini-batch (of the current `self.batch`) and is used to estimate the variance of the gradients.
- `target_noise` (`float`, optional, default=`1e-3`): Desired (target) gradient-variance level. The function scales the batch size to move `ema_noise` toward this `target_noise`. A smaller `target_noise` generally leads to larger computed batches.
- `scale` (`float`, optional, default=`1.0`): Multiplicative factor applied when converting the noise ratio `(ema_noise / target_noise)` into a candidate batch size.
- `smooth_alpha` (`float`, optional, default=`0.2`): Smoothing coefficient for the EMA update of the noise: `ema_noise = smooth_alpha * estimated_noise + (1 - smooth_alpha) * prev_ema_noise`.
- `min_batch` (`int`, optional): Minimum allowed batch size. If `None`, defaults to `max(1, self.batch // 2)`.
- `max_batch` (`int`, optional): Maximum allowed batch size. If `None`, defaults to `max(1, buffer_length)`, where `buffer_length` is the length of the underlying replay buffer used (single-process `self.state_pool` or pooled `self.state_pool[7]`).
- `align` (`int`, optional): Round the resulting batch down to a multiple of `align`. If `None`, defaults to the current `self.batch`.
- `jit_compile` (`bool`, optional, default=`True`): Passed to `estimate_gradient_variance` to select whether to run the gradient estimation with the JIT-compiled (`tf.function(jit_compile=True)`) path or the plain `tf.function` path. Affects performance and possibly numerical behavior.
- `alpha_min` (`float`, optional): Lower bound for `self.alpha` if `alpha_lr` is used to adapt it. Required if `alpha_lr` is supplied.
- `alpha_max` (`float`, optional): Upper bound for `self.alpha` if `alpha_lr` is used.
- `alpha_lr` (`float`, optional): If provided, enables a heuristic online update of `self.alpha` to influence sampling-weight sharpness based on the noise gap. The update rule uses `(target_noise - ema_noise) / target_noise` scaled by `alpha_lr`, clipped to `[alpha_min, alpha_max]`, and smoothed into `self.alpha` (the code uses a smoothing factor of `0.9` for the old alpha and `0.1` for the target update).
Returns:
- `int`: The new mini-batch size (≥ 1), clipped to `[min_batch, max_batch]` and aligned to `align`.
Details:
- Estimate gradient variance:
  - Calls `self.estimate_gradient_variance(self.batch, num_samples, jit_compile)`, which:
    - draws `num_samples` gradient estimates using random mini-batches of size `self.batch` from the replay buffer,
    - flattens and stacks the gradients, computes the mean gradient, and returns the (scalar) variance estimate.
- EMA of noise:
  - Maintains `self.ema_noise`. If not present, it is initialized to the first `estimated_noise`. Otherwise it is updated with `smooth_alpha`: `ema_noise = smooth_alpha * estimated_noise + (1 - smooth_alpha) * self.ema_noise`, then `self.ema_noise = ema_noise`.
- Buffer length & min/max defaults:
  - Determines `buf_len` from `self.state_pool` (single-process) or `self.state_pool[7]` (the pooled / special multi-process case).
  - If `min_batch` is `None`, it is set to `max(1, self.batch // 2)`.
  - If `max_batch` is `None`, it is set to `max(1, buf_len)`.
- Map noise → batch:
  - Compute a candidate batch by scaling the current batch size by the noise ratio: `base_new_batch = round(self.batch * (ema_noise / target_noise) * scale)` (i.e. if the noise is larger than the target, the batch size increases; if smaller, it decreases).
  - Clip `base_new_batch` into `[min_batch, max_batch]`.
- Alignment:
  - Align the clipped batch down to a multiple of `align`:
    - If `align` is `None`, use `self.batch`.
    - `new_batch = align * (clipped_batch // align)`.
    - Ensure `new_batch >= 1` and `new_batch <= max_batch`.
- Optional `self.alpha` adaptation:
  - If `alpha_lr` is provided, compute: `target_alpha = self.alpha + alpha_lr * (target_noise - ema_noise) / target_noise`, then `target_alpha = clip(target_alpha, alpha_min, alpha_max)`, and finally `self.alpha = 0.9 * self.alpha + 0.1 * target_alpha`. This is a heuristic to change how sharply the TD/ratio scores are converted into weights (it affects prioritized sampling).
- Return:
  - Returns `new_batch` as an `int`.
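An illustrative, standalone sketch of the noise → batch-size mapping described above (not the library's method; the helper name and arguments are hypothetical):
import numpy as np
def noise_scaled_batch(estimated_noise, prev_ema, current_batch, buffer_len,
                       target_noise=1e-3, scale=1.0, smooth_alpha=0.2,
                       min_batch=None, max_batch=None, align=None):
    ema = estimated_noise if prev_ema is None else smooth_alpha * estimated_noise + (1 - smooth_alpha) * prev_ema
    min_batch = max(1, current_batch // 2) if min_batch is None else min_batch
    max_batch = max(1, buffer_len) if max_batch is None else max_batch
    align = current_batch if align is None else align
    batch = round(current_batch * (ema / target_noise) * scale)      # more noise -> larger batch
    batch = int(np.clip(batch, min_batch, max_batch))
    batch = max(1, align * (batch // align))                         # align down to a multiple
    return min(batch, max_batch), ema                                # new batch and updated EMA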
Side Effects:
- Updates `self.ema_noise` (the EMA of the gradient variance).
- May update `self.alpha` when `alpha_lr` is supplied.
- Does not change `self.batch` itself; the caller should assign the returned value back to `self.batch` if desired.
Edge Cases & Notes:
- The function expects the replay buffer to contain at least `self.batch` samples. If the buffer is smaller than `self.batch`, `estimate_gradient_variance` (and thus `adabatch`) may fail or produce noisy estimates.
- `num_samples` controls the estimator variance of the gradient-variance estimate: a larger `num_samples` yields a more stable estimate but costs more computation.
- `target_noise` should be chosen based on empirical gradient magnitudes; if it is set too small, it will push batch sizes up aggressively (possibly to the buffer limit).
- The `alpha_lr` adaptation is heuristic; tune `alpha_lr`, `alpha_min`, `alpha_max`, and the smoothing carefully to avoid instability.
- `align` is useful to keep batch sizes compatible with hardware or parallelization constraints (e.g., data-parallel multiples).
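A hedged, standalone sketch in the spirit of the gradient-variance estimation described above (the callable arguments here are assumptions; the library method's actual signature is `estimate_gradient_variance(self.batch, num_samples, jit_compile)`):
import tensorflow as tf
def estimate_gradient_variance(model, sample_batch_loss, num_samples):
    # sample_batch_loss() is a hypothetical callable that draws a random mini-batch of size
    # self.batch from the replay buffer and returns the scalar loss for it.
    flat_grads = []
    for _ in range(num_samples):
        with tf.GradientTape() as tape:
            loss = sample_batch_loss()
        grads = tape.gradient(loss, model.trainable_variables)
        flat_grads.append(tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0))
    stacked = tf.stack(flat_grads)                                   # [num_samples, n_params]
    return float(tf.reduce_mean(tf.math.reduce_variance(stacked, axis=0)))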
Usage Example:
https://github.com/NoteDance/Note_rl/blob/main/Note_rl/examples/keras/pool_network/PPO_pr.py https://github.com/NoteDance/Note_rl/blob/main/Note_rl/examples/pytorch/pool_network/PPO_pr.py
Usage:
Create a Note_rl agent, then execute this code:
from Note_rl.lr_finder import LRFinder
# agent is a Note_rl agent
agent.optimizer = tf.keras.optimizers.Adam()
lr_finder = LRFinder(agent)
# Train an agent for 77 episodes
# with learning rate growing exponentially from 0.0001 to 1
# N: Total number of iterations (or mini-batch steps) over which the learning rate is increased.
# This parameter determines how many updates occur between the starting learning rate (start_lr)
# and the ending learning rate (end_lr). The learning rate is increased exponentially by a fixed
# multiplicative factor computed as:
# factor = (end_lr / start_lr) ** (1.0 / N)
# This ensures that after N updates, the learning rate will reach exactly end_lr.
#
# window_size: The size of the sliding window (i.e., the number of most recent episodes)
# used to compute the moving average and standard deviation of the rewards.
# This normalization helps smooth out the reward signal and adjust for the fact that
# early episodes may have lower rewards (due to limited experience) compared to later ones.
# By using only the recent window_size rewards, we obtain a more stable and current estimate
# of the reward statistics for normalization.
lr_finder.find(train_loss, pool_network=False, N=77, window_size=7, start_lr=0.0001, end_lr=1, episodes=77)
or
from Note_rl.lr_finder import LRFinder
# agent is a Note_rl agent
agent.optimizer = tf.keras.optimizers.Adam()
strategy = tf.distribute.MirroredStrategy()
lr_finder = LRFinder(agent)
# Train an agent for 77 episodes
# with learning rate growing exponentially from 0.0001 to 1
# N: Total number of iterations (or mini-batch steps) over which the learning rate is increased.
# This parameter determines how many updates occur between the starting learning rate (start_lr)
# and the ending learning rate (end_lr). The learning rate is increased exponentially by a fixed
# multiplicative factor computed as:
# factor = (end_lr / start_lr) ** (1.0 / N)
# This ensures that after N updates, the learning rate will reach exactly end_lr.
#
# window_size: The size of the sliding window (i.e., the number of most recent episodes)
# used to compute the moving average and standard deviation of the rewards.
# This normalization helps smooth out the reward signal and adjust for the fact that
# early episodes may have lower rewards (due to limited experience) compared to later ones.
# By using only the recent window_size rewards, we obtain a more stable and current estimate
# of the reward statistics for normalization.
lr_finder.find(pool_network=False, strategy=strategy, N=77, window_size=7, start_lr=0.0001, end_lr=1, episodes=77)
# Plot the reward, ignore 20 batches in the beginning and 5 in the end
lr_finder.plot_reward(n_skip_beginning=20, n_skip_end=5)
# Plot rate of change of the reward
# Ignore 20 batches in the beginning and 5 in the end
# Smooth the curve using simple moving average of 20 batches
# Limit the range for y axis to (-0.01, 0.01)
lr_finder.plot_reward_change(sma=20, n_skip_beginning=20, n_skip_end=5, y_lim=(-0.01, 0.01))
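As a quick standalone check of the exponential schedule described in the comments above (plain arithmetic, not LRFinder code):
start_lr, end_lr, N = 0.0001, 1.0, 77
factor = (end_lr / start_lr) ** (1.0 / N)
lrs = [start_lr * factor ** i for i in range(N + 1)]
# lrs[0] == start_lr and lrs[N] == end_lr (up to floating-point error)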
Usage:
Create a Note agent, then execute this code:
from Note_rl.opt_finder import OptFinder
# agent is a Note agent
optimizers = [tf.keras.optimizers.Adam(), tf.keras.optimizers.AdamW(), tf.keras.optimizers.Adamax()]
opt_finder = OptFinder(agent, optimizers)
# Train an agent for 7 episodes
opt_finder.find(train_loss, pool_network=False, episodes=7)
or
from Note_rl.opt_finder import OptFinder
# agent is a Note agent
optimizers = [tf.keras.optimizers.Adam(), tf.keras.optimizers.AdamW(), tf.keras.optimizers.Adamax()]
strategy = tf.distribute.MirroredStrategy()
opt_finder = OptFinder(agent, optimizers)
# Train an agent for 7 episodes
opt_finder.find(pool_network=False, strategy=strategy, episodes=7)
Overview
The AgentFinder class is designed for reinforcement learning or multi-agent training scenarios. It trains multiple agents in parallel and selects the best performing agent based on a chosen metric (reward or loss). The class employs multiprocessing to run each agent’s training in its own process and uses callbacks at the end of each episode to update performance logs. Depending on the selected metric, at the end of the training episodes, it computes the mean reward or mean loss for each agent and updates the shared logs with the best optimizer and corresponding performance value.
Key Attributes
- `agents`
  Type: `list`
  Description: A list of agent instances to be trained. Each agent will run its training in a separate process.
- `optimizers`
  Type: `list`
  Description: A list of optimizers corresponding to the agents, used during the training process.
- `rewards`
  Type: shared dictionary (created via `multiprocessing.Manager().dict()`)
  Description: Records the reward values for each episode for every agent. For each agent, a list of rewards is maintained.
- `losses`
  Type: shared dictionary
  Description: Records the loss values for each episode for every agent. For each agent, a list of losses is maintained.
- `logs`
  Type: shared dictionary
  Description: Stores key training information. Initially, it contains:
  - `best_reward`: set to a very low value (-1e9) to store the best mean reward.
  - `best_loss`: set to a high value (1e9) to store the lowest mean loss.
  - When training is complete, it also stores `best_opt`, which corresponds to the optimizer of the best performing agent.
- `lock`
  Type: `multiprocessing.Lock`
  Description: A multiprocessing lock used to ensure data consistency and thread safety when multiple processes update the shared dictionaries.
- `episode`
  Type: `int`
  Description: The total number of training episodes, set in the `find` method. This value is used to determine if the current episode is the final one.
Main Methods
1. __init__(self, agents, optimizers)
Purpose:
Initializes an AgentFinder instance by setting the list of agents and corresponding optimizers. It also creates shared dictionaries for rewards, losses, and logs, and initializes a multiprocessing lock to ensure safe data access.
Parameters:
- `agents`: A list of agent instances.
- `optimizers`: A list of optimizers corresponding to the agents.
Details:
The constructor uses `multiprocessing.Manager()` to create shared dictionaries (`rewards`, `losses`, `logs`) and sets initial values for the best reward and best loss for subsequent comparisons. A lock object is created to synchronize updates in a multiprocessing environment.
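An illustrative sketch of the constructor described above (not the library's exact code; only the attributes documented under Key Attributes are used):
import multiprocessing as mp
class AgentFinder:
    def __init__(self, agents, optimizers):
        self.agents = agents
        self.optimizers = optimizers
        manager = mp.Manager()
        self.rewards = manager.dict()        # per-agent reward histories
        self.losses = manager.dict()         # per-agent loss histories
        self.logs = manager.dict()
        self.logs['best_reward'] = -1e9      # best mean reward seen so far
        self.logs['best_loss'] = 1e9         # lowest mean loss seen so far
        self.lock = mp.Lock()                # guards updates to the shared dictionaries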
2. on_episode_end(self, episode, logs, agent=None, lock=None)
Purpose:
This callback function is invoked at the end of each episode when the metric is set to 'reward'. It updates the corresponding agent’s reward list and, if the episode is the last one, calculates the mean reward. If the mean reward exceeds the current best reward recorded in the shared logs, it updates the logs with the new best reward and the corresponding optimizer.
Parameters:
- `episode`: The current episode number (starting from 0).
- `logs`: A dictionary containing training information for the current episode; it must include the key `'reward'`.
- `agent`: The current agent instance, used to update the reward list and access its optimizer.
- `lock`: The multiprocessing lock used to synchronize access to shared data.
Key Logic:
- Acquire the lock with `lock.acquire()` to ensure safe data updates.
- Retrieve the current episode's reward from `logs`.
- Append the reward to the corresponding agent's list in the `rewards` dictionary.
- If this is the last episode (i.e., `episode + 1 == self.episode`), calculate the mean reward.
- If the mean reward is higher than the current `best_reward` in the shared logs, update `logs['best_reward']` and `logs['best_opt']` (using the agent's optimizer).
- Release the lock using `lock.release()`.
3. on_episode_end_(self, episode, logs, agent=None, lock=None)
Purpose:
This callback function is used when the metric is set to 'loss'. It updates the corresponding agent’s loss list and, at the end of the final episode, computes the mean loss. If the mean loss is lower than the current best loss recorded in the shared logs, it updates the logs with the new best loss and the corresponding optimizer.
Parameters:
- `episode`: The current episode number (starting from 0).
- `logs`: A dictionary containing training information for the current episode; it must include the key `'loss'`.
- `agent`: The current agent instance.
- `lock`: The multiprocessing lock used to synchronize access to shared data.
Key Logic:
- Acquire the lock to ensure safe updates.
- Retrieve the loss from `logs` and append it to the corresponding agent's list in the `losses` dictionary.
- At the last episode, calculate the mean loss and compare it to the current best loss.
- If the mean loss is lower, update `logs['best_loss']` and `logs['best_opt']` (with the agent's optimizer).
- Release the lock.
4. find(self, train_loss=None, pool_network=True, processes=None, processes_her=None, processes_pr=None, strategy=None, episodes=1, metrics='reward', jit_compile=True)
Purpose:
Starts the training of multiple agents using multiprocessing and utilizes callback functions to update the best agent information based on the selected metric (reward or loss).
Parameters:
- `train_loss`: A function or parameter for computing the training loss (optional).
- `pool_network`: Boolean flag indicating whether to use a shared network pool.
- `processes`: Number of processes to be used for training (optional).
- `processes_her`: Parameters related to HER (Hindsight Experience Replay) (optional).
- `processes_pr`: Parameters possibly related to Prioritized Experience Replay (optional).
- `strategy`: Distributed training strategy (optional). If provided, the distributed training mode is used; otherwise, standard training is performed.
- `episodes`: Total number of training episodes.
- `metrics`: The metric to be used, either `'reward'` or `'loss'`. This choice determines which callback function is used.
- `jit_compile`: Boolean flag indicating whether to enable JIT compilation to speed up training.
Key Logic:
- Set the total number of episodes to `self.episode`.
- Iterate over each agent:
  - If the selected metric is `'reward'`:
    - Use `functools.partial` to create a `partial_callback` that binds the agent, the lock, and the `on_episode_end` callback.
    - Create a callback instance using `nn.LambdaCallback`.
    - Initialize the agent's reward list in the `rewards` dictionary.
  - If the selected metric is `'loss'`:
    - Similarly, bind the `on_episode_end_` callback.
    - Initialize the agent's loss list in the `losses` dictionary.
  - Assign the corresponding optimizer to each agent.
  - Depending on whether a `strategy` is provided, choose the training mode:
    - If `strategy` is `None`, call the agent's `train` method with the appropriate parameters (e.g., training loss, episodes, network pool options, process parameters, callbacks, and jit_compile settings).
    - If a `strategy` is provided, call the agent's `distributed_training` method with similar parameters and a similar callback setup.
- Start all training processes and wait for them to complete using `join()`.
Example Usage
Below is an example demonstrating how to use AgentFinder to train multiple agents and select the best performing agent based on either reward or loss:
from Note_rl.parallel_finder import ParallelFinder
# Assume agent1 and agent2 are two initialized agent instances,
# and optimizer1 and optimizer2 are their respective optimizers.
agent1 = ... # Initialize agent 1
agent2 = ... # Initialize agent 2
optimizer1 = ... # Optimizer for agent 1
optimizer2 = ... # Optimizer for agent 2
# Create lists of agents and optimizers
agents = [agent1, agent2]
optimizers = [optimizer1, optimizer2]
# Initialize the AgentFinder instance
parallel_finder = ParallelFinder(agents, optimizers)
# Assume train_loss is defined as a function or metric for calculating training loss (if needed)
train_loss = ...
# Choose the evaluation metric: 'reward' or 'loss'
metrics_choice = 'reward' # or 'loss'
# Execute training with 10 episodes and enable JIT compilation
parallel_finder.find(
train_loss=train_loss,
pool_network=True,
processes=4,
processes_her=2,
processes_pr=2,
strategy=None, # Pass None to use standard training (not distributed)
episodes=10,
metrics=metrics_choice,
jit_compile=True
)
# After training, retrieve the best record from parallel_finder.logs
if metrics_choice == 'reward':
    print("Best Mean Reward:", parallel_finder.logs['best_reward'])
else:
    print("Best Mean Loss:", parallel_finder.logs['best_loss'])
print("Best Optimizer:", parallel_finder.logs['best_opt'])