9080 expose mat mul precision #9081
Merged

Commits (22):
86200a5 initial commit: binding and backends package, no tests. (yaoshiang)
2901b77 Tests for default, high, and highest precision. (yaoshiang)
564685a clang-format (yaoshiang)
4c0b52e formatter (yaoshiang)
0d3f44f Updates to error messages (yaoshiang)
3c333d8 fixed test class names (yaoshiang)
967ccda typo (yaoshiang)
393257e typo on error message. unit tested and yapfed. (yaoshiang)
f48d7f6 linter (yaoshiang)
114456b minor edits. (yaoshiang)
87cff3d Updated TODO per review (yaoshiang)
b96f671 Update todo and precision math per comment. (yaoshiang)
c26ad8f Merge branch 'master' into 9080-expose-mat_mul_precision (yaoshiang)
30d1c19 yapf (yaoshiang)
959b0dc linter (yaoshiang)
8aaa979 parameterized, but in a process isolated way. (yaoshiang)
04516c7 removed dead code (yaoshiang)
c3856c8 added issue for repeatable, unexpected behavior. (yaoshiang)
79af416 Updated docstring. (yaoshiang)
6ff7f59 changed naming of is_on_tpu (yaoshiang)
89c72a8 Merge branch 'master' into 9080-expose-mat_mul_precision (yaoshiang)
9db6e3e CICD friendly version, hopefully. (yaoshiang)
New file (+75 lines): numeric tests for 'default' mat mul precision
"""Numeric tests for default precision of mat mul. | ||
|
||
There are three similar test files, suffixed default, high, and highest. | ||
Unfortunately, the precision cannot reliably | ||
be dynamically changed between tensor operations, so | ||
the tests are split into three files to ensure a fresh | ||
environment for each test. | ||
""" | ||
|
||
import os | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.backends | ||
import torch_xla.runtime as xr | ||
|
||
|
||
def _is_on_tpu(): | ||
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' | ||
|
||
|
||
skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') | ||
|
||
|
||
class TestMatMulPrecisionDefault(unittest.TestCase): | ||
|
||
yaoshiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def _make_input(self): | ||
eye = torch.eye(1024, device='cpu', dtype=torch.float64) | ||
rand_ = torch.testing.make_tensor((1024, 1024), | ||
dtype=torch.float64, | ||
device="cpu", | ||
low=0.99, | ||
high=1.01) | ||
return eye * rand_ | ||
|
||
# DO NOT add epsilons to this test. These tests must be numerically exact. | ||
@skipIfNotTpu | ||
def test_mat_mul_precision_numerics_default(self): | ||
yaoshiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Arrange | ||
torch_xla.backends.set_mat_mul_precision('default') | ||
|
||
# Diagonal matrices force mat mul through MXU | ||
# but require only one non-zero accumulation. | ||
x = self._make_input() | ||
y = self._make_input() | ||
reference_float64 = torch.matmul(x, y) | ||
|
||
# This is slightly worse than 1/2**[mantissa bits] because in multiplication, both | ||
# operands may lose 1/256 (=2^8, where 8 = 7 + 1, | ||
# where 7 is the mantissa and 1 is the implicit bit). | ||
worst_atol = torch.tensor(1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64) | ||
yaoshiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
x = x.to(torch.float32).to('xla') | ||
y = y.to(torch.float32).to('xla') | ||
|
||
# Act | ||
actual = torch.matmul(x, y).to('cpu').to(torch.float64) | ||
|
||
# Disable rtol, we know exactly the atol for default, high, and highest. | ||
torch.testing.assert_close( | ||
actual, | ||
reference_float64, | ||
rtol=0.0, | ||
atol=worst_atol, | ||
) | ||
|
||
assert not torch.equal(actual, reference_float64), ( | ||
"Actual product and reference product should not be equal, " | ||
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main(verbosity=0) |
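As a sanity check on the tolerance above, here is a small standalone sketch (plain Python, using the per-operand bit counts the three test files assume: 8 here, and the 14 and 22 estimates used by the high and highest files below) that evaluates the worst-case atol for each mode:

# Worst-case product error for operands near 1.0 when each operand
# can lose up to 2**-bits of relative precision: 1 - (1 - 2**-bits)**2.
for mode, bits in [('default', 8), ('high', 14), ('highest', 22)]:
  worst_atol = 1 - ((2**bits - 1) / 2**bits)**2
  print(f'{mode}: ~{bits} bits per operand -> atol ~ {worst_atol:.3e}')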
New file (+62 lines): get/set tests for mat mul precision
"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" | ||
|
||
import sys | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.backends | ||
|
||
|
||
class TestMatMulPrecisionGetAndSet(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self._original = torch_xla.backends.get_mat_mul_precision() | ||
torch.set_printoptions(precision=20) | ||
torch_xla.sync() | ||
|
||
def tearDown(self): | ||
torch_xla.backends.set_mat_mul_precision(self._original) | ||
torch.set_printoptions(profile="default") | ||
torch_xla.sync() | ||
|
||
def test_set_mat_mul_precision_error(self): | ||
# Assert | ||
with self.assertRaises(ValueError): | ||
# Act | ||
torch_xla.backends.set_mat_mul_precision('BAD VALUE') | ||
|
||
def test_get_and_set_mat_mul_precision_default(self): | ||
# Arrange | ||
torch_xla.backends.set_mat_mul_precision('default') | ||
|
||
# Act | ||
status = torch_xla.backends.get_mat_mul_precision() | ||
|
||
# Assert | ||
self.assertEqual(status, 'default') | ||
|
||
def test_get_and_set_mat_mul_precision_high(self): | ||
# Arrange | ||
torch_xla.backends.set_mat_mul_precision('high') | ||
|
||
# Act | ||
status = torch_xla.backends.get_mat_mul_precision() | ||
|
||
# Assert | ||
self.assertEqual(status, 'high') | ||
|
||
def test_get_and_set_mat_mul_precision_highest(self): | ||
# Arrange | ||
torch_xla.backends.set_mat_mul_precision('highest') | ||
|
||
# Act | ||
status = torch_xla.backends.get_mat_mul_precision() | ||
|
||
# Assert | ||
self.assertEqual(status, 'highest') | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
New file (+74 lines): numeric tests for 'high' mat mul precision
"""Numeric tests for high precision of mat mul. | ||
|
||
There are three similar test files, suffixed default, high, and highest. | ||
Unfortunately, the precision cannot reliably | ||
be dynamically changed between tensor operations, so | ||
the tests are split into three files to ensure a fresh | ||
environment for each test. | ||
""" | ||
|
||
import os | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.backends | ||
import torch_xla.runtime as xr | ||
|
||
|
||
def _is_on_tpu(): | ||
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' | ||
|
||
|
||
skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') | ||
|
||
|
||
class TestMatMulPrecisionHigh(unittest.TestCase): | ||
|
||
def _make_input(self): | ||
eye = torch.eye(1024, device='cpu', dtype=torch.float64) | ||
rand_ = torch.testing.make_tensor((1024, 1024), | ||
dtype=torch.float64, | ||
device="cpu", | ||
low=0.99, | ||
high=1.01) | ||
return eye * rand_ | ||
|
||
# DO NOT add epsilons to this test. These tests must be numerically exact. | ||
@skipIfNotTpu | ||
def test_mat_mul_precision_numerics_high(self): | ||
# Arrange | ||
torch_xla.backends.set_mat_mul_precision('high') | ||
|
||
# Diagonal matrices force mat mul through MXU | ||
# but require only one non-zero accumulation. | ||
x = self._make_input() | ||
y = self._make_input() | ||
reference_float64 = torch.matmul(x, y) | ||
|
||
# 14 bits is an estimate of precision in the | ||
# three pass technique. | ||
worst_atol = torch.tensor(1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64) | ||
|
||
x = x.to(torch.float32).to('xla') | ||
y = y.to(torch.float32).to('xla') | ||
|
||
# Act | ||
actual = torch.matmul(x, y).to('cpu').to(torch.float64) | ||
|
||
# Disable rtol, we know exactly the atol for default, high, and highest. | ||
torch.testing.assert_close( | ||
actual, | ||
reference_float64, | ||
rtol=0.0, | ||
atol=worst_atol, | ||
) | ||
|
||
assert not torch.equal(actual, reference_float64), ( | ||
"Actual product and reference product should not be equal, " | ||
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main(verbosity=0) |
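To make the three-pass idea concrete, here is a minimal CPU-only sketch. It illustrates the splitting scheme described in the torch_xla.backends docstring below, not the TPU's actual implementation: each FP32 operand is split into a BF16 "high" part and a BF16 residual, and only three of the four partial products are kept.

import torch

x = torch.randn(1024, dtype=torch.float32)
y = torch.randn(1024, dtype=torch.float32)

# Split each FP32 value into a BF16 high part and a BF16 residual.
xh = x.to(torch.bfloat16).to(torch.float32)
xl = (x - xh).to(torch.bfloat16).to(torch.float32)
yh = y.to(torch.bfloat16).to(torch.float32)
yl = (y - yh).to(torch.bfloat16).to(torch.float32)

exact = x * y
one_pass = xh * yh  # roughly what 'default' computes
three_pass = xh * yh + xh * yl + xl * yh  # 'high': only xl*yl is dropped

# The three-pass error is roughly the square of the one-pass error.
print('one pass:  ', (exact - one_pass).abs().max().item())
print('three pass:', (exact - three_pass).abs().max().item())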
New file (+74 lines): numeric tests for 'highest' mat mul precision
"""Numeric tests for highest precision of mat mul. | ||
|
||
There are three similar test files, suffixed default, high, and highest. | ||
Unfortunately, the precision cannot reliably | ||
be dynamically changed between tensor operations, so | ||
the tests are split into three files to ensure a fresh | ||
environment for each test. | ||
""" | ||
|
||
import os | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.backends | ||
import torch_xla.runtime as xr | ||
|
||
|
||
def _is_on_tpu(): | ||
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' | ||
|
||
|
||
skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') | ||
|
||
|
||
class TestMatMulPrecisionHighest(unittest.TestCase): | ||
|
||
def _make_input(self): | ||
eye = torch.eye(1024, device='cpu', dtype=torch.float64) | ||
rand_ = torch.testing.make_tensor((1024, 1024), | ||
dtype=torch.float64, | ||
device="cpu", | ||
low=0.99, | ||
high=1.01) | ||
return eye * rand_ | ||
|
||
# DO NOT add epsilons to this test. These tests must be numerically precise. | ||
@skipIfNotTpu | ||
def test_mat_mul_precision_numerics_default(self): | ||
# Arrange | ||
torch_xla.backends.set_mat_mul_precision('highest') | ||
|
||
# Diagonal matrices force mat mul through MXU | ||
# but require only one non-zero accumulation. | ||
x = self._make_input() | ||
y = self._make_input() | ||
reference_float64 = torch.matmul(x, y) | ||
|
||
# 22 bits is an estimate of precision in the | ||
# six pass technique. | ||
worst_atol = torch.tensor(1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64) | ||
yaoshiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
x = x.to(torch.float32).to('xla') | ||
y = y.to(torch.float32).to('xla') | ||
|
||
# Act | ||
actual = torch.matmul(x, y).to('cpu').to(torch.float64) | ||
|
||
# Disable rtol, we know exactly the atol for default, high, and highest. | ||
torch.testing.assert_close( | ||
actual, | ||
reference_float64, | ||
rtol=0.0, | ||
atol=worst_atol, | ||
) | ||
|
||
assert not torch.equal(actual, reference_float64), ( | ||
"Actual product and reference product should not be equal, " | ||
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main(verbosity=0) |
New file (+82 lines): the torch_xla.backends module
"""torch_xla.backends is intended to be moved to torch.backends.xla""" | ||
|
||
# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py | ||
# for an example of how backends are implemented in PyTorch | ||
# in the __init__.py file. | ||
|
||
# Literal is available from Python 3.8, | ||
# matching the Python versions for PyTorch and PyTorchXLA. | ||
from typing import Final, Literal, TypeAlias | ||
|
||
import torch_xla | ||
|
||
__all__ = ['set_mat_mul_precision', 'get_mat_mul_precision'] | ||
|
||
# Valid values for get_mat_mul_precision/set_mat_mul_precision | ||
# Note: it is idiomatic to PyTorch to use strings rather than enums. | ||
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9 | ||
|
||
_DEFAULT: Final = 'default' | ||
_HIGH: Final = 'high' | ||
_HIGHEST: Final = 'highest' | ||
|
||
# Use of variables with Final typehint instead of literals is valid. | ||
_PrecisionType: TypeAlias = Literal[ | ||
_DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm] | ||
|
||
|
||
# Some of this description adapted from Jax documentation. | ||
def set_mat_mul_precision(precision: _PrecisionType) -> None: | ||
yaoshiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Control the default matmul and conv precision for 32bit inputs. | ||
|
||
Some platforms, like TPU, offer configurable precision levels for | ||
matrix multiplication and convolution computations, | ||
trading off accuracy for speed. | ||
|
||
This option can be used to control the default precision level for | ||
computations involved in matrix multiplication and convolution on | ||
32bit inputs. The levels describe the precision at | ||
which scalar products are computed. | ||
|
||
On a TPU: | ||
* `default` is the fastest and least precise, essentially | ||
downcasting an FP32 to BF16 before multiplying. | ||
|
||
  * `high` breaks each FP32 input into upper and lower bits and
    computes the product of the pair of upper bits plus the two
    mixed pairs of upper and lower bits; the pair of lower bits is
    ignored. Since the BF16 format carries 8 bits of significand
    precision (7 mantissa bits plus one implicit leading bit), this
    delivers approximately 16 bits of precision, more than
    TensorFloat32 (10 bits) but less than a full FP32 (23 bits).
    It requires three passes over the data.
  * `highest` is the most precise, and the slowest. It further
    breaks up each input into a triple of upper, middle, and lower
    bits and effectively calculates 23 bits of precision.
    It requires six passes over the data.

  Args:
    precision (str): The precision to set for matrix multiplication.
      Must be one of 'default', 'high', or 'highest'.
  """
  if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
    raise ValueError(f"Invalid precision: {precision}. "
                     "Must be one of 'default', 'high', or 'highest'.")

  torch_xla._XLAC._xla_set_mat_mul_precision(precision)


def get_mat_mul_precision() -> _PrecisionType:
  """Get the current mat mul precision for 32-bit inputs.

  Returns:
    str: The current precision setting for matrix multiplication,
      one of 'default', 'high', or 'highest'.
  """
  precision = torch_xla._XLAC._xla_get_mat_mul_precision()
  assert precision in [_DEFAULT, _HIGH, _HIGHEST], (
      f"Invalid precision: {precision}. "
      "Must be one of 'default', 'high', or 'highest'.")
  return precision
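Finally, a short usage sketch of the new API (a hypothetical session; the numeric behavior exercised by the tests above is only guaranteed on TPU):

import torch
import torch_xla
import torch_xla.backends

# Trade accuracy for speed on TPU matmuls and convolutions.
torch_xla.backends.set_mat_mul_precision('high')
assert torch_xla.backends.get_mat_mul_precision() == 'high'

x = torch.randn(1024, 1024).to('xla')
y = torch.randn(1024, 1024).to('xla')
product = torch.matmul(x, y)  # on TPU, computed with the three-pass technique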