From b9ebef07a08d6f8543895b8c820749d0f358543f Mon Sep 17 00:00:00 2001 From: "achandel@habana.ai" Date: Fri, 11 Jul 2025 08:22:52 +0300 Subject: [PATCH] Generalize base and compile tests of float8 --- test/float8/test_base.py | 113 ++++++++++++++++++------------------ test/float8/test_compile.py | 90 ++++++++++++++-------------- torchao/utils.py | 10 ++++ 3 files changed, 112 insertions(+), 101 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index df86c6f04e..682e6658ac 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -16,12 +16,12 @@ from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_89, is_sm_at_least_90, ) -if not TORCH_VERSION_AT_LEAST_2_5: +if not TORCH_VERSION_AT_LEAST_2_7: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -237,11 +237,12 @@ def test_axiswise_reshape(self): (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE), ], ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_90() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): - a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") - b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + a = torch.randn(*a_shape, dtype=torch.bfloat16, device=torch.accelerator.current_accelerator().type) + b = torch.randn(64, 32, dtype=torch.bfloat16, device=torch.accelerator.current_accelerator().type) linear_mm_config = LinearMMConfig() @@ -270,7 +271,7 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): sqnr = compute_error(c_ref, c_fp8_compute) assert sqnr >= 25.0 - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_fp8_dtype( self, ): @@ -317,7 +318,8 @@ def _test_linear_impl( torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) @pytest.mark.parametrize( - "emulate", [True, False] if is_sm_at_least_89() else [True] + "emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda" + and not is_sm_at_least_89()) else [True, False] ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( @@ -335,7 +337,7 @@ def _test_linear_impl( @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @pytest.mark.parametrize("use_ac", [False, True]) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_linear_from_config_params( self, x_shape, @@ -347,8 +349,8 @@ def test_linear_from_config_params( linear_bias: bool, use_ac: bool, ): - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) - m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) config = get_test_float8_linear_config( scaling_type_input, @@ 
-380,7 +382,7 @@ def test_linear_from_config_params( @pytest.mark.parametrize( "linear_dtype", [torch.bfloat16, torch.float16, torch.float32] ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") @skip_if_rocm("ROCm enablement in progress") def test_linear_from_recipe( self, @@ -389,14 +391,9 @@ def test_linear_from_recipe( linear_dtype: torch.dtype, linear_bias: bool, ): - if torch.cuda.get_device_capability() < (9, 0): - warnings.warn( - f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" - ) - pytest.skip() - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) - m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) + x = torch.randn(*x_shape, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=linear_bias, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) config = Float8LinearConfig.from_recipe_name(recipe_name) self._test_linear_impl( x, @@ -405,20 +402,21 @@ def test_linear_from_recipe( ) @pytest.mark.parametrize( - "emulate", [True, False] if is_sm_at_least_89() else [True] + "emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda" + and not is_sm_at_least_89()) else [True, False] ) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_autocast_outputs( self, emulate: bool, linear_dtype: torch.dtype, ): m_ref = nn.Sequential( - nn.Linear(32, 32, device="cuda", dtype=linear_dtype), - nn.Linear(32, 32, device="cuda", dtype=linear_dtype), + nn.Linear(32, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype), + nn.Linear(32, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype), ) config = Float8LinearConfig( emulate=emulate, @@ -426,16 +424,16 @@ def test_autocast_outputs( m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) # autocast off - x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) + x = torch.randn(16, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on - with torch.autocast("cuda"): + with torch.autocast(torch.accelerator.current_accelerator().type): y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" - with torch.autocast("cuda", dtype=torch.bfloat16): + with torch.autocast(torch.accelerator.current_accelerator().type, dtype=torch.bfloat16): y = m(x) assert y.dtype == torch.bfloat16, ( f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -445,11 +443,12 @@ def test_autocast_outputs( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @pytest.mark.parametrize( - "emulate", [True, False] if is_sm_at_least_89() else [True] + "emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda" + and not is_sm_at_least_89()) else [True, False] ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) + m = nn.Linear(32, 16,
device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) config = Float8LinearConfig(emulate=emulate) m = Float8Linear.from_float(copy.deepcopy(m), config) @@ -457,16 +456,16 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = m.to(dtype=linear_dtype) # autocast off - x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) + x = torch.randn(16, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on - with torch.autocast("cuda"): + with torch.autocast(torch.accelerator.current_accelerator().type): y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" - with torch.autocast("cuda", dtype=torch.bfloat16): + with torch.autocast(torch.accelerator.current_accelerator().type, dtype=torch.bfloat16): y = m(x) assert y.dtype == torch.bfloat16, ( f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -484,18 +483,22 @@ def test_repr(self): s = m.__repr__() assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s - @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, arch 8.9 not available") def test_inference_mode(self): - x = torch.randn(32, 32, device="cuda") - m = nn.Sequential(nn.Linear(32, 32)).cuda() + x = torch.randn(32, 32, device=torch.accelerator.current_accelerator().type) + m = nn.Sequential(nn.Linear(32, 32)).to(device=torch.accelerator.current_accelerator().type) m = convert_to_float8_training(m) with torch.inference_mode(mode=True): m(x) - @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available") + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, arch 8.9 not available") def test_quantize(self): - x = torch.randn(32, 32, device="cuda") - m = nn.Sequential(nn.Linear(32, 32)).cuda() + x = torch.randn(32, 32, device=torch.accelerator.current_accelerator().type) + m = nn.Sequential(nn.Linear(32, 32)).to(device=torch.accelerator.current_accelerator().type) m = convert_to_float8_training(m) assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" from torchao.quantization.quant_api import float8_weight_only, quantize_ @@ -509,10 +512,9 @@ def test_quantize(self): class TestScaledMM: - @unittest.skipIf( - not is_sm_at_least_89(), - "CUDA not available", - ) + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, arch 8.9 not available") @pytest.mark.parametrize( "base_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -523,8 +525,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): output_dtype = base_dtype compare_type = torch.float32 - a = torch.randn(16, 16, device="cuda", dtype=base_dtype) - b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + a = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type, dtype=base_dtype) + b = torch.randn(32, 16, device=torch.accelerator.current_accelerator().type, dtype=base_dtype).t() a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() @@ -555,10 +557,12 @@ def 
test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, arch 8.9 not available") def test_different_configs_error(self): - x_fp32 = torch.randn(16, 16, device="cuda") - x_scale = torch.tensor(1.0, device="cuda") + x_fp32 = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type) + x_scale = torch.tensor(1.0, device=torch.accelerator.current_accelerator().type) fp8_dtype = e4m3_dtype linear_config_a = LinearMMConfig( ScaledMMConfig(False, True, False, False), @@ -590,10 +594,9 @@ def test_different_configs_error(self): ): a @ b - @unittest.skipIf( - not is_sm_at_least_89(), - "CUDA not available", - ) + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, arch 8.9 not available") @pytest.mark.parametrize( "base_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -603,8 +606,8 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): input_dtype = e4m3_dtype compare_type = torch.float32 - a = torch.randn(16, 41, device="cuda", dtype=base_dtype) - b = torch.randn(41, 128, device="cuda", dtype=base_dtype) + a = torch.randn(16, 41, device=torch.accelerator.current_accelerator().type, dtype=base_dtype) + b = torch.randn(41, 128, device=torch.accelerator.current_accelerator().type, dtype=base_dtype) a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() @@ -682,7 +685,7 @@ class TestNumerics: torch.float8_e5m2fnuz, ], ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_small_amax_float16(self, float8_dtype): # If we calculate scale naively with FP8_MAX_POS / amax, # the result may not be representable in fp16. 
Verify that @@ -701,7 +704,7 @@ def test_small_amax_float16(self, float8_dtype): FP16_MAX_POS = torch.finfo(torch.float16).max target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12) - x = torch.tensor([target_amax], dtype=torch.float16, device="cuda") + x = torch.tensor([target_amax], dtype=torch.float16, device=torch.accelerator.current_accelerator().type) scale = tensor_to_scale(x, float8_dtype) assert not torch.any(torch.isinf(scale)) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index aaf9d3d3f5..097a473c8a 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -12,12 +12,13 @@ import pytest from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_89, is_sm_at_least_90, + get_backend ) -if not TORCH_VERSION_AT_LEAST_2_5: +if not TORCH_VERSION_AT_LEAST_2_7: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch @@ -51,9 +52,9 @@ def _test_compile_base( x_shape = (16, 16) linear_dtype = torch.bfloat16 - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + x = torch.randn(*x_shape, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype).requires_grad_() x_ref = copy.deepcopy(x) - m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) + m_ref = nn.Linear(16, 32, bias=True, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype) m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), @@ -86,9 +87,10 @@ def _test_compile_base( "scaling_type_grad_output", [ScalingType.DYNAMIC], ) -@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) +@pytest.mark.parametrize("emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda" + and not is_sm_at_least_89()) else [True, False]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_eager_only( fullgraph, emulate: bool, @@ -113,7 +115,8 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) +@pytest.mark.parametrize("emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda" + and not is_sm_at_least_89()) else [True, False]) @pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", @@ -124,7 +127,7 @@ def test_eager_only( [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available") def test_aot_eager( fullgraph, emulate: bool, @@ -159,12 +162,11 @@ def test_aot_eager( "scaling_type_grad_output", [ScalingType.DYNAMIC], ) -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) +@unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, float8 support not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -def test_inductor_from_config_params( +def test_torch_compile_backend_from_config_params( fullgraph, emulate: bool, scaling_type_input: ScalingType, @@ -180,7 
+182,7 @@ def test_inductor_from_config_params( emulate, ) _test_compile_base( - "inductor", + get_backend(), fullgraph, config, dtype, @@ -198,16 +200,16 @@ def test_inductor_from_config_params( Float8LinearRecipeName.ROWWISE_WITH_GW_HP, ], ) -@unittest.skipIf( - not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" -) -def test_inductor_from_recipe(recipe_name): +@unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_90() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, capability 9.0 or greater not available") +def test_torch_compile_backend_from_recipe(recipe_name): torch._dynamo.reset() config = Float8LinearConfig.from_recipe_name(recipe_name) fullgraph = True dtype = torch.bfloat16 _test_compile_base( - "inductor", + get_backend(), fullgraph, config, dtype, @@ -233,35 +235,33 @@ def forward(self, x): return x_fp8 # TODO(future): figure out why the test below fails on CUDA capability 8.9 - @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_90(), - "CUDA with capability 9.0 or greater not available", - ) + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_90() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, capability 9.0 or greater not available") def test_float8_with_graph_break_in_the_middle(self): """Test that having Float8Tensor object at the boundary of a subgraph""" - cnts = CompileCounterWithBackend("inductor") - mod = self.MockLinear(graph_break=True).cuda() + cnts = CompileCounterWithBackend(get_backend()) + mod = self.MockLinear(graph_break=True).to(device=torch.accelerator.current_accelerator().type) compiled_mod = copy.deepcopy(mod) compiled_mod = torch.compile(compiled_mod, backend=cnts) - x = torch.randn(16, 16, device="cuda") + x = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type) y_eager = mod(x) y_compiled = compiled_mod(x) self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!") torch.testing.assert_close(y_eager, y_compiled) - @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", - ) + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, float8 support not available") def test_float8_graph_input(self): """Test that having Float8Tensor object as a graph input""" def to_float(x): return x.to_original_precision() - cnts = CompileCounterWithBackend("inductor") - mod = self.MockLinear(graph_break=False).cuda() - x = torch.randn(2, 2, device="cuda") + cnts = CompileCounterWithBackend(get_backend()) + mod = self.MockLinear(graph_break=False).to(device=torch.accelerator.current_accelerator().type) + x = torch.randn(2, 2, device=torch.accelerator.current_accelerator().type) compiled_to_float = torch.compile(to_float, backend=cnts) y = mod(x) y2_eager = to_float(y) @@ -273,16 +273,15 @@ def to_float(x): ) torch.testing.assert_close(y2_eager, y2_compiled) - @unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", - ) + @unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, float8 support not available") def test_float8_graph_output(self): 
"""Test that having Float8Tensor object as a graph output works""" - cnts = CompileCounterWithBackend("inductor") - mod = self.MockLinear(graph_break=False).cuda() + cnts = CompileCounterWithBackend(get_backend()) + mod = self.MockLinear(graph_break=False).to(device=torch.accelerator.current_accelerator().type) compiled_mod = torch.compile(mod, backend=cnts) - x = torch.randn(16, 16, device="cuda") + x = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type) y_compiled = compiled_mod(x) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") @@ -320,10 +319,9 @@ def __exit__(self, *args): sys.stderr = self.sys_stderr -@unittest.skipIf( - not is_sm_at_least_89(), - "CUDA not available", -) +@unittest.skipIf(not torch.accelerator.is_available() or + (not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"), + "Accelerator not available or If CUDA, capability 8.9 or greater not available") @pytest.mark.parametrize( "dtype", [ @@ -344,7 +342,7 @@ def test_dynamic_scale_numeric_parity( ): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(42) - hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) + hp_tensor1 = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type, dtype=dtype) hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), @@ -381,7 +379,7 @@ def test_dynamic_scale_numeric_parity( round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) torch._dynamo.reset() - float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( + float8_compile = torch.compile(hp_tensor_to_float8_dynamic,backend=get_backend())( hp_tensor2, e4m3_dtype, linear_mm_config, diff --git a/torchao/utils.py b/torchao/utils.py index c56b607b7b..501043d9a0 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -42,6 +42,7 @@ "is_sm_at_least_89", "is_sm_at_least_90", "is_package_at_least", + "get_backend", ] @@ -732,3 +733,12 @@ def _is_fbgemm_genai_gpu_available(): return False return True + +def get_backend(): + """ + Get device specific backend + """ + if torch.accelerator.current_accelerator().type == "hpu": + return "hpu_backend" + else: + return "inductor" \ No newline at end of file