🐞 Bug Description
There are many cases where `Dense` layers are applied across multiple batch dimensions. However, when a `Dense` layer is wrapped in a `SpectralNormalization` layer, the wrapper only accepts inputs with the same number of dimensions as the input it was built on. The error comes from the rigid definition of the wrapper's `InputSpec`, which fixes the expected rank at build time.
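To illustrate the difference, here is a simplified, pure-Python model of input-spec checking. It is not Keras's actual `InputSpec` code: `SimpleSpec` and `is_compatible` are hypothetical names. The "rigid" spec assumes the wrapper freezes every non-batch dimension of the first input (which is what the `expected ndim=2` error suggests), while the "relaxed" spec only requires a minimum rank and a matching last (feature) axis, in the spirit of Keras's `min_ndim`/`axes` arguments.

```python
# Simplified model of rank/shape checking (hypothetical; NOT Keras's real implementation).
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Tuple


@dataclass
class SimpleSpec:
    # Exact expected shape; None entries are wildcards. Fixing a shape also fixes the rank.
    shape: Optional[Tuple[Optional[int], ...]] = None
    # Minimum accepted rank.
    min_ndim: Optional[int] = None
    # Required sizes for specific axes, e.g. {-1: 4} for the feature axis.
    axes: Dict[int, int] = field(default_factory=dict)


def is_compatible(spec: SimpleSpec, shape: Sequence[int]) -> bool:
    ndim = len(shape)
    if spec.shape is not None:
        if ndim != len(spec.shape):
            return False  # rank mismatch, like "expected ndim=2, found ndim=3"
        for expected, actual in zip(spec.shape, shape):
            if expected is not None and expected != actual:
                return False
    if spec.min_ndim is not None and ndim < spec.min_ndim:
        return False
    for axis, size in spec.axes.items():
        if not -ndim <= axis < ndim or shape[axis] != size:
            return False
    return True


built_shape = (2, 4)  # shape of the first input (x1)

# Rigid: batch dim free, all other dims (and therefore the rank) frozen.
rigid = SimpleSpec(shape=(None,) + built_shape[1:])
# Relaxed: any rank >= 2, only the feature axis must match.
relaxed = SimpleSpec(min_ndim=2, axes={-1: built_shape[-1]})

print(is_compatible(rigid, (2, 4)))       # True
print(is_compatible(rigid, (2, 5, 4)))    # False
print(is_compatible(relaxed, (2, 5, 4)))  # True
```

Under this model, a spec that only constrains the last axis would let the wrapper accept the same inputs as the wrapped `Dense` layer.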
✅ Expected Behavior
A `Dense` layer wrapped in a `SpectralNormalization` layer should behave in the same way as a plain `Dense` layer (i.e., it should accept inputs with any number of batch dimensions).
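The expectation is reasonable because spectral normalization only rescales the kernel; the input rank never enters the computation. The NumPy sketch below assumes a simplified version of the technique: it estimates the kernel's largest singular value with many power-iteration steps in one go (rather than one step per training call with a persistent vector), divides the kernel by it, and then applies the normalized kernel to inputs with one and two batch dimensions. The helper names `spectral_normalize` and `dense` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

in_dim, units = 4, 32
kernel = rng.normal(size=(in_dim, units))
bias = np.zeros(units)


def spectral_normalize(w, n_iters=100):
    """Divide w by a power-iteration estimate of its largest singular value."""
    w_mat = w.reshape(-1, w.shape[-1])  # 2D view: (in_dim, units)
    u = rng.normal(size=(1, w_mat.shape[-1]))
    for _ in range(n_iters):
        v = u @ w_mat.T
        v /= np.linalg.norm(v)
        u = v @ w_mat
        u /= np.linalg.norm(u)
    sigma = (v @ w_mat @ u.T).item()  # estimated top singular value
    return w / sigma


def dense(x, w, b):
    # matmul contracts only the last axis of x, so any leading batch dims are kept.
    return x @ w + b


w_sn = spectral_normalize(kernel)
x1 = rng.normal(size=(2, 4))
x2 = rng.normal(size=(2, 5, 4))  # extra batch dimension

print(dense(x1, w_sn, bias).shape)  # (2, 32)
print(dense(x2, w_sn, bias).shape)  # (2, 5, 32)
```

Nothing in the normalization step depends on whether the input has shape `(2, 4)` or `(2, 5, 4)`, which is why the rank restriction in the wrapper appears unnecessary.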
🔁 Steps to Reproduce
```python
import keras

spectral_dense = keras.layers.SpectralNormalization(keras.layers.Dense(32))
dense = keras.layers.Dense(32)

x1 = keras.random.normal((2, 4))
# The second input has an additional batch dimension
x2 = keras.random.normal((2, 5, 4))

# Works: builds the layer
dense(x1)
# Works: correctly handles multiple batch dims
dense(x2)

# Works: builds the layer
spectral_dense(x1)
# Fails: ValueError: Input 0 of layer "spectral_normalization_1" is incompatible with the layer: expected ndim=2, found ndim=3.
spectral_dense(x2)
```
🧪 Environment
The error occurs on all backends.