diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py
index 98760f8cf6..6b275e8a8a 100644
--- a/test/quantization/test_gptq.py
+++ b/test/quantization/test_gptq.py
@@ -2,13 +2,16 @@
 from pathlib import Path
 
 import torch
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import (
+    TestCase,
+)
 
 from torchao._models.llama.model import (
     ModelArgs,
     Transformer,
     prepare_inputs_for_model,
 )
+from torchao.utils import auto_detect_device
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
@@ -18,10 +21,10 @@
 torch.manual_seed(0)
 
+_DEVICE = auto_detect_device()
 
 
 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,
@@ -30,7 +33,6 @@ def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 
         precision = torch.bfloat16
-        device = "cuda"
         checkpoint_path = Path(
             "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
         )
@@ -80,7 +82,7 @@ def test_gptq_quantizer_int4_weight_only(self):
         model = quantizer.quantize(model, *inputs).cuda()
 
         model.reset_caches()
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)
 
         limit = 1
@@ -89,7 +91,7 @@ def test_gptq_quantizer_int4_weight_only(self):
             tokenizer,
             model.config.block_size,
             prepare_inputs_for_model,
-            device,
+            _DEVICE,
         ).run_eval(
             ["wikitext"],
             limit,
@@ -102,7 +104,6 @@ def test_gptq_quantizer_int4_weight_only(self):
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -115,7 +116,6 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[1], tensor2))
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -127,7 +127,6 @@ def test_multitensor_pad_unpad(self):
         self.assertEqual(mt.count, 1)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -138,7 +137,6 @@ def test_multitensor_inplace_operation(self):
 
 
 class TestMultiTensorInputRecorder(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder
 
@@ -159,7 +157,6 @@ def test_multitensor_input_recorder(self):
         self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
         self.assertEqual(MT_input[3], torch.float)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_gptq_with_input_recorder(self):
         from torchao.quantization.GPTQ import (
             Int4WeightOnlyGPTQQuantizer,
@@ -170,7 +167,7 @@ def test_gptq_with_input_recorder(self):
 
         config = ModelArgs(n_layer=2)
 
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model = Transformer(config)
             model.setup_caches(max_batch_size=2, max_seq_length=100)
             idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
@@ -191,7 +188,11 @@ def test_gptq_with_input_recorder(self):
 
         args = input_recorder.get_recorded_inputs()
 
-        quantizer = Int4WeightOnlyGPTQQuantizer()
+        if _DEVICE == "xpu":
+            from torchao.dtypes import Int4XPULayout
+            quantizer = Int4WeightOnlyGPTQQuantizer(device=torch.device("xpu"), layout=Int4XPULayout())
+        else:
+            quantizer = Int4WeightOnlyGPTQQuantizer()
 
         quantizer.quantize(model, *args)
diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py
index 425b881dba..0385bb0925 100644
--- a/test/quantization/test_moe_quant.py
+++ b/test/quantization/test_moe_quant.py
@@ -31,8 +31,11 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_sm_at_least_90,
+    auto_detect_device,
 )
 
+_DEVICE = auto_detect_device()
+
 if torch.version.hip is not None:
     pytest.skip(
         "ROCm support for MoE quantization is under development",
@@ -52,7 +55,7 @@ def _test_impl_moe_quant(
         base_class=AffineQuantizedTensor,
         tensor_impl_class=None,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
         fullgraph=False,
     ):
         """
@@ -114,8 +117,6 @@ def _test_impl_moe_quant(
         ]
     )
     def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")
 
@@ -138,10 +139,6 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
         ]
     )
     def test_int4wo_base(self, name, num_tokens, fullgraph):
-        if not torch.cuda.is_available():
-            self.skipTest("Need CUDA available")
-        if not is_sm_at_least_90():
-            self.skipTest("Requires CUDA capability >= 9.0")
         if not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest("Test only enabled for 2.5+")
 
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 323802757d..0ba87df886 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -83,11 +83,13 @@
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_6,
+    auto_detect_device,
 )
 
 # TODO: put this in a common test utils file
-_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+_GPU_IS_AVAILABLE = torch.cuda.is_available() or torch.xpu.is_available()
+_DEVICE = auto_detect_device()
 
 
 class Sub(torch.nn.Module):
     def __init__(self):
@@ -329,7 +331,7 @@ def _set_ptq_weight(
             group_size,
         )
         q_weight = torch.ops.aten._convert_weight_to_int4pack(
-            q_weight.to("cuda"),
+            q_weight.to(_DEVICE),
             qat_linear.inner_k_tiles,
         )
         ptq_linear.weight = q_weight
@@ -600,13 +602,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -654,13 +656,13 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -701,14 +703,14 @@ def test_qat_4w_quantizer_gradients(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 0435a6c59b..c414b9b9b8 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -74,8 +74,11 @@
     is_sm_at_least_89,
     is_sm_at_least_90,
     unwrap_tensor_subclass,
+    auto_detect_device,
 )
 
+_DEVICE = auto_detect_device()
+
 try:
     import gemlite  # noqa: F401
 
@@ -301,7 +304,7 @@ def api(model):
 
         m2.load_state_dict(state_dict)
         m2 = m2.to(device="cuda")
-        example_inputs = map(lambda x: x.cuda(), example_inputs)
+        example_inputs = map(lambda x: x.to(_DEVICE), example_inputs)
         res = m2(*example_inputs)
 
         torch.testing.assert_close(ref, res.cpu())
@@ -337,12 +340,13 @@ def test_8da4w_quantizer_linear_bias(self):
         m(*example_inputs)
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_quantizer_int4_weight_only(self):
         from torchao._models._eval import TransformerEvalWrapper
         from torchao.quantization.linear_quant_modules import Int4WeightOnlyQuantizer
 
         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
@@ -359,7 +363,7 @@ def test_quantizer_int4_weight_only(self):
         quantizer = Int4WeightOnlyQuantizer(
             groupsize,
         )
-        model = quantizer.quantize(model).cuda()
+        model = quantizer.quantize(model).to(_DEVICE)
         result = TransformerEvalWrapper(
             model,
             tokenizer,
@@ -375,11 +379,12 @@ def test_quantizer_int4_weight_only(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_eval_wrapper(self):
         from torchao._models._eval import TransformerEvalWrapper
 
         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
@@ -408,11 +413,12 @@ def test_eval_wrapper(self):
 
     # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_eval_wrapper_llama3(self):
         from torchao._models._eval import TransformerEvalWrapper
 
         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path(
             ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"
         )
@@ -607,11 +613,15 @@ def test_int8wo_quantized_model_to_device(self):
         self.assertEqual(cuda_res.cpu(), ref)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
    def test_int4wo_quantized_model_to_device(self):
         # TODO: change initial model to "cpu"
-        devices = ["cuda", "cuda:0"]
+        if _DEVICE == "cuda":
+            devices = ["cuda", "cuda:0"]
+        elif _DEVICE == "xpu":
+            devices = ["xpu", "xpu:0"]
+
         for device in devices:
             m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
             example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)
@@ -625,10 +635,10 @@ def test_int4wo_quantized_model_to_device(self):
             self.assertEqual(cuda_res.cpu(), ref)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
-        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
-        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device=_DEVICE)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device=_DEVICE)
         quantize_(m, int8_weight_only())
         ref = m(*example_inputs)
@@ -641,32 +651,50 @@ def test_quantized_tensor_subclass_save_load_map_location(self):
 
         m_copy = ToyLinearModel().eval()
         m_copy.load_state_dict(state_dict, assign=True)
-        m_copy.to(dtype=torch.bfloat16, device="cuda")
+        m_copy.to(dtype=torch.bfloat16, device=_DEVICE)
 
         res = m_copy(*example_inputs)
         self.assertEqual(res, ref)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_quantized_model_streaming(self):
-        def reset_memory():
-            gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+
+        def get_max_memory_allocated(device):
+            if device == "cuda":
+                return torch.cuda.max_memory_allocated(device)
+            elif device == "xpu":
+                return torch.xpu.max_memory_allocated(device)
+            elif device == "cpu":
+                return 0
+            else:
+                raise ValueError(f"Unsupported device type: {device}")
 
-        reset_memory()
+        def reset_memory(device):
+            gc.collect()
+            if device == "cuda":
+                torch.cuda.empty_cache()
+                torch.cuda.reset_peak_memory_stats()
+            elif device == "xpu":
+                torch.xpu.empty_cache()
+            elif device == "cpu":
+                pass
+            else:
+                raise ValueError(f"Unsupported device type: {device}")
+
+        reset_memory(_DEVICE)
         m = ToyLinearModel()
-        quantize_(m.to(device="cuda"), int8_weight_only())
-        memory_baseline = torch.cuda.max_memory_allocated()
+        quantize_(m.to(device=_DEVICE), int8_weight_only())
+        memory_baseline = get_max_memory_allocated(_DEVICE)
 
         del m
-        reset_memory()
+        reset_memory(_DEVICE)
         m = ToyLinearModel()
-        quantize_(m, int8_weight_only(), device="cuda")
-        memory_streaming = torch.cuda.max_memory_allocated()
+        quantize_(m, int8_weight_only(), device=_DEVICE)
+        memory_streaming = get_max_memory_allocated(_DEVICE)
 
         for param in m.parameters():
-            assert param.is_cuda
+            assert getattr(param, f"is_{_DEVICE}")
         self.assertLess(memory_streaming, memory_baseline)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@@ -697,7 +725,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
 
     # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     @common_utils.parametrize(
         "config",
         [
@@ -742,17 +770,17 @@ def test_workflow_e2e_numerics(self, config):
         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
         if isinstance(config, float8_static_activation_float8_weight):
-            config.scale = config.scale.to("cuda")
+            config.scale = config.scale.to(_DEVICE)
 
         dtype = torch.bfloat16
         if isinstance(config, gemlite_uintx_weight_only):
             dtype = torch.float16
 
         # set up inputs
-        x = torch.randn(128, 128, device="cuda", dtype=dtype)
+        x = torch.randn(128, 128, device=_DEVICE, dtype=dtype)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
-        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(_DEVICE).to(dtype)
         m_q = copy.deepcopy(m_ref)
 
         # quantize
@@ -765,13 +793,13 @@ def test_workflow_e2e_numerics(self, config):
         sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_module_fqn_to_config_default(self):
         config1 = Int4WeightOnlyConfig(group_size=32)
         config2 = Int8WeightOnlyConfig()
         config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
@@ -779,13 +807,13 @@ def test_module_fqn_to_config_default(self):
         assert isinstance(model.linear2.weight, AffineQuantizedTensor)
         assert isinstance(model.linear2.weight._layout, PlainLayout)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_module_fqn_to_config_module_name(self):
         config1 = Int4WeightOnlyConfig(group_size=32)
         config2 = Int8WeightOnlyConfig()
         config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
@@ -825,25 +853,25 @@ def test_module_fqn_to_config_embedding_linear(self):
         assert isinstance(model.emb.weight._layout, QDQLayout)
         assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_module_fqn_to_config_skip(self):
         config1 = Int4WeightOnlyConfig(group_size=32)
         config = ModuleFqnToConfig({"_default": config1, "linear2": None})
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         quantize_(model, config)
         model(*example_inputs)
         assert isinstance(model.linear1.weight, AffineQuantizedTensor)
         assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
         assert not isinstance(model.linear2.weight, AffineQuantizedTensor)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
     def test_int4wo_cuda_serialization(self):
         config = Int4WeightOnlyConfig(group_size=32)
-        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
         # quantize in cuda
         quantize_(model, config)
-        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
         model(*example_inputs)
         with tempfile.NamedTemporaryFile() as ckpt:
             # save checkpoint in cuda
@@ -852,7 +880,7 @@ def test_int4wo_cuda_serialization(self):
             # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
             sd = torch.load(ckpt.name, weights_only=False, map_location="cpu")
             for k, v in sd.items():
-                sd[k] = v.to("cuda")
+                sd[k] = v.to(_DEVICE)
             # load state_dict in cuda
             model.load_state_dict(sd, assign=True)
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index e69d68b27f..7d7f5ea290 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -36,11 +36,14 @@
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
+    auto_detect_device,
 )
 
 _SEED = 1234
 torch.manual_seed(_SEED)
+_GPU_IS_AVAILABLE = torch.cuda.is_available() or torch.xpu.is_available()
+_DEVICE = auto_detect_device()
 
 
 # Helper function to run a function twice
 # and verify that the result is the same.
@@ -614,12 +617,10 @@ def test_choose_qparams_tensor_asym_eps(self):
         eps = torch.finfo(torch.float32).eps
         self.assertEqual(scale, eps)
 
-    @unittest.skipIf(
-        not torch.cuda.is_available(), "skipping when cuda is not available"
-    )
+    @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available")
     def test_get_group_qparams_symmetric_memory(self):
         """Check the memory usage of the op"""
-        weight = torch.randn(1024, 1024).to(device="cuda")
+        weight = torch.randn(1024, 1024).to(device=_DEVICE)
         original_mem_use = torch.cuda.memory_allocated()
         n_bit = 4
         groupsize = 128
diff --git a/test/test_ao_models.py b/test/test_ao_models.py
index 79e4cc3ef5..9b71019f63 100644
--- a/test/test_ao_models.py
+++ b/test/test_ao_models.py
@@ -7,8 +7,10 @@
 import torch
 
 from torchao._models.llama.model import Transformer
+from torchao.utils import get_available_devices
 
-_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+
+_DEVICES = get_available_devices()
 
 
 def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
@@ -17,7 +19,7 @@ def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
     return model.eval()
 
 
-@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
+@pytest.mark.parametrize("device", _DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("is_training", [True, False])
 def test_ao_llama_model_inference_mode(device, batch_size, is_training):
diff --git a/torchao/utils.py b/torchao/utils.py
index 416d23d785..36d20e7e55 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -143,6 +143,16 @@ def get_available_devices():
         devices.append("mps")
     return devices
 
+def auto_detect_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch.version, "hip") and torch.version.hip:
+        return "rocm"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+    else:
+        return "cpu"
+
 
 def get_compute_capability():
     if torch.cuda.is_available():