Description
Feature request
- I want to make this library compatible with Apple's GPU, which only requires two lines of code to be modified.
What is the expected behavior?
- Currently, running a training on Apple's GPU almost works by setting the `device_name` to `"mps"`. Yet, at the end of the training, when the `TabModel.explain` method is called, it raises an error.
- Specifically, if I start the training with the following initialization of the model, line 354 of `TabModel.explain` raises `TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.`

```python
regressor = TabNetRegressor(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    device_name="mps",
    mask_type="entmax",
)
```
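For reference, a minimal reproduction might look like the sketch below. The synthetic data and the `X_train`/`y_train` names are placeholders of my own, not code from the original report; it only runs on a machine where the MPS backend is available.

```python
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetRegressor

# Placeholder synthetic data; numpy defaults to float64, which is what
# eventually trips up the MPS backend in explain().
X_train = np.random.rand(256, 10)
y_train = np.random.rand(256, 1)

regressor = TabNetRegressor(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    device_name="mps",
    mask_type="entmax",
)

# Training itself completes on MPS.
regressor.fit(X_train, y_train, max_epochs=5)

# This call raises the float64 TypeError described above.
explain_matrix, masks = regressor.explain(X_train)
```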
- The error comes from the following line: the `data` is in float64, while Apple's GPU only supports float32. See `pytorch_tabnet/abstract_model.py`, lines 353 to 356 (commit `2c0c4eb`):

```python
for batch_nb, data in enumerate(dataloader):
    data = data.to(self.device).float()
    M_explain, masks = self.network.forward_masks(data)
```
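The underlying limitation can be seen outside of tabnet as well. The following sketch is only an illustration of the MPS float64 restriction, not code from the library:

```python
import torch

# Tensors built from default numpy arrays are float64.
x = torch.rand(4, dtype=torch.float64)

if torch.backends.mps.is_available():
    try:
        # Raises: TypeError: Cannot convert a MPS Tensor to float64 dtype ...
        x.to("mps")
    except TypeError as err:
        print(err)

    # Casting to float32 first works fine.
    x.to(torch.float32).to("mps")
```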
What is the motivation or use case for adding/changing the behavior?
- I believe that utilizing the GPU during training benefits users of Apple computers.
How should this be implemented in your opinion?
- I confirmed that adding the two lines below to the method solves the issue.
```diff
 for batch_nb, data in enumerate(dataloader):
+    if self.device == torch.device("mps"):
+        data = data.to(torch.float32)
     data = data.to(self.device).float()
     M_explain, masks = self.network.forward_masks(data)
```
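For what it's worth, an equivalent, device-agnostic variant of the same loop (my suggestion, not part of the proposal above) would be to cast the batch to float32 before moving it to the device, which avoids the MPS-specific check:

```python
for batch_nb, data in enumerate(dataloader):
    # Casting on the CPU first means a float64 tensor never reaches the
    # MPS device; on other devices this is effectively a no-op change.
    data = data.float().to(self.device)
    M_explain, masks = self.network.forward_masks(data)
```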
Are you willing to work on this yourself?
yes