Description
🚀 Feature
Optional BatchNorm integration in NatureCNN
Motivation
Batch Normalization helps stabilize and accelerate training by reducing internal covariate shift, which is especially important in high-variance, pixel-based environments like Atari games. By normalizing the activations after each convolutional layer, we expect smoother gradient flow, faster convergence, and reduced sensitivity to hyperparameters.
Alternatives Considered
- LayerNorm: normalizes across channels for each sample, but doesn't leverage batch statistics; it proved slower to converge in our early trials.
- GroupNorm: a middle ground between BatchNorm and LayerNorm that normalizes over groups of channels; it improved stability but added implementation complexity and similar runtime overhead.
BatchNorm offered the best trade-off of simplicity, runtime efficiency, and empirical performance.
Early Results
We ran PPO with NatureCNN + BatchNorm on Breakout (A.L.E.) for ~200K timesteps. By 200K timesteps, the agent achieves an average reward of 18.4 ± 6.5, showing both faster early learning and higher final performance than the baseline without BatchNorm.
Proposed Implementation
- Introduce a new `use_batch_norm: bool = False` argument in `NatureCNN.__init__`.
- When `use_batch_norm=True`, insert `nn.BatchNorm2d` immediately after each convolutional layer:

```python
layers = []
layers.append(nn.Conv2d(...))
if use_batch_norm:
    layers.append(nn.BatchNorm2d(...))
layers.append(nn.ReLU())
# repeat for each conv block
```

- Default behavior remains unchanged (`use_batch_norm=False`), ensuring full backward compatibility.
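To make the proposal concrete, here is a minimal, self-contained sketch of the conv stack with the proposed toggle. The layer sizes follow the standard Nature-DQN architecture (8×8/4, 4×4/2, 3×3/1); the `nature_cnn` helper and its exact signature are illustrative assumptions, not the actual SB3 code.

```python
import torch
import torch.nn as nn


def nature_cnn(n_input_channels: int = 4, use_batch_norm: bool = False) -> nn.Sequential:
    """Nature-DQN conv stack, optionally with BatchNorm2d after each conv layer."""
    # (in_channels, out_channels, kernel_size, stride) per conv block
    specs = [(n_input_channels, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1)]
    layers = []
    for in_ch, out_ch, kernel, stride in specs:
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride))
        if use_batch_norm:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.ReLU())
    layers.append(nn.Flatten())
    return nn.Sequential(*layers)


cnn = nature_cnn(use_batch_norm=True)
out = cnn(torch.zeros(2, 4, 84, 84))  # batch of two 84x84 frame stacks
print(out.shape)  # torch.Size([2, 3136])
```

Since the flag defaults to `False`, calling `nature_cnn()` reproduces the current layer sequence exactly, which is what keeps the change backward compatible.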
Pitch
Enable an optional BatchNorm toggle in the NatureCNN feature extractor so users can easily turn on/off batch normalization after each convolutional layer, improving training stability and convergence in high-variance, image-based environments.
Alternatives
By default, use_batch_norm is set to False, so there is zero performance or behavioral impact unless the flag is explicitly turned on. When enabled, BatchNorm leverages batch-level statistics to stabilize and accelerate learning in high-variance, image-based inputs.
Other normalization strategies I evaluated:
- LayerNorm: normalizes per sample across channels; it does not use batch statistics and led to slower convergence in our Atari benchmarks.
- GroupNorm: splits channels into groups for normalization; more stable than LayerNorm, but it incurs extra complexity and similar runtime overhead.
Neither alternative matched the simplicity, efficiency, and empirical gains of toggled-on BatchNorm, so we opted for a boolean flag that keeps it completely off by default.
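For reference, the three options differ only in which axes they normalize over. A small illustrative comparison on a dummy conv feature map (the shapes below are assumptions chosen to resemble the last NatureCNN conv output):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, 7, 7)  # (batch, channels, height, width)

batch_norm = nn.BatchNorm2d(64)        # each channel, across the whole batch
group_norm = nn.GroupNorm(8, 64)       # 8 groups of 8 channels, per sample
layer_norm = nn.LayerNorm([64, 7, 7])  # all channels/pixels, per sample

for name, norm in [("BatchNorm2d", batch_norm),
                   ("GroupNorm", group_norm),
                   ("LayerNorm", layer_norm)]:
    y = norm(x)
    print(f"{name}: shape {tuple(y.shape)}, mean {y.mean().item():.4f}")
```

All three preserve the tensor shape and produce roughly zero-mean activations; only BatchNorm pools statistics across the batch dimension, which is the property we found most useful here.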
Additional context
No response
Checklist
- I have checked that there is no similar issue in the repo
- If I'm requesting a new feature, I have proposed alternatives