
Commit 2cf357a

james-martens authored and DKSdev committed
- Adding implementation of Per-Location Normalization
- Updating README.md
- Advancing version number to 0.1.2
- Adding whitespace to improve readability
- Fixing minor typos in docstrings
- Increasing version requirement on PyTorch to 1.8.0. Tests now performed at 1.13.1

PiperOrigin-RevId: 503432432
1 parent dd7bfb0 commit 2cf357a

16 files changed: +666, -33 lines

README.md

Lines changed: 22 additions & 18 deletions
@@ -3,13 +3,14 @@
 
 # Official Python package for Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT)
 
-This Python package implements the activation function transformations and
-weight initializations used in Deep Kernel Shaping (DKS) and Tailored Activation
-Transformations (TAT). DKS and TAT, which were introduced in the [DKS paper] and
-[TAT paper], are methods for constructing/transforming neural networks to make
-them much easier to train. For example, these methods can be used in conjunction
-with K-FAC to train deep vanilla deep convnets (without skip connections or
-normalization layers) as fast as standard ResNets of the same depth.
+This Python package implements the activation function transformations, weight
+initializations, and dataset preprocessing used in Deep Kernel Shaping (DKS) and
+Tailored Activation Transformations (TAT). DKS and TAT, which were introduced in
+the [DKS paper] and [TAT paper], are methods for constructing/transforming
+neural networks to make them much easier to train. For example, these methods
+can be used in conjunction with K-FAC to train deep vanilla deep convnets
+(without skip connections or normalization layers) as fast as standard ResNets
+of the same depth.
 
 The package supports the JAX, PyTorch, and TensorFlow tensor programming
 frameworks.
@@ -23,16 +24,18 @@ from Github will be rejected. Instead, please email us if you find a bug.
 ## Usage
 
 For each of the supported tensor programming frameworks, there is a
-corresponding subpackage which handles the activation function transformations
-and weight initializations. (These are `dks.jax`, `dks.pytorch`, and
-`dks.tensorflow`.) It's up to the user to import these and use them
-appropriately within their model code. Activation functions are transformed by
-the function `get_transformed_activations()` in the module
+corresponding subpackage which handles the activation function transformations,
+weight initializations, and (optional) data preprocessing. (These are `dks.jax`,
+`dks.pytorch`, and `dks.tensorflow`.) It's up to the user to import these and
+use them appropriately within their model code. Activation functions are
+transformed by the function `get_transformed_activations()` in the module
 `activation_transform` of the appropriate subpackage. Sampling initial
 parameters is done using functions inside of the module
-`parameter_sampling_functions` of said subpackage. Note that in order to avoid
-having to import all of the tensor programming frameworks, the user is required
-to individually import whatever framework subpackage they want. e.g. `import
+`parameter_sampling_functions` of said subpackage. And data preprocessing is
+done using the function `per_location_normalization` inside of the module
+`data_preprocessing` of said subpackage. Note that in order to avoid having to
+import all of the tensor programming frameworks, the user is required to
+individually import whatever framework subpackage they want. e.g. `import
 dks.jax`. Meanwhile, `import dks` won't actually do anything.
 
 `get_transformed_activations()` requires the user to pass either the "maximal
@@ -52,9 +55,10 @@ weighted sums into "normalized sums" (which are weighted sums whose
 non-trainable weights have a sum of squares equal to 1). See the section titled
 "Summary of our method" of the [DKS paper] for more details.
 
-Note that this package doesn't currently include an implementation of
-Per-Location Normalization (PLN) data pre-processing. While not required for
-CIFAR or ImageNet, PLN could potentially be important for other datasets. Also
+Note that the data preprocessing method implemented, called Per-Location
+Normalization (PLN), may not always be needed in practice, but we have observed
+certain situations where not using it can lead to problems. (For example,
+training on datasets that contain all-zero pixels, such as CIFAR-10.) Also
 note that ReLUs are only partially supported by DKS, and unsupported by TAT, and
 so their use is *highly* discouraged. Instead, one should use Leaky ReLUs, which
 are fully supported by DKS, and work especially well with TAT.
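The README changes above describe the workflow but give no snippet. Below is a minimal sketch in JAX of how the pieces might fit together. It is not taken from the package's documentation: the exact argument names (`method`, `subnet_max_func`), the returned dict, and what `per_location_normalization` accepts are assumptions, so treat it as illustrative only.

```python
import jax.numpy as jnp
import dks.jax  # note: `import dks` by itself does nothing, per the README


# A toy "subnetwork maximizing function" for a model whose residual branches
# apply two combined layers; a real model supplies its own (see the modified
# ResNet example in this commit).
def subnet_max_func(x, r_fn):
  return r_fn(r_fn(x))


# Transform the activation functions (DKS shown here; TAT is analogous).
acts = dks.jax.activation_transform.get_transformed_activations(
    ["leaky_relu"], method="DKS", subnet_max_func=subnet_max_func)
leaky_relu = acts["leaky_relu"]

# Optional Per-Location Normalization of the input data (assumed here to take
# a batch of images as a single array).
images = jnp.ones([8, 32, 32, 3])
images = dks.jax.data_preprocessing.per_location_normalization(images)

# Initial parameters would then be sampled with the functions in
# `dks.jax.parameter_sampling_functions` (or `dks.jax.haiku_initializers`).
```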

dks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@
 # Do not directly import this package; it won't do anything. Instead, import one
 # of the framework-specific subpackages.
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"

dks/examples/haiku/modified_resnet.py

Lines changed: 36 additions & 0 deletions
@@ -47,12 +47,15 @@ def __init__(
       w_init: Optional[Any],
       name: Optional[str] = None,
   ):
+
     super().__init__(name=name)
+
     self.use_projection = use_projection
     self.use_batch_norm = use_batch_norm
     self.shortcut_weight = shortcut_weight
 
     if self.use_projection and self.shortcut_weight != 0.0:
+
       self.proj_conv = hk.Conv2D(
           output_channels=channels,
           kernel_shape=1,
@@ -61,11 +64,13 @@
           with_bias=not use_batch_norm,
           padding="SAME",
           name="shortcut_conv")
+
       if use_batch_norm:
         self.proj_batchnorm = hk.BatchNorm(
             name="shortcut_batchnorm", **BN_CONFIG)
 
     channel_div = 4 if bottleneck else 1
+
     conv_0 = hk.Conv2D(
         output_channels=channels // channel_div,
         kernel_shape=1 if bottleneck else 3,
@@ -87,8 +92,10 @@ def __init__(
     layers = (conv_0, conv_1)
 
     if use_batch_norm:
+
       bn_0 = hk.BatchNorm(name="batchnorm_0", **BN_CONFIG)
       bn_1 = hk.BatchNorm(name="batchnorm_1", **BN_CONFIG)
+
       bn_layers = (bn_0, bn_1)
 
     if bottleneck:
@@ -112,23 +119,31 @@ def __init__(
     self.activation = activation
 
   def __call__(self, inputs, is_training, test_local_stats):
+
     out = shortcut = inputs
 
     if self.use_projection and self.shortcut_weight != 0.0:
+
       shortcut = self.proj_conv(shortcut)
+
       if self.use_batch_norm:
         shortcut = self.proj_batchnorm(shortcut, is_training, test_local_stats)
 
     for i, conv_i in enumerate(self.layers):
+
       out = conv_i(out)
+
       if self.use_batch_norm:
         out = self.bn_layers[i](out, is_training, test_local_stats)
+
       if i < len(self.layers) - 1:  # Don't apply activation on last layer
         out = self.activation(out)
 
     if self.shortcut_weight is None:
       return self.activation(out + shortcut)
+
     elif self.shortcut_weight != 0.0:
+
       return self.activation(
           math.sqrt(1 - self.shortcut_weight**2) * out +
           self.shortcut_weight * shortcut)
@@ -151,12 +166,15 @@ def __init__(
       w_init: Optional[Any],
       name: Optional[str] = None,
   ):
+
     super().__init__(name=name)
+
     self.use_projection = use_projection
     self.use_batch_norm = use_batch_norm
     self.shortcut_weight = shortcut_weight
 
     if self.use_projection and self.shortcut_weight != 0.0:
+
       self.proj_conv = hk.Conv2D(
           output_channels=channels,
           kernel_shape=1,
@@ -167,6 +185,7 @@ def __init__(
           name="shortcut_conv")
 
     channel_div = 4 if bottleneck else 1
+
     conv_0 = hk.Conv2D(
         output_channels=channels // channel_div,
         kernel_shape=1 if bottleneck else 3,
@@ -188,11 +207,14 @@ def __init__(
     layers = (conv_0, conv_1)
 
     if use_batch_norm:
+
       bn_0 = hk.BatchNorm(name="batchnorm_0", **BN_CONFIG)
       bn_1 = hk.BatchNorm(name="batchnorm_1", **BN_CONFIG)
+
       bn_layers = (bn_0, bn_1)
 
     if bottleneck:
+
       conv_2 = hk.Conv2D(
           output_channels=channels,
           kernel_shape=1,
@@ -205,8 +227,10 @@ def __init__(
       layers = layers + (conv_2,)
 
       if use_batch_norm:
+
         bn_2 = hk.BatchNorm(name="batchnorm_2", **BN_CONFIG)
         bn_layers += (bn_2,)
+
     self.bn_layers = bn_layers
 
     self.layers = layers
@@ -229,9 +253,11 @@ def __call__(self, inputs, is_training, test_local_stats):
 
     if self.shortcut_weight is None:
       return x + shortcut
+
     elif self.shortcut_weight != 0.0:
       return math.sqrt(
           1 - self.shortcut_weight**2) * x + self.shortcut_weight * shortcut
+
     else:
       return x
 
@@ -272,13 +298,17 @@ def __init__(
                     name="block_%d" % (i)))
 
   def __call__(self, inputs, is_training, test_local_stats):
+
     out = inputs
+
     for block in self.blocks:
       out = block(out, is_training, test_local_stats)
+
     return out
 
 
 def check_length(length, value, name):
+
   if len(value) != length:
     raise ValueError(f"`{name}` must be of length 4 not {len(value)}")
 
@@ -481,12 +511,15 @@ def __init__(
     self.logits = hk.Linear(num_classes, **logits_config)
 
   def __call__(self, inputs, is_training, test_local_stats=False):
+
     out = inputs
     out = self.initial_conv(out)
 
     if not self.resnet_v2:
+
       if self.use_batch_norm:
         out = self.initial_batchnorm(out, is_training, test_local_stats)
+
       out = self.activation(out)
 
     out = hk.max_pool(
@@ -525,15 +558,18 @@ def subnet_max_func(x, r_fn, depth, shortcut_weight, resnet_v2=True):
 
   if bottleneck and resnet_v2:
     res_fn = lambda z: r_fn(r_fn(r_fn(z)))
+
   elif (not bottleneck and resnet_v2) or (bottleneck and not resnet_v2):
     res_fn = lambda z: r_fn(r_fn(z))
+
   else:
     res_fn = r_fn
 
   res_branch_subnetwork = res_fn(x)
 
   for i in range(4):
     for j in range(blocks_per_group[i]):
+
       res_x = res_fn(x)
 
       if j == 0 and use_projection[i] and resnet_v2:
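The shortcut-combination code in the hunks above, e.g. `math.sqrt(1 - self.shortcut_weight**2) * out + self.shortcut_weight * shortcut`, is the "normalized sum" construction the README refers to: the residual branch and the shortcut are mixed with fixed, non-trainable weights whose squares sum to one. A standalone sketch of that idea (the helper name is hypothetical, not part of the package):

```python
import math


def normalized_sum(residual, shortcut, shortcut_weight):
  """Combines two branches with non-trainable weights sqrt(1 - w**2) and w.

  The squared weights sum to one, which is the property DKS/TAT require of
  weighted sums appearing in the network.
  """
  return (math.sqrt(1 - shortcut_weight**2) * residual
          + shortcut_weight * shortcut)
```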

dks/jax/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
 """Subpackage for JAX."""
 
 from dks.jax import activation_transform
+from dks.jax import data_preprocessing
 from dks.jax import haiku_initializers
 from dks.jax import parameter_sampling_functions
