
Keras 3 may not work with PyTorch DirectML #21228


Open
TaiXeflar opened this issue Apr 30, 2025 · 6 comments
Labels
stat:contributions welcome (A pull request to fix this issue would be welcome.), type:Bug

Comments

@TaiXeflar

I think this might be a technical issue, so I opened an issue here from Discussions.

Issue description:
Keras 3 may not set the selected device correctly with a DirectML-based PyTorch backend. I looked at #21190 and still couldn't find a useful answer.

Background:
I have a Windows 11 machine with two graphics cards: an NVIDIA GeForce RTX 4070 Ti Super and an AMD Radeon RX 7800 XT.
I deployed a Python 3.11.9 environment with torch-directml installed, so that PyTorch runs on the DirectML backend. In PyTorch we can manually call .to(device) to select a GPU device, and we can check the available devices like this:

import torch_directml

# List all available DML devices
device_count = torch_directml.device_count()
print(f"Available DirectML devices: {device_count}")

# Loop over devices and print info
for i in range(device_count):
    dml_device = torch_directml.device(i)
    print(f"Device {i}: {dml_device}")

# Select device 1 (example)
dml = torch_directml.device(1)

Output:

Available DirectML devices: 2
Device 0: privateuseone:0
Device 1: privateuseone:1

From the code above and its output, we know that device 1 is the AMD card.
How can I make Keras select a DirectML device (or add a preprocessing step that does so)? If we do nothing, the model runs on the CPU only.
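
For reference, a minimal sanity check (my own sketch, not from the Keras docs) that the selected DirectML device actually executes work outside of Keras, using the same explicit .to(device) pattern:

import torch
import torch_directml

# Device 1 is the AMD card in this setup (privateuseone:1).
dml = torch_directml.device(1)

# Move a tensor to the DirectML device and run a matmul on it.
x = torch.randn(1024, 1024).to(dml)
y = x @ x
print(y.device)  # expected: privateuseone:1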

Hardware and software environment:

Intel Xeon W-3175X
NVIDIA GeForce RTX 4070 Ti Super
AMD Radeon RX 7800 XT

Windows 11 24H2
Python 3.13.2 ---> PyTorch 2.6.0+cu126, Keras 3        # The NVIDIA CUDA environment
Python 3.11.9 ---> PyTorch-DirectML, Keras 3           # The DirectML environment for the AMD card
@dhantule dhantule added type:Bug keras-team-review-pending Pending review by a Keras team member. labels May 2, 2025
@SamanehSaadat
Member

Hi @TaiXeflar
There might be some internal work we need to do to support this.
Would you like to take this on and contribute to Keras? If not, we will need to wait for another contributor to pick it up, or for the Keras team to plan support for it.

@SamanehSaadat SamanehSaadat added stat:contributions welcome A pull request to fix this issue would be welcome. and removed keras-team-review-pending Pending review by a Keras team member. labels May 9, 2025
@TaiXeflar
Author

OK. How can I help support this?

@SamanehSaadat
Member

Fixing this requires investigating the root causes of DirectML not working in the Keras codebase. One example is that there are instances of hard-coded cuda in the code, like here. The fix involves identifying those root causes and proposing solutions for them.
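
For illustration, a device-resolution helper along these lines (a rough sketch under my own assumptions, not the actual Keras source) would avoid relying only on a hard-coded cuda string:

import torch

def _resolve_default_device():
    # Prefer CUDA when available, matching the current behavior.
    if torch.cuda.is_available():
        return "cuda"
    # Hypothetical extra branch: probe for DirectML before falling back to CPU.
    try:
        import torch_directml
        if torch_directml.device_count() > 0:
            return str(torch_directml.device(0))  # e.g. "privateuseone:0"
    except ImportError:
        pass
    return "cpu"

DEFAULT_DEVICE = _resolve_default_device()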

@TaiXeflar
Author

TaiXeflar commented May 10, 2025

I got something working at an early stage a few hours ago, but I need to run more tests, such as the MNIST example.
One odd thing is that PyTorch-DirectML runs extremely slowly on some layers, such as attention-based layers.

I will keep this thread updated.
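
If it helps, a rough way to quantify that slowdown (my own sketch, using a toy attention computation rather than an actual Keras layer) would be:

import time
import torch
import torch_directml

def toy_attention(q, k, v):
    # Plain scaled dot-product attention built from basic ops.
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

for device in [torch.device("cpu"), torch_directml.device(1)]:
    q = k = v = torch.randn(8, 128, 64).to(device)
    start = time.time()
    for _ in range(50):
        out = toy_attention(q, k, v)
    out.cpu()  # pull the result back so queued device work has to finish
    print(device, f"{time.time() - start:.3f}s")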

@TaiXeflar
Author

@SamanehSaadat
I created a venv for an initial local test.

First, I use an environment variable to control whether torch-directml is enabled. Then I added this to keras/src/backend/torch/core.py:

import os

# KERAS_TORCH_BACKEND is a custom environment variable for this experiment.
# Using .get() avoids a KeyError when the variable is not set.
if os.environ.get("KERAS_TORCH_BACKEND") == "DirectML":
    import torch_directml
    DEFAULT_DEVICE = "privateuseone:1"

Note that the device name is already known: the AMD card is privateuseone:1.
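
A slightly more flexible version of the same idea (still just a sketch; KERAS_DIRECTML_DEVICE is a made-up variable name) would read the adapter index from the environment instead of hard-coding :1:

import os

if os.environ.get("KERAS_TORCH_BACKEND") == "DirectML":
    import torch_directml
    # Pick the DirectML adapter index from an env var; default to adapter 0.
    index = int(os.environ.get("KERAS_DIRECTML_DEVICE", "0"))
    DEFAULT_DEVICE = str(torch_directml.device(index))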

I ran a non-linear regression test (on a dataset of my own) and found that it works:

[image attachment: non-linear regression test output]

Then I ran the MNIST example and found something wrong: some of the computation falls back to the CPU.

[image attachment: MNIST run output]
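
One way to see where the weights actually end up (a sketch, assuming a built Keras model named model on the torch backend) is to print the device of each variable's underlying tensor:

for v in model.weights:
    # With the torch backend, v.value is a torch tensor, so .device shows whether
    # the weight landed on privateuseone (DirectML) or stayed on the CPU.
    print(v.path, v.value.device)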

That's all I've tested so far. If any new test turns up something useful, I'll report it here.

@TaiXeflar
Author

I think there are no rules for specifying the device when Keras uses the torch-directml-based PyTorch backend, so I'm running some detection tests here...
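
For example, a minimal detection check (just a sketch of what such a test could look like, without touching Keras internals) might be:

def detect_directml():
    # Return the list of DirectML device strings, or None if torch_directml
    # is not installed at all.
    try:
        import torch_directml
    except ImportError:
        return None
    return [str(torch_directml.device(i))
            for i in range(torch_directml.device_count())]

print(detect_directml())  # e.g. ['privateuseone:0', 'privateuseone:1']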
