
Significant performance regression #21316


Closed

rivershah opened this issue May 22, 2025 · 5 comments

Comments

@rivershah

rivershah commented May 22, 2025

I don't have a minimum viable example to provide, but on TPU there has been a significant performance regression.

git diff v3.10.0 pr-21187 -- keras/src/backend/jax/

Training step times across many models are 2-3x longer. Please review and fix.

GPUs are fine; TPUs + JAX are exhibiting the 2-3x slowdown.
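
To quantify the slowdown, per-step wall time can be measured with a simple Keras callback. A minimal sketch (illustrative, not part of the original report; the StepTimer name is arbitrary):

import time
import numpy as np
import keras

class StepTimer(keras.callbacks.Callback):
    """Record the wall-clock time of every training batch."""

    def on_train_begin(self, logs=None):
        self.step_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self._t0 = time.perf_counter()

    def on_train_batch_end(self, batch, logs=None):
        self.step_times.append(time.perf_counter() - self._t0)

    def on_train_end(self, logs=None):
        # Skip the first few steps, which include JIT compilation.
        steady = self.step_times[5:]
        print(f"median step time: {np.median(steady) * 1e3:.1f} ms")

Passing StepTimer() to model.fit(callbacks=[...]) under v3.10.0 and under pr-21187 gives directly comparable numbers.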

@dhantule
Contributor

dhantule commented May 23, 2025

Hi @rivershah, thanks for the report.
I will need some more details to reproduce the issue. Could you tell me which models you’ve encountered this problem with?

@rivershah
Author

I will write a minimal reproducible example and share it. Apologies for not filing one with the issue report. I will spend some time on this over the next few days. Thank you.

@rivershah
Author

I have tried with a simple model in Colab and can't replicate the issue. I need to dig more carefully into our custom models. I am investigating whether the issue is upstream in Keras or something else breaking downstream, somewhere in the diff between pr-21187 and v3.10.0:

!pip install -U keras==3.10.0
!pip install -U jax[tpu]==0.6.0
!pip install -U tensorflow==2.19.0

import os
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
from keras import layers

# Load Fashion-MNIST, scale pixels to [0, 1], and add a channel dimension.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train[..., None]
x_test = x_test[..., None]

# Replicate the model and split each batch across all available devices.
data_parallel = keras.distribution.DataParallel()
keras.distribution.set_distribution(data_parallel)

process_index = jax.process_index()
process_count = jax.process_count()
device_count = jax.device_count()
local_device_count = jax.local_device_count()
print(
    f"process_index: {process_index} "
    f"process_count: {process_count} "
    f"device_count: {device_count} "
    f"local_device_count: {local_device_count}"
)

model = keras.Sequential([
    layers.Input((28, 28, 1)),
    layers.Conv2D(32, 3, activation="relu"),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation="relu"),
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, activation="relu"),
    layers.Conv2D(128, 3, activation="relu"),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(10, activation="softmax")
])

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train, y_train, epochs=10,
          batch_size=128, validation_split=0.1, verbose=1)
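
If the simple model above does not reproduce the slowdown, a device trace of the custom models under both versions can show where the extra step time goes. A minimal sketch using JAX's built-in profiler (an assumption about how one might investigate, not part of the original comment; the trace directory is arbitrary):

import jax

# Trace a short run; open the result in TensorBoard (or Perfetto) and
# compare v3.10.0 against pr-21187 side by side.
jax.profiler.start_trace("/tmp/jax-trace")
model.fit(x_train, y_train, epochs=1, batch_size=128, verbose=1)
jax.profiler.stop_trace()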

@rivershah
Author

The issue was downstream, related to the dataset distribution diffs, not in Keras itself. Fixed.
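
For anyone who hits a similar slowdown: one way to check whether the input pipeline, rather than Keras itself, is at fault is to inspect how batches end up sharded across the TPU devices. A minimal sketch, assuming the JAX backend; the toy batch and names below are illustrative, not from the original comments:

import jax
import jax.numpy as jnp
import numpy as np

# Shard a toy batch the way DataParallel is expected to: the batch
# dimension is split across all devices (batch size assumed divisible
# by the device count).
devices = np.array(jax.devices())
mesh = jax.sharding.Mesh(devices, axis_names=("batch",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("batch"))

x = jax.device_put(jnp.zeros((128, 28, 28, 1)), sharding)

# Each device should hold only its slice of the batch. If every shard is
# full-size (i.e. the array is replicated), the data is not being split
# and each step pays extra host-to-device transfer cost.
print(x.sharding)
print([shard.data.shape for shard in x.addressable_shards])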

