Skip to content

Commit 22ab5a3

Browse files
Improve error message for class_weight with PyTorch DataLoaders (#21414)
* Improve error message for class_weight with PyTorch DataLoaders * Improve error message for class_weight with PyTorch DataLoaders
1 parent e32175a commit 22ab5a3

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

keras/src/trainers/data_adapters/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,12 @@ def get_data_adapter(
9797
if class_weight is not None:
9898
raise ValueError(
9999
"Argument `class_weight` is not supported for torch "
100-
f"DataLoader inputs. Received: class_weight={class_weight}"
100+
f"DataLoader inputs. You can modify your `__getitem__ ` method"
101+
" to return input tensor, label and class_weight. "
102+
"Alternatively, use a custom training loop. See the User Guide "
103+
"https://keras.io/guides/custom_train_step_in_torch/"
104+
"#supporting-sampleweight-amp-classweight for more details. "
105+
f"Received: class_weight={class_weight}"
101106
)
102107
return TorchDataLoaderAdapter(x)
103108
# TODO: should we warn or not?

0 commit comments

Comments
 (0)