|
| 1 | +from typing import Optional, Union, Tuple, Dict |
| 2 | + |
| 3 | +import equinox as eqx |
| 4 | +import chex |
| 5 | +import jax |
| 6 | +import jax.numpy as jnp |
| 7 | + |
| 8 | +import e3nn_jax as e3nn |
| 9 | +from e3nn_jax._src.utils.dtype import get_pytree_dtype |
| 10 | + |
| 11 | +from .linear import ( |
| 12 | + FunctionalLinear, |
| 13 | + linear_indexed, |
| 14 | + linear_mixed, |
| 15 | + linear_mixed_per_channel, |
| 16 | + linear_vanilla, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +def _get_gradient_normalization( |
| 21 | + gradient_normalization: Optional[Union[float, str]] |
| 22 | +) -> float: |
| 23 | + """Get the gradient normalization from the config or from the argument.""" |
| 24 | + if gradient_normalization is None: |
| 25 | + gradient_normalization = e3nn.config("gradient_normalization") |
| 26 | + if isinstance(gradient_normalization, str): |
| 27 | + return {"element": 0.0, "path": 1.0}[gradient_normalization] |
| 28 | + return gradient_normalization |
| 29 | + |
| 30 | + |
| 31 | +class Linear(eqx.Module): |
| 32 | + r"""Equivariant Linear Flax module |
| 33 | +
|
| 34 | + Args: |
| 35 | + irreps_out (`Irreps`): output representations, if allowed bu Schur's lemma. |
| 36 | + channel_out (optional int): if specified, the last axis before the irreps |
| 37 | + is assumed to be the channel axis and is mixed with the irreps. |
| 38 | + irreps_in (`Irreps`): input representations. If not specified, |
| 39 | + the input representations is obtained when calling the module. |
| 40 | + channel_in (optional int): required when using 'mixed_per_channel' linear_type, |
| 41 | + indicating the size of the last axis before the irreps in the input. |
| 42 | + biases (bool): whether to add a bias to the output. |
| 43 | + path_normalization (str or float): Normalization of the paths, ``element`` or ``path``. |
| 44 | + 0/1 corresponds to a normalization where each element/path has an equal contribution to the forward. |
| 45 | + gradient_normalization (str or float): Normalization of the gradients, ``element`` or ``path``. |
| 46 | + 0/1 corresponds to a normalization where each element/path has an equal contribution to the learning. |
| 47 | + num_indexed_weights (optional int): number of indexed weights. See example below. |
| 48 | + weights_per_channel (bool): whether to have one set of weights per channel. |
| 49 | + force_irreps_out (bool): whether to force the output irreps to be the one specified in ``irreps_out``. |
| 50 | +
|
| 51 | + Due to how Equinox is implemented, the random key, irreps_in and irreps_out must be supplied at initialization. |
| 52 | + The type of the linear layer must also be supplied at initialization: |
| 53 | + 'vanilla', 'indexed', 'mixed', 'mixed_per_channel' |
| 54 | + Also, depending on what type of linear layer is used, additional options |
| 55 | + (eg. 'num_indexed_weights', 'weights_per_channel', 'weights_dim', 'channel_in') |
| 56 | + must be supplied. |
| 57 | +
|
| 58 | + Examples: |
| 59 | + Vanilla:: |
| 60 | +
|
| 61 | + >>> import e3nn_jax as e3nn |
| 62 | + >>> import jax |
| 63 | +
|
| 64 | + >>> x = e3nn.normal("0e + 1o") |
| 65 | + >>> linear = e3nn.equinox.Linear( |
| 66 | + irreps_out="2x0e + 1o + 2e", |
| 67 | + irreps_in=x.irreps, |
| 68 | + key=jax.random.PRNGKey(0), |
| 69 | + ) |
| 70 | + >>> linear(x).irreps # Note that the 2e is discarded. Avoid this by setting force_irreps_out=True. |
| 71 | + 2x0e+1x1o |
| 72 | + >>> linear(x).shape |
| 73 | + (5,) |
| 74 | +
|
| 75 | + External weights:: |
| 76 | +
|
| 77 | + >>> linear = e3nn.equinox.Linear( |
| 78 | + irreps_out="2x0e + 1o", |
| 79 | + irreps_in=x.irreps, |
| 80 | + linear_type="mixed", |
| 81 | + weights_dim=4, |
| 82 | + key=jax.random.PRNGKey(0), |
| 83 | + ) |
| 84 | + >>> e = jnp.array([1., 2., 3., 4.]) |
| 85 | + >>> linear(e, x).irreps |
| 86 | + 2x0e+1x1o |
| 87 | + >>> linear(e, x).shape |
| 88 | + (5,) |
| 89 | +
|
| 90 | + Indexed weights:: |
| 91 | +
|
| 92 | + >>> linear = e3nn.equinox.Linear( |
| 93 | + irreps_out="2x0e + 1o + 2e", |
| 94 | + irreps_in=x.irreps, |
| 95 | + linear_type="indexed", |
| 96 | + num_indexed_weights=3, |
| 97 | + key=jax.random.PRNGKey(0), |
| 98 | + ) |
| 99 | + >>> i = jnp.array(2) |
| 100 | + >>> linear(i, x).irreps |
| 101 | + 2x0e+1x1o |
| 102 | + >>> linear(i, x).shape |
| 103 | + (5,) |
| 104 | + """ |
| 105 | + irreps_out: e3nn.Irreps |
| 106 | + irreps_in: e3nn.Irreps |
| 107 | + channel_out: int |
| 108 | + channel_in: int |
| 109 | + gradient_normalization: Optional[Union[float, str]] |
| 110 | + path_normalization: Optional[Union[float, str]] |
| 111 | + biases: bool |
| 112 | + num_indexed_weights: Optional[int] |
| 113 | + weights_per_channel: bool |
| 114 | + force_irreps_out: bool |
| 115 | + weights_dim: Optional[int] |
| 116 | + linear_type: str |
| 117 | + |
| 118 | + # These are used internally. |
| 119 | + _linear: FunctionalLinear |
| 120 | + _weights: Dict[str, jnp.ndarray] |
| 121 | + _input_dtype: jnp.dtype |
| 122 | + |
| 123 | + def __init__( |
| 124 | + self, |
| 125 | + *, |
| 126 | + irreps_out: e3nn.Irreps, |
| 127 | + irreps_in: e3nn.Irreps, |
| 128 | + channel_out: Optional[int] = None, |
| 129 | + channel_in: Optional[int] = None, |
| 130 | + biases: bool = False, |
| 131 | + path_normalization: Optional[Union[str, float]] = None, |
| 132 | + gradient_normalization: Optional[Union[str, float]] = None, |
| 133 | + num_indexed_weights: Optional[int] = None, |
| 134 | + weights_per_channel: bool = False, |
| 135 | + force_irreps_out: bool = False, |
| 136 | + weights_dim: Optional[int] = None, |
| 137 | + input_dtype: jnp.dtype = jnp.float32, |
| 138 | + linear_type: str = "vanilla", |
| 139 | + key: chex.PRNGKey, |
| 140 | + ): |
| 141 | + irreps_in_regrouped = e3nn.Irreps(irreps_in).regroup() |
| 142 | + irreps_out = e3nn.Irreps(irreps_out) |
| 143 | + |
| 144 | + self.irreps_in = irreps_in_regrouped |
| 145 | + self.channel_in = channel_in |
| 146 | + self.channel_out = channel_out |
| 147 | + self.biases = biases |
| 148 | + self.path_normalization = path_normalization |
| 149 | + self.num_indexed_weights = num_indexed_weights |
| 150 | + self.weights_per_channel = weights_per_channel |
| 151 | + self.force_irreps_out = force_irreps_out |
| 152 | + self.linear_type = linear_type |
| 153 | + self.weights_dim = weights_dim |
| 154 | + self._input_dtype = input_dtype |
| 155 | + |
| 156 | + self.gradient_normalization = _get_gradient_normalization( |
| 157 | + gradient_normalization |
| 158 | + ) |
| 159 | + |
| 160 | + channel_irrep_multiplier = 1 |
| 161 | + if self.channel_out is not None: |
| 162 | + assert not self.weights_per_channel |
| 163 | + channel_irrep_multiplier = self.channel_out |
| 164 | + |
| 165 | + if not self.force_irreps_out: |
| 166 | + irreps_out = irreps_out.filter(keep=irreps_in_regrouped) |
| 167 | + irreps_out = irreps_out.simplify() |
| 168 | + self.irreps_out = irreps_out |
| 169 | + |
| 170 | + self._linear = FunctionalLinear( |
| 171 | + irreps_in_regrouped, |
| 172 | + channel_irrep_multiplier * irreps_out, |
| 173 | + biases=self.biases, |
| 174 | + path_normalization=self.path_normalization, |
| 175 | + gradient_normalization=self.gradient_normalization, |
| 176 | + ) |
| 177 | + self._weights = self._get_weights(key) |
| 178 | + |
| 179 | + def _get_weights(self, key: chex.PRNGKey): |
| 180 | + """Constructs the weights for the linear module.""" |
| 181 | + irreps_in = self._linear.irreps_in |
| 182 | + irreps_out = self._linear.irreps_out |
| 183 | + |
| 184 | + weights = {} |
| 185 | + for ins in self._linear.instructions: |
| 186 | + weight_key, key = jax.random.split(key) |
| 187 | + if ins.i_in == -1: |
| 188 | + name = f"b[{ins.i_out}] {irreps_out[ins.i_out]}" |
| 189 | + else: |
| 190 | + name = f"w[{ins.i_in},{ins.i_out}] {irreps_in[ins.i_in]},{irreps_out[ins.i_out]}" |
| 191 | + |
| 192 | + if self.linear_type == "vanilla": |
| 193 | + weight_shape = ins.path_shape |
| 194 | + weight_std = ins.weight_std |
| 195 | + |
| 196 | + if self.linear_type == "indexed": |
| 197 | + if self.num_indexed_weights is None: |
| 198 | + raise ValueError( |
| 199 | + "num_indexed_weights must be provided when 'linear_type' is 'indexed'" |
| 200 | + ) |
| 201 | + |
| 202 | + weight_shape = (self.num_indexed_weights,) + ins.path_shape |
| 203 | + weight_std = ins.weight_std |
| 204 | + |
| 205 | + if self.linear_type in ["mixed", "mixed_per_channel"]: |
| 206 | + if self.weights_dim is None: |
| 207 | + raise ValueError( |
| 208 | + "weights_dim must be provided when 'linear_type' is 'mixed'" |
| 209 | + ) |
| 210 | + |
| 211 | + d = self.weights_dim |
| 212 | + if self.linear_type == "mixed": |
| 213 | + weight_shape = (d,) + ins.path_shape |
| 214 | + |
| 215 | + if self.linear_type == "mixed_per_channel": |
| 216 | + if self.channel_in is None: |
| 217 | + raise ValueError( |
| 218 | + "channel_in must be provided when 'linear_type' is 'mixed_per_channel'" |
| 219 | + ) |
| 220 | + weight_shape = (d, self.channel_in) + ins.path_shape |
| 221 | + |
| 222 | + alpha = 1 / d |
| 223 | + stddev = jnp.sqrt(alpha) ** (1.0 - self.gradient_normalization) |
| 224 | + weight_std = stddev * ins.weight_std |
| 225 | + |
| 226 | + weights[name] = weight_std * jax.random.normal( |
| 227 | + weight_key, |
| 228 | + weight_shape, |
| 229 | + self._input_dtype, |
| 230 | + ) |
| 231 | + return weights |
| 232 | + |
| 233 | + def __call__(self, weights_or_input, input_or_none=None) -> e3nn.IrrepsArray: |
| 234 | + """Apply the linear operator. |
| 235 | +
|
| 236 | + Args: |
| 237 | + weights (optional IrrepsArray or jnp.ndarray): scalar weights that are contracted with free parameters. |
| 238 | + An array of shape ``(..., contracted_axis)``. Broadcasting with `input` is supported. |
| 239 | + input (IrrepsArray): input irreps-array of shape ``(..., [channel_in,] irreps_in.dim)``. |
| 240 | + Broadcasting with `weights` is supported. |
| 241 | +
|
| 242 | + Returns: |
| 243 | + IrrepsArray: output irreps-array of shape ``(..., [channel_out,] irreps_out.dim)``. |
| 244 | + Properly normalized assuming that the weights and input are properly normalized. |
| 245 | + """ |
| 246 | + if input_or_none is None: |
| 247 | + weights = None |
| 248 | + input: e3nn.IrrepsArray = weights_or_input |
| 249 | + else: |
| 250 | + weights: jnp.ndarray = weights_or_input |
| 251 | + input: e3nn.IrrepsArray = input_or_none |
| 252 | + del weights_or_input, input_or_none |
| 253 | + |
| 254 | + input = e3nn.as_irreps_array(input) |
| 255 | + |
| 256 | + dtype = get_pytree_dtype(weights, input) |
| 257 | + if dtype.kind == "i": |
| 258 | + dtype = jnp.float32 |
| 259 | + input = input.astype(dtype) |
| 260 | + |
| 261 | + if self.irreps_in != input.irreps.regroup(): |
| 262 | + raise ValueError( |
| 263 | + f"e3nn.equinox.Linear: The input irreps ({input.irreps}) " |
| 264 | + f"do not match the expected irreps ({self.irreps_in})." |
| 265 | + ) |
| 266 | + |
| 267 | + if self.channel_in is not None: |
| 268 | + if self.channel_in != input.shape[-2]: |
| 269 | + raise ValueError( |
| 270 | + f"e3nn.equinox.Linear: The input channel ({input.shape[-2]}) " |
| 271 | + f"does not match the expected channel ({self.channel_in})." |
| 272 | + ) |
| 273 | + |
| 274 | + input = input.remove_zero_chunks().regroup() |
| 275 | + |
| 276 | + def get_parameter( |
| 277 | + name: str, |
| 278 | + path_shape: Tuple[int, ...], |
| 279 | + weight_std: float, |
| 280 | + dtype: jnp.dtype = jnp.float32, |
| 281 | + ): |
| 282 | + del path_shape, weight_std, dtype |
| 283 | + return self._weights[name] |
| 284 | + |
| 285 | + assertion_message = ( |
| 286 | + "Weights cannot be provided when 'linear_type' is 'vanilla'." |
| 287 | + "Otherwise, weights must be provided." |
| 288 | + "If weights are provided, they must be either: \n" |
| 289 | + "* integers and num_indexed_weights must be provided, or \n" |
| 290 | + "* floats and num_indexed_weights must not be provided.\n" |
| 291 | + f"weights.dtype={weights.dtype if weights is not None else None}, " |
| 292 | + f"num_indexed_weights={self.num_indexed_weights}" |
| 293 | + ) |
| 294 | + |
| 295 | + if self.linear_type == "vanilla": |
| 296 | + assert weights is None, assertion_message |
| 297 | + output = linear_vanilla(input, self._linear, get_parameter) |
| 298 | + |
| 299 | + if self.linear_type in ["indexed", "mixed", "mixed_per_channel"]: |
| 300 | + assert weights is not None, assertion_message |
| 301 | + if isinstance(weights, e3nn.IrrepsArray): |
| 302 | + if not weights.irreps.is_scalar(): |
| 303 | + raise ValueError("weights must be scalar") |
| 304 | + weights = weights.array |
| 305 | + |
| 306 | + if self.linear_type == "indexed": |
| 307 | + assert weights.dtype.kind == "i", assertion_message |
| 308 | + if self.weights_per_channel: |
| 309 | + raise NotImplementedError( |
| 310 | + "weights_per_channel not implemented for indexed weights" |
| 311 | + ) |
| 312 | + |
| 313 | + output = linear_indexed( |
| 314 | + input, self._linear, get_parameter, weights, self.num_indexed_weights |
| 315 | + ) |
| 316 | + |
| 317 | + if self.linear_type in ["mixed", "mixed_per_channel"]: |
| 318 | + assert weights.dtype.kind in "fc", assertion_message |
| 319 | + assert self.num_indexed_weights is None, assertion_message |
| 320 | + |
| 321 | + if self.linear_type == "mixed": |
| 322 | + output = linear_mixed( |
| 323 | + input, |
| 324 | + self._linear, |
| 325 | + get_parameter, |
| 326 | + weights, |
| 327 | + self.gradient_normalization, |
| 328 | + ) |
| 329 | + |
| 330 | + if self.linear_type == "mixed_per_channel": |
| 331 | + output = linear_mixed_per_channel( |
| 332 | + input, |
| 333 | + self._linear, |
| 334 | + get_parameter, |
| 335 | + weights, |
| 336 | + self.gradient_normalization, |
| 337 | + ) |
| 338 | + |
| 339 | + if self.channel_out is not None: |
| 340 | + output = output.mul_to_axis(self.channel_out) |
| 341 | + |
| 342 | + return output.rechunk(self.irreps_out) |
0 commit comments