diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py new file mode 100644 index 000000000000..6da3aec3f259 --- /dev/null +++ b/test/test_mat_mul_precision.py @@ -0,0 +1,100 @@ +"""Numeric tests for default precision of mat mul.""" + +import unittest + +import torch +import torch_xla +import torch_xla.backends + +import test_utils + + +class TestMatMulPrecision(unittest.TestCase): + + def _make_input(self): + eye = torch.eye(1024, device='cpu', dtype=torch.float64) + rand_ = torch.testing.make_tensor((1024, 1024), + dtype=torch.float64, + device="cpu", + low=0.99, + high=1.01) + return eye * rand_ + + # TODO: Figure out why either PT/XLA or unittest + # is unable to successfully run this test in a parameterized way. + # https://github.com/pytorch/xla/issues/9129 + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') + @unittest.expectedFailure + def test_all(self): + # The number of bit of precise mantissa expected in the result. + parameters = [ + ('highest', 22), + ('high', 14), + ('default', 8), + ] + # Although pytest has a slightly more elegant parameterized testing function, + # all TPU tests user unittest. + for i, (precision, bits) in enumerate(parameters): + with self.subTest(precision=precision, bits=bits): + self._test_parameterized(precision, bits) + + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') + def test_highest(self): + self._test_parameterized('highest', 22) + + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') + def test_high(self): + self._test_parameterized('high', 14) + + @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.') + def test_default(self): + self._test_parameterized('default', 8) + + # DO NOT add epsilons to this test. These tests must be numerically exact. + def _test_parameterized(self, precision, bits): + # Arrange + torch_xla.backends.set_mat_mul_precision(precision) + + # Diagonal matrices force mat mul through MXU + # but require only one non-zero accumulation. + x = self._make_input() + y = self._make_input() + reference_float64 = torch.matmul(x, y) + + # TODO: Justify this logic. Why isn't it Why is it not + # 1 - ((2**8 - 1) / 2**8)**2 (equation stated by per TPU expert)? + widest_atol = torch.tensor( + -1 + ((2**(bits) + 1) / 2**bits)**2, dtype=torch.float64) + + narrowest_atol = widest_atol / 4.0 + + x = x.to(torch.float32).to('xla') + y = y.to(torch.float32).to('xla') + + # Act + actual = torch.matmul(x, y).to('cpu').to(torch.float64) + + # Disable rtol, we know exactly the atol for default, high, and highest. + torch.testing.assert_close( + actual, + reference_float64, + rtol=0.0, + atol=widest_atol, + ) + + with self.assertRaises(AssertionError): + torch.testing.assert_close( + actual, + reference_float64, + rtol=0.0, + atol=narrowest_atol, + ) + + assert not torch.equal(actual, reference_float64), ( + "Actual product and reference product should not be closer than equal, " + f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}" + ) + + +# There is no main function. This is designed to be run from +# python -m unittest ... diff --git a/test/test_mat_mul_precision_get_and_set.py b/test/test_mat_mul_precision_get_and_set.py new file mode 100644 index 000000000000..ad47a9cc0e60 --- /dev/null +++ b/test/test_mat_mul_precision_get_and_set.py @@ -0,0 +1,62 @@ +"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" + +import sys +import unittest + +import torch +import torch_xla +import torch_xla.backends + + +class TestMatMulPrecisionGetAndSet(unittest.TestCase): + + def setUp(self): + self._original = torch_xla.backends.get_mat_mul_precision() + torch.set_printoptions(precision=20) + torch_xla.sync() + + def tearDown(self): + torch_xla.backends.set_mat_mul_precision(self._original) + torch.set_printoptions(profile="default") + torch_xla.sync() + + def test_set_mat_mul_precision_error(self): + # Assert + with self.assertRaises(ValueError): + # Act + torch_xla.backends.set_mat_mul_precision('BAD VALUE') + + def test_get_and_set_mat_mul_precision_default(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('default') + + # Act + status = torch_xla.backends.get_mat_mul_precision() + + # Assert + self.assertEqual(status, 'default') + + def test_get_and_set_mat_mul_precision_high(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('high') + + # Act + status = torch_xla.backends.get_mat_mul_precision() + + # Assert + self.assertEqual(status, 'high') + + def test_get_and_set_mat_mul_precision_highest(self): + # Arrange + torch_xla.backends.set_mat_mul_precision('highest') + + # Act + status = torch_xla.backends.get_mat_mul_precision() + + # Assert + self.assertEqual(status, 'highest') + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_utils.py b/test/test_utils.py index 6d1a2ff9a278..f238f4c82540 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -13,6 +13,7 @@ import torch_xla import torch_xla.core.xla_model as xm import torch_xla.utils.utils as xu +import torch_xla.runtime as xr def _set_rng_seed(seed): @@ -420,3 +421,8 @@ def temporary_env(**kwargs): else: # Restore the original value os.environ[key] = old_value + + +# Taken from test_operations.py +def is_on_tpu(): + return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 6b596a157460..e265494dbaba 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -6,6 +6,11 @@ TEST_CDIR="$(dirname "$CDIR")" source "${TEST_CDIR}/utils/run_tests_utils.sh" # TODO: merge with other run_tests +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_high) +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_default) +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_highest) +(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_all) +python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py" python3 "$TEST_CDIR/test_operations.py" -v python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py" python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py" diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py new file mode 100644 index 000000000000..7256c33bc6bf --- /dev/null +++ b/torch_xla/backends/__init__.py @@ -0,0 +1,78 @@ +"""torch_xla.backends controls the behavior of the XLA backend. + +This subpackage parallels the torch.backends.{cuda, cpu, mps, etc} +subpackages in PyTorch. +""" + +# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py +# for an example of how backends are implemented in PyTorch +# in the __init__.py file, despite general style guidelines against this. + +# Literal is available from Python 3.8, +# matching the Python versions for PyTorch and PyTorch/XLA. +from typing import Final, Literal, TypeAlias + +import torch_xla + +__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"] + +# Valid values for get_mat_mul_precision/set_mat_mul_precision +# Note: it is idiomatic to PyTorch to use strings rather than enums. +# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9 + +_DEFAULT: Final = "default" +_HIGH: Final = "high" +_HIGHEST: Final = "highest" + +# Use of variables with Final typehint instead of literals is valid. +_PrecisionType: TypeAlias = Literal[ + _DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm] + + +# Some of this description adapted from Jax documentation. +# TODO: Once the numerics tutorial is released, link from this docstring. +def set_mat_mul_precision(precision: _PrecisionType) -> None: + """Control the default matmul and conv precision for 32bit inputs. + + Some platforms, like TPU, offer configurable precision levels for + matrix multiplication and convolution computations, + trading off accuracy for speed. + + This option controls the default precision level for + computations involved in matrix multiplication and convolution on + 32bit inputs. The levels describe the precision at + which scalar products are computed. + + On a TPU: + * `default` is the fastest and least precise, + downcasting an FP32 to BF16 before multiplying. + + * `high` takes three passes and generates approximately 14 bits of + precision. + + * `highest` is the most precise, and the slowest. It takes six + passes and generates approximately 22 bits of precision. + + Args: + precision (str): The precision to set for matrix multiplication. + Must be one of 'default', 'high', or 'highest'. + """ + if precision not in [_DEFAULT, _HIGH, _HIGHEST]: + raise ValueError(f"Invalid precision: {precision}. " + f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") + + torch_xla._XLAC._xla_set_mat_mul_precision(precision) + + +def get_mat_mul_precision() -> _PrecisionType: + """Get the current mat mul precision for 32bit inputs. + + Returns: + str: The current precision setting for matrix multiplication, + one of 'default', 'high', or 'highest'. + """ + precision = torch_xla._XLAC._xla_get_mat_mul_precision() + assert precision in [_DEFAULT, _HIGH, _HIGHEST + ], (f"Invalid precision: {precision}. " + f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") + return precision diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d1b38efe00ea..9160342383df 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2124,6 +2124,10 @@ void InitXlaModuleBindings(py::module m) { ConsumeValue(xla::StringToPrecision(mat_mul_precision)); XlaHelpers::set_mat_mul_precision(precision); }); + m.def("_xla_get_mat_mul_precision", []() { + xla::PrecisionConfig::Precision precision = XlaHelpers::mat_mul_precision(); + return xla::PrecisionToString(precision); + }); py::class_(m, "XlaBuilder"); py::class_(m, "XlaOp"); diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp index b84cc2c0a0e1..1fbae78eee95 100644 --- a/torch_xla/csrc/xla_op_builder.cpp +++ b/torch_xla/csrc/xla_op_builder.cpp @@ -208,6 +208,10 @@ xla::PrecisionConfig DotPrecisonConfig(py::dict args) { precision = xla::PrecisionConfig::HIGH; } else if (*arg_precision_config == "highest") { precision = xla::PrecisionConfig::HIGHEST; + } else { + XLA_ERROR() << "Invalid precision config in args: " + << *arg_precision_config + << " (valid values: default, high, highest)"; } } return XlaHelpers::BuildPrecisionConfig(precision);