9080 expose mat mul precision #9081

Merged
merged 22 commits on May 9, 2025
Changes from 5 commits

75 changes: 75 additions & 0 deletions test/test_mat_mul_precision_default.py
@@ -0,0 +1,75 @@
"""Numeric tests for default precision of mat mul.

There are three similar test files, suffixed default, high, and highest.
Unfortunately, the precision cannot reliably
be dynamically changed between tensor operations, so
the tests are split into three files to ensure a fresh
environment for each test.
"""

import os
import unittest

import torch
import torch_xla
import torch_xla.backends
import torch_xla.runtime as xr


def _is_on_tpu():
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU')


class TestMatMulPrecisionDefault(unittest.TestCase):

def _make_input(self):
eye = torch.eye(1024, device='cpu', dtype=torch.float64)
rand_ = torch.testing.make_tensor((1024, 1024),
dtype=torch.float64,
device="cpu",
low=0.99,
high=1.01)
return eye * rand_

# DO NOT add epsilons to this test. These tests must be numerically exact.
@skipIfNotTpu
def test_mat_mul_precision_numerics_default(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('default')

# Diagonal matrices force mat mul through MXU
# but require only one non-zero accumulation.
x = self._make_input()
y = self._make_input()
reference_float64 = torch.matmul(x, y)

# This is slightly worse than 1/2**(mantissa bits) because in multiplication,
# each operand may lose up to 1/2**8 of its value (8 = 7 mantissa bits
# plus 1 implicit leading bit in bfloat16).
worst_atol = torch.tensor(1 - ((2**8 - 1) / 2**8)**2, dtype=torch.float64)

x = x.to(torch.float32).to('xla')
y = y.to(torch.float32).to('xla')

# Act
actual = torch.matmul(x, y).to('cpu').to(torch.float64)

# Disable rtol; we know the exact atol for default, high, and highest.
torch.testing.assert_close(
actual,
reference_float64,
rtol=0.0,
atol=worst_atol,
)

assert not torch.equal(actual, reference_float64), (
"Actual product and reference product should not be equal, "
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
)


if __name__ == '__main__':
unittest.main(verbosity=0)
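A quick sketch of the arithmetic behind the worst_atol bounds used in these numeric tests (illustrative only, not part of the diff); the same pattern covers the 8-, 14-, and 22-bit cases in the default, high, and highest test files:

# Illustrative only: the absolute-error bound each numeric test uses, assuming
# each operand of a product near 1.0 keeps at least (2**bits - 1) / 2**bits of
# its value at the given precision.
def worst_atol(bits: int) -> float:
    kept = (2**bits - 1) / 2**bits  # fraction of each operand that survives
    return 1 - kept**2  # error of the product of two lossy operands


print(worst_atol(8))  # default: ~7.8e-3
print(worst_atol(14))  # high:    ~1.2e-4
print(worst_atol(22))  # highest: ~4.8e-7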
62 changes: 62 additions & 0 deletions test/test_mat_mul_precision_get_and_set.py
@@ -0,0 +1,62 @@
"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp"""

import sys
import unittest

import torch
import torch_xla
import torch_xla.backends


class TestMatMulPrecisionGetAndSet(unittest.TestCase):

def setUp(self):
self._original = torch_xla.backends.get_mat_mul_precision()
torch.set_printoptions(precision=20)
torch_xla.sync()

def tearDown(self):
torch_xla.backends.set_mat_mul_precision(self._original)
torch.set_printoptions(profile="default")
torch_xla.sync()

def test_set_mat_mul_precision_error(self):
# Assert
with self.assertRaises(ValueError):
# Act
torch_xla.backends.set_mat_mul_precision('BAD VALUE')

def test_get_and_set_mat_mul_precision_default(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('default')

# Act
status = torch_xla.backends.get_mat_mul_precision()

# Assert
self.assertEqual(status, 'default')

def test_get_and_set_mat_mul_precision_high(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('high')

# Act
status = torch_xla.backends.get_mat_mul_precision()

# Assert
self.assertEqual(status, 'high')

def test_get_and_set_mat_mul_precision_highest(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('highest')

# Act
status = torch_xla.backends.get_mat_mul_precision()

# Assert
self.assertEqual(status, 'highest')


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
74 changes: 74 additions & 0 deletions test/test_mat_mul_precision_high.py
@@ -0,0 +1,74 @@
"""Numeric tests for high precision of mat mul.

There are three similar test files, suffixed default, high, and highest.
Unfortunately, the precision cannot reliably
be dynamically changed between tensor operations, so
the tests are split into three files to ensure a fresh
environment for each test.
"""

import os
import unittest

import torch
import torch_xla
import torch_xla.backends
import torch_xla.runtime as xr


def _is_on_tpu():
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU')


class TestMatMulPrecisionHigh(unittest.TestCase):

def _make_input(self):
eye = torch.eye(1024, device='cpu', dtype=torch.float64)
rand_ = torch.testing.make_tensor((1024, 1024),
dtype=torch.float64,
device="cpu",
low=0.99,
high=1.01)
return eye * rand_

# DO NOT add epsilons to this test. These tests must be numerically exact.
@skipIfNotTpu
def test_mat_mul_precision_numerics_high(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('high')

# Diagonal matrices force mat mul through MXU
# but require only one non-zero accumulation.
x = self._make_input()
y = self._make_input()
reference_float64 = torch.matmul(x, y)

# 14 bits is an estimate of precision in the
# three pass technique.
worst_atol = torch.tensor(1 - ((2**14 - 1) / 2**14)**2, dtype=torch.float64)

x = x.to(torch.float32).to('xla')
y = y.to(torch.float32).to('xla')

# Act
actual = torch.matmul(x, y).to('cpu').to(torch.float64)

# Disable rtol; we know the exact atol for default, high, and highest.
torch.testing.assert_close(
actual,
reference_float64,
rtol=0.0,
atol=worst_atol,
)

assert not torch.equal(actual, reference_float64), (
"Actual product and reference product should not be equal, "
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
)


if __name__ == '__main__':
unittest.main(verbosity=0)
74 changes: 74 additions & 0 deletions test/test_mat_mul_precision_highest.py
@@ -0,0 +1,74 @@
"""Numeric tests for highest precision of mat mul.

There are three similar test files, suffixed default, high, and highest.
Unfortunately, the precision cannot reliably
be dynamically changed between tensor operations, so
the tests are split into three files to ensure a fresh
environment for each test.
"""

import os
import unittest

import torch
import torch_xla
import torch_xla.backends
import torch_xla.runtime as xr


def _is_on_tpu():
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'


skipIfNotTpu = unittest.skipIf(not _is_on_tpu(), 'Only supported on TPU')


class TestMatMulPrecisionHighest(unittest.TestCase):

def _make_input(self):
eye = torch.eye(1024, device='cpu', dtype=torch.float64)
rand_ = torch.testing.make_tensor((1024, 1024),
dtype=torch.float64,
device="cpu",
low=0.99,
high=1.01)
return eye * rand_

# DO NOT add epsilons to this test. These tests must be numerically exact.
@skipIfNotTpu
def test_mat_mul_precision_numerics_highest(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('highest')

# Diagonal matrices force mat mul through MXU
# but require only one non-zero accumulation.
x = self._make_input()
y = self._make_input()
reference_float64 = torch.matmul(x, y)

# 22 bits is an estimate of precision in the
# six pass technique.
worst_atol = torch.tensor(1 - ((2**22 - 1) / 2**22)**2, dtype=torch.float64)

x = x.to(torch.float32).to('xla')
y = y.to(torch.float32).to('xla')

# Act
actual = torch.matmul(x, y).to('cpu').to(torch.float64)

# Disable rtol; we know the exact atol for default, high, and highest.
torch.testing.assert_close(
actual,
reference_float64,
rtol=0.0,
atol=worst_atol,
)

assert not torch.equal(actual, reference_float64), (
"Actual product and reference product should not be equal, "
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
)


if __name__ == '__main__':
unittest.main(verbosity=0)
4 changes: 4 additions & 0 deletions test/tpu/run_tests.sh
@@ -6,6 +6,10 @@ TEST_CDIR="$(dirname "$CDIR")"
source "${TEST_CDIR}/utils/run_tests_utils.sh"

# TODO: merge with other run_tests
python3 "$TEST_CDIR/test_mat_mul_precision_highest.py"
python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py"
python3 "$TEST_CDIR/test_mat_mul_precision_default.py"
python3 "$TEST_CDIR/test_mat_mul_precision_high.py"
python3 "$TEST_CDIR/test_operations.py" -v
python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"
python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py"
81 changes: 81 additions & 0 deletions torch_xla/backends/__init__.py
@@ -0,0 +1,81 @@
"""torch_xla.backends is intended to be moved to torch.backends.xla"""

# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py
# for an example of how a backend is implemented in PyTorch.

# Literal is available from Python 3.8,
# matching the Python versions for PyTorch and PyTorchXLA.
from typing import Final, Literal, TypeAlias

import torch_xla

__all__ = ['set_mat_mul_precision', 'get_mat_mul_precision']

# Valid values for get_mat_mul_precision/set_mat_mul_precision
# Note: it is idiomatic in PyTorch to use strings rather than enums.
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9

_DEFAULT: Final = 'default'
_HIGH: Final = 'high'
_HIGHEST: Final = 'highest'

# Use of variables with Final typehint instead of literals is valid.
_PrecisionType: TypeAlias = Literal[
_DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm]


# Some of this description adapted from Jax documentation.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
"""Control the default matmul and conv precision for 32bit inputs.

Some platforms, like TPU, offer configurable precision levels for
matrix multiplication and convolution computations,
trading off accuracy for speed.

This option can be used to control the default precision level for
computations involved in matrix multiplication and convolution on
32-bit inputs. The levels describe the precision at
which scalar products are computed.

On a TPU:
* `default` is the fastest and least precise, essentially
downcasting an FP32 to BF16 before multiplying.

* `high` breaks the FP32 inputs into upper and lower bits and
computes the product of the pair of upper bits, and the two pairs of
upper and lower bits. The pair of lower bits is ignored. Since the BF16
format has 9 bits of precision (7 bits of mantissa plus one implicit
leading one bit plus one sign bit), this delivers approximately
18 bits of precision, more than the TensorFloat32 (10 bits)
but less than a full FP32 (23 bits).
It requires three passes over the data.

* `highest` is the most precise, and the slowest. It further breaks
up each input into a triple of upper, middle, and lower bits and
effectively delivers approximately 23 bits of precision, matching a
full FP32. It requires six passes over the data.

Args:
precision (str): The precision to set for matrix multiplication. Must be
one of 'default', 'high', or 'highest'.
"""
if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
raise ValueError(f"Invalid precision: {precision}. "
"Must be one of 'default', 'high', or 'highest'.")

torch_xla._XLAC._xla_set_mat_mul_precision(precision)


def get_mat_mul_precision() -> _PrecisionType:
"""Get the current mat mul precision for 32bit inputs.

Returns:
str: The current precision setting for matrix multiplication,
one of 'default', 'high', or 'highest'.
"""
precision = torch_xla._XLAC._xla_get_mat_mul_precision()
assert precision in [_DEFAULT, _HIGH, _HIGHEST
], (f"Invalid precision: {precision}. "
"Must be one of 'default', 'high', or 'highest'.")
return precision
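Taken together, a minimal usage sketch for the new public API (illustrative, mirroring what the tests above exercise):

# Minimal usage sketch: pick the TPU matmul precision, read it back, and run a
# matmul on the XLA device. Mirrors test_mat_mul_precision_get_and_set.py.
import torch
import torch_xla
import torch_xla.backends

torch_xla.backends.set_mat_mul_precision('high')
assert torch_xla.backends.get_mat_mul_precision() == 'high'

x = torch.randn(128, 128).to('xla')
y = torch.randn(128, 128).to('xla')
z = torch.matmul(x, y)  # computed at 'high' precision on TPU

The 'high' mode described in the docstring is the familiar split-multiply idea; a rough numerical sketch of a single three-pass product (not the XLA implementation):

# Rough sketch, not the XLA implementation: split each FP32 operand into a
# bfloat16 "upper" part plus a residual "lower" part, keep three of the four
# partial products, and drop the lo*lo term.
def three_pass_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x_hi = x.to(torch.bfloat16).to(torch.float32)
    x_lo = (x - x_hi).to(torch.bfloat16).to(torch.float32)
    y_hi = y.to(torch.bfloat16).to(torch.float32)
    y_lo = (y - y_hi).to(torch.bfloat16).to(torch.float32)
    return x_hi * y_hi + x_hi * y_lo + x_lo * y_hi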
4 changes: 4 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2124,6 +2124,10 @@ void InitXlaModuleBindings(py::module m) {
ConsumeValue(xla::StringToPrecision(mat_mul_precision));
XlaHelpers::set_mat_mul_precision(precision);
});
m.def("_xla_get_mat_mul_precision", []() {
xla::PrecisionConfig::Precision precision = XlaHelpers::mat_mul_precision();
return xla::PrecisionToString(precision);
});

py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");