It seems that `layers.Rescaling` broadcasts its input to match the size of the provided scale/offset parameters.
This may or may not be expected behaviour (it's not clear from the documentation).
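The shape change is what ordinary NumPy-style broadcasting would produce. A minimal sketch, assuming `Rescaling` computes `inputs * scale + offset` element-wise (which matches the observed output shape), reproduces the arithmetic in plain NumPy:

```python
import numpy as np

x = np.zeros((8, 100, 1))      # batch of 8 sequences, 100 steps, 1 channel
scale = np.array([1.0, 1.0])   # 2 elements, as passed to Rescaling
offset = np.array([0.0, 0.0])  # 2 elements, as passed to Rescaling

# Trailing dimensions are aligned: (8, 100, 1) vs (2,).
# The size-1 channel axis is stretched to 2, so the result gains a channel.
y = x * scale + offset
print(y.shape)  # (8, 100, 2)
```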
What definitely seems wrong, however, is that this broadcasting only shows up when I call the model on data: `model.summary()` and the `model.output_shape` attribute both still report the unbroadcast input shape.
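`model.summary()` and `model.output_shape` appear to assume that the layer preserves its input shape. One way to make that contract explicit is a guard that rejects parameters which would change the shape. The `safe_rescale` helper below is hypothetical (it is not part of Keras) and is written against NumPy (`np.broadcast_shapes` requires NumPy 1.20 or later):

```python
import numpy as np

def safe_rescale(inputs, scale, offset):
    """Hypothetical helper: rescale, but refuse to change the input shape."""
    inputs = np.asarray(inputs)
    # Shape that broadcasting would actually produce for inputs * scale + offset.
    out_shape = np.broadcast_shapes(inputs.shape, np.shape(scale), np.shape(offset))
    if out_shape != inputs.shape:
        raise ValueError(
            f"scale/offset would broadcast input of shape {inputs.shape} "
            f"to {out_shape}"
        )
    return inputs * scale + offset

x = np.zeros((8, 100, 1))
print(safe_rescale(x, 2.0, 1.0).shape)  # (8, 100, 1): scalars keep the shape
try:
    safe_rescale(x, [1.0, 1.0], [0.0, 0.0])  # the parameters from the report
except ValueError as err:
    print(err)
```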
If I then train the model against a target whose shape differs from the prediction, no error is raised either. This also does not seem like expected behaviour.
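The silent `fit` is consistent with how a mean-squared-error reduction behaves under broadcasting. A NumPy sketch, assuming the loss averages `(y_true - y_pred) ** 2` over the last axis (as Keras' `mse` does), shows the mismatched shapes combining without complaint; `checked_mse` is a hypothetical stricter variant, not a Keras API:

```python
import numpy as np

def mse(y_true, y_pred):
    # Mean squared error over the last axis, broadcasting as NumPy does.
    return np.mean(np.square(y_true - y_pred), axis=-1)

def checked_mse(y_true, y_pred):
    # Hypothetical stricter variant: refuse to broadcast target vs prediction.
    if np.shape(y_true) != np.shape(y_pred):
        raise ValueError(
            f"target shape {np.shape(y_true)} != prediction shape {np.shape(y_pred)}"
        )
    return mse(y_true, y_pred)

y_target = np.zeros((8, 100, 1))
y_pred = np.zeros((8, 100, 2))  # shape produced by the broadcasting Rescaling layer

# (8, 100, 1) - (8, 100, 2) broadcasts to (8, 100, 2); the mean drops the last axis.
print(mse(y_target, y_pred).shape)  # (8, 100): no error raised
```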
Code to reproduce:
```python
import numpy as np
import keras
from keras import layers

x = np.zeros((8, 100, 1))
y_target = np.zeros((8, 100, 1))

model = keras.Sequential(
    [
        layers.Input(shape=(None, 1), name="Feature"),
        # scale and offset each have 2 elements, while the input has 1 channel.
        layers.Rescaling(scale=[1.0, 1.0], offset=[0.0, 0.0]),
    ]
)
model.summary()  # reports the unbroadcast shape

y_pred = model(x)  # actual output is broadcast along the last axis
print(f"Expected model output shape: {model.output_shape}")
print(f"True model output shape: {y_pred.shape}")
print(f"Target shape: {y_target.shape}")

model.compile(
    optimizer="adam",
    loss="mse",
)
# Prediction and target shapes differ, yet no error is raised.
model.fit(x, y_target, epochs=1, batch_size=8)
print("Fit executed with no errors, despite inconsistent shapes.")
```
Versions:
keras==3.8.0
jax==0.5.0