From 86200a58e03ec68282e78c63ec50c042408ab333 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 14:45:30 +0000 Subject: [PATCH 01/20] initial commit: binding and backends package, no tests. --- test/test_mat_mul_precision.py | 4 ++ torch_xla/backends/__init__.py | 78 +++++++++++++++++++++++++ torch_xla/csrc/init_python_bindings.cpp | 5 ++ 3 files changed, 87 insertions(+) create mode 100644 test/test_mat_mul_precision.py create mode 100644 torch_xla/backends/__init__.py diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py new file mode 100644 index 000000000000..4a7572771aea --- /dev/null +++ b/test/test_mat_mul_precision.py @@ -0,0 +1,4 @@ +"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" + +import unittest + diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py new file mode 100644 index 000000000000..85a92a5d30db --- /dev/null +++ b/torch_xla/backends/__init__.py @@ -0,0 +1,78 @@ +"""torch_xla.backends is intended to be moved to torch.backends.xla""" + +# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py +# for an example of how a backend is implemented in PyTorch. + +# Literal is availabe from Python 3.8, +# matching the Python versions for PyTorch and PyTorchXLA. +from typing import Final, Literal, TypeAlias + +import torch_xla + +__all__ = ['set_mat_mul_precision', 'get_mat_mul_precision'] + +# Valid values for get_mat_mul_precision/set_mat_mul_precision +# Note: it is idiomatic to PyTorch to use strings rather than enums. +# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9 + +_DEFAULT: Final = 'default' +_HIGH: Final = 'high' +_HIGHEST: Final = 'highest' + +# Use of variables with Final typehint instead of literals is valid. +_PrecisionType: TypeAlias = Literal[_DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm] + +# Some of this description adapted from Jax documentation. +def set_mat_mul_precision(precision: _PrecisionType) -> None: + r"""Control the default matmul and conv precision for 32bit inputs. + + Some platforms, like TPU, offer configurable precision levels for + matrix multiplication and convolution computations, + trading off accuracy for speed. + + This option can be used to control the default precision level for + computations involved in matrix multiplication and convolution on + 32bit inputs. The levels describe the precision at + which scalar products are computed. + + On a TPU: + * `default` is the fastest and least precise, essentially + downcasting an FP32 to BF16 before multiplying. + + * `high` breaks the FP32 inputs into upper and lower bits and + computes the product of the pair of upper bits, and the two pairs of + upper and lower bits. The pair of lower bits is ignored. Since the BF16 + format has 9 bits of precision (7 bits of mantissa plus one implicit + leading one bit plus one sign bit), this delivers approximately + 18 bits of precision, more than the TensorFloat32 (10 bits) + but less than a full FP32 (23 bits). + It requires three passes over the data. + + * `highest` is the most precise, and the slowest. + It further breaks up each input into a triple of upper, middle, + and lower bits and effectively calculates + 23 bits of precision. This is the most precise option, but also + requires six passes. + + Args: + precision (str): The precision to set for matrix multiplication. Must be + one of 'default', 'high', or 'highest'. 
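+
+    Example (an illustrative sketch, not part of the API contract; assumes an
+    XLA device such as a TPU is available)::
+
+        import torch
+        import torch_xla.backends
+
+        torch_xla.backends.set_mat_mul_precision('high')
+        x = torch.randn(128, 128).to('xla')
+        y = torch.randn(128, 128).to('xla')
+        z = torch.matmul(x, y)  # runs with the three-pass 'high' precision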
+ """ + if precision not in [_DEFAULT, _HIGH, _HIGHEST]: + raise ValueError(f"Invalid precision: {precision}. " + "Must be one of 'float32', 'bfloat16', or 'float16'.") + + torch_xla._XLAC._set_mat_mul_precision(precision) + +def get_mat_mul_precision() -> _PrecisionType: + r"""Get the current mat mul precision for 32bit inputs. + + Returns: + str: The current precision setting for matrix multiplication, + one of 'default', 'high', or 'highest'. + """ + precision = torch_xla._XLAC._get_mat_mul_precision() + assert precision in [_DEFAULT, _HIGH, _HIGHEST], ( + f"Invalid precision: {precision}. " + "Must be one of 'float32', 'bfloat16', or 'float16'.") + return precision diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d1b38efe00ea..61e101d312bb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2124,6 +2124,11 @@ void InitXlaModuleBindings(py::module m) { ConsumeValue(xla::StringToPrecision(mat_mul_precision)); XlaHelpers::set_mat_mul_precision(precision); }); + m.def("_xla_get_mat_mul_precision", []() { + xla::PrecisionConfig::Precision precision = + XlaHelpers::mat_mul_precision(); + return xla::PrecisionToString(precision); + }); py::class_(m, "XlaBuilder"); py::class_(m, "XlaOp"); From 2901b77c59021c92538c27c93061fc754e414792 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 18:42:16 +0000 Subject: [PATCH 02/20] Tests for default, high, and highest precision. --- test/test_mat_mul_precision.py | 4 -- test/test_mat_mul_precision_default.py | 78 ++++++++++++++++++++++ test/test_mat_mul_precision_get_and_set.py | 62 +++++++++++++++++ test/test_mat_mul_precision_high.py | 77 +++++++++++++++++++++ test/test_mat_mul_precision_highest.py | 77 +++++++++++++++++++++ test/tpu/run_tests.sh | 4 ++ torch_xla/backends/__init__.py | 4 +- torch_xla/csrc/xla_op_builder.cpp | 3 + 8 files changed, 303 insertions(+), 6 deletions(-) delete mode 100644 test/test_mat_mul_precision.py create mode 100644 test/test_mat_mul_precision_default.py create mode 100644 test/test_mat_mul_precision_get_and_set.py create mode 100644 test/test_mat_mul_precision_high.py create mode 100644 test/test_mat_mul_precision_highest.py diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py deleted file mode 100644 index 4a7572771aea..000000000000 --- a/test/test_mat_mul_precision.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" - -import unittest - diff --git a/test/test_mat_mul_precision_default.py b/test/test_mat_mul_precision_default.py new file mode 100644 index 000000000000..3fc1377e8ee8 --- /dev/null +++ b/test/test_mat_mul_precision_default.py @@ -0,0 +1,78 @@ +"""Numeric tests for default precision of mat mul. + +There are three similar test files, suffixed default, high, and highest. +Unfortunately, the precision cannot reliably +be dynamically changed between tensor operations, so +the tests are split into three files to ensure a fresh +environment for each test. 
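+
+Each file is intended to be run in its own process, for example
+(command shown for illustration, from the repository root):
+
+    python3 test/test_mat_mul_precision_default.py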
+""" + +import os +import sys +import unittest + +import torch +import torch_xla +import torch_xla.backends +import torch_xla.runtime as xr + + +def _is_on_tpu(): + return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' + + +skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') + + +class TestMatMulPrecisionDefaultDefault(unittest.TestCase): + + def _make_input(self): + eye = torch.eye(1024, device='cpu', dtype=torch.float64) + rand_ = torch.testing.make_tensor((1024, 1024), + dtype=torch.float64, + device="cpu", + low=0.99, + high=1.01) + return eye * rand_ + + # DO NOT add epsilons to this test. These tests must be numerically exact. + @skipIfNotTpu + def test_mat_mul_precision_numerics_default(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('default') + + # Diagonal matrices force mat mul through MXU + # but require only one non-zero accumulation. + x = self._make_input() + y = self._make_input() + actual_float64 = torch.matmul(x, y) + + # This is slightly worse than 1/2**[mantissa bits] because in multiplication, both + # operands may lose 1/256 (=2^8, where 8 = 7 + 1, + # where 7 is the mantissa and 1 is the implicit bit). + worst_error_default = torch.tensor( + 1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64) + + x = x.to(torch.float32).to('xla') + y = y.to(torch.float32).to('xla') + + # Act + actual_default = torch.matmul(x, y).to('cpu').to(torch.float64) + + # Disable rtol, we know exactly the atol for default, high, and highest. + torch.testing.assert_close( + actual_default, + actual_float64, + rtol=0.0, + atol=worst_error_default, + ) + + assert not torch.equal(actual_default, actual_float64), ( + f"Default and high precision should not be equal, " + f"but they are: {torch.diag(actual_default)} == {torch.diag(actual_float64)}" + ) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_mat_mul_precision_get_and_set.py b/test/test_mat_mul_precision_get_and_set.py new file mode 100644 index 000000000000..ad47a9cc0e60 --- /dev/null +++ b/test/test_mat_mul_precision_get_and_set.py @@ -0,0 +1,62 @@ +"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" + +import sys +import unittest + +import torch +import torch_xla +import torch_xla.backends + + +class TestMatMulPrecisionGetAndSet(unittest.TestCase): + + def setUp(self): + self._original = torch_xla.backends.get_mat_mul_precision() + torch.set_printoptions(precision=20) + torch_xla.sync() + + def tearDown(self): + torch_xla.backends.set_mat_mul_precision(self._original) + torch.set_printoptions(profile="default") + torch_xla.sync() + + def test_set_mat_mul_precision_error(self): + # Assert + with self.assertRaises(ValueError): + # Act + torch_xla.backends.set_mat_mul_precision('BAD VALUE') + + def test_get_and_set_mat_mul_precision_default(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('default') + + # Act + status = torch_xla.backends.get_mat_mul_precision() + + # Assert + self.assertEqual(status, 'default') + + def test_get_and_set_mat_mul_precision_high(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('high') + + # Act + status = torch_xla.backends.get_mat_mul_precision() + + # Assert + self.assertEqual(status, 'high') + + def test_get_and_set_mat_mul_precision_highest(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('highest') + + # Act + status = torch_xla.backends.get_mat_mul_precision() + + # Assert + self.assertEqual(status, 'highest') + + +if 
__name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_mat_mul_precision_high.py b/test/test_mat_mul_precision_high.py new file mode 100644 index 000000000000..a966a60ab49a --- /dev/null +++ b/test/test_mat_mul_precision_high.py @@ -0,0 +1,77 @@ +"""Numeric tests for high precision of mat mul. + +There are three similar test files, suffixed default, high, and highest. +Unfortunately, the precision cannot reliably +be dynamically changed between tensor operations, so +the tests are split into three files to ensure a fresh +environment for each test. +""" + +import os +import sys +import unittest + +import torch +import torch_xla +import torch_xla.backends +import torch_xla.runtime as xr + + +def _is_on_tpu(): + return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' + + +skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') + + +class TestMatMulPrecisionHigh(unittest.TestCase): + + def _make_input(self): + eye = torch.eye(1024, device='cpu', dtype=torch.float64) + rand_ = torch.testing.make_tensor((1024, 1024), + dtype=torch.float64, + device="cpu", + low=0.99, + high=1.01) + return eye * rand_ + + # DO NOT add epsilons to this test. These tests must be numerically exact. + @skipIfNotTpu + def test_mat_mul_precision_numerics_high(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('high') + + # Diagonal matrices force mat mul through MXU + # but require only one non-zero accumulation. + x = self._make_input() + y = self._make_input() + actual_float64 = torch.matmul(x, y) + + # 14 bits is an estimate of precision in the + # three pass technique. + worst_error_default = torch.tensor( + 1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64) + + x = x.to(torch.float32).to('xla') + y = y.to(torch.float32).to('xla') + + # Act + actual_default = torch.matmul(x, y).to('cpu').to(torch.float64) + + # Disable rtol, we know exactly the atol for default, high, and highest. + torch.testing.assert_close( + actual_default, + actual_float64, + rtol=0.0, + atol=worst_error_default, + ) + + assert not torch.equal(actual_default, actual_float64), ( + f"Default and high precision should not be equal, " + f"but they are: {torch.diag(actual_default)} == {torch.diag(actual_float64)}" + ) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_mat_mul_precision_highest.py b/test/test_mat_mul_precision_highest.py new file mode 100644 index 000000000000..edd0467e3676 --- /dev/null +++ b/test/test_mat_mul_precision_highest.py @@ -0,0 +1,77 @@ +"""Numeric tests for highest precision of mat mul. + +There are three similar test files, suffixed default, high, and highest. +Unfortunately, the precision cannot reliably +be dynamically changed between tensor operations, so +the tests are split into three files to ensure a fresh +environment for each test. 
+""" + +import os +import sys +import unittest + +import torch +import torch_xla +import torch_xla.backends +import torch_xla.runtime as xr + + +def _is_on_tpu(): + return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' + + +skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') + + +class TestMatMulPrecisionDefault(unittest.TestCase): + + def _make_input(self): + eye = torch.eye(1024, device='cpu', dtype=torch.float64) + rand_ = torch.testing.make_tensor((1024, 1024), + dtype=torch.float64, + device="cpu", + low=0.99, + high=1.01) + return eye * rand_ + + # DO NOT add epsilons to this test. These tests must be numerically exact. + @skipIfNotTpu + def test_mat_mul_precision_numerics_default(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('highest') + + # Diagonal matrices force mat mul through MXU + # but require only one non-zero accumulation. + x = self._make_input() + y = self._make_input() + actual_float64 = torch.matmul(x, y) + + # 22 bits is an estimate of precision in the + # six pass technique. + worst_error_default = torch.tensor( + 1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64) + + x = x.to(torch.float32).to('xla') + y = y.to(torch.float32).to('xla') + + # Act + actual_default = torch.matmul(x, y).to('cpu').to(torch.float64) + + # Disable rtol, we know exactly the atol for default, high, and highest. + torch.testing.assert_close( + actual_default, + actual_float64, + rtol=0.0, + atol=worst_error_default, + ) + + assert not torch.equal(actual_default, actual_float64), ( + f"Default and high precision should not be equal, " + f"but they are: {torch.diag(actual_default)} == {torch.diag(actual_float64)}" + ) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 6b596a157460..68669c0d084a 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -6,6 +6,10 @@ TEST_CDIR="$(dirname "$CDIR")" source "${TEST_CDIR}/utils/run_tests_utils.sh" # TODO: merge with other run_tests +python3 "$TEST_CDIR/test_mat_mul_precision_highest.py" +python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py" +python3 "$TEST_CDIR/test_mat_mul_precision_default.py" +python3 "$TEST_CDIR/test_mat_mul_precision_high.py" python3 "$TEST_CDIR/test_operations.py" -v python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py" python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py" diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index 85a92a5d30db..d22e7859c4d3 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -62,7 +62,7 @@ def set_mat_mul_precision(precision: _PrecisionType) -> None: raise ValueError(f"Invalid precision: {precision}. " "Must be one of 'float32', 'bfloat16', or 'float16'.") - torch_xla._XLAC._set_mat_mul_precision(precision) + torch_xla._XLAC._xla_set_mat_mul_precision(precision) def get_mat_mul_precision() -> _PrecisionType: r"""Get the current mat mul precision for 32bit inputs. @@ -71,7 +71,7 @@ def get_mat_mul_precision() -> _PrecisionType: str: The current precision setting for matrix multiplication, one of 'default', 'high', or 'highest'. """ - precision = torch_xla._XLAC._get_mat_mul_precision() + precision = torch_xla._XLAC._xla_get_mat_mul_precision() assert precision in [_DEFAULT, _HIGH, _HIGHEST], ( f"Invalid precision: {precision}. 
" "Must be one of 'float32', 'bfloat16', or 'float16'.") diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp index b84cc2c0a0e1..5a99567e7bde 100644 --- a/torch_xla/csrc/xla_op_builder.cpp +++ b/torch_xla/csrc/xla_op_builder.cpp @@ -208,6 +208,9 @@ xla::PrecisionConfig DotPrecisonConfig(py::dict args) { precision = xla::PrecisionConfig::HIGH; } else if (*arg_precision_config == "highest") { precision = xla::PrecisionConfig::HIGHEST; + } else { + XLA_ERROR() << "Invalid precision config in args: " << *arg_precision_config << + " (valid values: default, high, highest)"; } } return XlaHelpers::BuildPrecisionConfig(precision); From 564685a7a2420b1a21dd10294b575c1c84c105db Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 18:59:41 +0000 Subject: [PATCH 03/20] clang-format --- torch_xla/csrc/init_python_bindings.cpp | 3 +-- torch_xla/csrc/xla_op_builder.cpp | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 61e101d312bb..9160342383df 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2125,8 +2125,7 @@ void InitXlaModuleBindings(py::module m) { XlaHelpers::set_mat_mul_precision(precision); }); m.def("_xla_get_mat_mul_precision", []() { - xla::PrecisionConfig::Precision precision = - XlaHelpers::mat_mul_precision(); + xla::PrecisionConfig::Precision precision = XlaHelpers::mat_mul_precision(); return xla::PrecisionToString(precision); }); diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp index 5a99567e7bde..1fbae78eee95 100644 --- a/torch_xla/csrc/xla_op_builder.cpp +++ b/torch_xla/csrc/xla_op_builder.cpp @@ -209,8 +209,9 @@ xla::PrecisionConfig DotPrecisonConfig(py::dict args) { } else if (*arg_precision_config == "highest") { precision = xla::PrecisionConfig::HIGHEST; } else { - XLA_ERROR() << "Invalid precision config in args: " << *arg_precision_config << - " (valid values: default, high, highest)"; + XLA_ERROR() << "Invalid precision config in args: " + << *arg_precision_config + << " (valid values: default, high, highest)"; } } return XlaHelpers::BuildPrecisionConfig(precision); From 4c0b52ee561b62c5ae0f4c6a755f6d7db5e4fe66 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 19:15:03 +0000 Subject: [PATCH 04/20] formatter --- torch_xla/backends/__init__.py | 105 +++++++++++++++++---------------- 1 file changed, 54 insertions(+), 51 deletions(-) diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index d22e7859c4d3..38be9b8f9e88 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -20,59 +20,62 @@ _HIGHEST: Final = 'highest' # Use of variables with Final typehint instead of literals is valid. -_PrecisionType: TypeAlias = Literal[_DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm] +_PrecisionType: TypeAlias = Literal[ + _DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm] + # Some of this description adapted from Jax documentation. def set_mat_mul_precision(precision: _PrecisionType) -> None: - r"""Control the default matmul and conv precision for 32bit inputs. - - Some platforms, like TPU, offer configurable precision levels for - matrix multiplication and convolution computations, - trading off accuracy for speed. 
- - This option can be used to control the default precision level for - computations involved in matrix multiplication and convolution on - 32bit inputs. The levels describe the precision at - which scalar products are computed. - - On a TPU: - * `default` is the fastest and least precise, essentially - downcasting an FP32 to BF16 before multiplying. - - * `high` breaks the FP32 inputs into upper and lower bits and - computes the product of the pair of upper bits, and the two pairs of - upper and lower bits. The pair of lower bits is ignored. Since the BF16 - format has 9 bits of precision (7 bits of mantissa plus one implicit - leading one bit plus one sign bit), this delivers approximately - 18 bits of precision, more than the TensorFloat32 (10 bits) - but less than a full FP32 (23 bits). - It requires three passes over the data. - - * `highest` is the most precise, and the slowest. - It further breaks up each input into a triple of upper, middle, - and lower bits and effectively calculates - 23 bits of precision. This is the most precise option, but also - requires six passes. - - Args: - precision (str): The precision to set for matrix multiplication. Must be - one of 'default', 'high', or 'highest'. - """ - if precision not in [_DEFAULT, _HIGH, _HIGHEST]: - raise ValueError(f"Invalid precision: {precision}. " - "Must be one of 'float32', 'bfloat16', or 'float16'.") - - torch_xla._XLAC._xla_set_mat_mul_precision(precision) + """Control the default matmul and conv precision for 32bit inputs. + + Some platforms, like TPU, offer configurable precision levels for + matrix multiplication and convolution computations, + trading off accuracy for speed. + + This option can be used to control the default precision level for + computations involved in matrix multiplication and convolution on + 32bit inputs. The levels describe the precision at + which scalar products are computed. + + On a TPU: + * `default` is the fastest and least precise, essentially + downcasting an FP32 to BF16 before multiplying. + + * `high` breaks the FP32 inputs into upper and lower bits and + computes the product of the pair of upper bits, and the two pairs of + upper and lower bits. The pair of lower bits is ignored. Since the BF16 + format has 9 bits of precision (7 bits of mantissa plus one implicit + leading one bit plus one sign bit), this delivers approximately + 18 bits of precision, more than the TensorFloat32 (10 bits) + but less than a full FP32 (23 bits). + It requires three passes over the data. + + * `highest` is the most precise, and the slowest. + It further breaks up each input into a triple of upper, middle, + and lower bits and effectively calculates + 23 bits of precision. This is the most precise option, but also + requires six passes. + + Args: + precision (str): The precision to set for matrix multiplication. Must be + one of 'default', 'high', or 'highest'. + """ + if precision not in [_DEFAULT, _HIGH, _HIGHEST]: + raise ValueError(f"Invalid precision: {precision}. " + "Must be one of 'float32', 'bfloat16', or 'float16'.") + + torch_xla._XLAC._xla_set_mat_mul_precision(precision) + def get_mat_mul_precision() -> _PrecisionType: - r"""Get the current mat mul precision for 32bit inputs. - - Returns: - str: The current precision setting for matrix multiplication, - one of 'default', 'high', or 'highest'. - """ - precision = torch_xla._XLAC._xla_get_mat_mul_precision() - assert precision in [_DEFAULT, _HIGH, _HIGHEST], ( - f"Invalid precision: {precision}. 
" - "Must be one of 'float32', 'bfloat16', or 'float16'.") - return precision + """Get the current mat mul precision for 32bit inputs. + + Returns: + str: The current precision setting for matrix multiplication, + one of 'default', 'high', or 'highest'. + """ + precision = torch_xla._XLAC._xla_get_mat_mul_precision() + assert precision in [_DEFAULT, _HIGH, _HIGHEST + ], (f"Invalid precision: {precision}. " + "Must be one of 'float32', 'bfloat16', or 'float16'.") + return precision From 0d3f44ffed91d10d1acc88b5ee0b9237de48305e Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 22:27:22 +0000 Subject: [PATCH 05/20] Updates to error messages --- test/test_mat_mul_precision_default.py | 23 ++++++++++------------- test/test_mat_mul_precision_high.py | 23 ++++++++++------------- test/test_mat_mul_precision_highest.py | 25 +++++++++++-------------- 3 files changed, 31 insertions(+), 40 deletions(-) diff --git a/test/test_mat_mul_precision_default.py b/test/test_mat_mul_precision_default.py index 3fc1377e8ee8..913ab518a7be 100644 --- a/test/test_mat_mul_precision_default.py +++ b/test/test_mat_mul_precision_default.py @@ -8,7 +8,6 @@ """ import os -import sys import unittest import torch @@ -45,34 +44,32 @@ def test_mat_mul_precision_numerics_default(self): # but require only one non-zero accumulation. x = self._make_input() y = self._make_input() - actual_float64 = torch.matmul(x, y) + reference_float64 = torch.matmul(x, y) # This is slightly worse than 1/2**[mantissa bits] because in multiplication, both # operands may lose 1/256 (=2^8, where 8 = 7 + 1, # where 7 is the mantissa and 1 is the implicit bit). - worst_error_default = torch.tensor( - 1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64) + worst_atol = torch.tensor(1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') # Act - actual_default = torch.matmul(x, y).to('cpu').to(torch.float64) + actual = torch.matmul(x, y).to('cpu').to(torch.float64) # Disable rtol, we know exactly the atol for default, high, and highest. torch.testing.assert_close( - actual_default, - actual_float64, + actual, + reference_float64, rtol=0.0, - atol=worst_error_default, + atol=worst_atol, ) - assert not torch.equal(actual_default, actual_float64), ( - f"Default and high precision should not be equal, " - f"but they are: {torch.diag(actual_default)} == {torch.diag(actual_float64)}" + assert not torch.equal(actual, reference_float64), ( + "Actual product and reference product should not be equal, " + f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" ) if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) + unittest.main(verbosity=0) diff --git a/test/test_mat_mul_precision_high.py b/test/test_mat_mul_precision_high.py index a966a60ab49a..ff94a260acaa 100644 --- a/test/test_mat_mul_precision_high.py +++ b/test/test_mat_mul_precision_high.py @@ -8,7 +8,6 @@ """ import os -import sys import unittest import torch @@ -45,33 +44,31 @@ def test_mat_mul_precision_numerics_high(self): # but require only one non-zero accumulation. x = self._make_input() y = self._make_input() - actual_float64 = torch.matmul(x, y) + reference_float64 = torch.matmul(x, y) # 14 bits is an estimate of precision in the # three pass technique. 
- worst_error_default = torch.tensor( - 1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64) + worst_atol = torch.tensor(1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') # Act - actual_default = torch.matmul(x, y).to('cpu').to(torch.float64) + actual = torch.matmul(x, y).to('cpu').to(torch.float64) # Disable rtol, we know exactly the atol for default, high, and highest. torch.testing.assert_close( - actual_default, - actual_float64, + actual, + reference_float64, rtol=0.0, - atol=worst_error_default, + atol=worst_atol, ) - assert not torch.equal(actual_default, actual_float64), ( - f"Default and high precision should not be equal, " - f"but they are: {torch.diag(actual_default)} == {torch.diag(actual_float64)}" + assert not torch.equal(actual, reference_float64), ( + "Actual product and reference product should not be equal, " + f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" ) if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) + unittest.main(verbosity=0) diff --git a/test/test_mat_mul_precision_highest.py b/test/test_mat_mul_precision_highest.py index edd0467e3676..be8c6aa801ea 100644 --- a/test/test_mat_mul_precision_highest.py +++ b/test/test_mat_mul_precision_highest.py @@ -8,7 +8,6 @@ """ import os -import sys import unittest import torch @@ -35,7 +34,7 @@ def _make_input(self): high=1.01) return eye * rand_ - # DO NOT add epsilons to this test. These tests must be numerically exact. + # DO NOT add epsilons to this test. These tests must be numerically precise. @skipIfNotTpu def test_mat_mul_precision_numerics_default(self): # Arrange @@ -45,33 +44,31 @@ def test_mat_mul_precision_numerics_default(self): # but require only one non-zero accumulation. x = self._make_input() y = self._make_input() - actual_float64 = torch.matmul(x, y) + reference_float64 = torch.matmul(x, y) # 22 bits is an estimate of precision in the # six pass technique. - worst_error_default = torch.tensor( - 1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64) + worst_atol = torch.tensor(1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') # Act - actual_default = torch.matmul(x, y).to('cpu').to(torch.float64) + actual = torch.matmul(x, y).to('cpu').to(torch.float64) # Disable rtol, we know exactly the atol for default, high, and highest. 
torch.testing.assert_close( - actual_default, - actual_float64, + actual, + reference_float64, rtol=0.0, - atol=worst_error_default, + atol=worst_atol, ) - assert not torch.equal(actual_default, actual_float64), ( - f"Default and high precision should not be equal, " - f"but they are: {torch.diag(actual_default)} == {torch.diag(actual_float64)}" + assert not torch.equal(actual, reference_float64), ( + "Actual product and reference product should not be equal, " + f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" ) if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) + unittest.main(verbosity=0) From 3c333d8024bcbade7b67ef4727bd5ca19463db20 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 23:53:06 +0000 Subject: [PATCH 06/20] fixed test class names --- test/test_mat_mul_precision_default.py | 2 +- test/test_mat_mul_precision_highest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_mat_mul_precision_default.py b/test/test_mat_mul_precision_default.py index 913ab518a7be..c9169c3d5664 100644 --- a/test/test_mat_mul_precision_default.py +++ b/test/test_mat_mul_precision_default.py @@ -23,7 +23,7 @@ def _is_on_tpu(): skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') -class TestMatMulPrecisionDefaultDefault(unittest.TestCase): +class TestMatMulPrecisionDefault(unittest.TestCase): def _make_input(self): eye = torch.eye(1024, device='cpu', dtype=torch.float64) diff --git a/test/test_mat_mul_precision_highest.py b/test/test_mat_mul_precision_highest.py index be8c6aa801ea..9192e9ff507c 100644 --- a/test/test_mat_mul_precision_highest.py +++ b/test/test_mat_mul_precision_highest.py @@ -23,7 +23,7 @@ def _is_on_tpu(): skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') -class TestMatMulPrecisionDefault(unittest.TestCase): +class TestMatMulPrecisionHighest(unittest.TestCase): def _make_input(self): eye = torch.eye(1024, device='cpu', dtype=torch.float64) From 967ccda1449b9d13fffd49192af5596edb31a55a Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 2 May 2025 23:56:07 +0000 Subject: [PATCH 07/20] typo --- torch_xla/backends/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index 38be9b8f9e88..9f14fb6720b9 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -1,9 +1,10 @@ """torch_xla.backends is intended to be moved to torch.backends.xla""" # See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py -# for an example of how a backend is implemented in PyTorch. +# for an example of how backends are implemented in PyTorch +# in the __init__.py file. -# Literal is availabe from Python 3.8, +# Literal is available from Python 3.8, # matching the Python versions for PyTorch and PyTorchXLA. from typing import Final, Literal, TypeAlias From 393257e5e289605c9c99d459445a21b7d856c101 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Mon, 5 May 2025 17:54:01 +0000 Subject: [PATCH 08/20] typo on error message. unit tested and yapfed. 
--- torch_xla/backends/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index 9f14fb6720b9..a7607dcd0187 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -2,10 +2,10 @@ # See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py # for an example of how backends are implemented in PyTorch -# in the __init__.py file. +# in the __init__.py file, despite general style guidelines against this. # Literal is available from Python 3.8, -# matching the Python versions for PyTorch and PyTorchXLA. +# matching the Python versions for PyTorch and PyTorch/XLA. from typing import Final, Literal, TypeAlias import torch_xla @@ -63,7 +63,7 @@ def set_mat_mul_precision(precision: _PrecisionType) -> None: """ if precision not in [_DEFAULT, _HIGH, _HIGHEST]: raise ValueError(f"Invalid precision: {precision}. " - "Must be one of 'float32', 'bfloat16', or 'float16'.") + f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") torch_xla._XLAC._xla_set_mat_mul_precision(precision) @@ -78,5 +78,5 @@ def get_mat_mul_precision() -> _PrecisionType: precision = torch_xla._XLAC._xla_get_mat_mul_precision() assert precision in [_DEFAULT, _HIGH, _HIGHEST ], (f"Invalid precision: {precision}. " - "Must be one of 'float32', 'bfloat16', or 'float16'.") + f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") return precision From f48d7f6dbf0eaaad5362aad93a00c19a96ca526d Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Mon, 5 May 2025 18:54:16 +0000 Subject: [PATCH 09/20] linter --- torch_xla/backends/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index a7607dcd0187..8d57d018e775 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -2,7 +2,7 @@ # See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py # for an example of how backends are implemented in PyTorch -# in the __init__.py file, despite general style guidelines against this. +# in the __init__.py file, despite general style guidelines against this. # Literal is available from Python 3.8, # matching the Python versions for PyTorch and PyTorch/XLA. From 114456bf94f070822eb7a41a1eacbb078f817771 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Mon, 5 May 2025 23:20:38 +0000 Subject: [PATCH 10/20] minor edits. --- torch_xla/backends/__init__.py | 78 ++++++++++++++++------------------ 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index 8d57d018e775..7256c33bc6bf 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -1,4 +1,8 @@ -"""torch_xla.backends is intended to be moved to torch.backends.xla""" +"""torch_xla.backends controls the behavior of the XLA backend. + +This subpackage parallels the torch.backends.{cuda, cpu, mps, etc} +subpackages in PyTorch. +""" # See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py # for an example of how backends are implemented in PyTorch @@ -10,15 +14,15 @@ import torch_xla -__all__ = ['set_mat_mul_precision', 'get_mat_mul_precision'] +__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"] # Valid values for get_mat_mul_precision/set_mat_mul_precision # Note: it is idiomatic to PyTorch to use strings rather than enums. 
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9 -_DEFAULT: Final = 'default' -_HIGH: Final = 'high' -_HIGHEST: Final = 'highest' +_DEFAULT: Final = "default" +_HIGH: Final = "high" +_HIGHEST: Final = "highest" # Use of variables with Final typehint instead of literals is valid. _PrecisionType: TypeAlias = Literal[ @@ -26,41 +30,33 @@ # Some of this description adapted from Jax documentation. +# TODO: Once the numerics tutorial is released, link from this docstring. def set_mat_mul_precision(precision: _PrecisionType) -> None: """Control the default matmul and conv precision for 32bit inputs. - Some platforms, like TPU, offer configurable precision levels for - matrix multiplication and convolution computations, - trading off accuracy for speed. - - This option can be used to control the default precision level for - computations involved in matrix multiplication and convolution on - 32bit inputs. The levels describe the precision at - which scalar products are computed. - - On a TPU: - * `default` is the fastest and least precise, essentially - downcasting an FP32 to BF16 before multiplying. - - * `high` breaks the FP32 inputs into upper and lower bits and - computes the product of the pair of upper bits, and the two pairs of - upper and lower bits. The pair of lower bits is ignored. Since the BF16 - format has 9 bits of precision (7 bits of mantissa plus one implicit - leading one bit plus one sign bit), this delivers approximately - 18 bits of precision, more than the TensorFloat32 (10 bits) - but less than a full FP32 (23 bits). - It requires three passes over the data. - - * `highest` is the most precise, and the slowest. - It further breaks up each input into a triple of upper, middle, - and lower bits and effectively calculates - 23 bits of precision. This is the most precise option, but also - requires six passes. - - Args: - precision (str): The precision to set for matrix multiplication. Must be - one of 'default', 'high', or 'highest'. - """ + Some platforms, like TPU, offer configurable precision levels for + matrix multiplication and convolution computations, + trading off accuracy for speed. + + This option controls the default precision level for + computations involved in matrix multiplication and convolution on + 32bit inputs. The levels describe the precision at + which scalar products are computed. + + On a TPU: + * `default` is the fastest and least precise, + downcasting an FP32 to BF16 before multiplying. + + * `high` takes three passes and generates approximately 14 bits of + precision. + + * `highest` is the most precise, and the slowest. It takes six + passes and generates approximately 22 bits of precision. + + Args: + precision (str): The precision to set for matrix multiplication. + Must be one of 'default', 'high', or 'highest'. + """ if precision not in [_DEFAULT, _HIGH, _HIGHEST]: raise ValueError(f"Invalid precision: {precision}. " f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") @@ -71,10 +67,10 @@ def set_mat_mul_precision(precision: _PrecisionType) -> None: def get_mat_mul_precision() -> _PrecisionType: """Get the current mat mul precision for 32bit inputs. - Returns: - str: The current precision setting for matrix multiplication, - one of 'default', 'high', or 'highest'. - """ + Returns: + str: The current precision setting for matrix multiplication, + one of 'default', 'high', or 'highest'. 
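+
+  Example (an illustrative sketch)::
+
+      import torch_xla.backends
+
+      precision = torch_xla.backends.get_mat_mul_precision()
+      print(precision)  # one of 'default', 'high', or 'highest'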
+ """ precision = torch_xla._XLAC._xla_get_mat_mul_precision() assert precision in [_DEFAULT, _HIGH, _HIGHEST ], (f"Invalid precision: {precision}. " From 87cff3d4d643d98f341d8a60fc6a8e723d7f5a8c Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Tue, 6 May 2025 21:04:06 +0000 Subject: [PATCH 11/20] Updated TODO per review --- test/test_mat_mul_precision_default.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_mat_mul_precision_default.py b/test/test_mat_mul_precision_default.py index c9169c3d5664..ef6b228c6a6e 100644 --- a/test/test_mat_mul_precision_default.py +++ b/test/test_mat_mul_precision_default.py @@ -46,10 +46,8 @@ def test_mat_mul_precision_numerics_default(self): y = self._make_input() reference_float64 = torch.matmul(x, y) - # This is slightly worse than 1/2**[mantissa bits] because in multiplication, both - # operands may lose 1/256 (=2^8, where 8 = 7 + 1, - # where 7 is the mantissa and 1 is the implicit bit). - worst_atol = torch.tensor(1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64) + # TODO: Justify this logic. Why isn't it Why is it not -1 + ((2**8 - 1) / 2**8)**2. + worst_atol = torch.tensor(-1 + ((2**8 + 1) / 2**8)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') From b96f671c8567f19f8a10def2c508431963022cb6 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Tue, 6 May 2025 21:52:05 +0000 Subject: [PATCH 12/20] Update todo and precision math per comment. --- test/test_mat_mul_precision_high.py | 2 +- test/test_mat_mul_precision_highest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_mat_mul_precision_high.py b/test/test_mat_mul_precision_high.py index ff94a260acaa..19cc7eaa9df3 100644 --- a/test/test_mat_mul_precision_high.py +++ b/test/test_mat_mul_precision_high.py @@ -48,7 +48,7 @@ def test_mat_mul_precision_numerics_high(self): # 14 bits is an estimate of precision in the # three pass technique. - worst_atol = torch.tensor(1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64) + worst_atol = torch.tensor(-1 + ((2**14 + 1) / 2**14)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') diff --git a/test/test_mat_mul_precision_highest.py b/test/test_mat_mul_precision_highest.py index 9192e9ff507c..3576d86d6795 100644 --- a/test/test_mat_mul_precision_highest.py +++ b/test/test_mat_mul_precision_highest.py @@ -48,7 +48,7 @@ def test_mat_mul_precision_numerics_default(self): # 22 bits is an estimate of precision in the # six pass technique. - worst_atol = torch.tensor(1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64) + worst_atol = torch.tensor(-1 + ((2**22 + 1) / 2**22)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') From 30d1c19111f6a4b177d40e35cced1fdaef14fe92 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Tue, 6 May 2025 22:18:44 +0000 Subject: [PATCH 13/20] yapf --- test/test_mat_mul_precision_high.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_mat_mul_precision_high.py b/test/test_mat_mul_precision_high.py index 19cc7eaa9df3..d798a4a59658 100644 --- a/test/test_mat_mul_precision_high.py +++ b/test/test_mat_mul_precision_high.py @@ -48,7 +48,8 @@ def test_mat_mul_precision_numerics_high(self): # 14 bits is an estimate of precision in the # three pass technique. 
- worst_atol = torch.tensor(-1 + ((2**14 + 1) / 2**14)**2, dtype=torch.float64) + worst_atol = torch.tensor( + -1 + ((2**14 + 1) / 2**14)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') From 959b0dc7cc17207f19fb2c8a6dd2011472b197b1 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Wed, 7 May 2025 18:58:09 +0000 Subject: [PATCH 14/20] linter --- test/test_mat_mul_precision_highest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_mat_mul_precision_highest.py b/test/test_mat_mul_precision_highest.py index 3576d86d6795..0f233c035372 100644 --- a/test/test_mat_mul_precision_highest.py +++ b/test/test_mat_mul_precision_highest.py @@ -48,7 +48,8 @@ def test_mat_mul_precision_numerics_default(self): # 22 bits is an estimate of precision in the # six pass technique. - worst_atol = torch.tensor(-1 + ((2**22 + 1) / 2**22)**2, dtype=torch.float64) + worst_atol = torch.tensor( + -1 + ((2**22 + 1) / 2**22)**2, dtype=torch.float64) x = x.to(torch.float32).to('xla') y = y.to(torch.float32).to('xla') From 8aaa9790a2d6316d1aef177e04d1f319b8a18c35 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 9 May 2025 02:28:24 +0000 Subject: [PATCH 15/20] parameterized, but in a process isolated way. --- test/test_mat_mul_precision.py | 109 +++++++++++++++++++++++++ test/test_mat_mul_precision_default.py | 73 ----------------- test/test_mat_mul_precision_high.py | 75 ----------------- test/test_mat_mul_precision_highest.py | 75 ----------------- test/test_utils.py | 10 +++ test/tpu/run_tests.sh | 7 +- 6 files changed, 123 insertions(+), 226 deletions(-) create mode 100644 test/test_mat_mul_precision.py delete mode 100644 test/test_mat_mul_precision_default.py delete mode 100644 test/test_mat_mul_precision_high.py delete mode 100644 test/test_mat_mul_precision_highest.py diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py new file mode 100644 index 000000000000..3240bd93d996 --- /dev/null +++ b/test/test_mat_mul_precision.py @@ -0,0 +1,109 @@ +"""Numeric tests for default precision of mat mul. + +There are three similar test files, suffixed default, high, and highest. +Unfortunately, the precision cannot reliably +be dynamically changed between tensor operations, so +the tests are split into three files to ensure a fresh +environment for each test. +""" + +import unittest +import time + +import torch +import torch_xla +import torch_xla.backends + +import test_utils + + +class TestMatMulPrecision(unittest.TestCase): + + def _make_input(self): + eye = torch.eye(1024, device='cpu', dtype=torch.float64) + rand_ = torch.testing.make_tensor((1024, 1024), + dtype=torch.float64, + device="cpu", + low=0.99, + high=1.01) + return eye * rand_ + + # TODO: Figure out why either PT/XLA or unittest + # is unable to successfully run this test in a parameterized way. + # Is it something about unittest leaking global state? + # Or, switching precision too quickly or without a barrier? + @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + @unittest.expectedFailure + def test_all(self): + # The number of bit of precise mantissa expected in the result. + parameters = [ + ('highest', 22), + ('high', 14), + ('default', 8), + ] + # Although pytest has a slightly more elegant parameterized testing function, + # all TPU tests user unittest. 
+ for i, (precision, bits) in enumerate(parameters): + with self.subTest(precision=precision, bits=bits): + self._test_parameterized(precision, bits) + + @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + def test_highest(self): + self._test_parameterized('highest', 22) + + @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + def test_high(self): + self._test_parameterized('high', 14) + + @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + def test_default(self): + self._test_parameterized('default', 8) + + # DO NOT add epsilons to this test. These tests must be numerically exact. + def _test_parameterized(self, precision, bits): + # Arrange + torch_xla.backends.set_mat_mul_precision(precision) + + # Diagonal matrices force mat mul through MXU + # but require only one non-zero accumulation. + x = self._make_input() + y = self._make_input() + reference_float64 = torch.matmul(x, y) + + # TODO: Justify this logic. Why isn't it Why is it not + # 1 - ((2**8 - 1) / 2**8)**2 (equation stated by per TPU expert)? + widest_atol = torch.tensor( + -1 + ((2**(bits) + 1) / 2**bits)**2, dtype=torch.float64) + + narrowest_atol = widest_atol / 4.0 + + x = x.to(torch.float32).to('xla') + y = y.to(torch.float32).to('xla') + + # Act + actual = torch.matmul(x, y).to('cpu').to(torch.float64) + + # Disable rtol, we know exactly the atol for default, high, and highest. + torch.testing.assert_close( + actual, + reference_float64, + rtol=0.0, + atol=widest_atol, + ) + + with self.assertRaises(AssertionError): + torch.testing.assert_close( + actual, + reference_float64, + rtol=0.0, + atol=narrowest_atol, + ) + + assert not torch.equal(actual, reference_float64), ( + "Actual product and reference product should not be closer than equal, " + f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" + ) + + +# There is no main function. This is designed to be run from +# python -m unittest ... diff --git a/test/test_mat_mul_precision_default.py b/test/test_mat_mul_precision_default.py deleted file mode 100644 index ef6b228c6a6e..000000000000 --- a/test/test_mat_mul_precision_default.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Numeric tests for default precision of mat mul. - -There are three similar test files, suffixed default, high, and highest. -Unfortunately, the precision cannot reliably -be dynamically changed between tensor operations, so -the tests are split into three files to ensure a fresh -environment for each test. -""" - -import os -import unittest - -import torch -import torch_xla -import torch_xla.backends -import torch_xla.runtime as xr - - -def _is_on_tpu(): - return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' - - -skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') - - -class TestMatMulPrecisionDefault(unittest.TestCase): - - def _make_input(self): - eye = torch.eye(1024, device='cpu', dtype=torch.float64) - rand_ = torch.testing.make_tensor((1024, 1024), - dtype=torch.float64, - device="cpu", - low=0.99, - high=1.01) - return eye * rand_ - - # DO NOT add epsilons to this test. These tests must be numerically exact. - @skipIfNotTpu - def test_mat_mul_precision_numerics_default(self): - # Arrange - torch_xla.backends.set_mat_mul_precision('default') - - # Diagonal matrices force mat mul through MXU - # but require only one non-zero accumulation. - x = self._make_input() - y = self._make_input() - reference_float64 = torch.matmul(x, y) - - # TODO: Justify this logic. 
Why isn't it Why is it not -1 + ((2**8 - 1) / 2**8)**2. - worst_atol = torch.tensor(-1 + ((2**8 + 1) / 2**8)**2, dtype=torch.float64) - - x = x.to(torch.float32).to('xla') - y = y.to(torch.float32).to('xla') - - # Act - actual = torch.matmul(x, y).to('cpu').to(torch.float64) - - # Disable rtol, we know exactly the atol for default, high, and highest. - torch.testing.assert_close( - actual, - reference_float64, - rtol=0.0, - atol=worst_atol, - ) - - assert not torch.equal(actual, reference_float64), ( - "Actual product and reference product should not be equal, " - f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" - ) - - -if __name__ == '__main__': - unittest.main(verbosity=0) diff --git a/test/test_mat_mul_precision_high.py b/test/test_mat_mul_precision_high.py deleted file mode 100644 index d798a4a59658..000000000000 --- a/test/test_mat_mul_precision_high.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Numeric tests for high precision of mat mul. - -There are three similar test files, suffixed default, high, and highest. -Unfortunately, the precision cannot reliably -be dynamically changed between tensor operations, so -the tests are split into three files to ensure a fresh -environment for each test. -""" - -import os -import unittest - -import torch -import torch_xla -import torch_xla.backends -import torch_xla.runtime as xr - - -def _is_on_tpu(): - return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' - - -skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') - - -class TestMatMulPrecisionHigh(unittest.TestCase): - - def _make_input(self): - eye = torch.eye(1024, device='cpu', dtype=torch.float64) - rand_ = torch.testing.make_tensor((1024, 1024), - dtype=torch.float64, - device="cpu", - low=0.99, - high=1.01) - return eye * rand_ - - # DO NOT add epsilons to this test. These tests must be numerically exact. - @skipIfNotTpu - def test_mat_mul_precision_numerics_high(self): - # Arrange - torch_xla.backends.set_mat_mul_precision('high') - - # Diagonal matrices force mat mul through MXU - # but require only one non-zero accumulation. - x = self._make_input() - y = self._make_input() - reference_float64 = torch.matmul(x, y) - - # 14 bits is an estimate of precision in the - # three pass technique. - worst_atol = torch.tensor( - -1 + ((2**14 + 1) / 2**14)**2, dtype=torch.float64) - - x = x.to(torch.float32).to('xla') - y = y.to(torch.float32).to('xla') - - # Act - actual = torch.matmul(x, y).to('cpu').to(torch.float64) - - # Disable rtol, we know exactly the atol for default, high, and highest. - torch.testing.assert_close( - actual, - reference_float64, - rtol=0.0, - atol=worst_atol, - ) - - assert not torch.equal(actual, reference_float64), ( - "Actual product and reference product should not be equal, " - f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" - ) - - -if __name__ == '__main__': - unittest.main(verbosity=0) diff --git a/test/test_mat_mul_precision_highest.py b/test/test_mat_mul_precision_highest.py deleted file mode 100644 index 0f233c035372..000000000000 --- a/test/test_mat_mul_precision_highest.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Numeric tests for highest precision of mat mul. - -There are three similar test files, suffixed default, high, and highest. -Unfortunately, the precision cannot reliably -be dynamically changed between tensor operations, so -the tests are split into three files to ensure a fresh -environment for each test. 
-""" - -import os -import unittest - -import torch -import torch_xla -import torch_xla.backends -import torch_xla.runtime as xr - - -def _is_on_tpu(): - return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' - - -skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') - - -class TestMatMulPrecisionHighest(unittest.TestCase): - - def _make_input(self): - eye = torch.eye(1024, device='cpu', dtype=torch.float64) - rand_ = torch.testing.make_tensor((1024, 1024), - dtype=torch.float64, - device="cpu", - low=0.99, - high=1.01) - return eye * rand_ - - # DO NOT add epsilons to this test. These tests must be numerically precise. - @skipIfNotTpu - def test_mat_mul_precision_numerics_default(self): - # Arrange - torch_xla.backends.set_mat_mul_precision('highest') - - # Diagonal matrices force mat mul through MXU - # but require only one non-zero accumulation. - x = self._make_input() - y = self._make_input() - reference_float64 = torch.matmul(x, y) - - # 22 bits is an estimate of precision in the - # six pass technique. - worst_atol = torch.tensor( - -1 + ((2**22 + 1) / 2**22)**2, dtype=torch.float64) - - x = x.to(torch.float32).to('xla') - y = y.to(torch.float32).to('xla') - - # Act - actual = torch.matmul(x, y).to('cpu').to(torch.float64) - - # Disable rtol, we know exactly the atol for default, high, and highest. - torch.testing.assert_close( - actual, - reference_float64, - rtol=0.0, - atol=worst_atol, - ) - - assert not torch.equal(actual, reference_float64), ( - "Actual product and reference product should not be equal, " - f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" - ) - - -if __name__ == '__main__': - unittest.main(verbosity=0) diff --git a/test/test_utils.py b/test/test_utils.py index ad00a1def62b..d63738e26c20 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -13,6 +13,7 @@ import torch_xla import torch_xla.core.xla_model as xm import torch_xla.utils.utils as xu +import torch_xla.runtime as xr def _set_rng_seed(seed): @@ -420,3 +421,12 @@ def temporary_env(**kwargs): else: # Restore the original value os.environ[key] = old_value + + +# Taken from test_operations.py +def _is_on_tpu(): + return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' + + +# Adapted from test_operations.py +unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 68669c0d084a..76d0cb441f88 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -6,10 +6,11 @@ TEST_CDIR="$(dirname "$CDIR")" source "${TEST_CDIR}/utils/run_tests_utils.sh" # TODO: merge with other run_tests -python3 "$TEST_CDIR/test_mat_mul_precision_highest.py" +PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_high +PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_default +PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_highest +PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_all python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py" -python3 "$TEST_CDIR/test_mat_mul_precision_default.py" -python3 "$TEST_CDIR/test_mat_mul_precision_high.py" python3 "$TEST_CDIR/test_operations.py" -v python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py" python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py" From 
04516c778cdd3d262270781d8585a968182fbc7f Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 9 May 2025 02:53:22 +0000 Subject: [PATCH 16/20] removed dead code --- test/test_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index d63738e26c20..102321f5955b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -426,7 +426,3 @@ def temporary_env(**kwargs): # Taken from test_operations.py def _is_on_tpu(): return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' - - -# Adapted from test_operations.py -unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU') From c3856c88098750a3fbfb490ec0724a947ecbf1aa Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 9 May 2025 03:22:58 +0000 Subject: [PATCH 17/20] added issue for repeatable, unexpected behavior. --- test/test_mat_mul_precision.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py index 3240bd93d996..9303fd83248c 100644 --- a/test/test_mat_mul_precision.py +++ b/test/test_mat_mul_precision.py @@ -30,8 +30,7 @@ def _make_input(self): # TODO: Figure out why either PT/XLA or unittest # is unable to successfully run this test in a parameterized way. - # Is it something about unittest leaking global state? - # Or, switching precision too quickly or without a barrier? + # https://github.com/pytorch/xla/issues/9129 @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') @unittest.expectedFailure def test_all(self): From 79af4169b7d4ceed5b64b1f99d4d7fb221de602f Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 9 May 2025 16:14:27 +0000 Subject: [PATCH 18/20] Updated docstring. --- test/test_mat_mul_precision.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py index 9303fd83248c..caad3f6165ce 100644 --- a/test/test_mat_mul_precision.py +++ b/test/test_mat_mul_precision.py @@ -1,14 +1,6 @@ -"""Numeric tests for default precision of mat mul. - -There are three similar test files, suffixed default, high, and highest. -Unfortunately, the precision cannot reliably -be dynamically changed between tensor operations, so -the tests are split into three files to ensure a fresh -environment for each test. -""" +"""Numeric tests for default precision of mat mul.""" import unittest -import time import torch import torch_xla From 6ff7f5949b36552591bf85ac1de4a2cf0b9a2d61 Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 9 May 2025 19:07:06 +0000 Subject: [PATCH 19/20] changed naming of is_on_tpu --- test/test_mat_mul_precision.py | 8 ++++---- test/test_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py index caad3f6165ce..6da3aec3f259 100644 --- a/test/test_mat_mul_precision.py +++ b/test/test_mat_mul_precision.py @@ -23,7 +23,7 @@ def _make_input(self): # TODO: Figure out why either PT/XLA or unittest # is unable to successfully run this test in a parameterized way. # https://github.com/pytorch/xla/issues/9129 - @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') @unittest.expectedFailure def test_all(self): # The number of bit of precise mantissa expected in the result. 
@@ -38,15 +38,15 @@ def test_all(self): with self.subTest(precision=precision, bits=bits): self._test_parameterized(precision, bits) - @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') def test_highest(self): self._test_parameterized('highest', 22) - @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') def test_high(self): self._test_parameterized('high', 14) - @unittest.skipIf(not test_utils._is_on_tpu(), 'Skipping, not on TPU.') + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') def test_default(self): self._test_parameterized('default', 8) diff --git a/test/test_utils.py b/test/test_utils.py index 102321f5955b..e425c51902e3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -424,5 +424,5 @@ def temporary_env(**kwargs): # Taken from test_operations.py -def _is_on_tpu(): +def is_on_tpu(): return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' From 9db6e3e7c885aa9c6a1fbe9ffc4ce79415d50c8b Mon Sep 17 00:00:00 2001 From: Yaoshiang Ho Date: Fri, 9 May 2025 21:04:38 +0000 Subject: [PATCH 20/20] CICD friendly version, hopefully. --- test/tpu/run_tests.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 76d0cb441f88..e265494dbaba 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -6,10 +6,10 @@ TEST_CDIR="$(dirname "$CDIR")" source "${TEST_CDIR}/utils/run_tests_utils.sh" # TODO: merge with other run_tests -PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_high -PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_default -PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_highest -PYTHONPATH="$TEST_CDIR${PYTHONPATH:+:$PYTHONPATH}" python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_all +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_high) +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_default) +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_highest) +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_all) python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py" python3 "$TEST_CDIR/test_operations.py" -v python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"