Skip to content

Commit 9e83e6c

Browse files
committed
enable to_mxfp8 cast for DTensor
Summary: 1. add a test for casting a DTensor to mxfp8 2. make the test pass: a. remove addition of epsilon, it's not supported in DTensor world but we also don't need it anymore since we are no longer using `log2` anywhere. b. replace `<<` with `torch.bitwise_left_shift` and `>>` with `torch.bitwise_right_shift`. The short versions are silently broken for DTensor inputs, but the verbose versions work. 3. set up the wiring for testing mxfp8 with TP on a toy model. Note that making this work is split for the next PR, as this PR got too large. Test Plan: ```bash ./test/prototype/mx_formats/test_dtensor.sh ./test/float8/test_dtensor.sh pytest test/prototype/mx_formats/ ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: e054ed4 ghstack-comment-id: 2993264092 Pull Request resolved: #2420
1 parent 4362c9f commit 9e83e6c

File tree

6 files changed

+866
-18
lines changed

6 files changed

+866
-18
lines changed

0 commit comments

Comments
 (0)