Summary:
1. Add a test for casting a DTensor to mxfp8.
2. Make the test pass:
   a. Remove the addition of epsilon. It is not supported in the DTensor
      world, and we no longer need it because we no longer use `log2`
      anywhere.
   b. Replace `<<` with `torch.bitwise_left_shift` and `>>` with
      `torch.bitwise_right_shift`. The operator forms are silently broken
      for DTensor inputs, but the function forms work.
3. Set up the wiring for testing mxfp8 with TP on a toy model. Making
   this work is deferred to the next PR, as this PR got too large.
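
For context on 2a, here is a minimal pure-Python sketch of reading a float32 exponent straight from the bit pattern with a shift and a mask. It is an illustration only and is not the torchao implementation; the helper names `float32_bits`, `biased_exponent` and `unbiased_exponent` are hypothetical. It shows one plausible reason an epsilon stops being necessary once `log2` is gone: `log2(0)` is undefined, while the exponent field of zero is simply 0.

```python
import math
import struct


def float32_bits(x: float) -> int:
    """Reinterpret x, rounded to float32, as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def biased_exponent(x: float) -> int:
    """Return the 8-bit biased exponent field of x as a float32.

    Shifting right by 23 drops the mantissa bits; masking with 0xFF
    drops the sign bit.
    """
    return (float32_bits(x) >> 23) & 0xFF


def unbiased_exponent(x: float) -> int:
    """Subtract the float32 exponent bias (127)."""
    return biased_exponent(x) - 127


if __name__ == "__main__":
    for value in (1.0, 6.0, 0.5, -6.0, 0.0):
        print(value, biased_exponent(value))
    # For normal positive values this matches floor(log2(x)).
    print(unbiased_exponent(6.0) == math.floor(math.log2(6.0)))
```

Zero needs no special casing here: its exponent field is 0, so no epsilon has to be added before the computation.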
Test Plan:
```bash
./test/prototype/mx_formats/test_dtensor.sh
./test/float8/test_dtensor.sh
pytest test/prototype/mx_formats/
```
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: e054ed4
ghstack-comment-id: 2993264092
Pull Request resolved: #2420