9080 expose mat mul precision #9081
Merged
Commits (22)
86200a5  initial commit: binding and backends package, no tests. (yaoshiang)
2901b77  Tests for default, high, and highest precision. (yaoshiang)
564685a  clang-format (yaoshiang)
4c0b52e  formatter (yaoshiang)
0d3f44f  Updates to error messages (yaoshiang)
3c333d8  fixed test class names (yaoshiang)
967ccda  typo (yaoshiang)
393257e  typo on error message. unit tested and yapfed. (yaoshiang)
f48d7f6  linter (yaoshiang)
114456b  minor edits. (yaoshiang)
87cff3d  Updated TODO per review (yaoshiang)
b96f671  Update todo and precision math per comment. (yaoshiang)
c26ad8f  Merge branch 'master' into 9080-expose-mat_mul_precision (yaoshiang)
30d1c19  yapf (yaoshiang)
959b0dc  linter (yaoshiang)
8aaa979  parameterized, but in a process isolated way. (yaoshiang)
04516c7  removed dead code (yaoshiang)
c3856c8  added issue for repeatable, unexpected behavior. (yaoshiang)
79af416  Updated docstring. (yaoshiang)
6ff7f59  changed naming of is_on_tpu (yaoshiang)
89c72a8  Merge branch 'master' into 9080-expose-mat_mul_precision (yaoshiang)
9db6e3e  CICD friendly version, hopefully. (yaoshiang)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@
"""Numeric tests for the default precision of mat mul.

There are three similar test files, suffixed default, high, and highest.
Unfortunately, the precision cannot reliably be dynamically changed
between tensor operations, so the tests are split into three files to
ensure a fresh environment for each test.
"""

import os
import unittest

import torch
import torch_xla
import torch_xla.backends
import torch_xla.runtime as xr


def _is_on_tpu():
  return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU')


class TestMatMulPrecisionDefault(unittest.TestCase):

  def _make_input(self):
    eye = torch.eye(1024, device='cpu', dtype=torch.float64)
    rand_ = torch.testing.make_tensor((1024, 1024),
                                      dtype=torch.float64,
                                      device="cpu",
                                      low=0.99,
                                      high=1.01)
    return eye * rand_

  # DO NOT add epsilons to this test. These tests must be numerically exact.
  @skipIfNotTpu
  def test_mat_mul_precision_numerics_default(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('default')

    # Diagonal matrices force mat mul through the MXU
    # but require only one non-zero accumulation.
    x = self._make_input()
    y = self._make_input()
    reference_float64 = torch.matmul(x, y)

    # This is slightly worse than 1 / 2**(mantissa bits) because in
    # multiplication, each operand may lose up to 1/256 of relative
    # precision (1/256 = 2**-8, where 8 = 7 + 1: 7 mantissa bits plus
    # the implicit bit).
    worst_atol = torch.tensor(1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64)

    x = x.to(torch.float32).to('xla')
    y = y.to(torch.float32).to('xla')

    # Act
    actual = torch.matmul(x, y).to('cpu').to(torch.float64)

    # Disable rtol; we know exactly the atol for default, high, and highest.
    torch.testing.assert_close(
        actual,
        reference_float64,
        rtol=0.0,
        atol=worst_atol,
    )

    assert not torch.equal(actual, reference_float64), (
        "Actual product and reference product should not be equal, "
        f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
    )


if __name__ == '__main__':
  unittest.main(verbosity=0)
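For reference, the worst-case tolerance used above follows from each operand retaining roughly k bits of relative precision, so a single product can deviate from the exact value by up to 1 - ((2**k - 1) / 2**k)**2, which is approximately 2**(1 - k). A quick illustrative sketch (plain Python, not part of the PR), using the bit counts the test comments cite for the three modes:

# Worst-case absolute error of a product of two operands that each
# retain k bits of relative precision.
for mode, k in [("default", 8), ("high", 14), ("highest", 22)]:
  worst_atol = 1 - ((2**k - 1) / 2**k)**2
  # For large k, worst_atol is approximately 2**(1 - k).
  print(f"{mode:8s} k={k:2d} atol={worst_atol:.3e} ~ {2.0**(1 - k):.3e}")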
@@ -0,0 +1,62 @@
"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp"""

import sys
import unittest

import torch
import torch_xla
import torch_xla.backends


class TestMatMulPrecisionGetAndSet(unittest.TestCase):

  def setUp(self):
    self._original = torch_xla.backends.get_mat_mul_precision()
    torch.set_printoptions(precision=20)
    torch_xla.sync()

  def tearDown(self):
    torch_xla.backends.set_mat_mul_precision(self._original)
    torch.set_printoptions(profile="default")
    torch_xla.sync()

  def test_set_mat_mul_precision_error(self):
    # Assert
    with self.assertRaises(ValueError):
      # Act
      torch_xla.backends.set_mat_mul_precision('BAD VALUE')

  def test_get_and_set_mat_mul_precision_default(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('default')

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, 'default')

  def test_get_and_set_mat_mul_precision_high(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('high')

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, 'high')

  def test_get_and_set_mat_mul_precision_highest(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('highest')

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, 'highest')


if __name__ == '__main__':
  # exit=False is needed so the sys.exit below actually runs;
  # unittest.main() would otherwise exit the interpreter itself.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
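Because each numeric test file must start in a fresh process (see the module docstrings above), a CI driver could run them along these lines. This is an illustrative sketch, not part of the PR, and the file names are assumptions inferred from the docstrings:

import subprocess
import sys

# Hypothetical file names for the three process-isolated numeric tests.
_TEST_FILES = [
    "test_mat_mul_precision_default.py",
    "test_mat_mul_precision_high.py",
    "test_mat_mul_precision_highest.py",
]

for test_file in _TEST_FILES:
  # A fresh interpreter per file guarantees the precision setting
  # cannot leak from one test run into the next.
  result = subprocess.run([sys.executable, test_file])
  if result.returncode != 0:
    sys.exit(result.returncode)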
@@ -0,0 +1,74 @@
"""Numeric tests for high precision of mat mul.

There are three similar test files, suffixed default, high, and highest.
Unfortunately, the precision cannot reliably be dynamically changed
between tensor operations, so the tests are split into three files to
ensure a fresh environment for each test.
"""

import os
import unittest

import torch
import torch_xla
import torch_xla.backends
import torch_xla.runtime as xr


def _is_on_tpu():
  return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU')


class TestMatMulPrecisionHigh(unittest.TestCase):

  def _make_input(self):
    eye = torch.eye(1024, device='cpu', dtype=torch.float64)
    rand_ = torch.testing.make_tensor((1024, 1024),
                                      dtype=torch.float64,
                                      device="cpu",
                                      low=0.99,
                                      high=1.01)
    return eye * rand_

  # DO NOT add epsilons to this test. These tests must be numerically exact.
  @skipIfNotTpu
  def test_mat_mul_precision_numerics_high(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('high')

    # Diagonal matrices force mat mul through the MXU
    # but require only one non-zero accumulation.
    x = self._make_input()
    y = self._make_input()
    reference_float64 = torch.matmul(x, y)

    # 14 bits is an estimate of the precision of the three-pass technique.
    worst_atol = torch.tensor(1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64)

    x = x.to(torch.float32).to('xla')
    y = y.to(torch.float32).to('xla')

    # Act
    actual = torch.matmul(x, y).to('cpu').to(torch.float64)

    # Disable rtol; we know exactly the atol for default, high, and highest.
    torch.testing.assert_close(
        actual,
        reference_float64,
        rtol=0.0,
        atol=worst_atol,
    )

    assert not torch.equal(actual, reference_float64), (
        "Actual product and reference product should not be equal, "
        f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
    )


if __name__ == '__main__':
  unittest.main(verbosity=0)
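The "three-pass technique" referenced above is commonly described as splitting each float32 operand into a bfloat16 high part plus a bfloat16 remainder and accumulating the three largest partial products (the low-times-low term is dropped). Below is a rough NumPy sketch of that idea, not the exact TPU algorithm; it truncates where hardware would round to nearest:

import numpy as np


def to_bf16(x):
  # Emulate bfloat16 by zeroing the low 16 bits of a float32:
  # what remains is the sign, the 8 exponent bits, and 7 mantissa bits.
  bits = np.asarray(x, dtype=np.float32).view(np.uint32)
  return (bits & np.uint32(0xFFFF0000)).view(np.float32)


def matmul_bf16_3x(a, b):
  # Split each operand into a bf16 high part and a bf16 remainder,
  # then sum the three largest partial products.
  a32, b32 = a.astype(np.float32), b.astype(np.float32)
  a_hi, b_hi = to_bf16(a32), to_bf16(b32)
  a_lo, b_lo = to_bf16(a32 - a_hi), to_bf16(b32 - b_hi)
  return a_hi @ b_hi + a_hi @ b_lo + a_lo @ b_hi

On inputs like those in the tests above, matmul_bf16_3x recovers far more of the float64 reference than a single bf16 pass does, which is the effect the 14-bit estimate captures.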
@@ -0,0 +1,74 @@
"""Numeric tests for highest precision of mat mul.

There are three similar test files, suffixed default, high, and highest.
Unfortunately, the precision cannot reliably be dynamically changed
between tensor operations, so the tests are split into three files to
ensure a fresh environment for each test.
"""

import os
import unittest

import torch
import torch_xla
import torch_xla.backends
import torch_xla.runtime as xr


def _is_on_tpu():
  return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU')


class TestMatMulPrecisionHighest(unittest.TestCase):

  def _make_input(self):
    eye = torch.eye(1024, device='cpu', dtype=torch.float64)
    rand_ = torch.testing.make_tensor((1024, 1024),
                                      dtype=torch.float64,
                                      device="cpu",
                                      low=0.99,
                                      high=1.01)
    return eye * rand_

  # DO NOT add epsilons to this test. These tests must be numerically exact.
  @skipIfNotTpu
  def test_mat_mul_precision_numerics_highest(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('highest')

    # Diagonal matrices force mat mul through the MXU
    # but require only one non-zero accumulation.
    x = self._make_input()
    y = self._make_input()
    reference_float64 = torch.matmul(x, y)

    # 22 bits is an estimate of the precision of the six-pass technique.
    worst_atol = torch.tensor(1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64)

    x = x.to(torch.float32).to('xla')
    y = y.to(torch.float32).to('xla')

    # Act
    actual = torch.matmul(x, y).to('cpu').to(torch.float64)

    # Disable rtol; we know exactly the atol for default, high, and highest.
    torch.testing.assert_close(
        actual,
        reference_float64,
        rtol=0.0,
        atol=worst_atol,
    )

    assert not torch.equal(actual, reference_float64), (
        "Actual product and reference product should not be equal, "
        f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
    )


if __name__ == '__main__':
  unittest.main(verbosity=0)
@@ -0,0 +1,78 @@
"""torch_xla.backends controls the behavior of the XLA backend.

This subpackage parallels the torch.backends.{cuda, cpu, mps, etc.}
subpackages in PyTorch.
"""

# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py
# for an example of how backends are implemented in PyTorch
# in the __init__.py file, despite general style guidelines against this.

# Literal is available from Python 3.8,
# matching the Python versions for PyTorch and PyTorch/XLA.
from typing import Final, Literal, TypeAlias

import torch_xla

__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"]

# Valid values for get_mat_mul_precision/set_mat_mul_precision.
# Note: it is idiomatic in PyTorch to use strings rather than enums.
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9

_DEFAULT: Final = "default"
_HIGH: Final = "high"
_HIGHEST: Final = "highest"

# Use of variables with a Final type hint instead of literals is valid.
_PrecisionType: TypeAlias = Literal[
    _DEFAULT, _HIGH, _HIGHEST]  # pyright: ignore[reportInvalidTypeForm]


# Some of this description is adapted from the JAX documentation.
# TODO: Once the numerics tutorial is released, link to it from this docstring.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
  """Control the default matmul and conv precision for 32-bit inputs.

  Some platforms, like TPU, offer configurable precision levels for
  matrix multiplication and convolution computations,
  trading off accuracy for speed.

  This option controls the default precision level for
  computations involved in matrix multiplication and convolution on
  32-bit inputs. The levels describe the precision at
  which scalar products are computed.

  On a TPU:

  * `default` is the fastest and least precise,
    downcasting an FP32 to BF16 before multiplying.

  * `high` takes three passes and generates approximately 14 bits of
    precision.

  * `highest` is the most precise, and the slowest. It takes six
    passes and generates approximately 22 bits of precision.

  Args:
    precision (str): The precision to set for matrix multiplication.
      Must be one of 'default', 'high', or 'highest'.
  """
  if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
    raise ValueError(f"Invalid precision: {precision}. "
                     f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")

  torch_xla._XLAC._xla_set_mat_mul_precision(precision)


def get_mat_mul_precision() -> _PrecisionType:
  """Get the current mat mul precision for 32-bit inputs.

  Returns:
    str: The current precision setting for matrix multiplication,
      one of 'default', 'high', or 'highest'.
  """
  precision = torch_xla._XLAC._xla_get_mat_mul_precision()
  assert precision in [_DEFAULT, _HIGH, _HIGHEST
                      ], (f"Invalid precision: {precision}. "
                          f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")
  return precision
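Putting the new API together, here is a short usage sketch (it assumes a torch_xla installation with an XLA device attached, and is not part of the PR):

import torch
import torch_xla
import torch_xla.backends

# Inspect the current setting ('default' unless previously changed).
print(torch_xla.backends.get_mat_mul_precision())

# Trade speed for accuracy on 32-bit matmuls and convolutions.
torch_xla.backends.set_mat_mul_precision('highest')

x = torch.randn(128, 128, device='xla')
y = torch.randn(128, 128, device='xla')
z = torch.matmul(x, y)  # on TPU, computed with the six-pass, ~22-bit technique

# Invalid values raise ValueError.
try:
  torch_xla.backends.set_mat_mul_precision('medium')
except ValueError as err:
  print(err)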