Skip to content

9080 expose mat mul precision #9081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

yaoshiang
Copy link
Collaborator

@yaoshiang yaoshiang commented May 2, 2025

I added a binding to get the mat mul precision to the init_python_bindings.cpp

I exposed it as a brand new module, torch_xla.backends. This is an ackward name but the goal is to eventually migrate this to torch.backends.xla, to parallel torch.backends.{cuda, cpu, mps, etc}.

I got advice from on very exact numerics for default (1 pass). I made estimates on the 3 pass and 6 pass technique.

I was careful to ensure that there was a non-zero delta from the torch64 cpu calculation - it's easy to make a mistake and end up with your "reference" math also be rounded.

I have a detailed guide on this coming after this PR goes in.

@yaoshiang yaoshiang linked an issue May 2, 2025 that may be closed by this pull request
@yaoshiang
Copy link
Collaborator Author

Do I need to reduce the verbosity of my unit test?

(torch310) yho_google_com@t1v-n-2db64fb2-w-0:~/pytorch/xla/test/tpu$ ./run_tests.sh 
+++ dirname ./run_tests.sh
++ cd .
++ pwd -P
+ CDIR=/home/yho_google_com/pytorch/xla/test/tpu
++ dirname /home/yho_google_com/pytorch/xla/test/tpu
+ TEST_CDIR=/home/yho_google_com/pytorch/xla/test
+ source /home/yho_google_com/pytorch/xla/test/utils/run_tests_utils.sh
++ set -exo pipefail
+ python3 /home/yho_google_com/pytorch/xla/test/test_mat_mul_precision_highest.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
.
----------------------------------------------------------------------
Ran 1 test in 24.850s

OK
+ python3 /home/yho_google_com/pytorch/xla/test/test_mat_mul_precision_get_and_set.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
....
----------------------------------------------------------------------
Ran 4 tests in 25.572s

OK
+ python3 /home/yho_google_com/pytorch/xla/test/test_mat_mul_precision_default.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
.
----------------------------------------------------------------------
Ran 1 test in 26.447s

OK
+ python3 /home/yho_google_com/pytorch/xla/test/test_mat_mul_precision_high.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
.
----------------------------------------------------------------------
Ran 1 test in 27.707s

OK
+ python3 /home/yho_google_com/pytorch/xla/test/test_operations.py -v
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
test_mp_decorator (__main__.MpDecoratorTest) ... ok
test_multi_init_xla_backend (__main__.RegisterXLAKeyTest) ... ok
test_dropout (__main__.TestActivationCheckpoint) ... ok
test_opt_barrier (__main__.TestActivationCheckpoint) ... ok
test (__main__.TestAtenTensorTo) ... ok
test_add_mixed_device (__main__.TestAtenXlaTensor) ... ok
test_addmm_integer_types (__main__.TestAtenXlaTensor) ... skipped 'Not supported on TPU'
test_ailing_slice (__main__.TestAtenXlaTensor) ... ok
test_amp_foreach_non_finite_check_and_unscale_ (__main__.TestAtenXlaTensor) ... skipped 'Not supported on TPU'
test_amp_norm_append_dtype (__main__.TestAtenXlaTensor) ... ok
test_arange_nan (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r1 (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r1_dim1 (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r1_slice (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r1_t (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r1_t_off (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r1_t_slice (__main__.TestAtenXlaTensor) ... ok
test_as_strided_r2_t_update (__main__.TestAtenXlaTensor) ... ok

@yaoshiang yaoshiang requested a review from qihqi May 2, 2025 19:02


# Some of this description adapted from Jax documentation.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Should we move these functions to a separate precision.py file and import these here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the suggestion - I expected it and mentioned in a comment to look at how torch.backends.mps and cpu is implemented. Trying to stay parallel to them.


# 22 bits is an estimate of precision in the
# six pass technique.
worst_atol = torch.tensor(1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know the derivation of this calculation? I was looking at
image

And trying to calculate an upper bound on error, a=1 + eps, b = 1 + eps, Q = a*b, I reached an upper bound of sqrt(2)*eps

Copy link
Collaborator Author

@yaoshiang yaoshiang May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have heard from an authoritative source (I'll share offline) that the correct error for 1 pass mode (default) is 1 - (255/256)**2. I did not get as Clear of guidance on 3 and 6 pass mode so I empirically derived tighter bounds. I believe that there's an argument that 6 rounds of bf16 basically gets you fp32, but that there could still be some delta on the final bit due to implementation details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Expose mat_mul_precision
2 participants